Skip to content

Commit

Permalink
Add Spire delegate API as CertificateProvider
Browse files Browse the repository at this point in the history
This adds an implementation of the Delegate API of a SPIRE server
as a source for certificates to be used in an mTLS handhake.

It will connect to the admin socket of a SPIRE agent where it will
be able to get the certificates and keys in name of all Cilium
workloads which are receiving an SVID from the controller.
This is then cached in memory for the auth handler to request.

Signed-off-by: Maartje Eyskens <maartje.eyskens@isovalent.com>
  • Loading branch information
meyskens authored and sayboras committed Mar 8, 2023
1 parent 7275f04 commit ab5ea83
Show file tree
Hide file tree
Showing 2 changed files with 360 additions and 0 deletions.
74 changes: 74 additions & 0 deletions pkg/auth/spire/certificate_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package spire

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
)

// This file implements the CertificateProvider interface

func (s *SpireDelegateClient) GetTrustBundle() (*x509.CertPool, error) {
if s.trustBundle == nil {
return nil, errors.New("trust bundle not yet available")
}
return s.trustBundle, nil
}

func (s *SpireDelegateClient) GetCertificateForIdentity(identity string) (*tls.Certificate, error) {
spiffeID := s.sniToSPIFFEID(identity)
svid, ok := s.svidStore[spiffeID]
if !ok {
return nil, fmt.Errorf("no SPIFFE ID for %s", spiffeID)
}

if len(svid.X509Svid.CertChain) == 0 {
return nil, fmt.Errorf("no certificate chain inside %s", spiffeID)
}

var leafCert *x509.Certificate
for _, cert := range svid.X509Svid.CertChain {
cert, err := x509.ParseCertificate(cert)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}

if !cert.IsCA {
leafCert = cert
break
}
}
if leafCert == nil {
return nil, fmt.Errorf("no leaf certificate inside %s", spiffeID)
}

privKey, err := x509.ParsePKCS8PrivateKey(svid.X509SvidKey)
if err != nil {
return nil, fmt.Errorf("failed to parse private keyof %s: %w", spiffeID, err)
}

return &tls.Certificate{
Certificate: svid.X509Svid.CertChain,
PrivateKey: privKey,
Leaf: leafCert,
}, nil
}

func (s *SpireDelegateClient) sniToSPIFFEID(sni string) string {
return fmt.Sprintf("spiffe://%s/cilium-id/%s", s.cfg.SpiffeTrustDomain, sni)
}

func (s *SpireDelegateClient) ValidateIdentity(identity string, cert *x509.Certificate) (bool, error) {
spiffeID := s.sniToSPIFFEID(identity)

// Spec: SVIDs containing more than one URI SAN MUST be rejected
if len(cert.URIs) != 1 {
return false, errors.New("SPIFFE IDs must have exactly one URI SAN")
}

return cert.URIs[0].String() == spiffeID, nil
}
286 changes: 286 additions & 0 deletions pkg/auth/spire/delegate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium

package spire

import (
"bytes"
"context"
"crypto/x509"
"errors"
"fmt"
"os"
"time"

delegatedidentityv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/agent/delegatedidentity/v1"
spiffeTypes "github.com/spiffe/spire-api-sdk/proto/spire/api/types"

"github.com/sirupsen/logrus"
"github.com/spf13/pflag"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/cilium/cilium/pkg/auth/certs"
"github.com/cilium/cilium/pkg/backoff"
"github.com/cilium/cilium/pkg/hive"
"github.com/cilium/cilium/pkg/hive/cell"
"github.com/cilium/cilium/pkg/logging/logfields"
)

type SpireDelegateClient struct {
cfg SpireDelegateConfig
log logrus.FieldLogger

connectionAttempts int

stream delegatedidentityv1.DelegatedIdentity_SubscribeToX509SVIDsClient
trustStream delegatedidentityv1.DelegatedIdentity_SubscribeToX509BundlesClient

svidStore map[string]*delegatedidentityv1.X509SVIDWithKey
trustBundle *x509.CertPool

cancelListenForUpdates context.CancelFunc
}

