Skip to content

Commit

Permalink
Add GPU instance support to the up command
Browse files Browse the repository at this point in the history
Refer to aws#729
  • Loading branch information
efekarakus committed Mar 8, 2019
1 parent 320c6ce commit 9a6b2dc
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 79 deletions.
51 changes: 22 additions & 29 deletions ecs-cli/modules/cli/cluster/cluster_app.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"bufio"
"fmt"
"os"
"regexp"
"strconv"
"strings"

Expand Down Expand Up @@ -72,10 +71,6 @@ const (
ParameterKeySpotPrice = "SpotPrice"
)

const (
defaultARM64InstanceType = "a1.medium"
)

var flagNamesToStackParameterKeys map[string]string
var requiredParameters []string = []string{ParameterKeyCluster}

Expand Down Expand Up @@ -312,20 +307,13 @@ func createCluster(context *cli.Context, awsClients *AWSClients, commandConfig *
}

if launchType == config.LaunchTypeEC2 {
architecture, err := determineArchitecture(cfnParams)
if err != nil {
return err
}

// Check if image id was supplied, else populate
_, err = cfnParams.GetParameter(ParameterKeyAmiId)
if err == cloudformation.ParameterNotFoundError {
amiMetadata, err := metadataClient.GetRecommendedECSLinuxAMI(architecture)
err := populateAMIID(cfnParams, metadataClient)
if err != nil {
return err
}
logrus.Infof("Using recommended %s AMI with ECS Agent %s and %s", amiMetadata.OsName, amiMetadata.AgentVersion, amiMetadata.RuntimeVersion)
cfnParams.Add(ParameterKeyAmiId, amiMetadata.ImageID)
} else if err != nil {
return err
}
Expand Down Expand Up @@ -359,7 +347,6 @@ func createCluster(context *cli.Context, awsClients *AWSClients, commandConfig *
return err
}

logrus.Info("Waiting for your cluster resources to be created...")
// Wait for stack creation
return cfnClient.WaitUntilCreateComplete(stackName)
}
Expand Down Expand Up @@ -388,26 +375,32 @@ func canEnableContainerInstanceTagging(client ecsclient.ECSClient) (bool, error)
return false, nil
}

