diff --git a/pkg/lib/aws/ec2.go b/pkg/lib/aws/ec2.go index 9ca9ef36bc..8ae55b0cb8 100644 --- a/pkg/lib/aws/ec2.go +++ b/pkg/lib/aws/ec2.go @@ -30,7 +30,10 @@ import ( s "github.com/cortexlabs/cortex/pkg/lib/strings" ) -var _digitsRegex = regexp.MustCompile(`[0-9]+`) +var ( + _digitsRegex = regexp.MustCompile(`[0-9]+`) + _gpuInstanceFamilies = strset.New("g", "p") +) type ParsedInstanceType struct { Family string @@ -117,6 +120,23 @@ func IsARMInstance(instanceType string) (bool, error) { return false, nil } +func IsAMDGPUInstance(instanceType string) (bool, error) { + parsedType, err := ParseInstanceType(instanceType) + if err != nil { + return false, err + } + + if !_gpuInstanceFamilies.Has(parsedType.Family) { + return false, nil + } + + if parsedType.Capabilities.Has("a") { + return true, nil + } + + return false, nil +} + func (c *Client) SpotInstancePrice(instanceType string) (float64, error) { result, err := c.EC2().DescribeSpotPriceHistory(&ec2.DescribeSpotPriceHistoryInput{ InstanceTypes: []*string{aws.String(instanceType)}, diff --git a/pkg/types/clusterconfig/cluster_config.go b/pkg/types/clusterconfig/cluster_config.go index cd34c4ff7c..0bac1e602d 100644 --- a/pkg/types/clusterconfig/cluster_config.go +++ b/pkg/types/clusterconfig/cluster_config.go @@ -1269,6 +1269,14 @@ func validateInstanceType(instanceType string) (string, error) { return "", ErrorARMInstancesNotSupported(instanceType) } + isAMDGPU, err := aws.IsAMDGPUInstance(instanceType) + if err != nil { + return "", err + } + if isAMDGPU { + return "", ErrorAMDGPUInstancesNotSupported(instanceType) + } + if err := checkCNISupport(instanceType); err != nil { return "", err } diff --git a/pkg/types/clusterconfig/errors.go b/pkg/types/clusterconfig/errors.go index e238912aab..fa69e95c23 100644 --- a/pkg/types/clusterconfig/errors.go +++ b/pkg/types/clusterconfig/errors.go @@ -47,6 +47,7 @@ const ( ErrSpotPriceGreaterThanMaxPrice = "clusterconfig.spot_price_greater_than_max_price" ErrInstanceTypeNotSupportedByCortex = "clusterconfig.instance_type_not_supported_by_cortex" ErrARMInstancesNotSupported = "clusterconfig.arm_instances_not_supported" + ErrAMDGPUInstancesNotSupported = "clusterconfig.amd_gpu_instances_not_supported" ErrAtLeastOneInstanceDistribution = "clusterconfig.at_least_one_instance_distribution" ErrNoCompatibleSpotInstanceFound = "clusterconfig.no_compatible_spot_instance_found" ErrConfiguredWhenSpotIsNotEnabled = "clusterconfig.configured_when_spot_is_not_enabled" @@ -203,7 +204,14 @@ func ErrorInstanceTypeNotSupportedByCortex(instanceType string) error { func ErrorARMInstancesNotSupported(instanceType string) error { return errors.WithStack(&errors.Error{ Kind: ErrARMInstancesNotSupported, - Message: fmt.Sprintf("ARM-based instances (including %s) are not supported", instanceType), + Message: fmt.Sprintf("ARM-based instances (including %s) are not supported by cortex", instanceType), + }) +} + +func ErrorAMDGPUInstancesNotSupported(instanceType string) error { + return errors.WithStack(&errors.Error{ + Kind: ErrAMDGPUInstancesNotSupported, + Message: fmt.Sprintf("AMD GPU instances (including %s) are not supported by cortex", instanceType), }) }