Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v15] Close expired gRPC connections if disconnect_expired_cert is set #41827

Merged
merged 1 commit into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5666,6 +5666,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
creds, err := NewTransportCredentials(TransportCredentialsConfig{
TransportCredentials: &httplib.TLSCreds{Config: cfg.TLS},
UserGetter: cfg.Middleware,
GetAuthPreference: cfg.AuthServer.Cache.GetAuthPreference,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
11 changes: 9 additions & 2 deletions lib/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -662,12 +663,18 @@ func (h *fakeHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

type fakeConn struct {
net.Conn
closed atomic.Bool
}

func (f fakeConn) Close() error {
func (f *fakeConn) Close() error {
f.closed.CompareAndSwap(false, true)
return nil
}

func (f *fakeConn) RemoteAddr() net.Addr {
return &utils.NetAddr{}
}

func TestValidateClientVersion(t *testing.T) {
cases := []struct {
name string
Expand Down Expand Up @@ -729,7 +736,7 @@ func TestValidateClientVersion(t *testing.T) {
ctx = metadata.NewIncomingContext(ctx, metadata.New(map[string]string{"version": tt.clientVersion}))
}

tt.errAssertion(t, tt.middleware.ValidateClientVersion(ctx, IdentityInfo{Conn: fakeConn{}, IdentityGetter: TestBuiltin(types.RoleNode).I}))
tt.errAssertion(t, tt.middleware.ValidateClientVersion(ctx, IdentityInfo{Conn: &fakeConn{}, IdentityGetter: TestBuiltin(types.RoleNode).I}))
})
}
}
104 changes: 84 additions & 20 deletions lib/auth/transport_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ import (
"crypto/tls"
"io"
"net"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"google.golang.org/grpc/credentials"

"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/authz"
)
Expand Down Expand Up @@ -78,6 +81,12 @@ type TransportCredentialsConfig struct {
// of active connections is within the limit. If not set then no connection
// limits are enforced.
Enforcer ConnectionEnforcer
// Clock used to tell time.
Clock clockwork.Clock
// GetAuthPreference is used to retrieve the auth preference per connection
// to determine if connections should be terminated as soon as the client
// certificate has expired.
GetAuthPreference func(ctx context.Context) (types.AuthPreference, error)
}

// Check validates that the configuration is valid for use and
Expand Down Expand Up @@ -105,9 +114,11 @@ func (c *TransportCredentialsConfig) Check() error {
type TransportCredentials struct {
credentials.TransportCredentials

userGetter UserGetter
authorizer authz.Authorizer
enforcer ConnectionEnforcer
userGetter UserGetter
authorizer authz.Authorizer
enforcer ConnectionEnforcer
getAuthPreference func(context.Context) (types.AuthPreference, error)
clock clockwork.Clock
}

// NewTransportCredentials returns a new TransportCredentials
Expand All @@ -116,11 +127,25 @@ func NewTransportCredentials(cfg TransportCredentialsConfig) (*TransportCredenti
return nil, trace.Wrap(err)
}

getAuthPreference := func(context.Context) (types.AuthPreference, error) {
return types.DefaultAuthPreference(), nil
}
if cfg.GetAuthPreference != nil {
getAuthPreference = cfg.GetAuthPreference
}

clock := clockwork.NewRealClock()
if cfg.Clock != nil {
clock = cfg.Clock
}

return &TransportCredentials{
TransportCredentials: cfg.TransportCredentials,
userGetter: cfg.UserGetter,
authorizer: cfg.Authorizer,
enforcer: cfg.Enforcer,
getAuthPreference: getAuthPreference,
clock: clock,
}, nil
}

Expand All @@ -143,38 +168,78 @@ type IdentityInfo struct {
Conn net.Conn
}

// ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about
// the connection.
// At minimum the TLS handshake is performed and the identity is built from
// timeoutConn wraps a connection that is to be closed when
// the timer expires.
type timeoutConn struct {
net.Conn // The underlying [net.Conn] of the gRPC connection.
timer clockwork.Timer
}

// newTimeoutConn creates a [net.Conn] wrapper that closes the rawConn
// if the timeout is exceeded.
func newTimeoutConn(conn net.Conn, clock clockwork.Clock, expires time.Time) (net.Conn, error) {
if expires.IsZero() {
return conn, nil
}

return &timeoutConn{
Conn: conn,
timer: clock.AfterFunc(expires.Sub(clock.Now()), func() { conn.Close() }),
}, nil
}

// Close closes the wrapped [net.Conn] and stops the timer
// to prevent leaking it.
func (c *timeoutConn) Close() error {
c.timer.Stop()
return trace.Wrap(c.Conn.Close())
}

// ServerHandshake performs the authentication handshake for servers as per
// the [credentials.TransportCredentials] interface. It returns the authenticated
// connection and the corresponding auth information about the connection.
// At minimum, the TLS handshake is performed and the identity is built from
// the [tls.ConnectionState]. If the TransportCredentials is configured with
// and Authorizer and ConnectionEnforcer then additional session controls are
// applied before the handshake completes.
func (c *TransportCredentials) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
// an [authz.Authorizer] and a [ConnectionEnforcer], then additional session
// controls are applied before the handshake completes.
func (c *TransportCredentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn, tlsInfo, err := c.performTLSHandshake(rawConn)
if err != nil {
return nil, nil, trace.Wrap(err)
}

defer func() {
if err != nil {
conn.Close()
}
}()
validatedConn, info, err := c.validateIdentity(conn, tlsInfo)
if err != nil {
return nil, nil, trace.NewAggregate(err, conn.Close())
}
return validatedConn, info, nil
}

// validateIdentity extracts the identity from the client certificate,
// authorizes the user, enforces any connection limits, and ensures the
// connection is terminated at expiry of the client certificate if required.
func (c *TransportCredentials) validateIdentity(conn net.Conn, tlsInfo *credentials.TLSInfo) (net.Conn, IdentityInfo, error) {
identityGetter, err := c.userGetter.GetUser(tlsInfo.State)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, IdentityInfo{}, trace.Wrap(err)
}

ctx := context.Background()
authCtx, err := c.authorize(ctx, conn.RemoteAddr(), identityGetter, &tlsInfo.State)
if err != nil {
return nil, nil, trace.Wrap(err)
return nil, IdentityInfo{}, trace.Wrap(err)
}

if err := c.enforceConnectionLimits(ctx, authCtx, conn); err != nil {
return nil, nil, trace.Wrap(err)
return nil, IdentityInfo{}, trace.Wrap(err)
}

if authPreference, err := c.getAuthPreference(ctx); err == nil {
expiry := authCtx.GetDisconnectCertExpiry(authPreference)
conn, err = newTimeoutConn(conn, c.clock, expiry)
if err != nil {
return nil, IdentityInfo{}, trace.Wrap(err)
}
}

return conn, IdentityInfo{
Expand All @@ -195,8 +260,7 @@ func (c *TransportCredentials) performTLSHandshake(rawConn net.Conn) (net.Conn,

tlsInfo, ok := info.(credentials.TLSInfo)
if !ok {
conn.Close()
return nil, nil, trace.BadParameter("unexpected type in tls auth info %T", info)
return nil, nil, trace.NewAggregate(conn.Close(), trace.BadParameter("unexpected type in tls auth info %T", info))
}

return conn, &tlsInfo, nil
Expand Down
118 changes: 115 additions & 3 deletions lib/auth/transport_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
)

// TestTransportCredentials_Check validates the returned values
Expand Down Expand Up @@ -306,18 +309,121 @@ func TestTransportCredentials_ServerHandshake(t *testing.T) {
}
}

type fakeUserGetter struct {
identity authz.IdentityGetter
}

func (f fakeUserGetter) GetUser(tls.ConnectionState) (authz.IdentityGetter, error) {
return f.identity, nil
}

func TestTransportCredentialsDisconnection(t *testing.T) {
cases := []struct {
name string
expiry time.Duration
}{
{
name: "no expiry",
},
{
name: "closed on expiry",
expiry: time.Hour,
},
{
name: "already expired",
expiry: -time.Hour,
},
}

// Assert that the connections remain open.
connectionOpenAssertion := func(t *testing.T, conn *fakeConn) {
assert.False(t, conn.closed.Load())
}

// Assert that the connections are eventually closed.
connectionClosedAssertion := func(t *testing.T, conn *fakeConn) {
require.EventuallyWithT(t, func(t *assert.CollectT) {
assert.True(t, conn.closed.Load())
}, 5*time.Second, 100*time.Millisecond)
}

pref := types.DefaultAuthPreference()
pref.SetDisconnectExpiredCert(true)
for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
clock := clockwork.NewFakeClock()
conn := &fakeConn{}

var expiry time.Time
if test.expiry != 0 {
expiry = clock.Now().Add(test.expiry)
}
identity := TestIdentity{
I: authz.LocalUser{
Username: "llama",
Identity: tlsca.Identity{Username: "llama", Expires: expiry},
},
}

creds, err := NewTransportCredentials(TransportCredentialsConfig{
TransportCredentials: credentials.NewTLS(&tls.Config{}),
Authorizer: &fakeAuthorizer{checker: &fakeChecker{}, identity: identity.I},
UserGetter: fakeUserGetter{
identity: identity.I,
},
Clock: clock,
GetAuthPreference: func(ctx context.Context) (types.AuthPreference, error) { return pref, nil },
})
require.NoError(t, err, "creating transport credentials")

validatedConn, _, err := creds.validateIdentity(conn, &credentials.TLSInfo{State: tls.ConnectionState{}})
switch {
case test.expiry == 0:
require.NoError(t, err)
require.NotNil(t, validatedConn)

connectionOpenAssertion(t, conn)
clock.Advance(time.Hour)
connectionOpenAssertion(t, conn)
case test.expiry < 0:
require.NoError(t, err)
require.NotNil(t, validatedConn)

connectionClosedAssertion(t, conn)
default:
require.NoError(t, err)
require.NotNil(t, validatedConn)

connectionOpenAssertion(t, conn)
clock.BlockUntil(1)
clock.Advance(test.expiry)
connectionClosedAssertion(t, conn)
}
})
}
}

type fakeChecker struct {
services.AccessChecker
maxConnections int64
maxConnections int64
disconnectExpired *bool
}

func (c *fakeChecker) MaxConnections() int64 {
return c.maxConnections
}

func (c *fakeChecker) AdjustDisconnectExpiredCert(b bool) bool {
if c.disconnectExpired == nil {
return b
}
return *c.disconnectExpired
}

type fakeAuthorizer struct {
authorizeError error
checker services.AccessChecker
identity authz.IdentityGetter
}

func (a *fakeAuthorizer) Authorize(ctx context.Context) (*authz.Context, error) {
Expand All @@ -330,9 +436,15 @@ func (a *fakeAuthorizer) Authorize(ctx context.Context) (*authz.Context, error)
return nil, err
}

identity := a.identity
if identity == nil {
identity = TestUser(user.GetName()).I
}

return &authz.Context{
User: user,
Checker: a.checker,
User: user,
Checker: a.checker,
Identity: identity,
}, nil
}

Expand Down
Loading
Loading