Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions pkg/redshift/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import (
)

type API struct {
DataClient redshiftdataapiserviceiface.RedshiftDataAPIServiceAPI
SecretsClient secretsmanageriface.SecretsManagerAPI
ManagementClient redshiftiface.RedshiftAPI
settings *models.RedshiftDataSourceSettings
DataClient redshiftdataapiserviceiface.RedshiftDataAPIServiceAPI
SecretsClient secretsmanageriface.SecretsManagerAPI
ManagementClient redshiftiface.RedshiftAPI
settings *models.RedshiftDataSourceSettings
}

func New(sessionCache *awsds.SessionCache, settings awsModels.Settings) (api.AWSAPI, error) {
Expand Down Expand Up @@ -54,10 +54,10 @@ func New(sessionCache *awsds.SessionCache, settings awsModels.Settings) (api.AWS
}

return &API{
DataClient: redshiftdataapiservice.New(sess),
SecretsClient: secretsmanager.New(sess),
DataClient: redshiftdataapiservice.New(sess),
SecretsClient: secretsmanager.New(sess),
ManagementClient: redshift.New(sess),
settings: redshiftSettings,
settings: redshiftSettings,
}, nil
}

Expand Down Expand Up @@ -342,30 +342,26 @@ func (c *API) Secret(ctx aws.Context, options sqlds.Options) (*models.RedshiftSe
return res, nil
}