func determineArchitecture(cfnParams *cloudformation.CfnStackParams) (string, error) {
architecture := amimetadata.ArchitectureTypeX86
func retrieveInstanceType(cfnParams *cloudformation.CfnStackParams) (string, error) {
param, err := cfnParams.GetParameter(ParameterKeyInstanceType)

// a1 instances get the Arm based ECS AMI
instanceTypeParam, err := cfnParams.GetParameter(ParameterKeyInstanceType)
if err == cloudformation.ParameterNotFoundError {
logrus.Infof("Defaulting instance type to t2.micro")
} else if err != nil {
return cloudformation.DefaultECSInstanceType, nil
}
if err != nil {
return "", err
} else {
instanceType := aws.StringValue(instanceTypeParam.ParameterValue)
// This regex matches all current a1 instances, and should work for any future additions as well
r := regexp.MustCompile("a1\\.(medium|\\d*x?large)")
if r.MatchString(instanceType) {
logrus.Infof("Using Arm ecs-optimized AMI because instance type was %s", instanceType)
architecture = amimetadata.ArchitectureTypeARM64
}
}
return aws.StringValue(param.ParameterValue), nil
}

func populateAMIID(cfnParams *cloudformation.CfnStackParams, client amimetadata.Client) error {
instanceType, err := retrieveInstanceType(cfnParams)
if err != nil {
return err
}

return architecture, nil
amiMetadata, err := client.GetRecommendedECSLinuxAMI(instanceType)
if err != nil {
return err
}
logrus.Infof("Using recommended %s AMI with ECS Agent %s and %s",
amiMetadata.OsName, amiMetadata.AgentVersion, amiMetadata.RuntimeVersion)
cfnParams.Add(ParameterKeyAmiId, amiMetadata.ImageID)
return nil
}

// unfortunately go SDK lacks a unified Tag type
Expand Down
43 changes: 8 additions & 35 deletions ecs-cli/modules/cli/cluster/cluster_app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func TestClusterUpWithForce(t *testing.T) {
)

gomock.InOrder(
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
)

gomock.InOrder(
Expand Down Expand Up @@ -179,7 +179,7 @@ func TestClusterUpWithoutPublicIP(t *testing.T) {
)

gomock.InOrder(
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
)

gomock.InOrder(
Expand Down Expand Up @@ -232,7 +232,7 @@ func TestClusterUpWithUserData(t *testing.T) {
)

gomock.InOrder(
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
)

gomock.InOrder(
Expand Down Expand Up @@ -281,7 +281,7 @@ func TestClusterUpWithSpotPrice(t *testing.T) {
)

gomock.InOrder(
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
)

gomock.InOrder(
Expand Down Expand Up @@ -976,7 +976,7 @@ func TestClusterUpARM64(t *testing.T) {
)

gomock.InOrder(
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("arm64").Return(amiMetadata(armAMIID), nil),
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("a1.medium").Return(amiMetadata(armAMIID), nil),
)

gomock.InOrder(
Expand Down Expand Up @@ -1050,7 +1050,7 @@ func TestClusterUpWithTags(t *testing.T) {
}),
)
gomock.InOrder(
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
)
gomock.InOrder(
mockCloudformation.EXPECT().ValidateStackExists(stackName).Return(errors.New("error")),
Expand Down Expand Up @@ -1131,7 +1131,7 @@ func TestClusterUpWithTagsContainerInstanceTaggingEnabled(t *testing.T) {
}),
)
gomock.InOrder(
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
)
gomock.InOrder(
mockCloudformation.EXPECT().ValidateStackExists(stackName).Return(errors.New("error")),
Expand Down Expand Up @@ -1165,33 +1165,6 @@ func TestClusterUpWithTagsContainerInstanceTaggingEnabled(t *testing.T) {
assert.True(t, userdataMock.enableTagging, "Expected tagging to be enabled in container instance user data")
}

func TestDetermineArchitecture(t *testing.T) {
var testCases = []struct {
in string
out string
}{
{"a1.medium", "arm64"},
{"a1.large", "arm64"},
{"a1.xlarge", "arm64"},
{"a1.2xlarge", "arm64"},
{"a1.4xlarge", "arm64"},
{"t2.medium", "x86"},
{"c5.large", "x86"},
{"i3.metal", "x86"},
{"t3.micro", "x86"},
}

for _, tt := range testCases {
t.Run(tt.in, func(t *testing.T) {
cfnParams := cloudformation.NewCfnStackParams(requiredParameters)
cfnParams.Add(ParameterKeyInstanceType, tt.in)
arch, err := determineArchitecture(cfnParams)
assert.NoError(t, err, "Unexpected error determining architecture")
assert.Equal(t, tt.out, arch, "Expected architecture to match")
})
}
}

///////////////////
// Cluster Down //
//////////////////
Expand Down Expand Up @@ -1397,7 +1370,7 @@ func mocksForSuccessfulClusterUp(mockECS *mock_ecs.MockECSClient, mockCloudforma
mockECS.EXPECT().CreateCluster(clusterName, gomock.Any()).Return(clusterName, nil),
)
gomock.InOrder(
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
)
gomock.InOrder(
mockCloudformation.EXPECT().ValidateStackExists(stackName).Return(errors.New("error")),
Expand Down
71 changes: 59 additions & 12 deletions ecs-cli/modules/clients/aws/amimetadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,42 @@
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

// Package amimetadata provides AMI metadata given an instance type.
package amimetadata

import (
"encoding/json"

"github.com/aws/amazon-ecs-cli/ecs-cli/modules/clients"
"github.com/aws/amazon-ecs-cli/ecs-cli/modules/config"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"regexp"
"strings"
)

// SSM parameter names to retrieve ECS optimized AMI.
// See: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/retrieve-ecs-optimized_AMI.html
const (
amazonLinux2X86RecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/recommended"
amazonLinux2ARM64RecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/arm64/recommended"
amazonLinux2X86RecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/recommended"
amazonLinux2ARM64RecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/arm64/recommended"
amazonLinux2X86GPURecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended"
)

// Architecture types of EC2 instances.
const (
ArchitectureTypeARM64 = "arm64"
ArchitectureTypeX86 = "x86"
)

// AMIMetadata is returned through ssm:GetParameters and can be used to retrieve the ImageId
// while launching instances.
//
// See: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/retrieve-ecs-optimized_AMI.html
// See: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-as-launchconfig.html#cfn-as-launchconfig-imageid
type AMIMetadata struct {
ImageID string `json:"image_id"`
OsName string `json:"os"`
Expand All @@ -47,13 +59,13 @@ type Client interface {
GetRecommendedECSLinuxAMI(string) (*AMIMetadata, error)
}

// ssmClient implements Client
// metadataClient implements Client.
type metadataClient struct {
client ssmiface.SSMAPI
region string
}

// NewSSMClient creates an instance of Client.
// NewMetadataClient creates an instance of Client.
func NewMetadataClient(commandConfig *config.CommandConfig) Client {
client := ssm.New(commandConfig.Session)
client.Handlers.Build.PushBackNamed(clients.CustomUserAgentHandler())
Expand All @@ -63,20 +75,29 @@ func NewMetadataClient(commandConfig *config.CommandConfig) Client {
}
}

func (c *metadataClient) GetRecommendedECSLinuxAMI(architecture string) (*AMIMetadata, error) {
ssmParam := amazonLinux2X86RecommendedParameterName
if architecture == ArchitectureTypeARM64 {
ssmParam = amazonLinux2ARM64RecommendedParameterName
// GetRecommendedECSLinuxAMI returns the recommended Amazon ECS-Optimized AMI Metadata given the instance type.
func (c *metadataClient) GetRecommendedECSLinuxAMI(instanceType string) (*AMIMetadata, error) {
if isARM64Instance(instanceType) {
return c.parameterValueFor(amazonLinux2ARM64RecommendedParameterName)
}
if isGPUInstance(instanceType) {
return c.parameterValueFor(amazonLinux2X86GPURecommendedParameterName)
}
return c.parameterValueFor(amazonLinux2X86RecommendedParameterName)
}

func (c *metadataClient) parameterValueFor(ssmParamName string) (*AMIMetadata, error) {
response, err := c.client.GetParameter(&ssm.GetParameterInput{
Name: aws.String(ssmParam),
Name: aws.String(ssmParamName),
})
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
if aerr.Code() == ssm.ErrCodeParameterNotFound {
// Added for arm AMIs which are only supported in some regions
return nil, errors.Wrapf(err, "Could not find Recommended Amazon Linux 2 AMI in %s with architecture %s; the AMI may not be supported in this region", c.region, architecture)
// Added for AMIs which are only supported in some regions
return nil, errors.Wrapf(err,
"Could not find Recommended Amazon Linux 2 AMI %s in %s; the AMI may not be supported in this region",
ssmParamName,
c.region)
}
}
return nil, err
Expand All @@ -85,3 +106,29 @@ func (c *metadataClient) GetRecommendedECSLinuxAMI(architecture string) (*AMIMet
err = json.Unmarshal([]byte(aws.StringValue(response.Parameter.Value)), metadata)
return metadata, err
}

func isARM64Instance(instanceType string) bool {
return architectureFor(instanceType) == ArchitectureTypeARM64
}

func isGPUInstance(instanceType string) bool {
if strings.HasPrefix(instanceType, "p2.") {
return true
}
if strings.HasPrefix(instanceType, "p3.") {
return true
}
if strings.HasPrefix(instanceType, "p3dn.") {
return true
}
return false
}

func architectureFor(instanceType string) string {
r := regexp.MustCompile("a1\\.(medium|\\d*x?large)")
if r.MatchString(instanceType) {
logrus.Infof("Using Arm ecs-optimized AMI because instance type was %s", instanceType)
return ArchitectureTypeARM64
}
return ArchitectureTypeX86
}
Loading

0 comments on commit 9a6b2dc

Please sign in to comment.