type SpireDelegateConfig struct {
SpireAdminSocketPath string `mapstructure:"mesh-auth-spire-admin-socket"`
SpiffeTrustDomain string `mapstructure:"mesh-auth-spiffe-trust-domain"`
}

type certificateProviderResult struct {
cell.Out

CertificateProvider certs.CertificateProvider `group:"certificateProviders"`
}

var Cell = cell.Module(
"spire-delegate",
"Spire Delegate API Client",
cell.Provide(newSpireDelegateClient),
cell.Config(SpireDelegateConfig{}),
)

func newSpireDelegateClient(lc hive.Lifecycle, cfg SpireDelegateConfig, log logrus.FieldLogger) certificateProviderResult {
client := &SpireDelegateClient{
cfg: cfg,
log: log.WithField(logfields.LogSubsys, "spire-delegate"),
svidStore: map[string]*delegatedidentityv1.X509SVIDWithKey{},
}

lc.Append(hive.Hook{OnStart: client.onStart, OnStop: client.onStop})

return certificateProviderResult{
CertificateProvider: client,
}
}

func (cfg SpireDelegateConfig) Flags(flags *pflag.FlagSet) {
flags.StringVar(&cfg.SpireAdminSocketPath, "mesh-auth-spire-admin-socket", "/run/spire/sockets/admin.sock", "The path for the SPIRE admin agent Unix socket.")
flags.StringVar(&cfg.SpiffeTrustDomain, "mesh-auth-spiffe-trust-domain", "spiffe.cilium.io", "The trust domain for the SPIFFE identity.")
}

func (s *SpireDelegateClient) onStart(ctx hive.HookContext) error {
s.log.Info("Spire Delegate API Client is running")

listenCtx, cancel := context.WithCancel(context.Background())
go s.listenForUpdates(listenCtx)

s.cancelListenForUpdates = cancel

return nil
}

func (s *SpireDelegateClient) onStop(ctx hive.HookContext) error {
s.log.Info("SPIFFE Delegate API Client is stopping")

s.cancelListenForUpdates()

if s.stream != nil {
s.stream.CloseSend()
}

return nil
}

func (s *SpireDelegateClient) listenForUpdates(ctx context.Context) {
s.openStream(ctx)

listenCtx, cancel := context.WithCancel(ctx)
err := make(chan error)

go s.listenForSVIDUpdates(listenCtx, err)
go s.listenForBundleUpdates(listenCtx, err)

backoffTime := backoff.Exponential{Min: 100 * time.Millisecond, Max: 10 * time.Second}
for {
select {
case <-ctx.Done():
cancel()
return
case e := <-err:
s.log.WithError(e).Error("error in delegate stream, restarting")
time.Sleep(backoffTime.Duration(s.connectionAttempts))
cancel()
s.connectionAttempts++
s.listenForUpdates(ctx)
return
}
}
}

func (s *SpireDelegateClient) listenForSVIDUpdates(ctx context.Context, errorChan chan error) {
for {
select {
case <-ctx.Done():
return
default:
resp, err := s.stream.Recv()
if err != nil {
errorChan <- err
return
}

s.log.Debugf("received %d X509-SVIDs in update", len(resp.X509Svids))
s.handleX509SVIDUpdate(resp.X509Svids)
}
}
}

func (s *SpireDelegateClient) listenForBundleUpdates(ctx context.Context, errorChan chan error) {
for {
select {
case <-ctx.Done():
return
default:
resp, err := s.trustStream.Recv()
if err != nil {
errorChan <- err
return
}

s.log.Debugf("received %d X509-Bundles in update", len(resp.CaCertificates))
s.handleX509BundleUpdate(resp.CaCertificates)
}
}
}

