Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query S2A Address from MDS #1400

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
58 changes: 58 additions & 0 deletions oauth2_http/java/com/google/auth/oauth2/MtlsConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.google.auth.oauth2;

import com.google.errorprone.annotations.CanIgnoreReturnValue;

/** Holds an mTLS configuration (consists of address of S2A) retrieved from the Metadata Server. */
public final class MtlsConfig {
// plaintextS2AAddress is the plaintext address to reach the S2A.
private final String plaintextS2AAddress;

// mtlsS2AAddress is the mTLS address to reach the S2A.
private final String mtlsS2AAddress;

public static Builder createBuilder() {
return new Builder();
}

public String getPlaintextS2AAddress() {
return plaintextS2AAddress;
}

public String getMtlsS2AAddress() {
return mtlsS2AAddress;
}

public static final class Builder {
// plaintextS2AAddress is the plaintext address to reach the S2A.
private String plaintextS2AAddress;

// mtlsS2AAddress is the mTLS address to reach the S2A.
private String mtlsS2AAddress;

Builder() {
plaintextS2AAddress = "";
mtlsS2AAddress = "";
}

@CanIgnoreReturnValue
public Builder setPlaintextS2AAddress(String plaintextS2AAddress) {
this.plaintextS2AAddress = plaintextS2AAddress;
return this;
}

@CanIgnoreReturnValue
public Builder setMtlsS2AAddress(String mtlsS2AAddress) {
this.mtlsS2AAddress = mtlsS2AAddress;
return this;
}

public MtlsConfig build() {
return new MtlsConfig(plaintextS2AAddress, mtlsS2AAddress);
}
}

private MtlsConfig(String plaintextS2AAddress, String mtlsS2AAddress) {
this.plaintextS2AAddress = plaintextS2AAddress;
this.mtlsS2AAddress = mtlsS2AAddress;
}
}
107 changes: 107 additions & 0 deletions oauth2_http/java/com/google/auth/oauth2/S2A.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.google.auth.oauth2;

import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpResponse;
import com.google.api.client.json.JsonObjectParser;
import com.google.api.client.util.GenericData;
import com.google.auth.http.HttpTransportFactory;
import com.google.common.collect.Iterables;
import java.io.IOException;
import java.io.InputStream;
import java.util.ServiceLoader;
import javax.annotation.concurrent.ThreadSafe;

