Skip to content

Commit

Permalink
certprovider: API update to include certificate name. (#3797)
Browse files Browse the repository at this point in the history
  • Loading branch information
easwars committed Aug 21, 2020
1 parent 6c0171f commit e14f1c2
Show file tree
Hide file tree
Showing 5 changed files with 353 additions and 321 deletions.
8 changes: 6 additions & 2 deletions credentials/tls/certprovider/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ type Distributor struct {
km *KeyMaterial
pErr error

ready *grpcsync.Event
// ready channel to unblock KeyMaterial() invocations blocked on
// availability of key material.
ready *grpcsync.Event
// done channel to notify provider implementations and unblock any
// KeyMaterial() calls, once the Distributor is closed.
closed *grpcsync.Event
}

Expand Down Expand Up @@ -75,7 +79,7 @@ func (d *Distributor) Set(km *KeyMaterial, err error) {
d.mu.Unlock()
}

// KeyMaterial returns the most recent key material provided to the distributor.
// KeyMaterial returns the most recent key material provided to the Distributor.
// If no key material was provided at the time of this call, it will block until
// the deadline on the context expires or fresh key material arrives.
func (d *Distributor) KeyMaterial(ctx context.Context) (*KeyMaterial, error) {
Expand Down
177 changes: 65 additions & 112 deletions credentials/tls/certprovider/distributor_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// +build go1.13

/*
*
* Copyright 2020 gRPC authors.
Expand All @@ -21,139 +23,90 @@ package certprovider
import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"testing"
"time"
)

var errProviderTestInternal = errors.New("provider internal error")

// TestDistributorEmpty tries to read key material from an empty distributor and
// expects the call to timeout.
func (s) TestDistributorEmpty(t *testing.T) {
dist := NewDistributor()

// This call to KeyMaterial() should timeout because no key material has
// been set on the distributor as yet.
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if err := readAndVerifyKeyMaterial(ctx, dist, nil); !errors.Is(err, context.DeadlineExceeded) {
t.Fatal(err)
}
}

// TestDistributor invokes the different methods on the Distributor type and
// verifies the results.
func (s) TestDistributor(t *testing.T) {
dist := NewDistributor()

// Read cert/key files from testdata.
km, err := loadKeyMaterials()
if err != nil {
km1 := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
km2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem")

// Push key material into the distributor and make sure that a call to
// KeyMaterial() returns the expected key material, with both the local
// certs and root certs.
dist.Set(km1, nil)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := readAndVerifyKeyMaterial(ctx, dist, km1); err != nil {
t.Fatal(err)
}
// wantKM1 has both local and root certs.
wantKM1 := *km
// wantKM2 has only local certs. Roots are nil-ed out.
wantKM2 := *km
wantKM2.Roots = nil

// Create a goroutines which work in lockstep with the rest of the test.
// This goroutine reads the key material from the distributor while the rest
// of the test sets it.
var wg sync.WaitGroup
wg.Add(1)
errCh := make(chan error)
proceedCh := make(chan struct{})
go func() {
defer wg.Done()

// The first call to KeyMaterial() should timeout because no key
// material has been set on the distributor as yet.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout/2)
defer cancel()
if _, err := dist.KeyMaterial(ctx); err != context.DeadlineExceeded {
errCh <- err
return
}
proceedCh <- struct{}{}

// This call to KeyMaterial() should return the key material with both
// the local certs and the root certs.
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
gotKM, err := dist.KeyMaterial(ctx)
if err != nil {
errCh <- err
return
}
if !reflect.DeepEqual(gotKM, &wantKM1) {
errCh <- fmt.Errorf("provider.KeyMaterial() = %+v, want %+v", gotKM, wantKM1)
}
proceedCh <- struct{}{}

// This call to KeyMaterial() should eventually return key material with
// only the local certs.
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for {
gotKM, err := dist.KeyMaterial(ctx)
if err != nil {
errCh <- err
return
}
if reflect.DeepEqual(gotKM, &wantKM2) {
break
}
}
proceedCh <- struct{}{}

// This call to KeyMaterial() should return nil key material and a
// non-nil error.
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for {
gotKM, err := dist.KeyMaterial(ctx)
if gotKM == nil && err == errProviderTestInternal {
break
}
if err != nil {
// If we have gotten any error other than
// errProviderTestInternal, we should bail out.
errCh <- err
return
}
}
proceedCh <- struct{}{}

// This call to KeyMaterial() should eventually return errProviderClosed
// error.
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for {
if _, err := dist.KeyMaterial(ctx); err == errProviderClosed {
break
}
time.Sleep(100 * time.Millisecond)
}
}()

waitAndDo(t, proceedCh, errCh, func() {
dist.Set(&wantKM1, nil)
})
// Push new key material into the distributor and make sure that a call to
// KeyMaterial() returns the expected key material, with only root certs.
dist.Set(km2, nil)
if err := readAndVerifyKeyMaterial(ctx, dist, km2); err != nil {
t.Fatal(err)
}

waitAndDo(t, proceedCh, errCh, func() {
dist.Set(&wantKM2, nil)
})
// Push an error into the distributor and make sure that a call to
// KeyMaterial() returns that error and nil keyMaterial.
dist.Set(km2, errProviderTestInternal)
if gotKM, err := dist.KeyMaterial(ctx); gotKM != nil || !errors.Is(err, errProviderTestInternal) {
t.Fatalf("KeyMaterial() = {%v, %v}, want {nil, %v}", gotKM, err, errProviderTestInternal)
}

waitAndDo(t, proceedCh, errCh, func() {
dist.Set(&wantKM2, errProviderTestInternal)
})
// Stop the distributor and KeyMaterial() should return errProviderClosed.
dist.Stop()
if km, err := dist.KeyMaterial(ctx); !errors.Is(err, errProviderClosed) {
t.Fatalf("KeyMaterial() = {%v, %v}, want {nil, %v}", km, err, errProviderClosed)
}
}

waitAndDo(t, proceedCh, errCh, func() {
dist.Stop()
})
// TestDistributorConcurrency invokes methods on the distributor in parallel. It
// exercises that the scenario where a distributor's KeyMaterial() method is
// blocked waiting for keyMaterial, while the Set() method is called from
// another goroutine. It verifies that the KeyMaterial() method eventually
// returns with expected keyMaterial.
func (s) TestDistributorConcurrency(t *testing.T) {
dist := NewDistributor()

}
// Read cert/key files from testdata.
km := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")

func waitAndDo(t *testing.T, proceedCh chan struct{}, errCh chan error, do func()) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

timer := time.NewTimer(defaultTestTimeout)
select {
case <-timer.C:
t.Fatalf("test timed out when waiting for event from distributor")
case <-proceedCh:
do()
case err := <-errCh:
// Push key material into the distributor from a goroutine and read from
// here to verify that the distributor returns the expected keyMaterial.
go func() {
// Add a small sleep here to make sure that the call to KeyMaterial()
// happens before the call to Set(), thereby the former is blocked till
// the latter happens.
time.Sleep(100 * time.Microsecond)
dist.Set(km, nil)
}()
if err := readAndVerifyKeyMaterial(ctx, dist, km); err != nil {
t.Fatal(err)
}
}
35 changes: 24 additions & 11 deletions credentials/tls/certprovider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*
*/

// Package certprovider defines APIs for certificate providers in gRPC.
// Package certprovider defines APIs for Certificate Providers in gRPC.
//
// Experimental
//
Expand All @@ -36,18 +36,18 @@ var (
// closed.
errProviderClosed = errors.New("provider instance is closed")

// m is a map from name to provider builder.
// m is a map from name to Provider builder.
m = make(map[string]Builder)
)

// Register registers the provider builder, whose name as returned by its Name()
// Register registers the Provider builder, whose name as returned by its Name()
// method will be used as the name registered with this builder. Registered
// Builders are used by the Store to create Providers.
func Register(b Builder) {
m[b.Name()] = b
}

// getBuilder returns the provider builder registered with the given name.
// getBuilder returns the Provider builder registered with the given name.
// If no builder is registered with the provided name, nil will be returned.
func getBuilder(name string) Builder {
if b, ok := m[name]; ok {
Expand All @@ -58,8 +58,9 @@ func getBuilder(name string) Builder {

// Builder creates a Provider.
type Builder interface {
// Build creates a new provider with the provided config.
Build(StableConfig) Provider
// Build creates a new Provider and initializes it with the given config and
// options combination.
Build(StableConfig, Options) Provider

// ParseConfig converts config input in a format specific to individual
// implementations and returns an implementation of the StableConfig
Expand All @@ -72,9 +73,9 @@ type Builder interface {
Name() string
}

// StableConfig wraps the method to return a stable provider configuration.
// StableConfig wraps the method to return a stable Provider configuration.
type StableConfig interface {
// Canonical returns provider config as an arbitrary byte slice.
// Canonical returns Provider config as an arbitrary byte slice.
// Equivalent configurations must return the same output.
Canonical() []byte
}
Expand All @@ -87,18 +88,30 @@ type StableConfig interface {
// the latest secrets, and free to share any state between different
// instantiations as they deem fit.
type Provider interface {
// KeyMaterial returns the key material sourced by the provider.
// KeyMaterial returns the key material sourced by the Provider.
// Callers are expected to use the returned value as read-only.
KeyMaterial(ctx context.Context) (*KeyMaterial, error)

// Close cleans up resources allocated by the provider.
// Close cleans up resources allocated by the Provider.
Close()
}

// KeyMaterial wraps the certificates and keys returned by a provider instance.
// KeyMaterial wraps the certificates and keys returned by a Provider instance.
type KeyMaterial struct {
// Certs contains a slice of cert/key pairs used to prove local identity.
Certs []tls.Certificate
// Roots contains the set of trusted roots to validate the peer's identity.
Roots *x509.CertPool
}

// Options contains configuration knobs passed to a Provider at creation time.
type Options struct {
// CertName holds the certificate name, whose key material is of interest to
// the caller.
CertName string
// WantRoot indicates if the caller is interested in the root certificate.
WantRoot bool
// WantIdentity indicates if the caller is interested in the identity
// certificate.
WantIdentity bool
}
28 changes: 17 additions & 11 deletions credentials/tls/certprovider/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type storeKey struct {
name string
// configuration of the certificate provider in string form.
config string
// opts contains the certificate name and other keyMaterial options.
opts Options
}

// wrappedProvider wraps a provider instance with a reference count.
Expand All @@ -57,17 +59,20 @@ type store struct {
providers map[storeKey]*wrappedProvider
}

// GetProvider returns a provider instance corresponding to name and config.
// name is the registered name of the provider and config is the
// provider-specific configuration. Implementations of the Builder interface
// should clearly document the type of configuration accepted by them.
// GetProvider returns a provider instance from which keyMaterial can be read.
//
// If a provider exists for the (name+config) combination, its reference count
// is incremented before returning. If no provider exists for the (name+config)
// combination, a new one is created using the registered builder. If no
// registered builder is found, or the provider configuration is rejected by it,
// a non-nil error is returned.
func GetProvider(name string, config interface{}) (Provider, error) {
// name is the registered name of the provider, config is the provider-specific
// configuration, opts contains extra information that controls the keyMaterial
// returned by the provider.
//
// Implementations of the Builder interface should clearly document the type of
// configuration accepted by them.
//
// If a provider exists for passed arguments, its reference count is incremented
// before returning. If no provider exists for the passed arguments, a new one
// is created using the registered builder. If no registered builder is found,
// or the provider configuration is rejected by it, a non-nil error is returned.
func GetProvider(name string, config interface{}, opts Options) (Provider, error) {
provStore.mu.Lock()
defer provStore.mu.Unlock()

Expand All @@ -83,13 +88,14 @@ func GetProvider(name string, config interface{}) (Provider, error) {
sk := storeKey{
name: name,
config: string(stableConfig.Canonical()),
opts: opts,
}
if wp, ok := provStore.providers[sk]; ok {
wp.refCount++
return wp, nil
}

provider := builder.Build(stableConfig)
provider := builder.Build(stableConfig, opts)
if provider == nil {
return nil, fmt.Errorf("certprovider.Build(%v) failed", sk)
}
Expand Down
Loading

0 comments on commit e14f1c2

Please sign in to comment.