Skip to content

Commit

Permalink
Close expired gRPC connections if disconnect_expired_cert is set (#41827
Browse files Browse the repository at this point in the history
)

Updates the transport credentials used by gRPC servers that require
mTLS to enforce that connections are terminated when the client
certificate expires if `disconnect_expired_cert == true`. To prevent
session resumption from leaving open sessions established through
the Proxy gRPC server the redial mechanism was updated to inspect
for certificate expired errors and abort any future reconnection
attempts.

Partially addresses #1199.
  • Loading branch information
rosstimothy committed May 21, 2024
1 parent 4b85df6 commit e9c1506
Show file tree
Hide file tree
Showing 12 changed files with 312 additions and 127 deletions.
1 change: 1 addition & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5675,6 +5675,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
creds, err := NewTransportCredentials(TransportCredentialsConfig{
TransportCredentials: &httplib.TLSCreds{Config: cfg.TLS},
UserGetter: cfg.Middleware,
GetAuthPreference: cfg.AuthServer.Cache.GetAuthPreference,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
11 changes: 9 additions & 2 deletions lib/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -662,12 +663,18 @@ func (h *fakeHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

type fakeConn struct {
net.Conn
closed atomic.Bool
}

func (f fakeConn) Close() error {
func (f *fakeConn) Close() error {
f.closed.CompareAndSwap(false, true)
return nil
}

func (f *fakeConn) RemoteAddr() net.Addr {
return &utils.NetAddr{}
}

func TestValidateClientVersion(t *testing.T) {
cases := []struct {
name string
Expand Down Expand Up @@ -729,7 +736,7 @@ func TestValidateClientVersion(t *testing.T) {
ctx = metadata.NewIncomingContext(ctx, metadata.New(map[string]string{"version": tt.clientVersion}))
}

tt.errAssertion(t, tt.middleware.ValidateClientVersion(ctx, IdentityInfo{Conn: fakeConn{}, IdentityGetter: TestBuiltin(types.RoleNode).I}))
tt.errAssertion(t, tt.middleware.ValidateClientVersion(ctx, IdentityInfo{Conn: &fakeConn{}, IdentityGetter: TestBuiltin(types.RoleNode).I}))
})
}
}
104 changes: 84 additions & 20 deletions lib/auth/transport_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ import (
"crypto/tls"
"io"
"net"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"google.golang.org/grpc/credentials"

"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/authz"
)
Expand Down Expand Up @@ -78,6 +81,12 @@ type TransportCredentialsConfig struct {
// of active connections is within the limit. If not set then no connection
// limits are enforced.
Enforcer ConnectionEnforcer
// Clock used to tell time.
Clock clockwork.Clock
// GetAuthPreference is used to retrieve the auth preference per connection
// to determine if connections should be terminated as soon as the client
// certificate has expired.
GetAuthPreference func(ctx context.Context) (types.AuthPreference, error)
}

// Check validates that the configuration is valid for use and
Expand Down Expand Up @@ -105,9 +114,11 @@ func (c *TransportCredentialsConfig) Check() error {
type TransportCredentials struct {
credentials.TransportCredentials

userGetter UserGetter
authorizer authz.Authorizer
enforcer ConnectionEnforcer
userGetter UserGetter
authorizer authz.Authorizer
enforcer ConnectionEnforcer
getAuthPreference func(context.Context) (types.AuthPreference, error)
clock clockwork.Clock
}

// NewTransportCredentials returns a new TransportCredentials
Expand All @@ -116,11 +127,25 @@ func NewTransportCredentials(cfg TransportCredentialsConfig) (*TransportCredenti
return nil, trace.Wrap(err)
}

getAuthPreference := func(context.Context) (types.AuthPreference, error) {
return types.DefaultAuthPreference(), nil
}
if cfg.GetAuthPreference != nil {
getAuthPreference = cfg.GetAuthPreference
}

clock := clockwork.NewRealClock()
if cfg.Clock != nil {
clock = cfg.Clock
}

return &TransportCredentials{
TransportCredentials: cfg.TransportCredentials,
userGetter: cfg.UserGetter,
authorizer: cfg.Authorizer,
enforcer: cfg.Enforcer,
getAuthPreference: getAuthPreference,
clock: clock,
}, nil
}

Expand All @@ -143,38 +168,78 @@ type IdentityInfo struct {
Conn net.Conn
}

// ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about
// the connection.
// At minimum the TLS handshake is performed and the identity is built from
// timeoutConn wraps a connection that is to be closed when
// the timer expires.
type timeoutConn struct {
net.Conn // The underlying [net.Conn] of the gRPC connection.
timer clockwork.Timer
}

// newTimeoutConn creates a [net.Conn] wrapper that closes the rawConn
// if the timeout is exceeded.
func newTimeoutConn(conn net.Conn, clock clockwork.Clock, expires time.Time) (net.Conn, error) {
if expires.IsZero() {
return conn, nil
}

return &timeoutConn{
Conn: conn,
timer: clock.AfterFunc(expires.Sub(clock.Now()), func() { conn.Close() }),
}, nil
}

// Close closes the wrapped [net.Conn] and stops the timer
// to prevent leaking it.
func (c *timeoutConn) Close() error {
c.timer.Stop()
return trace.Wrap(c.Conn.Close())
}

// ServerHandshake performs the authentication handshake for servers as per
// the [credentials.TransportCredentials] interface. It returns the authenticated
// connection and the corresponding auth information about the connection.
// At minimum, the TLS handshake is performed and the identity is built from
// the [tls.ConnectionState]. If the TransportCredentials is configured with
// and Authorizer and ConnectionEnforcer then additional session controls are
// applied before the handshake completes.
func (c *TransportCredentials) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
// an [authz.Authorizer] and a [ConnectionEnforcer], then additional session
// controls are applied before the handshake completes.
func (c *TransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn, tlsInfo, err := c.performTLSHandshake(rawConn)
if err != nil {
return nil, nil, trace.Wrap(err)
}

defer func() {
if err != nil {
conn.Close()
}
}()
validatedConn, info, err := c.validateIdentity(conn, tlsInfo)
if err != nil {
return nil, nil, trace.NewAggregate(err, conn.Close())
}
return validatedConn, info, nil
}

// validateIdentity extracts the identity from the client certificate,
// authorizes the user, enforces any connection limits, and ensures the
// connection is terminated at expiry of the client certificate if required.
func (c *TransportCredentials) validateIdentity(conn net.Conn, tlsInfo *credentials.TLSInfo) (net.Conn, IdentityInfo, error) {
identityGetter, err := c.userGetter.GetUser(tlsInfo.State)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, IdentityInfo{}, trace.Wrap(err)
}

ctx := context.Background()
authCtx, err := c.authorize(ctx, conn.RemoteAddr(), identityGetter, &tlsInfo.State)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, IdentityInfo{}, trace.Wrap(err)
}

if err := c.enforceConnectionLimits(ctx, authCtx, conn); err != nil {
return nil, nil, trace.Wrap(err)
return nil, IdentityInfo{}, trace.Wrap(err)
}

if authPreference, err := c.getAuthPreference(ctx); err == nil {
expiry := authCtx.GetDisconnectCertExpiry(authPreference)
conn, err = newTimeoutConn(conn, c.clock, expiry)
if err != nil {
return nil, IdentityInfo{}, trace.Wrap(err)
}
}

return conn, IdentityInfo{
Expand All @@ -195,8 +260,7 @@ func (c *TransportCredentials) performTLSHandshake(rawConn net.Conn) (net.Conn,

tlsInfo, ok := info.(credentials.TLSInfo)
if !ok {
conn.Close()
return nil, nil, trace.BadParameter("unexpected type in tls auth info %T", info)
return nil, nil, trace.NewAggregate(conn.Close(), trace.BadParameter("unexpected type in tls auth info %T", info))
}

return conn, &tlsInfo, nil
Expand Down
118 changes: 115 additions & 3 deletions lib/auth/transport_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
)

// TestTransportCredentials_Check validates the returned values
Expand Down Expand Up @@ -306,18 +309,121 @@ func TestTransportCredentials_ServerHandshake(t *testing.T) {
}
}

type fakeUserGetter struct {
identity authz.IdentityGetter
}

func (f fakeUserGetter) GetUser(tls.ConnectionState) (authz.IdentityGetter, error) {
return f.identity, nil
}

func TestTransportCredentialsDisconnection(t *testing.T) {
cases := []struct {
name string
expiry time.Duration
}{
{
name: "no expiry",
},
{
name: "closed on expiry",
expiry: time.Hour,
},
{
name: "already expired",
expiry: -time.Hour,
},
}

// Assert that the connections remain open.
connectionOpenAssertion := func(t *testing.T, conn *fakeConn) {
assert.False(t, conn.closed.Load())
}

// Assert that the connections are eventually closed.
connectionClosedAssertion := func(t *testing.T, conn *fakeConn) {
require.EventuallyWithT(t, func(t *assert.CollectT) {
assert.True(t, conn.closed.Load())
}, 5*time.Second, 100*time.Millisecond)
}

pref := types.DefaultAuthPreference()
pref.SetDisconnectExpiredCert(true)
for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
clock := clockwork.NewFakeClock()
conn := &fakeConn{}

var expiry time.Time
if test.expiry != 0 {
expiry = clock.Now().Add(test.expiry)
}
identity := TestIdentity{
I: authz.LocalUser{
Username: "llama",
Identity: tlsca.Identity{Username: "llama", Expires: expiry},
},
}

creds, err := NewTransportCredentials(TransportCredentialsConfig{
TransportCredentials: credentials.NewTLS(&tls.Config{}),
Authorizer: &fakeAuthorizer{checker: &fakeChecker{}, identity: identity.I},
UserGetter: fakeUserGetter{
identity: identity.I,
},
Clock: clock,
GetAuthPreference: func(ctx context.Context) (types.AuthPreference, error) { return pref, nil },
})
require.NoError(t, err, "creating transport credentials")

validatedConn, _, err := creds.validateIdentity(conn, &credentials.TLSInfo{State: tls.ConnectionState{}})
switch {
case test.expiry == 0:
require.NoError(t, err)
require.NotNil(t, validatedConn)

connectionOpenAssertion(t, conn)
clock.Advance(time.Hour)
connectionOpenAssertion(t, conn)
case test.expiry < 0:
require.NoError(t, err)
require.NotNil(t, validatedConn)

connectionClosedAssertion(t, conn)
default:
require.NoError(t, err)
require.NotNil(t, validatedConn)

connectionOpenAssertion(t, conn)
clock.BlockUntil(1)
clock.Advance(test.expiry)
connectionClosedAssertion(t, conn)
}
})
}
}

type fakeChecker struct {
services.AccessChecker
maxConnections int64
maxConnections int64
disconnectExpired *bool
}

func (c *fakeChecker) MaxConnections() int64 {
return c.maxConnections
}

func (c *fakeChecker) AdjustDisconnectExpiredCert(b bool) bool {
if c.disconnectExpired == nil {
return b
}
return *c.disconnectExpired
}

type fakeAuthorizer struct {
authorizeError error
checker services.AccessChecker
identity authz.IdentityGetter
}

func (a *fakeAuthorizer) Authorize(ctx context.Context) (*authz.Context, error) {
Expand All @@ -330,9 +436,15 @@ func (a *fakeAuthorizer) Authorize(ctx context.Context) (*authz.Context, error)
return nil, err
}

identity := a.identity
if identity == nil {
identity = TestUser(user.GetName()).I
}

return &authz.Context{
User: user,
Checker: a.checker,
User: user,
Checker: a.checker,
Identity: identity,
}, nil
}

Expand Down
Loading

0 comments on commit e9c1506

Please sign in to comment.