Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query S2A Address from MDS #1400

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
58 changes: 58 additions & 0 deletions oauth2_http/java/com/google/auth/oauth2/MtlsConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.google.auth.oauth2;

import com.google.errorprone.annotations.CanIgnoreReturnValue;

/** Holds an mTLS configuration (consists of address of S2A) retrieved from the Metadata Server. */
public final class MtlsConfig {
// plaintextS2AAddress is the plaintext address to reach the S2A.
private final String plaintextS2AAddress;

// mtlsS2AAddress is the mTLS address to reach the S2A.
private final String mtlsS2AAddress;

public static Builder createBuilder() {
return new Builder();
}

public String getPlaintextS2AAddress() {
return plaintextS2AAddress;
}

public String getMtlsS2AAddress() {
return mtlsS2AAddress;
}

public static final class Builder {
// plaintextS2AAddress is the plaintext address to reach the S2A.
private String plaintextS2AAddress;

// mtlsS2AAddress is the mTLS address to reach the S2A.
private String mtlsS2AAddress;

Builder() {
plaintextS2AAddress = "";
mtlsS2AAddress = "";
}

@CanIgnoreReturnValue
public Builder setPlaintextS2AAddress(String plaintextS2AAddress) {
this.plaintextS2AAddress = plaintextS2AAddress;
return this;
}

@CanIgnoreReturnValue
public Builder setMtlsS2AAddress(String mtlsS2AAddress) {
this.mtlsS2AAddress = mtlsS2AAddress;
return this;
}

public MtlsConfig build() {
return new MtlsConfig(plaintextS2AAddress, mtlsS2AAddress);
}
}

private MtlsConfig(String plaintextS2AAddress, String mtlsS2AAddress) {
this.plaintextS2AAddress = plaintextS2AAddress;
this.mtlsS2AAddress = mtlsS2AAddress;
}
}
112 changes: 112 additions & 0 deletions oauth2_http/java/com/google/auth/oauth2/S2A.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package com.google.auth.oauth2;

import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpResponse;
import com.google.api.client.json.JsonObjectParser;
import com.google.api.client.util.GenericData;
import com.google.auth.http.HttpTransportFactory;
import com.google.common.collect.Iterables;
import java.io.IOException;
import java.io.InputStream;
import java.util.ServiceLoader;
import javax.annotation.concurrent.ThreadSafe;

