/
client.go
137 lines (112 loc) · 3.03 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/*
* Copyright (c) 2018 Lyft. All rights reserved.
*/
// Package aws contains AWS-specific logic to handle execution and monitoring of batch jobs.
package aws
import (
"context"
"fmt"
"os"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/flyteorg/flyte/flytestdlib/errors"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
// Environment variable names read by newClient and the error codes it attaches
// to the errors it returns.
const (
// EnvSharedCredFilePath is the environment variable whose value, when
// non-empty, is used as the shared credentials file path.
EnvSharedCredFilePath = "AWS_SHARED_CREDENTIALS_FILE" // #nosec
// EnvAwsProfile is the environment variable whose value is passed as the
// profile name when loading shared credentials.
EnvAwsProfile = "AWS_PROFILE"
// ErrEmptyCredentials is the code used when retrieving the shared credentials fails.
ErrEmptyCredentials errors.ErrorCode = "EMPTY_CREDS"
// ErrUnknownHost is the code used when the local hostname cannot be determined.
ErrUnknownHost errors.ErrorCode = "UNKNOWN_HOST"
)
// singleton pairs the package-wide Client with the lock that guards it.
type singleton struct {
// client is the shared Client; nil until one has been created or set.
client Client
// lock is acquired by Init, GetClient and SetClient around accesses to client.
lock sync.RWMutex
}
// single holds the package-wide client. Its zero value is ready to use: the
// client is nil and the RWMutex is unlocked.
var single singleton
// Client is a generic AWS Client that can be used for all AWS Client libraries.
type Client interface {
// GetSession returns the client's AWS SDK session.
GetSession() *session.Session
// GetSdkConfig returns the client's AWS SDK configuration.
GetSdkConfig() *aws.Config
// GetConfig returns the client's Config.
GetConfig() *Config
// GetHostName returns the client's host name.
GetHostName() string
}
// client is the Client implementation built by newClient.
type client struct {
// config is the Config that was passed to newClient.
config *Config
// Session is the AWS SDK session created from SdkConfig.
Session *session.Session
// SdkConfig is the AWS SDK configuration the session was created with.
SdkConfig *aws.Config
// HostName is the value os.Hostname reported when the client was built.
HostName string
}
// GetSession returns the initialized AWS session stored on the client.
func (c client) GetSession() *session.Session {
return c.Session
}
// GetSdkConfig returns the final AWS SDK config that was used to initialize
// the AWS session.
func (c client) GetSdkConfig() *aws.Config {
return c.SdkConfig
}
// GetHostName returns the host name recorded on the client.
func (c client) GetHostName() string {
return c.HostName
}
// GetConfig returns the Config stored on the client.
func (c client) GetConfig() *Config {
return c.config
}
// newClient builds a Client from cfg.
//
// The SDK config is seeded with cfg.Region and cfg.Retries. When the
// AWS_SHARED_CREDENTIALS_FILE environment variable is non-empty, credentials
// are loaded from that file using the profile named by AWS_PROFILE and are
// retrieved eagerly so that unusable credentials are reported here.
//
// It returns an error when the shared credentials cannot be retrieved, when
// the AWS session cannot be created, or when the local hostname cannot be
// determined. No Client is returned alongside a non-nil error.
func newClient(ctx context.Context, cfg *Config) (Client, error) {
	awsConfig := aws.NewConfig().WithRegion(cfg.Region).WithMaxRetries(cfg.Retries)

	if credFile := os.Getenv(EnvSharedCredFilePath); credFile != "" {
		creds := credentials.NewSharedCredentials(credFile, os.Getenv(EnvAwsProfile))
		if creds == nil {
			return nil, fmt.Errorf("unable to Load AWS credentials")
		}

		if _, err := creds.Get(); err != nil {
			return nil, errors.Wrapf(ErrEmptyCredentials, err, "Empty credentials")
		}

		awsConfig = awsConfig.WithCredentials(creds)
	}

	sess, err := session.NewSession(awsConfig)
	if err != nil {
		// Surface the failure to the caller instead of terminating the whole
		// process: this function already has an error return, and callers
		// such as Init propagate it.
		logger.Errorf(ctx, "Error while creating session: %v", err)
		return nil, fmt.Errorf("creating AWS session: %w", err)
	}

	hostname, err := os.Hostname()
	if err != nil {
		return nil, errors.Wrapf(ErrUnknownHost, err, "Unable to discover current hostname")
	}

	return &client{
		config:    cfg,
		SdkConfig: awsConfig,
		Session:   sess,
		HostName:  hostname,
	}, nil
}
// Init initializes the singleton AWS Client from cfg if one hasn't been
// initialized yet. When a client already exists it does nothing and returns
// nil; otherwise it returns whatever error newClient reports.
//
// The lock is taken before single.client is inspected: reading the field
// without the lock would race with concurrent writers.
func Init(ctx context.Context, cfg *Config) (err error) {
	single.lock.Lock()
	defer single.lock.Unlock()

	if single.client == nil {
		single.client, err = newClient(ctx, cfg)
	}

	return err
}
// GetClient returns the singleton AWS Client, creating it from GetConfig() on
// first use if no client has been initialized or set yet.
//
// If creation fails, the error from newClient is returned and the singleton
// stays unset, so a later call tries again.
func GetClient() (c Client, err error) {
	// Fast path: an existing client only needs the read lock.
	single.lock.RLock()
	c = single.client
	single.lock.RUnlock()
	if c != nil {
		return c, nil
	}

	// Slow path: assigning single.client requires the write lock. Re-check
	// under it, since another goroutine may have populated the client between
	// the two lock acquisitions.
	single.lock.Lock()
	defer single.lock.Unlock()
	if single.client == nil {
		single.client, err = newClient(context.TODO(), GetConfig())
	}

	return single.client, err
}
// SetClient installs c as the singleton AWS Client, but only when no client is
// currently stored; if one already exists the call leaves it untouched.
func SetClient(c Client) {
	single.lock.Lock()
	defer single.lock.Unlock()

	if single.client != nil {
		return
	}

	single.client = c
}