-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Spire delegate API as CertificateProvider
This adds an implementation of the Delegate API of a SPIRE server as a source for certificates to be used in an mTLS handhake. It will connect to the admin socket of a SPIRE agent where it will be able to get the certificates and keys in name of all Cilium workloads which are receiving an SVID from the controller. This is then cached in memory for the auth handler to request. Signed-off-by: Maartje Eyskens <maartje.eyskens@isovalent.com>
- Loading branch information
Showing
2 changed files
with
360 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// Copyright Authors of Cilium | ||
|
||
package spire | ||
|
||
import ( | ||
"crypto/tls" | ||
"crypto/x509" | ||
"errors" | ||
"fmt" | ||
) | ||
|
||
// This file implements the CertificateProvider interface | ||
|
||
func (s *SpireDelegateClient) GetTrustBundle() (*x509.CertPool, error) { | ||
if s.trustBundle == nil { | ||
return nil, errors.New("trust bundle not yet available") | ||
} | ||
return s.trustBundle, nil | ||
} | ||
|
||
func (s *SpireDelegateClient) GetCertificateForIdentity(identity string) (*tls.Certificate, error) { | ||
spiffeID := s.sniToSPIFFEID(identity) | ||
svid, ok := s.svidStore[spiffeID] | ||
if !ok { | ||
return nil, fmt.Errorf("no SPIFFE ID for %s", spiffeID) | ||
} | ||
|
||
if len(svid.X509Svid.CertChain) == 0 { | ||
return nil, fmt.Errorf("no certificate chain inside %s", spiffeID) | ||
} | ||
|
||
var leafCert *x509.Certificate | ||
for _, cert := range svid.X509Svid.CertChain { | ||
cert, err := x509.ParseCertificate(cert) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to parse certificate: %w", err) | ||
} | ||
|
||
if !cert.IsCA { | ||
leafCert = cert | ||
break | ||
} | ||
} | ||
if leafCert == nil { | ||
return nil, fmt.Errorf("no leaf certificate inside %s", spiffeID) | ||
} | ||
|
||
privKey, err := x509.ParsePKCS8PrivateKey(svid.X509SvidKey) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to parse private keyof %s: %w", spiffeID, err) | ||
} | ||
|
||
return &tls.Certificate{ | ||
Certificate: svid.X509Svid.CertChain, | ||
PrivateKey: privKey, | ||
Leaf: leafCert, | ||
}, nil | ||
} | ||
|
||
func (s *SpireDelegateClient) sniToSPIFFEID(sni string) string { | ||
return fmt.Sprintf("spiffe://%s/cilium-id/%s", s.cfg.SpiffeTrustDomain, sni) | ||
} | ||
|
||
func (s *SpireDelegateClient) ValidateIdentity(identity string, cert *x509.Certificate) (bool, error) { | ||
spiffeID := s.sniToSPIFFEID(identity) | ||
|
||
// Spec: SVIDs containing more than one URI SAN MUST be rejected | ||
if len(cert.URIs) != 1 { | ||
return false, errors.New("SPIFFE IDs must have exactly one URI SAN") | ||
} | ||
|
||
return cert.URIs[0].String() == spiffeID, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// Copyright Authors of Cilium | ||
|
||
package spire | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/x509" | ||
"errors" | ||
"fmt" | ||
"os" | ||
"time" | ||
|
||
delegatedidentityv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/agent/delegatedidentity/v1" | ||
spiffeTypes "github.com/spiffe/spire-api-sdk/proto/spire/api/types" | ||
|
||
"github.com/sirupsen/logrus" | ||
"github.com/spf13/pflag" | ||
|
||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/credentials/insecure" | ||
|
||
"github.com/cilium/cilium/pkg/auth/certs" | ||
"github.com/cilium/cilium/pkg/backoff" | ||
"github.com/cilium/cilium/pkg/hive" | ||
"github.com/cilium/cilium/pkg/hive/cell" | ||
"github.com/cilium/cilium/pkg/logging/logfields" | ||
) | ||
|
||
type SpireDelegateClient struct { | ||
cfg SpireDelegateConfig | ||
log logrus.FieldLogger | ||
|
||
connectionAttempts int | ||
|
||
stream delegatedidentityv1.DelegatedIdentity_SubscribeToX509SVIDsClient | ||
trustStream delegatedidentityv1.DelegatedIdentity_SubscribeToX509BundlesClient | ||
|
||
svidStore map[string]*delegatedidentityv1.X509SVIDWithKey | ||
trustBundle *x509.CertPool | ||
|
||
cancelListenForUpdates context.CancelFunc | ||
} | ||
|
||
type SpireDelegateConfig struct { | ||
SpireAdminSocketPath string `mapstructure:"mesh-auth-spire-admin-socket"` | ||
SpiffeTrustDomain string `mapstructure:"mesh-auth-spiffe-trust-domain"` | ||
} | ||
|
||
type certificateProviderResult struct { | ||
cell.Out | ||
|
||
CertificateProvider certs.CertificateProvider `group:"certificateProviders"` | ||
} | ||
|
||
var Cell = cell.Module( | ||
"spire-delegate", | ||
"Spire Delegate API Client", | ||
cell.Provide(newSpireDelegateClient), | ||
cell.Config(SpireDelegateConfig{}), | ||
) | ||
|
||
func newSpireDelegateClient(lc hive.Lifecycle, cfg SpireDelegateConfig, log logrus.FieldLogger) certificateProviderResult { | ||
client := &SpireDelegateClient{ | ||
cfg: cfg, | ||
log: log.WithField(logfields.LogSubsys, "spire-delegate"), | ||
svidStore: map[string]*delegatedidentityv1.X509SVIDWithKey{}, | ||
} | ||
|
||
lc.Append(hive.Hook{OnStart: client.onStart, OnStop: client.onStop}) | ||
|
||
return certificateProviderResult{ | ||
CertificateProvider: client, | ||
} | ||
} | ||
|
||
func (cfg SpireDelegateConfig) Flags(flags *pflag.FlagSet) { | ||
flags.StringVar(&cfg.SpireAdminSocketPath, "mesh-auth-spire-admin-socket", "/run/spire/sockets/admin.sock", "The path for the SPIRE admin agent Unix socket.") | ||
flags.StringVar(&cfg.SpiffeTrustDomain, "mesh-auth-spiffe-trust-domain", "spiffe.cilium.io", "The trust domain for the SPIFFE identity.") | ||
} | ||
|
||
func (s *SpireDelegateClient) onStart(ctx hive.HookContext) error { | ||
s.log.Info("Spire Delegate API Client is running") | ||
|
||
listenCtx, cancel := context.WithCancel(context.Background()) | ||
go s.listenForUpdates(listenCtx) | ||
|
||
s.cancelListenForUpdates = cancel | ||
|
||
return nil | ||
} | ||
|
||
func (s *SpireDelegateClient) onStop(ctx hive.HookContext) error { | ||
s.log.Info("SPIFFE Delegate API Client is stopping") | ||
|
||
s.cancelListenForUpdates() | ||
|
||
if s.stream != nil { | ||
s.stream.CloseSend() | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (s *SpireDelegateClient) listenForUpdates(ctx context.Context) { | ||
s.openStream(ctx) | ||
|
||
listenCtx, cancel := context.WithCancel(ctx) | ||
err := make(chan error) | ||
|
||
go s.listenForSVIDUpdates(listenCtx, err) | ||
go s.listenForBundleUpdates(listenCtx, err) | ||
|
||
backoffTime := backoff.Exponential{Min: 100 * time.Millisecond, Max: 10 * time.Second} | ||
for { | ||
select { | ||
case <-ctx.Done(): | ||
cancel() | ||
return | ||
case e := <-err: | ||
s.log.WithError(e).Error("error in delegate stream, restarting") | ||
time.Sleep(backoffTime.Duration(s.connectionAttempts)) | ||
cancel() | ||
s.connectionAttempts++ | ||
s.listenForUpdates(ctx) | ||
return | ||
} | ||
} | ||
} | ||
|
||
func (s *SpireDelegateClient) listenForSVIDUpdates(ctx context.Context, errorChan chan error) { | ||
for { | ||
select { | ||
case <-ctx.Done(): | ||
return | ||
default: | ||
resp, err := s.stream.Recv() | ||
if err != nil { | ||
errorChan <- err | ||
return | ||
} | ||
|
||
s.log.Debugf("received %d X509-SVIDs in update", len(resp.X509Svids)) | ||
s.handleX509SVIDUpdate(resp.X509Svids) | ||
} | ||
} | ||
} | ||
|
||
func (s *SpireDelegateClient) listenForBundleUpdates(ctx context.Context, errorChan chan error) { | ||
for { | ||
select { | ||
case <-ctx.Done(): | ||
return | ||
default: | ||
resp, err := s.trustStream.Recv() | ||
if err != nil { | ||
errorChan <- err | ||
return | ||
} | ||
|
||
s.log.Debugf("received %d X509-Bundles in update", len(resp.CaCertificates)) | ||
s.handleX509BundleUpdate(resp.CaCertificates) | ||
} | ||
} | ||
} | ||
|
||
func (s *SpireDelegateClient) handleX509SVIDUpdate(svids []*delegatedidentityv1.X509SVIDWithKey) { | ||
newSvidStore := map[string]*delegatedidentityv1.X509SVIDWithKey{} | ||
|
||
for _, svid := range svids { | ||
|
||
s.log.Debugf("processing spiffe://%s%s, Expires at %s", svid.X509Svid.Id.TrustDomain, | ||
svid.X509Svid.Id.Path, | ||
time.Unix(svid.X509Svid.ExpiresAt, 0)) | ||
|
||
if svid.X509Svid.Id.TrustDomain != s.cfg.SpiffeTrustDomain { | ||
s.log.Debugf("skipping X509-SVID update for trust domain %s as it does not match ours", svid.X509Svid.Id.TrustDomain) | ||
return | ||
} | ||
|
||
key := fmt.Sprintf("spiffe://%s%s", svid.X509Svid.Id.TrustDomain, svid.X509Svid.Id.Path) | ||
|
||
if _, exists := s.svidStore[key]; exists { | ||
old := s.svidStore[key] | ||
if old.X509Svid.ExpiresAt != svid.X509Svid.ExpiresAt || !equalCertChains(old.X509Svid.CertChain, svid.X509Svid.CertChain) { | ||
s.log.Debugf("X509-SVID for %s has changed, updating", key) | ||
// this is a good point to in the future send a trigger for a new handshake | ||
} | ||
} else { | ||
s.log.Debugf("X509-SVID for %s is new, adding", key) | ||
} | ||
|
||
newSvidStore[key] = svid | ||
|
||
} | ||
|
||
s.svidStore = newSvidStore | ||
} | ||
|
||
func (s *SpireDelegateClient) handleX509BundleUpdate(bundles map[string][]byte) { | ||
pool := x509.NewCertPool() | ||
|
||
for trustDomain, bundle := range bundles { | ||
s.log.Debugf("processing trust domain %s cert bundle", trustDomain) | ||
|
||
certs, err := x509.ParseCertificates(bundle) | ||
if err != nil { | ||
s.log.WithError(err).Errorf("failed to parse X.509 DER bundle for trust domain %s", trustDomain) | ||
continue | ||
} | ||
|
||
for _, cert := range certs { | ||
pool.AddCert(cert) | ||
} | ||
} | ||
|
||
s.trustBundle = pool | ||
} | ||
|
||
func (s *SpireDelegateClient) openStream(ctx context.Context) { | ||
// try to init the watcher with a backoff | ||
backoffTime := backoff.Exponential{Min: 100 * time.Millisecond, Max: 10 * time.Second} | ||
for { | ||
s.log.Info("Connecting to SPIRE Delegate API Client") | ||
|
||
var err error | ||
s.stream, s.trustStream, err = s.initWatcher(ctx) | ||
if err != nil { | ||
s.log.WithError(err).Warn("SPIRE Delegate API Client failed to init watcher, retrying") | ||
time.Sleep(backoffTime.Duration(s.connectionAttempts)) | ||
s.connectionAttempts++ | ||
continue | ||
} | ||
break | ||
} | ||
} | ||
|
||
func (s *SpireDelegateClient) initWatcher(ctx context.Context) (delegatedidentityv1.DelegatedIdentity_SubscribeToX509SVIDsClient, delegatedidentityv1.DelegatedIdentity_SubscribeToX509BundlesClient, error) { | ||
if _, err := os.Stat(s.cfg.SpireAdminSocketPath); errors.Is(err, os.ErrNotExist) { | ||
return nil, nil, fmt.Errorf("SPIRE admin socket (%s) does not exist: %w", s.cfg.SpireAdminSocketPath, err) | ||
} | ||
|
||
unixPath := fmt.Sprintf("unix://%s", s.cfg.SpireAdminSocketPath) | ||
|
||
conn, err := grpc.Dial(unixPath, grpc.WithTransportCredentials(insecure.NewCredentials())) | ||
if err != nil { | ||
return nil, nil, fmt.Errorf("grpc.Dial() failed on %s: %w", unixPath, err) | ||
} | ||
|
||
client := delegatedidentityv1.NewDelegatedIdentityClient(conn) | ||
|
||
stream, err := client.SubscribeToX509SVIDs(ctx, &delegatedidentityv1.SubscribeToX509SVIDsRequest{ | ||
Selectors: []*spiffeTypes.Selector{ | ||
{ | ||
Type: "cilium", | ||
Value: "mtls", | ||
}, | ||
}, | ||
}) | ||
|
||
if err != nil { | ||
conn.Close() | ||
return nil, nil, fmt.Errorf("stream failed on %s: %w", unixPath, err) | ||
} | ||
|
||
trustStream, err := client.SubscribeToX509Bundles(ctx, &delegatedidentityv1.SubscribeToX509BundlesRequest{}) | ||
if err != nil { | ||
conn.Close() | ||
return nil, nil, fmt.Errorf("stream for x509 bundle failed on %s: %w", unixPath, err) | ||
} | ||
|
||
return stream, trustStream, nil | ||
} | ||
|
||
func equalCertChains(a, b [][]byte) bool { | ||
if len(a) != len(b) { | ||
return false | ||
} | ||
for i := range a { | ||
if !bytes.Equal(a[i], b[i]) { | ||
return false | ||
} | ||
} | ||
return true | ||
} |