/
sagemaker.go
93 lines (75 loc) · 2.44 KB
/
sagemaker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package sagemaker
import (
"log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sagemaker"
)
// Client is a thin wrapper around the AWS SDK SageMaker service client.
//
// The SDK client is embedded, so every SageMaker API operation (for example
// DescribeTrainingJob or DescribeEndpoint) can be called directly on a
// *Client, alongside any convenience methods declared on Client.
type Client struct {
	*sagemaker.SageMaker
}
// NewClient creates an AWS session and returns a Client bound to it.
//
// profile is used as the session's named profile, region is set as the
// session's AWS region, and sharedConfigFiles is passed through unchanged as
// session.Options.SharedConfigFiles. Shared config loading is always enabled
// (session.SharedConfigEnable). sessionName is accepted for API compatibility
// but is not used by this function.
//
// If the session cannot be created, an empty (non-nil) *Client is returned
// together with the error; callers must check err before using the client.
func NewClient(profile, region, sessionName string, sharedConfigFiles []string) (*Client, error) {
	opts := session.Options{
		Profile:           profile,
		SharedConfigFiles: sharedConfigFiles,
		SharedConfigState: session.SharedConfigEnable,
		Config: aws.Config{
			Region: aws.String(region),
		},
	}

	sess, err := session.NewSessionWithOptions(opts)
	if err != nil {
		return &Client{}, err
	}

	return &Client{SageMaker: sagemaker.New(sess)}, nil
}
// GetTrainingJob returns the current status of the provided job.
func (sm *Client) GetTrainingJob(jobID, region, iamRole string) (*sagemaker.DescribeTrainingJobOutput, error) {
log.Print(sm.Config.Region, region)
sm.Config.Region = ®ion
log.Print(sm.Config.Region, region)
jobInput := &sagemaker.DescribeTrainingJobInput{
TrainingJobName: aws.String(jobID),
}
job, err := sm.DescribeTrainingJob(jobInput)
if err != nil {
return &sagemaker.DescribeTrainingJobOutput{}, err
}
return job, nil
}
// GetTransformJob fetches the description of the batch transform job named
// jobID, exactly as returned by the SageMaker DescribeTransformJob API.
//
// region and iamRole are not used; the request goes through the client's
// existing configuration. On failure an empty, non-nil output is returned
// alongside the error.
func (sm *Client) GetTransformJob(jobID, region, iamRole string) (*sagemaker.DescribeTransformJobOutput, error) {
	desc, err := sm.DescribeTransformJob(&sagemaker.DescribeTransformJobInput{
		TransformJobName: aws.String(jobID),
	})
	if err == nil {
		return desc, nil
	}
	return &sagemaker.DescribeTransformJobOutput{}, err
}
// GetTuningJob fetches the description of the hyperparameter tuning job named
// jobID, exactly as returned by the SageMaker DescribeHyperParameterTuningJob
// API.
//
// region and iamRole are not used; the request goes through the client's
// existing configuration. On failure an empty, non-nil output is returned
// alongside the error.
func (sm *Client) GetTuningJob(jobID, region, iamRole string) (*sagemaker.DescribeHyperParameterTuningJobOutput, error) {
	desc, err := sm.DescribeHyperParameterTuningJob(&sagemaker.DescribeHyperParameterTuningJobInput{
		HyperParameterTuningJobName: aws.String(jobID),
	})
	if err == nil {
		return desc, nil
	}
	return &sagemaker.DescribeHyperParameterTuningJobOutput{}, err
}
// GetEndpoint fetches the description of the SageMaker endpoint called name,
// exactly as returned by the DescribeEndpoint API.
//
// region and iamRole are not used; the request goes through the client's
// existing configuration. On failure an empty, non-nil output is returned
// alongside the error.
func (sm *Client) GetEndpoint(name, region, iamRole string) (*sagemaker.DescribeEndpointOutput, error) {
	desc, err := sm.DescribeEndpoint(&sagemaker.DescribeEndpointInput{
		EndpointName: aws.String(name),
	})
	if err == nil {
		return desc, nil
	}
	return &sagemaker.DescribeEndpointOutput{}, err
}