Skip to content

Commit

Permalink
feat: Add context object to pass to supplier functions (#1363)
Browse files Browse the repository at this point in the history
* feat: adding context to supplier methods

* adds docs

* Add builder

* linting

* responding to docs

* Adding enum support

* linting

* builder methods package private

* Add examples on javadocs

* Add test class

* added docs

* Add expected values to context
  • Loading branch information
aeitzman committed Feb 2, 2024
1 parent 5a2d943 commit 1d9efc7
Show file tree
Hide file tree
Showing 12 changed files with 301 additions and 41 deletions.
14 changes: 11 additions & 3 deletions oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java
Expand Up @@ -61,7 +61,8 @@ public class AwsCredentials extends ExternalAccountCredentials {

private static final long serialVersionUID = -3670131891574618105L;

@Nullable private final AwsSecurityCredentialsSupplier awsSecurityCredentialsSupplier;
private final AwsSecurityCredentialsSupplier awsSecurityCredentialsSupplier;
private final ExternalAccountSupplierContext supplierContext;
// Regional credential verification url override. This needs to be its own value so we can
// correctly pass it to a builder.
@Nullable private final String regionalCredentialVerificationUrlOverride;
Expand All @@ -71,6 +72,12 @@ public class AwsCredentials extends ExternalAccountCredentials {
/** Internal constructor. See {@link AwsCredentials.Builder}. */
AwsCredentials(Builder builder) {
super(builder);
this.supplierContext =
ExternalAccountSupplierContext.newBuilder()
.setAudience(this.getAudience())
.setSubjectTokenType(this.getSubjectTokenType())
.build();

// Check that one and only one of supplier or credential source are provided.
if (builder.awsSecurityCredentialsSupplier != null && builder.credentialSource != null) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -128,9 +135,10 @@ public String retrieveSubjectToken() throws IOException {

// The targeted region is required to generate the signed request. The regional
// endpoint must also be used.
String region = awsSecurityCredentialsSupplier.getRegion();
String region = awsSecurityCredentialsSupplier.getRegion(supplierContext);

AwsSecurityCredentials credentials = awsSecurityCredentialsSupplier.getCredentials();
AwsSecurityCredentials credentials =
awsSecurityCredentialsSupplier.getCredentials(supplierContext);

// Generate the signed request to the AWS STS GetCallerIdentity API.
Map<String, String> headers = new HashMap<>();
Expand Down
Expand Up @@ -43,16 +43,18 @@ public interface AwsSecurityCredentialsSupplier extends Serializable {
/**
* Gets the AWS region to use.
*
* @param context relevant context from the calling credential.
* @return the AWS region that should be used for the credential.
* @throws IOException
*/
String getRegion() throws IOException;
String getRegion(ExternalAccountSupplierContext context) throws IOException;

/**
* Gets AWS security credentials.
*
* @param context relevant context from the calling credential.
* @return valid AWS security credentials that can be exchanged for a GCP access token.
* @throws IOException
*/
AwsSecurityCredentials getCredentials() throws IOException;
AwsSecurityCredentials getCredentials(ExternalAccountSupplierContext context) throws IOException;
}
@@ -0,0 +1,100 @@
package com.google.auth.oauth2;

import com.google.auth.oauth2.ExternalAccountCredentials.SubjectTokenTypes;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.Serializable;

/**
* Context object to pass relevant variables from external account credentials to suppliers. This
* will be passed on any call made to {@link IdentityPoolSubjectTokenSupplier} or {@link
* AwsSecurityCredentialsSupplier}.
*/
public class ExternalAccountSupplierContext implements Serializable {

private static final long serialVersionUID = -7852130853542313494L;

private final String audience;
private final String subjectTokenType;

/** Internal constructor. See {@link ExternalAccountSupplierContext.Builder}. */
private ExternalAccountSupplierContext(Builder builder) {
this.audience = builder.audience;
this.subjectTokenType = builder.subjectTokenType;
}

/**
* Returns the credentials' expected audience.
*
* @return the requested audience. For example:
* "//iam.googleapis.com/locations/global/workforcePools/$WORKFORCE_POOL_ID/providers/$PROVIDER_ID".
*/
public String getAudience() {
return audience;
}

/**
* Returns the credentials' expected Security Token Service subject token type based on the OAuth
* 2.0 token exchange spec.
*
* <p>Expected values:
*
* <p>"urn:ietf:params:oauth:token-type:jwt" "urn:ietf:params:aws:token-type:aws4_request"
* "urn:ietf:params:oauth:token-type:saml2" "urn:ietf:params:oauth:token-type:id_token"
*
* @return the requested subject token type. For example: "urn:ietf:params:oauth:token-type:jwt".
*/
public String getSubjectTokenType() {
return subjectTokenType;
}

static Builder newBuilder() {
return new Builder();
}

/** Builder for external account supplier context. */
static class Builder {

protected String audience;
protected String subjectTokenType;

/**
* Sets the Audience.
*
* @param audience the audience to set
* @return this {@code Builder} object
*/
@CanIgnoreReturnValue
Builder setAudience(String audience) {
this.audience = audience;
return this;
}

/**
* Sets the subject token type.
*
* @param subjectTokenType the subjectTokenType to set.
* @return this {@code Builder} object
*/
@CanIgnoreReturnValue
Builder setSubjectTokenType(String subjectTokenType) {
this.subjectTokenType = subjectTokenType;
return this;
}

/**
* Sets the subject token type.
*
* @param subjectTokenType the subjectTokenType to set.
* @return this {@code Builder} object
*/
@CanIgnoreReturnValue
Builder setSubjectTokenType(SubjectTokenTypes subjectTokenType) {
this.subjectTokenType = subjectTokenType.value;
return this;
}

ExternalAccountSupplierContext build() {
return new ExternalAccountSupplierContext(this);
}
}
}
Expand Up @@ -66,7 +66,7 @@ class FileIdentityPoolSubjectTokenSupplier implements IdentityPoolSubjectTokenSu
}

@Override
public String getSubjectToken() throws IOException {
public String getSubjectToken(ExternalAccountSupplierContext context) throws IOException {
String credentialFilePath = this.credentialSource.credentialLocation;
if (!Files.exists(Paths.get(credentialFilePath), LinkOption.NOFOLLOW_LINKS)) {
throw new IOException(
Expand Down
Expand Up @@ -49,16 +49,20 @@ public class IdentityPoolCredentials extends ExternalAccountCredentials {
static final String FILE_METRICS_HEADER_VALUE = "file";
static final String URL_METRICS_HEADER_VALUE = "url";
private static final long serialVersionUID = 2471046175477275881L;

private final IdentityPoolSubjectTokenSupplier subjectTokenSupplier;
private final ExternalAccountSupplierContext supplierContext;
private final String metricsHeaderValue;

/** Internal constructor. See {@link Builder}. */
IdentityPoolCredentials(Builder builder) {
super(builder);
IdentityPoolCredentialSource credentialSource =
(IdentityPoolCredentialSource) builder.credentialSource;

this.supplierContext =
ExternalAccountSupplierContext.newBuilder()
.setAudience(this.getAudience())
.setSubjectTokenType(this.getSubjectTokenType())
.build();
// Check that one and only one of supplier or credential source are provided.
if (builder.subjectTokenSupplier != null && credentialSource != null) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -99,7 +103,7 @@ public AccessToken refreshAccessToken() throws IOException {

@Override
public String retrieveSubjectToken() throws IOException {
return this.subjectTokenSupplier.getSubjectToken();
return this.subjectTokenSupplier.getSubjectToken(supplierContext);
}

@Override
Expand Down
Expand Up @@ -44,8 +44,9 @@ public interface IdentityPoolSubjectTokenSupplier extends Serializable {
/**
* Gets a subject token that can be exchanged for a GCP access token.
*
* @param context relevant context from the calling credential.
* @return a valid subject token.
* @throws IOException
*/
String getSubjectToken() throws IOException;
String getSubjectToken(ExternalAccountSupplierContext context) throws IOException;
}
Expand Up @@ -89,7 +89,8 @@ class InternalAwsSecurityCredentialsSupplier implements AwsSecurityCredentialsSu
}

@Override
public AwsSecurityCredentials getCredentials() throws IOException {
public AwsSecurityCredentials getCredentials(ExternalAccountSupplierContext context)
throws IOException {
// Check environment variables for credentials first.
if (canRetrieveSecurityCredentialsFromEnvironment()) {
String accessKeyId = environmentProvider.getEnv(AWS_ACCESS_KEY_ID);
Expand Down Expand Up @@ -129,7 +130,7 @@ public AwsSecurityCredentials getCredentials() throws IOException {
}

@Override
public String getRegion() throws IOException {
public String getRegion(ExternalAccountSupplierContext context) throws IOException {
String region;
if (canRetrieveRegionFromEnvironment()) {
// For AWS Lambda, the region is retrieved through the AWS_REGION environment variable.
Expand Down
Expand Up @@ -65,7 +65,7 @@ class UrlIdentityPoolSubjectTokenSupplier implements IdentityPoolSubjectTokenSup
}

@Override
public String getSubjectToken() throws IOException {
public String getSubjectToken(ExternalAccountSupplierContext context) throws IOException {
HttpRequest request =
transportFactory
.create()
Expand Down

0 comments on commit 1d9efc7

Please sign in to comment.