func (c *API) Cluster(options sqlds.Options) (*models.RedshiftCluster, error) {
clusterId := options["clusterIdentifier"]
input := &redshift.DescribeClustersInput{
ClusterIdentifier: aws.String(clusterId),
}
out, err := c.ManagementClient.DescribeClusters(input)
func (c *API) Clusters() ([]models.RedshiftCluster, error) {
out, err := c.ManagementClient.DescribeClusters(&redshift.DescribeClustersInput{})
if err != nil {
return nil, err
}
if out == nil {
return nil, fmt.Errorf("missing cluster content")
return nil, fmt.Errorf("missing clusters content")
}
res := &models.RedshiftCluster{}
for _,r := range out.Clusters {
if (r != nil && r.ClusterIdentifier != nil && *r.ClusterIdentifier == clusterId && r.Endpoint != nil && r.Endpoint.Address != nil && r.Endpoint.Port != nil) {
res.Endpoint = models.RedshiftEndpoint{
Address: *r.Endpoint.Address,
Port: *r.Endpoint.Port,
}
if (r.DBName != nil) {
res.Database = *r.DBName
}
return res, nil
res := []models.RedshiftCluster{}
for _, r := range out.Clusters {
if (r != nil && r.ClusterIdentifier != nil && r.Endpoint != nil && r.Endpoint.Address != nil && r.Endpoint.Port != nil && r.DBName != nil) {
res = append(res, models.RedshiftCluster{
ClusterIdentifier: *r.ClusterIdentifier,
Endpoint: models.RedshiftEndpoint{
Address: *r.Endpoint.Address,
Port: *r.Endpoint.Port,
},
Database: *r.DBName,
})
}
}
return nil, fmt.Errorf("ClusterId %s not found", clusterId)
return res, nil
}
84 changes: 41 additions & 43 deletions pkg/redshift/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func Test_apiInput(t *testing.T) {

func Test_Execute(t *testing.T) {
c := &API{
settings: &models.RedshiftDataSourceSettings{},
DataClient: &redshiftclientmock.MockRedshiftClient{ExecutionResult: &redshiftdataapiservice.ExecuteStatementOutput{Id: aws.String("foo")}},
settings: &models.RedshiftDataSourceSettings{},
DataClient: &redshiftclientmock.MockRedshiftClient{ExecutionResult: &redshiftdataapiservice.ExecuteStatementOutput{Id: aws.String("foo")}},
}
res, err := c.Execute(context.TODO(), &api.ExecuteQueryInput{Query: "select * from foo"})
if err != nil {
Expand Down Expand Up @@ -132,8 +132,8 @@ func Test_ListSchemas(t *testing.T) {
}
expectedResult := []string{"bar", "foo"}
c := &API{
settings: &models.RedshiftDataSourceSettings{},
DataClient: &redshiftclientmock.MockRedshiftClient{Resources: resources},
settings: &models.RedshiftDataSourceSettings{},
DataClient: &redshiftclientmock.MockRedshiftClient{Resources: resources},
}
res, err := c.Schemas(context.TODO(), sqlds.Options{})
if err != nil {
Expand All @@ -156,8 +156,8 @@ func Test_ListTables(t *testing.T) {
}
expectedResult := []string{"foofoo"}
c := &API{
settings: &models.RedshiftDataSourceSettings{},
DataClient: &redshiftclientmock.MockRedshiftClient{Resources: resources},
settings: &models.RedshiftDataSourceSettings{},
DataClient: &redshiftclientmock.MockRedshiftClient{Resources: resources},
}
res, err := c.Tables(context.TODO(), sqlds.Options{"schema": "foo"})
if err != nil {
Expand All @@ -182,8 +182,8 @@ func Test_ListColumns(t *testing.T) {
}
expectedResult := []string{"col1", "col2"}
c := &API{
settings: &models.RedshiftDataSourceSettings{},
DataClient: &redshiftclientmock.MockRedshiftClient{Resources: resources},
settings: &models.RedshiftDataSourceSettings{},
DataClient: &redshiftclientmock.MockRedshiftClient{Resources: resources},
}
res, err := c.Columns(context.TODO(), sqlds.Options{"schema": "public", "table": "foo"})
if err != nil {
Expand Down Expand Up @@ -218,58 +218,56 @@ func Test_GetSecret(t *testing.T) {
}
}

func Test_GetCluster(t *testing.T) {
fooC := &API{ManagementClient: &redshiftclientmock.MockRedshiftClient{Clusters: []string{"foo"}}}
c := &API{ManagementClient: &redshiftclientmock.MockRedshiftClient{Clusters: []string{"foo"}}}
func Test_GetClusters(t *testing.T) {
c := &API{ManagementClient: &redshiftclientmock.MockRedshiftClient{Clusters: []string{"foo", "bar"}}}
errC := &API{ManagementClient: &redshiftclientmock.MockRedshiftClientError{}}
nilC := &API{ManagementClient: &redshiftclientmock.MockRedshiftClientNil{}}
expectedCluster := &models.RedshiftCluster{
Endpoint: models.RedshiftEndpoint{
Address: "foo",
Port: 123,
},
expectedCluster1 := &models.RedshiftCluster{
ClusterIdentifier: "foo",
Endpoint: models.RedshiftEndpoint{
Address: "foo",
Port: 123,
},
Database: "foo",
}
expectedCluster2 := &models.RedshiftCluster{
ClusterIdentifier: "bar",
Endpoint: models.RedshiftEndpoint{
Address: "bar",
Port: 123,
},
Database: "bar",
}
tests := []struct {
c *API
desc string
clusterId string
errMsg string
expectedCluster *models.RedshiftCluster
c *API
desc string
errMsg string
expectedClusters []models.RedshiftCluster
}{
{
c: c,
desc: "Happy Path",
clusterId: "foo",
expectedCluster: expectedCluster,
},
{
c: fooC,
desc: "Error cluster ID not found",
clusterId: "xyz",
errMsg: "ClusterId xyz not found",
c: c,
desc: "Happy Path",
expectedClusters: []models.RedshiftCluster{*expectedCluster1, *expectedCluster2},
},
{
c: errC,
desc: "Error with DescribeCluster",
clusterId: "foo",
errMsg: "Boom!",
c: errC,
desc: "Error with DescribeCluster",
errMsg: "Boom!",
},
{
c: nilC,
desc: "DescribeCluster returned nil",
clusterId: "foo",
errMsg: "missing cluster content",
c: nilC,
desc: "DescribeCluster returned nil",
errMsg: "missing clusters content",
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
cluster, err := tt.c.Cluster(sqlds.Options{"clusterIdentifier": tt.clusterId})
if (tt.errMsg == "") {
clusters, err := tt.c.Clusters()
if tt.errMsg == "" {
assert.NoError(t, err)
assert.Equal(t, expectedCluster, cluster)
assert.Equal(t, tt.expectedClusters, clusters)
} else {
assert.Nil(t, cluster)
assert.Nil(t, clusters)
assert.EqualError(t, err, tt.errMsg)
}
})
Expand Down
6 changes: 3 additions & 3 deletions pkg/redshift/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type RedshiftDatasourceIface interface {
Columns(ctx context.Context, options sqlds.Options) ([]string, error)
Secrets(ctx context.Context, options sqlds.Options) ([]models.ManagedSecret, error)
Secret(ctx context.Context, options sqlds.Options) (*models.RedshiftSecret, error)
Cluster(ctx context.Context, options sqlds.Options) (*models.RedshiftCluster, error)
Clusters(ctx context.Context, options sqlds.Options) ([]models.RedshiftCluster, error)
}

type RedshiftDatasource struct {
Expand Down Expand Up @@ -141,10 +141,10 @@ func (s *RedshiftDatasource) Secret(ctx context.Context, options sqlds.Options)
return api.Secret(ctx, options)
}

func (s *RedshiftDatasource) Cluster(ctx context.Context, options sqlds.Options) (*models.RedshiftCluster, error) {
func (s *RedshiftDatasource) Clusters(ctx context.Context, options sqlds.Options) ([]models.RedshiftCluster, error) {
api, err := s.getApi(ctx, options)
if err != nil {
return nil, err
}
return api.Cluster(options)
return api.Clusters()
}
6 changes: 3 additions & 3 deletions pkg/redshift/fake/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
type RedshiftFakeDatasource struct {
SecretList []models.ManagedSecret
RSecret models.RedshiftSecret
RCluster models.RedshiftCluster
RClusters []models.RedshiftCluster
}

func (s *RedshiftFakeDatasource) Settings(_ backend.DataSourceInstanceSettings) sqlds.DriverSettings {
Expand Down Expand Up @@ -60,6 +60,6 @@ func (s *RedshiftFakeDatasource) Secrets(ctx context.Context, options sqlds.Opti
func (s *RedshiftFakeDatasource) Secret(ctx context.Context, options sqlds.Options) (*models.RedshiftSecret, error) {
return &s.RSecret, nil
}
func (s *RedshiftFakeDatasource) Cluster(ctx context.Context, options sqlds.Options) (*models.RedshiftCluster, error) {
return &s.RCluster, nil
func (s *RedshiftFakeDatasource) Clusters(ctx context.Context, options sqlds.Options) ([]models.RedshiftCluster, error) {
return s.RClusters, nil
}
7 changes: 4 additions & 3 deletions pkg/redshift/models/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ type RedshiftSecret struct {

type RedshiftEndpoint struct {
Address string `json:"address"`
Port int64 `json:"port"`
Port int64 `json:"port"`
}

type RedshiftCluster struct {
Endpoint RedshiftEndpoint `json:"endpoint"`
Database string `json:"database"`
ClusterIdentifier string `json:"clusterIdentifier"`
Endpoint RedshiftEndpoint `json:"endpoint"`
Database string `json:"database"`
}

type RedshiftDataSourceSettings struct {
Expand Down
14 changes: 4 additions & 10 deletions pkg/redshift/routes/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,15 @@ func (r *RedshiftResourceHandler) secret(rw http.ResponseWriter, req *http.Reque
routes.SendResources(rw, secret, err)
}

func (r *RedshiftResourceHandler) cluster(rw http.ResponseWriter, req *http.Request) {
reqBody, err := routes.ParseBody(req.Body)
if err != nil {
rw.WriteHeader(http.StatusBadRequest)
routes.Write(rw, []byte(err.Error()))
return
}
cluster, err := r.redshift.Cluster(req.Context(), reqBody)
routes.SendResources(rw, cluster, err)
func (r *RedshiftResourceHandler) clusters(rw http.ResponseWriter, req *http.Request) {
clusters, err := r.redshift.Clusters(req.Context(), sqlds.Options{})
routes.SendResources(rw, clusters, err)
}

func (r *RedshiftResourceHandler) Routes() map[string]func(http.ResponseWriter, *http.Request) {
routes := r.DefaultRoutes()
routes["/secrets"] = r.secrets
routes["/secret"] = r.secret
routes["/cluster"] = r.cluster
routes["/clusters"] = r.clusters
return routes
}
18 changes: 10 additions & 8 deletions pkg/redshift/routes/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ var ds = &fake.RedshiftFakeDatasource{
{Name: "secret1", ARN: "arn:secret1"},
},
RSecret: models.RedshiftSecret{ClusterIdentifier: "clu", DBUser: "user"},
RCluster: models.RedshiftCluster{
RClusters: []models.RedshiftCluster{models.RedshiftCluster{
ClusterIdentifier: "foo",
Endpoint: models.RedshiftEndpoint{
Address: "foo.a.b.c",
Port: 123,
Port: 123,
},
Database: "db-foo",
},
},
}

func TestRoutes(t *testing.T) {
Expand All @@ -47,10 +49,10 @@ func TestRoutes(t *testing.T) {
expectedResult: `{"dbClusterIdentifier":"clu","username":"user"}`,
},
{
description: "return cluster",
route: "cluster",
description: "return clusters",
route: "clusters",
expectedCode: http.StatusOK,
expectedResult: `{"endpoint":{"address":"foo.a.b.c","port":123},"database":"db-foo"}`,
expectedResult: `[{"clusterIdentifier":"foo","endpoint":{"address":"foo.a.b.c","port":123},"database":"db-foo"}]`,
},
}
for _, tt := range tests {
Expand All @@ -63,8 +65,8 @@ func TestRoutes(t *testing.T) {
rh.secrets(rw, req)
case "secret":
rh.secret(rw, req)
case "cluster":
rh.cluster(rw, req)
case "clusters":
rh.clusters(rw, req)
default:
t.Fatalf("unexpected route %s", tt.route)
}
Expand All @@ -90,5 +92,5 @@ func Test_Routes(t *testing.T) {
r := rh.Routes()
assert.Contains(t, r, "/secrets")
assert.Contains(t, r, "/secret")
assert.Contains(t, r, "/cluster")
assert.Contains(t, r, "/clusters")
}