Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v13] Support specifying assume_role_arn for Kube cluster matchers #28832

Merged
merged 1 commit into from Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/cloud/clients.go
Expand Up @@ -987,7 +987,7 @@ func (c *TestCloudClients) GetAzurePostgresClient(subscription string) (azure.DB

// GetAzureKubernetesClient returns an AKS client for the specified subscription
func (c *TestCloudClients) GetAzureKubernetesClient(subscription string) (azure.AKSClient, error) {
if len(c.AzurePostgresPerSub) != 0 {
if len(c.AzureAKSClientPerSub) != 0 {
return c.AzureAKSClientPerSub[subscription], nil
}
return c.AzureAKSClient, nil
Expand Down
69 changes: 68 additions & 1 deletion lib/cloud/mocks/aws.go
Expand Up @@ -18,11 +18,15 @@ package mocks

import (
"context"
"net/http"
"net/url"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/eks"
"github.com/aws/aws-sdk-go/service/eks/eksiface"
"github.com/aws/aws-sdk-go/service/elasticache"
"github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface"
"github.com/aws/aws-sdk-go/service/iam"
Expand All @@ -46,6 +50,7 @@ import (
type STSMock struct {
stsiface.STSAPI
ARN string
URL *url.URL
assumedRoleARNs []string
assumedRoleExternalIDs []string
mu sync.Mutex
Expand Down Expand Up @@ -98,6 +103,21 @@ func (m *STSMock) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput
}, nil
}

func (m *STSMock) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) {
return &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
URL: m.URL,
},
Operation: &request.Operation{
Name: "GetCallerIdentity",
HTTPMethod: "POST",
HTTPPath: "/",
},
Handlers: request.Handlers{},
}, nil
}

// RDSMock mocks AWS RDS API.
type RDSMock struct {
rdsiface.RDSAPI
Expand Down Expand Up @@ -210,6 +230,7 @@ func (m *RDSMock) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyD
}
return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier))
}

func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) {
if aws.StringValue(input.DBProxyName) == "" {
return &rds.DescribeDBProxiesOutput{
Expand All @@ -225,6 +246,7 @@ func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.Descr
}
return nil, trace.NotFound("proxy %v not found", aws.StringValue(input.DBProxyName))
}

func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) {
inputProxyName := aws.StringValue(input.DBProxyName)
inputProxyEndpointName := aws.StringValue(input.DBProxyEndpointName)
Expand Down Expand Up @@ -254,6 +276,7 @@ func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rd
}
return &rds.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil
}

func (m *RDSMock) DescribeDBProxyTargetsWithContext(ctx aws.Context, input *rds.DescribeDBProxyTargetsInput, options ...request.Option) (*rds.DescribeDBProxyTargetsOutput, error) {
// only mocking to return a port here
return &rds.DescribeDBProxyTargetsOutput{
Expand All @@ -262,18 +285,21 @@ func (m *RDSMock) DescribeDBProxyTargetsWithContext(ctx aws.Context, input *rds.
}},
}, nil
}

func (m *RDSMock) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error {
fn(&rds.DescribeDBProxiesOutput{
DBProxies: m.DBProxies,
}, true)
return nil
}

func (m *RDSMock) DescribeDBProxyEndpointsPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, fn func(*rds.DescribeDBProxyEndpointsOutput, bool) bool, options ...request.Option) error {
fn(&rds.DescribeDBProxyEndpointsOutput{
DBProxyEndpoints: m.DBProxyEndpoints,
}, true)
return nil
}

func (m *RDSMock) ListTagsForResourceWithContext(ctx aws.Context, input *rds.ListTagsForResourceInput, options ...request.Option) (*rds.ListTagsForResourceOutput, error) {
return &rds.ListTagsForResourceOutput{}, nil
}
Expand Down Expand Up @@ -381,6 +407,7 @@ func (m *RedshiftMock) GetClusterCredentialsWithContext(aws.Context, *redshift.G
}
return m.GetClusterCredentialsOutput, nil
}

