Skip to content

Commit

Permalink
Add structure for passing hardcoded credentials to task. (#722)
Browse files Browse the repository at this point in the history
Co-authored-by: Domas Monkus <domas@iterative.ai>

Co-authored-by: Helio Machado <0x2b3bfa0+git@googlemail.com>
  • Loading branch information
tasdomas and 0x2b3bfa0 authored Nov 18, 2022
1 parent bd37b68 commit 85645f4
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 43 deletions.
18 changes: 17 additions & 1 deletion task/aws/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/s3"
Expand All @@ -29,7 +30,22 @@ func New(ctx context.Context, cloud common.Cloud, tags map[string]string) (*Clie
region = val
}

config, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
options := []func(*config.LoadOptions) error{
config.WithRegion(region),
}

if awsCredentials := cloud.Credentials.AWSCredentials; awsCredentials != nil {
options = append(options, config.WithCredentialsProvider(credentials.StaticCredentialsProvider{
Value: aws.Credentials{
AccessKeyID: awsCredentials.AccessKeyID,
SecretAccessKey: awsCredentials.SecretAccessKey,
SessionToken: awsCredentials.SessionToken,
Source: "user-specified credentials",
},
}))
}

config, err := config.LoadDefaultConfig(ctx, options...)
if err != nil {
return nil, err
}
Expand Down
82 changes: 51 additions & 31 deletions task/az/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,52 @@ import (
"github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2020-06-01/resources"
"github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2021-04-01/storage"

"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure/auth"

"terraform-provider-iterative/task/common"
"terraform-provider-iterative/task/common/ssh"
)

func New(ctx context.Context, cloud common.Cloud, tags map[string]string) (*Client, error) {
settings, err := auth.GetSettingsFromEnvironment()
if err != nil {
return nil, err
}

subscription := settings.GetSubscriptionID()
if subscription == "" {
return nil, errors.New("subscription environment variable not found")
}

authorizer, err := settings.GetAuthorizer()
if err != nil {
return nil, err
var authorizer autorest.Authorizer

if azCredentials := cloud.Credentials.AZCredentials; azCredentials != nil {
au, err := auth.NewClientCredentialsConfig(
azCredentials.ClientID,
azCredentials.ClientSecret,
azCredentials.TenantID,
).Authorizer()
if err != nil {
return nil, err
}
authorizer = au
} else {
settings, err := auth.GetSettingsFromEnvironment()
if err != nil {
return nil, err
}
credentials, err := settings.GetClientCredentials()
if err != nil {
return nil, err
}
authorizer, err = settings.GetAuthorizer()
if err != nil {
return nil, err
}

cloud.Credentials.AZCredentials = &common.AZCredentials{
SubscriptionID: settings.GetSubscriptionID(),
ClientID: credentials.ClientID,
ClientSecret: credentials.ClientSecret,
TenantID: credentials.TenantID,
}
}

agent := "tpi"

c := &Client{
Cloud: cloud,
Settings: settings,
Cloud: cloud,
}

for key, value := range tags {
Expand All @@ -54,75 +73,79 @@ func New(ctx context.Context, cloud common.Cloud, tags map[string]string) (*Clie
region = val
}

if cloud.Credentials.AZCredentials.SubscriptionID == "" {
return nil, errors.New("subscription environment variable not found")
}

c.Region = region

c.Services.Groups = resources.NewGroupsClient(subscription)
c.Services.Groups = resources.NewGroupsClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.Groups.Authorizer = authorizer
if err := c.Services.Groups.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.SecurityGroups = network.NewSecurityGroupsClient(subscription)
c.Services.SecurityGroups = network.NewSecurityGroupsClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.SecurityGroups.Authorizer = authorizer
if err := c.Services.SecurityGroups.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.PublicIPPrefixes = network.NewPublicIPPrefixesClient(subscription)
c.Services.PublicIPPrefixes = network.NewPublicIPPrefixesClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.PublicIPPrefixes.Authorizer = authorizer
if err := c.Services.PublicIPPrefixes.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.PublicIPAddresses = network.NewPublicIPAddressesClient(subscription)
c.Services.PublicIPAddresses = network.NewPublicIPAddressesClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.PublicIPAddresses.Authorizer = authorizer
if err := c.Services.PublicIPAddresses.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.VirtualNetworks = network.NewVirtualNetworksClient(subscription)
c.Services.VirtualNetworks = network.NewVirtualNetworksClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.VirtualNetworks.Authorizer = authorizer
if err := c.Services.VirtualNetworks.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.Subnets = network.NewSubnetsClient(subscription)
c.Services.Subnets = network.NewSubnetsClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.Subnets.Authorizer = authorizer
if err := c.Services.Subnets.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.Interfaces = network.NewInterfacesClient(subscription)
c.Services.Interfaces = network.NewInterfacesClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.Interfaces.Authorizer = authorizer
if err := c.Services.Interfaces.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.VirtualMachines = compute.NewVirtualMachinesClient(subscription)
c.Services.VirtualMachines = compute.NewVirtualMachinesClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.VirtualMachines.Authorizer = authorizer
if err := c.Services.VirtualMachines.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.VirtualMachineScaleSets = compute.NewVirtualMachineScaleSetsClient(subscription)
c.Services.VirtualMachineScaleSets = compute.NewVirtualMachineScaleSetsClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.VirtualMachineScaleSets.Authorizer = authorizer
if err := c.Services.VirtualMachineScaleSets.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.VirtualMachineScaleSetVMs = compute.NewVirtualMachineScaleSetVMsClient(subscription)
c.Services.VirtualMachineScaleSetVMs = compute.NewVirtualMachineScaleSetVMsClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.VirtualMachineScaleSetVMs.Authorizer = authorizer
if err := c.Services.VirtualMachineScaleSetVMs.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.StorageAccounts = storage.NewAccountsClient(subscription)
c.Services.StorageAccounts = storage.NewAccountsClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.StorageAccounts.Authorizer = authorizer
if err := c.Services.StorageAccounts.AddToUserAgent(agent); err != nil {
return nil, err
}

c.Services.BlobContainers = storage.NewBlobContainersClient(subscription)
c.Services.BlobContainers = storage.NewBlobContainersClient(cloud.Credentials.AZCredentials.SubscriptionID)
c.Services.BlobContainers.Authorizer = authorizer
if err := c.Services.BlobContainers.AddToUserAgent(agent); err != nil {
return nil, err
Expand Down Expand Up @@ -153,10 +176,7 @@ type Client struct {
}

func (c *Client) GetKeyPair(ctx context.Context) (*ssh.DeterministicSSHKeyPair, error) {
credentials, err := c.Settings.GetClientCredentials()
if err != nil {
return nil, err
}
credentials := c.Cloud.Credentials.AZCredentials

if len(credentials.ClientSecret) == 0 {
return nil, errors.New("unable to find client secret")
Expand Down
9 changes: 2 additions & 7 deletions task/az/resources/data_source_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ type Credentials struct {
}

func (c *Credentials) Read(ctx context.Context) error {
credentials, err := c.client.Settings.GetClientCredentials()
if err != nil {
return err
}
credentials := c.client.Cloud.Credentials.AZCredentials

if len(credentials.ClientSecret) == 0 {
return errors.New("unable to find client secret")
Expand All @@ -44,12 +41,10 @@ func (c *Credentials) Read(ctx context.Context) error {
return err
}

subscriptionID := c.client.Settings.GetSubscriptionID()

c.Resource = map[string]string{
"AZURE_CLIENT_ID": credentials.ClientID,
"AZURE_CLIENT_SECRET": credentials.ClientSecret,
"AZURE_SUBSCRIPTION_ID": subscriptionID,
"AZURE_SUBSCRIPTION_ID": credentials.SubscriptionID,
"AZURE_TENANT_ID": credentials.TenantID,
"RCLONE_REMOTE": connectionString,
"TPI_TASK_CLOUD_PROVIDER": string(c.client.Cloud.Provider),
Expand Down
37 changes: 33 additions & 4 deletions task/common/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
)

type Cloud struct {
Timeouts Timeouts
Provider Provider
Region Region
Tags map[string]string
Timeouts Timeouts
Provider Provider
Credentials Credentials
Region Region
Tags map[string]string
}

type Timeouts struct {
Expand All @@ -29,6 +30,34 @@ const (
ProviderK8S Provider = "k8s"
)

type Credentials struct {
AWSCredentials *AWSCredentials
GCPCredentials *GCPCredentials
AZCredentials *AZCredentials
K8SCredentials *K8SCredentials
}

type AWSCredentials struct {
AccessKeyID string // AWS_ACCESS_KEY_ID
SecretAccessKey string // AWS_SECRET_ACCESS_KEY
SessionToken string // AWS_SESSION_TOKEN
}

type GCPCredentials struct {
ApplicationCredentials string // GOOGLE_APPLICATION_CREDENTIALS (contents of file)
}

type AZCredentials struct {
ClientID string // AZURE_CLIENT_ID
ClientSecret string // AZURE_CLIENT_SECRET
SubscriptionID string // AZURE_SUBSCRIPTION_ID
TenantID string // AZURE_TENANT_ID
}

type K8SCredentials struct {
Config string // KUBECONFIG (contents of file)
}

func (c *Cloud) GetClosestRegion(regions map[string]Region) (string, error) {
for key, value := range regions {
if value == c.Region {
Expand Down
4 changes: 4 additions & 0 deletions task/gcp/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ func New(ctx context.Context, cloud common.Cloud, tags map[string]string) (*Clie

credentialsData := []byte(os.Getenv("GOOGLE_APPLICATION_CREDENTIALS_DATA"))

if gcpCredentials := cloud.Credentials.GCPCredentials; gcpCredentials != nil {
credentialsData = []byte(gcpCredentials.ApplicationCredentials)
}

var err error
var credentials *google.Credentials
if len(credentialsData) > 0 {
Expand Down
4 changes: 4 additions & 0 deletions task/k8s/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ func New(ctx context.Context, cloud common.Cloud, tags map[string]string) (*Clie
kubeconfig = os.Getenv("KUBECONFIG_DATA")
}

if k8sCredentials := cloud.Credentials.K8SCredentials; k8sCredentials != nil {
kubeconfig = k8sCredentials.Config
}

config, err := clientcmd.NewClientConfigFromBytes([]byte(kubeconfig))
if err != nil || kubeconfig == "" {
config = clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
Expand Down

0 comments on commit 85645f4

Please sign in to comment.