Skip to content

Commit

Permalink
xds: fix the validation code to accept new-style CertificateProviderP…
Browse files Browse the repository at this point in the history
…luginInstance wherever used (#8892) (#8901)
  • Loading branch information
sanjaypujare committed Feb 9, 2022
1 parent 2564020 commit 81ac88a
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 3 deletions.
Expand Up @@ -26,9 +26,11 @@ public final class CommonTlsContextUtil {
private CommonTlsContextUtil() {}

static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) {
return commonTlsContext != null
&& (commonTlsContext.hasTlsCertificateCertificateProviderInstance()
|| hasCertProviderValidationContext(commonTlsContext));
if (commonTlsContext == null) {
return false;
}
return hasIdentityCertificateProviderInstance(commonTlsContext)
|| hasCertProviderValidationContext(commonTlsContext);
}

private static boolean hasCertProviderValidationContext(CommonTlsContext commonTlsContext) {
Expand All @@ -37,6 +39,19 @@ private static boolean hasCertProviderValidationContext(CommonTlsContext commonT
commonTlsContext.getCombinedValidationContext();
return combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance();
}
return hasValidationProviderInstance(commonTlsContext);
}

private static boolean hasIdentityCertificateProviderInstance(CommonTlsContext commonTlsContext) {
return commonTlsContext.hasTlsCertificateProviderInstance()
|| commonTlsContext.hasTlsCertificateCertificateProviderInstance();
}

private static boolean hasValidationProviderInstance(CommonTlsContext commonTlsContext) {
if (commonTlsContext.hasValidationContext() && commonTlsContext.getValidationContext()
.hasCaCertificateProviderInstance()) {
return true;
}
return commonTlsContext.hasValidationContextCertificateProviderInstance();
}

Expand Down
Expand Up @@ -208,6 +208,72 @@ public void createCertProviderClientSslContextProvider_2providers()
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}

@Test
public void createNewCertProviderClientSslContextProvider_withSans() {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);

CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);

Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}

@Test
public void createNewCertProviderClientSslContextProvider_onlyRootCert() {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);

Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}

@Test
public void createNullCommonTlsContext_exception() throws IOException {
clientSslContextProviderFactory =
Expand Down
Expand Up @@ -206,4 +206,41 @@ public void createCertProviderServerSslContextProvider_2providers()
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}

@Test
public void createNewCertProviderServerSslContextProvider_withSans()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();

DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildNewDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext,
/* requireClientCert= */ true);

Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
bootstrapInfo, certProviderServerSslContextProviderFactory);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}
}

0 comments on commit 81ac88a

Please sign in to comment.