func (s *SpireDelegateClient) handleX509SVIDUpdate(svids []*delegatedidentityv1.X509SVIDWithKey) {
newSvidStore := map[string]*delegatedidentityv1.X509SVIDWithKey{}

for _, svid := range svids {

s.log.Debugf("processing spiffe://%s%s, Expires at %s", svid.X509Svid.Id.TrustDomain,
svid.X509Svid.Id.Path,
time.Unix(svid.X509Svid.ExpiresAt, 0))

if svid.X509Svid.Id.TrustDomain != s.cfg.SpiffeTrustDomain {
s.log.Debugf("skipping X509-SVID update for trust domain %s as it does not match ours", svid.X509Svid.Id.TrustDomain)
return
}

key := fmt.Sprintf("spiffe://%s%s", svid.X509Svid.Id.TrustDomain, svid.X509Svid.Id.Path)

if _, exists := s.svidStore[key]; exists {
old := s.svidStore[key]
if old.X509Svid.ExpiresAt != svid.X509Svid.ExpiresAt || !equalCertChains(old.X509Svid.CertChain, svid.X509Svid.CertChain) {
s.log.Debugf("X509-SVID for %s has changed, updating", key)
// this is a good point to in the future send a trigger for a new handshake
}
} else {
s.log.Debugf("X509-SVID for %s is new, adding", key)
}

newSvidStore[key] = svid

}

s.svidStore = newSvidStore
}

func (s *SpireDelegateClient) handleX509BundleUpdate(bundles map[string][]byte) {
pool := x509.NewCertPool()

for trustDomain, bundle := range bundles {
s.log.Debugf("processing trust domain %s cert bundle", trustDomain)

certs, err := x509.ParseCertificates(bundle)
if err != nil {
s.log.WithError(err).Errorf("failed to parse X.509 DER bundle for trust domain %s", trustDomain)
continue
}

for _, cert := range certs {
pool.AddCert(cert)
}
}

s.trustBundle = pool
}

func (s *SpireDelegateClient) openStream(ctx context.Context) {
// try to init the watcher with a backoff
backoffTime := backoff.Exponential{Min: 100 * time.Millisecond, Max: 10 * time.Second}
for {
s.log.Info("Connecting to SPIRE Delegate API Client")

var err error
s.stream, s.trustStream, err = s.initWatcher(ctx)
if err != nil {
s.log.WithError(err).Warn("SPIRE Delegate API Client failed to init watcher, retrying")
time.Sleep(backoffTime.Duration(s.connectionAttempts))
s.connectionAttempts++
continue
}
break
}
}

func (s *SpireDelegateClient) initWatcher(ctx context.Context) (delegatedidentityv1.DelegatedIdentity_SubscribeToX509SVIDsClient, delegatedidentityv1.DelegatedIdentity_SubscribeToX509BundlesClient, error) {
if _, err := os.Stat(s.cfg.SpireAdminSocketPath); errors.Is(err, os.ErrNotExist) {
return nil, nil, fmt.Errorf("SPIRE admin socket (%s) does not exist: %w", s.cfg.SpireAdminSocketPath, err)
}

unixPath := fmt.Sprintf("unix://%s", s.cfg.SpireAdminSocketPath)

conn, err := grpc.Dial(unixPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, nil, fmt.Errorf("grpc.Dial() failed on %s: %w", unixPath, err)
}

client := delegatedidentityv1.NewDelegatedIdentityClient(conn)

stream, err := client.SubscribeToX509SVIDs(ctx, &delegatedidentityv1.SubscribeToX509SVIDsRequest{
Selectors: []*spiffeTypes.Selector{
{
Type: "cilium",
Value: "mtls",
},
},
})

if err != nil {
conn.Close()
return nil, nil, fmt.Errorf("stream failed on %s: %w", unixPath, err)
}

trustStream, err := client.SubscribeToX509Bundles(ctx, &delegatedidentityv1.SubscribeToX509BundlesRequest{})
if err != nil {
conn.Close()
return nil, nil, fmt.Errorf("stream for x509 bundle failed on %s: %w", unixPath, err)
}

return stream, trustStream, nil
}

func equalCertChains(a, b [][]byte) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !bytes.Equal(a[i], b[i]) {
return false
}
}
return true
}

0 comments on commit ab5ea83

Please sign in to comment.