Skip to content

Commit

Permalink
populate verified chains when using custom buildVerifyFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
mudhireddy committed Oct 28, 2023
1 parent e88e849 commit 9d07858
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
33 changes: 20 additions & 13 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,14 @@ func (o *ClientOptions) config() (*tls.Config, error) {
return config, nil
}

func (o *ServerOptions) config() (*tls.Config, error) {
func (o *ServerOptions) config(config *tls.Config) (*tls.Config, error) {
if config == nil {
config = &tls.Config{
ClientAuth: tls.NoClientCert,
MinVersion: o.MinVersion,
MaxVersion: o.MaxVersion,
}
}
if o.RequireClientCert && o.VType == SkipVerification && o.VerifyPeer == nil {
return nil, fmt.Errorf("server needs to provide custom verification mechanism if choose to skip default verification, but require client certificate(s)")
}
Expand All @@ -319,17 +326,11 @@ func (o *ServerOptions) config() (*tls.Config, error) {
if o.MinVersion > o.MaxVersion {
return nil, fmt.Errorf("the minimum TLS version is larger than the maximum TLS version")
}
clientAuth := tls.NoClientCert
if o.RequireClientCert {
// We have to set clientAuth to RequireAnyClientCert to force underlying
// TLS package to use the verification function we built from
// buildVerifyFunc.
clientAuth = tls.RequireAnyClientCert
}
config := &tls.Config{
ClientAuth: clientAuth,
MinVersion: o.MinVersion,
MaxVersion: o.MaxVersion,
config.ClientAuth = tls.RequireAnyClientCert
}
// Propagate root-certificate-related fields in tls.Config.
switch {
Expand Down Expand Up @@ -414,7 +415,8 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
if cfg.ServerName == "" {
cfg.ServerName = authority
}
cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn)
peerVerifiedChains := [][]*x509.Certificate{}
cfg.VerifyPeerCertificate = buildVerifyFunc(c, cfg.ServerName, rawConn, &peerVerifiedChains)
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
go func() {
Expand All @@ -438,12 +440,14 @@ func (c *advancedTLSCreds) ClientHandshake(ctx context.Context, authority string
},
}
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
info.State.VerifiedChains = peerVerifiedChains
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
}

func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
cfg := credinternal.CloneTLSConfig(c.config)
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn)
peerVerifiedChains := [][]*x509.Certificate{}
cfg.VerifyPeerCertificate = buildVerifyFunc(c, "", rawConn, &peerVerifiedChains)
conn := tls.Server(rawConn, cfg)
if err := conn.Handshake(); err != nil {
conn.Close()
Expand All @@ -456,6 +460,7 @@ func (c *advancedTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credenti
},
}
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
info.State.VerifiedChains = peerVerifiedChains
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
}

Expand All @@ -482,7 +487,8 @@ func (c *advancedTLSCreds) OverrideServerName(serverNameOverride string) error {
// to true.
func buildVerifyFunc(c *advancedTLSCreds,
serverName string,
rawConn net.Conn) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
rawConn net.Conn,
peerVerifiedChains *[][]*x509.Certificate) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
chains := verifiedChains
var leafCert *x509.Certificate
Expand Down Expand Up @@ -541,6 +547,7 @@ func buildVerifyFunc(c *advancedTLSCreds,
return err
}
leafCert = rawCertList[0]
*peerVerifiedChains = chains
}
// Perform certificate revocation check if specified.
if c.revocationConfig != nil {
Expand Down Expand Up @@ -587,8 +594,8 @@ func NewClientCreds(o *ClientOptions) (credentials.TransportCredentials, error)

// NewServerCreds uses ServerOptions to construct a TransportCredentials based
// on TLS.
func NewServerCreds(o *ServerOptions) (credentials.TransportCredentials, error) {
conf, err := o.config()
func NewServerCreds(o *ServerOptions, config *tls.Config) (credentials.TransportCredentials, error) {
conf, err := o.config(config)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions security/advancedtls/advancedtls_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func (s) TestEnd2End(t *testing.T) {
VerifyPeer: test.serverVerifyFunc,
VType: test.serverVType,
}
serverTLSCreds, err := NewServerCreds(serverOptions)
serverTLSCreds, err := NewServerCreds(serverOptions, nil)
if err != nil {
t.Fatalf("failed to create server creds: %v", err)
}
Expand Down Expand Up @@ -640,7 +640,7 @@ func (s) TestPEMFileProviderEnd2End(t *testing.T) {
},
VType: CertVerification,
}
serverTLSCreds, err := NewServerCreds(serverOptions)
serverTLSCreds, err := NewServerCreds(serverOptions, nil)
if err != nil {
t.Fatalf("failed to create server creds: %v", err)
}
Expand Down Expand Up @@ -771,7 +771,7 @@ func (s) TestDefaultHostNameCheck(t *testing.T) {
RequireClientCert: false,
VType: test.serverVType,
}
serverTLSCreds, err := NewServerCreds(serverOptions)
serverTLSCreds, err := NewServerCreds(serverOptions, nil)
if err != nil {
t.Fatalf("failed to create server creds: %v", err)
}
Expand Down Expand Up @@ -911,7 +911,7 @@ func (s) TestTLSVersions(t *testing.T) {
MinVersion: test.serverMinVersion,
MaxVersion: test.serverMaxVersion,
}
serverTLSCreds, err := NewServerCreds(serverOptions)
serverTLSCreds, err := NewServerCreds(serverOptions, nil)
if err != nil {
t.Fatalf("failed to create server creds: %v", err)
}
Expand Down
8 changes: 4 additions & 4 deletions security/advancedtls/advancedtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func (s) TestServerOptionsConfigErrorCases(t *testing.T) {
MinVersion: test.MinVersion,
MaxVersion: test.MaxVersion,
}
_, err := serverOptions.config()
_, err := serverOptions.config(nil)
if err == nil {
t.Fatalf("ServerOptions{%v}.config() returns no err, wantErr != nil", serverOptions)
}
Expand Down Expand Up @@ -316,7 +316,7 @@ func (s) TestServerOptionsConfigSuccessCases(t *testing.T) {
MinVersion: test.MinVersion,
MaxVersion: test.MaxVersion,
}
serverConfig, err := serverOptions.config()
serverConfig, err := serverOptions.config(nil)
if err != nil {
t.Fatalf("ServerOptions{%v}.config() = %v, wantErr == nil", serverOptions, err)
}
Expand Down Expand Up @@ -735,7 +735,7 @@ func (s) TestClientServerHandshake(t *testing.T) {
close(done)
return
}
serverTLS, err := NewServerCreds(serverOptions)
serverTLS, err := NewServerCreds(serverOptions, nil)
if err != nil {
serverRawConn.Close()
close(done)
Expand Down Expand Up @@ -879,7 +879,7 @@ func (s) TestGetCertificatesSNI(t *testing.T) {
},
},
}
serverConfig, err := serverOptions.config()
serverConfig, err := serverOptions.config(nil)
if err != nil {
t.Fatalf("serverOptions.config() failed: %v", err)
}
Expand Down

0 comments on commit 9d07858

Please sign in to comment.