/**
* Utilities to fetch the S2A (Secure Session Agent) address from the mTLS configuration.
*
* <p>mTLS configuration is queried from the MDS MTLS Autoconfiguration endpoint.
*/
@ThreadSafe
public final class S2A {
public static final String MTLS_CONFIG_ENDPOINT =
"/computeMetadata/v1/instance/platform-security/auto-mtls-configuration";

public static final String METADATA_FLAVOR = "Metadata-Flavor";
public static final String GOOGLE = "Google";
private static final int MAX_MDS_PING_TRIES = 3;
private static final String PARSE_ERROR_S2A = "Error parsing Mtls Auto Config response.";

private MtlsConfig config;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: MTLS?


private transient HttpTransportFactory transportFactory;

public S2A() {}

public void setHttpTransportFactory(HttpTransportFactory tf) {
this.transportFactory = tf;
}

/** @return the mTLS S2A Address from the mTLS config. */
public synchronized String getMtlsS2AAddress() {
if (config == null) {
config = getMdsMtlsConfig();
}
return config.getMtlsS2AAddress();
}

/** @return the plaintext S2A Address from the mTLS config. */
public synchronized String getPlaintextS2AAddress() {
if (config == null) {
config = getMdsMtlsConfig();
}
return config.getPlaintextS2AAddress();
}

/**
* Queries the MDS mTLS Autoconfiguration endpoint and returns the {@link MtlsConfig}.
*
* <p>Returns {@link MtlsConfig} with empty addresses on error.
*
* @return the {@link MtlsConfig}.
*/
private MtlsConfig getMdsMtlsConfig() {
String plaintextS2AAddress = "";
String mtlsS2AAddress = "";

String url = getMdsMtlsEndpoint();
GenericUrl genericUrl = new GenericUrl(url);

for (int i = 0; i < MAX_MDS_PING_TRIES; i++) {
try {
if (transportFactory == null) {
transportFactory =
Iterables.getFirst(
ServiceLoader.load(HttpTransportFactory.class), OAuth2Utils.HTTP_TRANSPORT_FACTORY);
}
HttpRequest request =
transportFactory.create().createRequestFactory().buildGetRequest(genericUrl);
JsonObjectParser parser = new JsonObjectParser(OAuth2Utils.JSON_FACTORY);
request.setParser(parser);
request.getHeaders().set(METADATA_FLAVOR, GOOGLE);
request.setThrowExceptionOnExecuteError(false);
HttpResponse response = request.execute();

if (!response.isSuccessStatusCode()) {
continue;
}

InputStream content = response.getContent();
if (content == null) {
continue;
}
GenericData responseData = response.parseAs(GenericData.class);
plaintextS2AAddress =
OAuth2Utils.validateString(responseData, "plaintext_address", PARSE_ERROR_S2A);
mtlsS2AAddress = OAuth2Utils.validateString(responseData, "mtls_address", PARSE_ERROR_S2A);
} catch (IOException e) {
continue;
}
return MtlsConfig.createBuilder()
.setPlaintextS2AAddress(plaintextS2AAddress)
.setMtlsS2AAddress(mtlsS2AAddress)
.build();
}
return MtlsConfig.createBuilder().build();
}

/** @return MDS mTLS autoconfig endpoint. */
private String getMdsMtlsEndpoint() {
return ComputeEngineCredentials.getMetadataServerUrl() + MTLS_CONFIG_ENDPOINT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ public class MockMetadataServerTransport extends MockHttpTransport {

private byte[] signature;

private String plaintextS2AAddress;

private String mtlsS2AAddress;

private boolean emptyContent;

public MockMetadataServerTransport() {}

public void setAccessToken(String accessToken) {
Expand All @@ -82,6 +88,18 @@ public void setIdToken(String idToken) {
this.idToken = idToken;
}

public void setPlaintextS2AAddress(String address) {
this.plaintextS2AAddress = address;
}

public void setMtlsS2AAddress(String address) {
this.mtlsS2AAddress = address;
}

public void setEmptyContent(boolean emptyContent) {
this.emptyContent = emptyContent;
}

@Override
public LowLevelHttpRequest buildRequest(String method, String url) throws IOException {
if (url.equals(ComputeEngineCredentials.getTokenServerEncodedUrl())) {
Expand All @@ -92,6 +110,8 @@ public LowLevelHttpRequest buildRequest(String method, String url) throws IOExce
return getMockRequestForSign(url);
} else if (isIdentityDocumentUrl(url)) {
return getMockRequestForIdentityDocument(url);
} else if (isMtlsConfigRequestUrl(url)) {
return getMockRequestForMtlsConfig(url);
}
return new MockLowLevelHttpRequest(url) {
@Override
Expand Down Expand Up @@ -233,6 +253,37 @@ public LowLevelHttpResponse execute() throws IOException {
};
}

private MockLowLevelHttpRequest getMockRequestForMtlsConfig(String url) {
return new MockLowLevelHttpRequest(url) {
@Override
public LowLevelHttpResponse execute() throws IOException {

String metadataRequestHeader = getFirstHeaderValue(S2A.METADATA_FLAVOR);
if (!S2A.GOOGLE.equals(metadataRequestHeader)) {
throw new IOException("Metadata request header not found");
}

// Create the JSON response
GenericJson content = new GenericJson();
content.setFactory(OAuth2Utils.JSON_FACTORY);
content.put("plaintext_address", plaintextS2AAddress);
content.put("mtls_address", mtlsS2AAddress);
String contentText = content.toPrettyString();

MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();

if (requestStatusCode != null) {
response.setStatusCode(requestStatusCode);
}
if (emptyContent == true) {
return response.setZeroContent();
}
response.setContentType(Json.MEDIA_TYPE).setContent(contentText);
return response;
}
};
}

protected boolean isGetServiceAccountsUrl(String url) {
return url.equals(ComputeEngineCredentials.getServiceAccountsUrl());
}
Expand All @@ -246,4 +297,10 @@ protected boolean isSignRequestUrl(String url) {
protected boolean isIdentityDocumentUrl(String url) {
return url.startsWith(String.format(ComputeEngineCredentials.getIdentityDocumentUrl()));
}

protected boolean isMtlsConfigRequestUrl(String url) {
return plaintextS2AAddress != null
&& mtlsS2AAddress != null
&& url.equals(String.format(ComputeEngineCredentials.getMetadataServerUrl() + S2A.MTLS_CONFIG_ENDPOINT));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.google.auth.oauth2;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Test cases for {@link MtlsConfig}. */
@RunWith(JUnit4.class)
public class MtlsConfigTest {
private static final String S2A_PLAINTEXT_ADDRESS = "plaintext";
private static final String S2A_MTLS_ADDRESS = "mtls";

@Test
public void createMtlsConfig_success() {
MtlsConfig config =
MtlsConfig.createBuilder()
.setPlaintextS2AAddress(S2A_PLAINTEXT_ADDRESS)
.setMtlsS2AAddress(S2A_MTLS_ADDRESS)
.build();
assertEquals(S2A_PLAINTEXT_ADDRESS, config.getPlaintextS2AAddress());
assertEquals(S2A_MTLS_ADDRESS, config.getMtlsS2AAddress());
}

@Test
public void createEmptyMtlsConfig_success() {
MtlsConfig config = MtlsConfig.createBuilder().build();
assertTrue(config.getPlaintextS2AAddress().isEmpty());
assertTrue(config.getMtlsS2AAddress().isEmpty());
}
}
65 changes: 65 additions & 0 deletions oauth2_http/javatests/com/google/auth/oauth2/S2ATest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package com.google.auth.oauth2;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import com.google.api.client.http.HttpStatusCodes;
import com.google.auth.oauth2.ComputeEngineCredentialsTest.MockMetadataServerTransportFactory;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Test cases for {@link S2A}. */
@RunWith(JUnit4.class)
public class S2ATest {

private static final String S2A_PLAINTEXT_ADDRESS = "plaintext";
private static final String S2A_MTLS_ADDRESS = "mtls";

@Test
public void getS2AAddress_validAddress() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setPlaintextS2AAddress(S2A_PLAINTEXT_ADDRESS);
transportFactory.transport.setMtlsS2AAddress(S2A_MTLS_ADDRESS);
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_OK);

S2A s2aUtils = new S2A();
s2aUtils.setHttpTransportFactory(transportFactory);
String plaintextS2AAddress = s2aUtils.getPlaintextS2AAddress();
String mtlsS2AAddress = s2aUtils.getMtlsS2AAddress();
assertEquals(S2A_PLAINTEXT_ADDRESS, plaintextS2AAddress);
assertEquals(S2A_MTLS_ADDRESS, mtlsS2AAddress);
}

@Test
public void getS2AAddress_queryEndpointResponseErrorCode_emptyAddress() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setPlaintextS2AAddress(S2A_PLAINTEXT_ADDRESS);
transportFactory.transport.setMtlsS2AAddress(S2A_MTLS_ADDRESS);
transportFactory.transport.setRequestStatusCode(
HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE);

S2A s2aUtils = new S2A();
s2aUtils.setHttpTransportFactory(transportFactory);
String plaintextS2AAddress = s2aUtils.getPlaintextS2AAddress();
String mtlsS2AAddress = s2aUtils.getMtlsS2AAddress();
assertTrue(plaintextS2AAddress.isEmpty());
assertTrue(mtlsS2AAddress.isEmpty());
}

@Test
public void getS2AAddress_queryEndpointResponseEmpty_emptyAddress() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setPlaintextS2AAddress(S2A_PLAINTEXT_ADDRESS);
transportFactory.transport.setMtlsS2AAddress(S2A_MTLS_ADDRESS);
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_OK);
transportFactory.transport.setEmptyContent(true);

S2A s2aUtils = new S2A();
s2aUtils.setHttpTransportFactory(transportFactory);
String plaintextS2AAddress = s2aUtils.getPlaintextS2AAddress();
String mtlsS2AAddress = s2aUtils.getMtlsS2AAddress();
assertTrue(plaintextS2AAddress.isEmpty());
assertTrue(mtlsS2AAddress.isEmpty());
}
}
Loading