func (m *RedshiftMock) DescribeClustersWithContext(ctx aws.Context, input *redshift.DescribeClustersInput, options ...request.Option) (*redshift.DescribeClustersOutput, error) {
if aws.StringValue(input.ClusterIdentifier) == "" {
return &redshift.DescribeClustersOutput{
Expand Down Expand Up @@ -432,12 +459,15 @@ func (m *RDSMockUnauth) DescribeDBInstancesPagesWithContext(ctx aws.Context, inp
func (m *RDSMockUnauth) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error {
return trace.AccessDenied("unauthorized")
}

func (m *RDSMockUnauth) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) {
return nil, trace.AccessDenied("unauthorized")
}

func (m *RDSMockUnauth) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) {
return nil, trace.AccessDenied("unauthorized")
}

func (m *RDSMockUnauth) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error {
return trace.AccessDenied("unauthorized")
}
Expand All @@ -453,28 +483,35 @@ type RDSMockByDBType struct {
func (m *RDSMockByDBType) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) {
return m.DBInstances.DescribeDBInstancesWithContext(ctx, input, options...)
}

func (m *RDSMockByDBType) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) {
return m.DBInstances.ModifyDBInstanceWithContext(ctx, input, options...)
}

func (m *RDSMockByDBType) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error {
return m.DBInstances.DescribeDBInstancesPagesWithContext(ctx, input, fn, options...)
}

func (m *RDSMockByDBType) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) {
return m.DBClusters.DescribeDBClustersWithContext(ctx, input, options...)
}

func (m *RDSMockByDBType) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) {
return m.DBClusters.ModifyDBClusterWithContext(ctx, input, options...)
}

func (m *RDSMockByDBType) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error {
return m.DBClusters.DescribeDBClustersPagesWithContext(aws, input, fn, options...)
}

func (m *RDSMockByDBType) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) {
return m.DBProxies.DescribeDBProxiesWithContext(ctx, input, options...)
}

func (m *RDSMockByDBType) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) {
return m.DBProxies.DescribeDBProxyEndpointsWithContext(ctx, input, options...)
}

