Skip to content

Commit

Permalink
Acquire connection lock for joining kubernetes sessions (#12473)
Browse files Browse the repository at this point in the history
  • Loading branch information
xacrimon authored May 10, 2022
1 parent 2f00cb4 commit e569902
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
31 changes: 22 additions & 9 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,21 @@ func (f *Forwarder) withAuthStd(handler handlerWithAuthFuncStd) http.HandlerFunc
}, f.formatResponseError)
}

// acquireConnectionLockWithIdentity acquires a connection lock under a given identity.
func (f *Forwarder) acquireConnectionLockWithIdentity(ctx context.Context, identity *authContext) error {
user := identity.Identity.GetIdentity().Username
roles, err := getRolesByName(f, identity.Identity.GetIdentity().Groups)
if err != nil {
return trace.Wrap(err)
}

if err := f.acquireConnectionLock(ctx, user, roles); err != nil {
return trace.Wrap(err)
}

return nil
}

func (f *Forwarder) withAuth(handler handlerWithAuthFunc) httprouter.Handle {
return httplib.MakeHandlerWithErrorWriter(func(w http.ResponseWriter, req *http.Request, p httprouter.Params) (interface{}, error) {
authContext, err := f.authenticate(req)
Expand All @@ -453,16 +468,10 @@ func (f *Forwarder) withAuth(handler handlerWithAuthFunc) httprouter.Handle {
if err := f.authorize(req.Context(), authContext); err != nil {
return nil, trace.Wrap(err)
}

user := authContext.Identity.GetIdentity().Username
roles, err := getRolesByName(f, authContext.Identity.GetIdentity().Groups)
err = f.acquireConnectionLockWithIdentity(req.Context(), authContext)
if err != nil {
return nil, trace.Wrap(err)
}

if err := f.AcquireConnectionLock(req.Context(), user, roles); err != nil {
return nil, trace.Wrap(err)
}
return handler(authContext, w, req, p)
}, f.formatResponseError)
}
Expand All @@ -477,6 +486,10 @@ func (f *Forwarder) withAuthPassthrough(handler handlerWithAuthFunc) httprouter.
return nil, trace.Wrap(err)
}
}
err = f.acquireConnectionLockWithIdentity(req.Context(), authContext)
if err != nil {
return nil, trace.Wrap(err)
}
return handler(authContext, w, req, p)
}, f.formatResponseError)
}
Expand Down Expand Up @@ -914,10 +927,10 @@ func wsProxy(wsSource *websocket.Conn, wsTarget *websocket.Conn) error {
return trace.Wrap(err)
}

// AcquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent.
// acquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent.
// The semaphore is releasted when the request is returned/connection is closed.
// Returns an error if a semaphore could not be acquired.
func (f *Forwarder) AcquireConnectionLock(ctx context.Context, user string, roles services.RoleSet) error {
func (f *Forwarder) acquireConnectionLock(ctx context.Context, user string, roles services.RoleSet) error {
maxConnections := roles.MaxKubernetesConnections()
if maxConnections == 0 {
return nil
Expand Down
33 changes: 29 additions & 4 deletions lib/kube/proxy/forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,8 @@ func newTestForwarder(ctx context.Context, cfg ForwarderConfig) *Forwarder {

type mockSemaphoreClient struct {
auth.ClientI
sem types.Semaphores
sem types.Semaphores
roles map[string]types.Role
}

func (m *mockSemaphoreClient) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) {
Expand All @@ -1079,6 +1080,15 @@ func (m *mockSemaphoreClient) CancelSemaphoreLease(ctx context.Context, lease ty
return m.sem.CancelSemaphoreLease(ctx, lease)
}

func (m *mockSemaphoreClient) GetRole(ctx context.Context, name string) (types.Role, error) {
role, ok := m.roles[name]
if !ok {
return nil, trace.NotFound("role %q not found", name)
}

return role, nil
}

func TestKubernetesConnectionLimit(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down Expand Up @@ -1131,13 +1141,28 @@ func TestKubernetesConnectionLimit(t *testing.T) {
require.NoError(t, err)

sem := local.NewPresenceService(backend)
client := &mockSemaphoreClient{sem: sem}
client := &mockSemaphoreClient{
sem: sem,
roles: map[string]types.Role{testCase.role.GetName(): testCase.role},
}

forwarder := newTestForwarder(ctx, ForwarderConfig{
AuthClient: client,
AuthClient: client,
CachingAuthClient: client,
})

identity := &authContext{
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: user.GetName(),
Groups: []string{testCase.role.GetName()},
}),
},
}

for i := 0; i < testCase.connections; i++ {
err = forwarder.AcquireConnectionLock(ctx, user.GetName(), services.NewRoleSet(testCase.role))
err = forwarder.acquireConnectionLockWithIdentity(ctx, identity)
if i == testCase.connections-1 {
testCase.assert(t, err)
}
Expand Down

0 comments on commit e569902

Please sign in to comment.