/**
* Utilities to fetch the S2A (Secure Session Agent) address from the mTLS configuration.
*
* <p>mTLS configuration is queried from the MDS MTLS Autoconfiguration endpoint.
*/
@ThreadSafe
public final class S2A {
public static final String DEFAULT_METADATA_SERVER_URL = "http://169.254.169.254";
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like getting metadata server address is already implemented here:

public static String getMetadataServerUrl(DefaultCredentialsProvider provider) {

should we reuse that definition?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Used ComputeEngineCredentials.getMeadataServerUrl()

public static final String MTLS_CONFIG_ENDPOINT =
"/computeMetadata/v1/instance/platform-security/auto-mtls-configuration";

public static final String METADATA_FLAVOR = "Metadata-Flavor";
public static final String GOOGLE = "Google";
private static final String PARSE_ERROR_S2A = "Error parsing Mtls Auto Config response.";

private MtlsConfig config;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: MTLS?


private transient HttpTransportFactory transportFactory;

public S2A() {}

public void setHttpTransportFactory(HttpTransportFactory tf) {
this.transportFactory = tf;
}

/** @return the mTLS S2A Address from the mTLS config. */
public synchronized String getMtlsS2AAddress() {
if (config == null) {
config = getMdsMtlsConfig();
}
return config.getMtlsS2AAddress();
}

/** @return the plaintext S2A Address from the mTLS config. */
public synchronized String getPlaintextS2AAddress() {
if (config == null) {
config = getMdsMtlsConfig();
}
return config.getPlaintextS2AAddress();
}

/**
* Queries the MDS mTLS Autoconfiguration endpoint and returns the {@link MtlsConfig}.
*
* <p>Returns {@link MtlsConfig} with empty addresses on error.
*
* @return the {@link MtlsConfig}.
*/
private MtlsConfig getMdsMtlsConfig() {
String plaintextS2AAddress = "";
String mtlsS2AAddress = "";
try {
if (transportFactory == null) {
transportFactory =
Iterables.getFirst(
ServiceLoader.load(HttpTransportFactory.class), OAuth2Utils.HTTP_TRANSPORT_FACTORY);
}
String url = getMdsMtlsEndpoint();
GenericUrl genericUrl = new GenericUrl(url);
HttpRequest request =
transportFactory.create().createRequestFactory().buildGetRequest(genericUrl);
JsonObjectParser parser = new JsonObjectParser(OAuth2Utils.JSON_FACTORY);
request.setParser(parser);
request.getHeaders().set(METADATA_FLAVOR, GOOGLE);
request.setThrowExceptionOnExecuteError(false);
HttpResponse response = request.execute();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be useful to have retry logic here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added retry logic, similar to


if (!response.isSuccessStatusCode()) {
return MtlsConfig.createBuilder().build();
}

InputStream content = response.getContent();
if (content == null) {
return MtlsConfig.createBuilder().build();
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to reuse the code below for querying mds endpoint?

private HttpResponse getMetadataResponse(String url) throws IOException {

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could use this function to get the MDS HttpResponse, but it would only replace a few lines (~7) in this function:

creating the HttpRequest and executing it to get the response, replaced with

response = getMetadataResponse(url).

However, in order to do this, we would also have to:

  • ignore the ComputeEngineCredentials specific errors thrown by the function, because we only care if we are able to successfully create a response (aligning with Go implementation, return empty S2A Address if any error). This technically works, although I am not sure it is best practice.
  • getMetadataResponse(String url) is not static, so we would have to create an instance of ComputeEngineCredentials to use it.

WDYT?

GenericData responseData = response.parseAs(GenericData.class);
plaintextS2AAddress =
OAuth2Utils.validateString(responseData, "plaintext_address", PARSE_ERROR_S2A);
mtlsS2AAddress = OAuth2Utils.validateString(responseData, "mtls_address", PARSE_ERROR_S2A);
} catch (IOException e) {
return MtlsConfig.createBuilder().build();
}
return MtlsConfig.createBuilder()
.setPlaintextS2AAddress(plaintextS2AAddress)
.setMtlsS2AAddress(mtlsS2AAddress)
.build();
}

/** @return MDS mTLS autoconfig endpoint. */
private String getMdsMtlsEndpoint() {
return DEFAULT_METADATA_SERVER_URL + MTLS_CONFIG_ENDPOINT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ public class MockMetadataServerTransport extends MockHttpTransport {

private byte[] signature;

private String plaintextS2AAddress;

private String mtlsS2AAddress;

private boolean emptyContent;

public MockMetadataServerTransport() {}

public void setAccessToken(String accessToken) {
Expand All @@ -82,6 +88,18 @@ public void setIdToken(String idToken) {
this.idToken = idToken;
}

public void setPlaintextS2AAddress(String address) {
this.plaintextS2AAddress = address;
}

public void setMtlsS2AAddress(String address) {
this.mtlsS2AAddress = address;
}

public void setEmptyContent(boolean emptyContent) {
this.emptyContent = emptyContent;
}

@Override
public LowLevelHttpRequest buildRequest(String method, String url) throws IOException {
if (url.equals(ComputeEngineCredentials.getTokenServerEncodedUrl())) {
Expand All @@ -92,6 +110,8 @@ public LowLevelHttpRequest buildRequest(String method, String url) throws IOExce
return getMockRequestForSign(url);
} else if (isIdentityDocumentUrl(url)) {
return getMockRequestForIdentityDocument(url);
} else if (isMtlsConfigRequestUrl(url)) {
return getMockRequestForMtlsConfig(url);
}
return new MockLowLevelHttpRequest(url) {
@Override
Expand Down Expand Up @@ -233,6 +253,37 @@ public LowLevelHttpResponse execute() throws IOException {
};
}

private MockLowLevelHttpRequest getMockRequestForMtlsConfig(String url) {
return new MockLowLevelHttpRequest(url) {
@Override
public LowLevelHttpResponse execute() throws IOException {

String metadataRequestHeader = getFirstHeaderValue(S2A.METADATA_FLAVOR);
if (!S2A.GOOGLE.equals(metadataRequestHeader)) {
throw new IOException("Metadata request header not found");
}

// Create the JSON response
GenericJson content = new GenericJson();
content.setFactory(OAuth2Utils.JSON_FACTORY);
content.put("plaintext_address", plaintextS2AAddress);
content.put("mtls_address", mtlsS2AAddress);
String contentText = content.toPrettyString();

MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();

if (requestStatusCode != null) {
response.setStatusCode(requestStatusCode);
}
if (emptyContent == true) {
return response.setZeroContent();
}
response.setContentType(Json.MEDIA_TYPE).setContent(contentText);
return response;
}
};
}

protected boolean isGetServiceAccountsUrl(String url) {
return url.equals(ComputeEngineCredentials.getServiceAccountsUrl());
}
Expand All @@ -246,4 +297,10 @@ protected boolean isSignRequestUrl(String url) {
protected boolean isIdentityDocumentUrl(String url) {
return url.startsWith(String.format(ComputeEngineCredentials.getIdentityDocumentUrl()));
}

protected boolean isMtlsConfigRequestUrl(String url) {
return plaintextS2AAddress != null
&& mtlsS2AAddress != null
&& url.equals(String.format(S2A.DEFAULT_METADATA_SERVER_URL + S2A.MTLS_CONFIG_ENDPOINT));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.google.auth.oauth2;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Test cases for {@link MtlsConfig}. */
@RunWith(JUnit4.class)
public class MtlsConfigTest {
private static final String S2A_PLAINTEXT_ADDRESS = "plaintext";
private static final String S2A_MTLS_ADDRESS = "mtls";

@Test
public void createMtlsConfig_success() {
MtlsConfig config =
MtlsConfig.createBuilder()
.setPlaintextS2AAddress(S2A_PLAINTEXT_ADDRESS)
.setMtlsS2AAddress(S2A_MTLS_ADDRESS)
.build();
assertEquals(S2A_PLAINTEXT_ADDRESS, config.getPlaintextS2AAddress());
assertEquals(S2A_MTLS_ADDRESS, config.getMtlsS2AAddress());
}

@Test
public void createEmptyMtlsConfig_success() {
MtlsConfig config = MtlsConfig.createBuilder().build();
assertTrue(config.getPlaintextS2AAddress().isEmpty());
assertTrue(config.getMtlsS2AAddress().isEmpty());
}
}
65 changes: 65 additions & 0 deletions oauth2_http/javatests/com/google/auth/oauth2/S2ATest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package com.google.auth.oauth2;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import com.google.api.client.http.HttpStatusCodes;
import com.google.auth.oauth2.ComputeEngineCredentialsTest.MockMetadataServerTransportFactory;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Test cases for {@link S2A}. */
@RunWith(JUnit4.class)
public class S2ATest {

private static final String S2A_PLAINTEXT_ADDRESS = "plaintext";
private static final String S2A_MTLS_ADDRESS = "mtls";

@Test
public void getS2AAddress_validAddress() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setPlaintextS2AAddress(S2A_PLAINTEXT_ADDRESS);
transportFactory.transport.setMtlsS2AAddress(S2A_MTLS_ADDRESS);
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_OK);

S2A s2aUtils = new S2A();
s2aUtils.setHttpTransportFactory(transportFactory);
String plaintextS2AAddress = s2aUtils.getPlaintextS2AAddress();
String mtlsS2AAddress = s2aUtils.getMtlsS2AAddress();
assertEquals(S2A_PLAINTEXT_ADDRESS, plaintextS2AAddress);
assertEquals(S2A_MTLS_ADDRESS, mtlsS2AAddress);
}

@Test
public void getS2AAddress_queryEndpointResponseErrorCode_emptyAddress() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setPlaintextS2AAddress(S2A_PLAINTEXT_ADDRESS);
transportFactory.transport.setMtlsS2AAddress(S2A_MTLS_ADDRESS);
transportFactory.transport.setRequestStatusCode(
HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE);

S2A s2aUtils = new S2A();
s2aUtils.setHttpTransportFactory(transportFactory);
String plaintextS2AAddress = s2aUtils.getPlaintextS2AAddress();
String mtlsS2AAddress = s2aUtils.getMtlsS2AAddress();
assertTrue(plaintextS2AAddress.isEmpty());
assertTrue(mtlsS2AAddress.isEmpty());
}

@Test
public void getS2AAddress_queryEndpointResponseEmpty_emptyAddress() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setPlaintextS2AAddress(S2A_PLAINTEXT_ADDRESS);
transportFactory.transport.setMtlsS2AAddress(S2A_MTLS_ADDRESS);
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_OK);
transportFactory.transport.setEmptyContent(true);

S2A s2aUtils = new S2A();
s2aUtils.setHttpTransportFactory(transportFactory);
String plaintextS2AAddress = s2aUtils.getPlaintextS2AAddress();
String mtlsS2AAddress = s2aUtils.getMtlsS2AAddress();
assertTrue(plaintextS2AAddress.isEmpty());
assertTrue(mtlsS2AAddress.isEmpty());
}
}