func (m *RDSMockByDBType) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error {
return m.DBProxies.DescribeDBProxiesPagesWithContext(ctx, input, fn, options...)
}
Expand Down Expand Up @@ -539,6 +576,7 @@ func (m *ElastiCacheMock) AddMockUser(user *elasticache.User, tagsMap map[string
m.Users = append(m.Users, user)
m.addTags(aws.StringValue(user.ARN), tagsMap)
}

func (m *ElastiCacheMock) addTags(arn string, tagsMap map[string]string) {
if m.TagsByARN == nil {
m.TagsByARN = make(map[string][]*elasticache.Tag)
Expand Down Expand Up @@ -582,6 +620,7 @@ func (m *ElastiCacheMock) DescribeReplicationGroupsWithContext(_ aws.Context, in
}
return nil, trace.NotFound("ElastiCache %v not found", aws.StringValue(input.ReplicationGroupId))
}

func (m *ElastiCacheMock) DescribeReplicationGroupsPagesWithContext(_ aws.Context, _ *elasticache.DescribeReplicationGroupsInput, fn func(*elasticache.DescribeReplicationGroupsOutput, bool) bool, _ ...request.Option) error {
if m.Unauth {
return trace.AccessDenied("unauthorized")
Expand All @@ -591,6 +630,7 @@ func (m *ElastiCacheMock) DescribeReplicationGroupsPagesWithContext(_ aws.Contex
}, true)
return nil
}

func (m *ElastiCacheMock) DescribeUsersPagesWithContext(_ aws.Context, _ *elasticache.DescribeUsersInput, fn func(*elasticache.DescribeUsersOutput, bool) bool, _ ...request.Option) error {
if m.Unauth {
return trace.AccessDenied("unauthorized")
Expand All @@ -607,12 +647,14 @@ func (m *ElastiCacheMock) DescribeCacheClustersPagesWithContext(aws.Context, *el
}
return trace.NotImplemented("elasticache:DescribeCacheClustersPagesWithContext is not implemented")
}

func (m *ElastiCacheMock) DescribeCacheSubnetGroupsPagesWithContext(aws.Context, *elasticache.DescribeCacheSubnetGroupsInput, func(*elasticache.DescribeCacheSubnetGroupsOutput, bool) bool, ...request.Option) error {
if m.Unauth {
return trace.AccessDenied("unauthorized")
}
return trace.NotImplemented("elasticache:DescribeCacheSubnetGroupsPagesWithContext is not implemented")
}

func (m *ElastiCacheMock) ListTagsForResourceWithContext(_ aws.Context, input *elasticache.ListTagsForResourceInput, _ ...request.Option) (*elasticache.TagListMessage, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
Expand All @@ -630,6 +672,7 @@ func (m *ElastiCacheMock) ListTagsForResourceWithContext(_ aws.Context, input *e
TagList: tags,
}, nil
}

func (m *ElastiCacheMock) ModifyUserWithContext(_ aws.Context, input *elasticache.ModifyUserInput, opts ...request.Option) (*elasticache.ModifyUserOutput, error) {
if m.Unauth {
return nil, trace.AccessDenied("unauthorized")
Expand Down Expand Up @@ -687,6 +730,7 @@ func (m *MemoryDBMock) AddMockUser(user *memorydb.User, tagsMap map[string]strin
m.Users = append(m.Users, user)
m.addTags(aws.StringValue(user.ARN), tagsMap)
}

func (m *MemoryDBMock) addTags(arn string, tagsMap map[string]string) {
if m.TagsByARN == nil {
m.TagsByARN = make(map[string][]*memorydb.Tag)
Expand All @@ -701,11 +745,12 @@ func (m *MemoryDBMock) addTags(arn string, tagsMap map[string]string) {
}
m.TagsByARN[arn] = tags
}

func (m *MemoryDBMock) DescribeSubnetGroupsWithContext(aws.Context, *memorydb.DescribeSubnetGroupsInput, ...request.Option) (*memorydb.DescribeSubnetGroupsOutput, error) {
return nil, trace.AccessDenied("unauthorized")
}
func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memorydb.DescribeClustersInput, _ ...request.Option) (*memorydb.DescribeClustersOutput, error) {

func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memorydb.DescribeClustersInput, _ ...request.Option) (*memorydb.DescribeClustersOutput, error) {
if aws.StringValue(input.ClusterName) == "" {
return &memorydb.DescribeClustersOutput{
Clusters: m.Clusters,
Expand All @@ -721,6 +766,7 @@ func (m *MemoryDBMock) DescribeClustersWithContext(_ aws.Context, input *memoryd
}
return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.ClusterName))
}

func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTagsInput, _ ...request.Option) (*memorydb.ListTagsOutput, error) {
if m.TagsByARN == nil {
return nil, trace.NotFound("no tags")
Expand All @@ -735,11 +781,13 @@ func (m *MemoryDBMock) ListTagsWithContext(_ aws.Context, input *memorydb.ListTa
TagList: tags,
}, nil
}

func (m *MemoryDBMock) DescribeUsersWithContext(aws.Context, *memorydb.DescribeUsersInput, ...request.Option) (*memorydb.DescribeUsersOutput, error) {
return &memorydb.DescribeUsersOutput{
Users: m.Users,
}, nil
}

func (m *MemoryDBMock) UpdateUserWithContext(_ aws.Context, input *memorydb.UpdateUserInput, opts ...request.Option) (*memorydb.UpdateUserOutput, error) {
for _, user := range m.Users {
if aws.StringValue(user.Name) == aws.StringValue(input.UserName) {
Expand Down Expand Up @@ -838,3 +886,22 @@ func RedshiftGetClusterCredentialsOutput(user, password string, clock clockwork.
Expiration: aws.Time(clock.Now().Add(15 * time.Minute)),
}
}

// EKSMock is a mock EKS client.
type EKSMock struct {
eksiface.EKSAPI
Clusters []*eks.Cluster
Notify chan struct{}
}

func (e *EKSMock) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) {
defer func() {
e.Notify <- struct{}{}
}()
for _, cluster := range e.Clusters {
if aws.StringValue(req.Name) == aws.StringValue(cluster.Name) {
return &eks.DescribeClusterOutput{Cluster: cluster}, nil
}
}
return nil, trace.NotFound("cluster %v not found", aws.StringValue(req.Name))
}
57 changes: 57 additions & 0 deletions lib/cloud/mocks/azure.go
@@ -0,0 +1,57 @@
/*
Copyright 2023 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mocks

import (
"context"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"k8s.io/client-go/rest"

"github.com/gravitational/teleport/lib/cloud/azure"
)

// AKSClusterEntry is an entry in the AKSMock.Clusters list.
type AKSClusterEntry struct {
azure.ClusterCredentialsConfig
Config *rest.Config
TTL time.Duration
}

// AKSMock implements the azure.AKSClient interface for tests.
type AKSMock struct {
azure.AKSClient
Clusters []AKSClusterEntry
Notify chan struct{}
Clock clockwork.Clock
}

func (a *AKSMock) ClusterCredentials(ctx context.Context, cfg azure.ClusterCredentialsConfig) (*rest.Config, time.Time, error) {
defer func() {
a.Notify <- struct{}{}
}()
for _, cluster := range a.Clusters {
if cluster.ClusterCredentialsConfig.ResourceGroup == cfg.ResourceGroup &&
cluster.ClusterCredentialsConfig.ResourceName == cfg.ResourceName &&
cluster.ClusterCredentialsConfig.TenantID == cfg.TenantID {
return cluster.Config, a.Clock.Now().Add(cluster.TTL), nil
}
}
return nil, time.Now(), trace.NotFound("cluster not found")
}
31 changes: 31 additions & 0 deletions lib/cloud/mocks/gcp.go
Expand Up @@ -19,8 +19,12 @@ package mocks
import (
"context"
"crypto/tls"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
"k8s.io/client-go/rest"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/cloud/gcp"
Expand Down Expand Up @@ -48,3 +52,30 @@ func (g *GCPSQLAdminClientMock) GetDatabaseInstance(ctx context.Context, db type
func (g *GCPSQLAdminClientMock) GenerateEphemeralCert(ctx context.Context, db types.Database, identity tlsca.Identity) (*tls.Certificate, error) {
return g.EphemeralCert, nil
}

// GKEClusterEntry is an entry in the GKEMock.Clusters list.
type GKEClusterEntry struct {
gcp.ClusterDetails
Config *rest.Config
TTL time.Duration
}

// GKEMock implements the gcp.GKEClient interface for tests.
type GKEMock struct {
gcp.GKEClient
Clusters []GKEClusterEntry
Notify chan struct{}
Clock clockwork.Clock
}

func (g *GKEMock) GetClusterRestConfig(ctx context.Context, cfg gcp.ClusterDetails) (*rest.Config, time.Time, error) {
defer func() {
g.Notify <- struct{}{}
}()
for _, cluster := range g.Clusters {
if cluster.ClusterDetails == cfg {
return cluster.Config, g.Clock.Now().Add(cluster.TTL), nil
}
}
return nil, time.Now(), trace.NotFound("cluster not found")
}
7 changes: 4 additions & 3 deletions lib/config/configuration.go
Expand Up @@ -1342,12 +1342,13 @@ func applyKubeConfig(fc *FileConfig, cfg *servicecfg.Config) error {
}

for _, matcher := range fc.Kube.ResourceMatchers {
if matcher.AWS.AssumeRoleARN != "" {
return trace.NotImplemented("assume_role_arn is not supported for kube resource matchers")
}
cfg.Kube.ResourceMatchers = append(cfg.Kube.ResourceMatchers,
services.ResourceMatcher{
Labels: matcher.Labels,
AWS: services.ResourceMatcherAWS{
AssumeRoleARN: matcher.AWS.AssumeRoleARN,
ExternalID: matcher.AWS.ExternalID,
},
})
}

Expand Down