diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index cfa44be69..56ca0d9dc 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -129,6 +129,7 @@ type Disk struct { CapacityGiB int64 AvailabilityZone string SnapshotID string + OutpostArn string } // DiskOptions represents parameters to create an EBS volume @@ -138,6 +139,7 @@ type DiskOptions struct { VolumeType string IOPSPerGB int AvailabilityZone string + OutpostArn string Encrypted bool // KmsKeyID represents a fully qualified resource name to the key to use for encryption. // example: arn:aws:kms:us-east-1:012345678910:key/abcd1234-a123-456a-a12b-a123b4cd56ef @@ -284,6 +286,12 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions * TagSpecifications: []*ec2.TagSpecification{&tagSpec}, Encrypted: aws.Bool(diskOptions.Encrypted), } + + // EBS doesn't handle empty outpost arn, so we have to include it only when it's non-empty + if len(diskOptions.OutpostArn) > 0 { + request.OutpostArn = aws.String(diskOptions.OutpostArn) + } + if len(diskOptions.KmsKeyID) > 0 { request.KmsKeyId = aws.String(diskOptions.KmsKeyID) request.Encrypted = aws.Bool(true) @@ -318,7 +326,9 @@ func (c *cloud) CreateDisk(ctx context.Context, volumeName string, diskOptions * return nil, fmt.Errorf("failed to get an available volume in EC2: %v", err) } - return &Disk{CapacityGiB: size, VolumeID: volumeID, AvailabilityZone: zone, SnapshotID: snapshotID}, nil + outpostArn := aws.StringValue(response.OutpostArn) + + return &Disk{CapacityGiB: size, VolumeID: volumeID, AvailabilityZone: zone, SnapshotID: snapshotID, OutpostArn: outpostArn}, nil } func (c *cloud) DeleteDisk(ctx context.Context, volumeID string) (bool, error) { @@ -486,6 +496,7 @@ func (c *cloud) GetDiskByName(ctx context.Context, name string, capacityBytes in CapacityGiB: volSizeBytes, AvailabilityZone: aws.StringValue(volume.AvailabilityZone), SnapshotID: aws.StringValue(volume.SnapshotId), + OutpostArn: aws.StringValue(volume.OutpostArn), }, nil } @@ -505,6 +516,7 @@ func (c *cloud) GetDiskByID(ctx context.Context, volumeID string) (*Disk, error) VolumeID: aws.StringValue(volume.VolumeId), CapacityGiB: aws.Int64Value(volume.Size), AvailabilityZone: aws.StringValue(volume.AvailabilityZone), + OutpostArn: aws.StringValue(volume.OutpostArn), }, nil } diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index e7f6c1ccc..d85683ef1 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -94,6 +94,39 @@ func TestCreateDisk(t *testing.T) { }, expErr: nil, }, + { + name: "success: outpost volume", + volumeName: "vol-test-name", + diskOptions: &DiskOptions{ + CapacityBytes: util.GiBToBytes(1), + Tags: map[string]string{VolumeNameTagKey: "vol-test"}, + AvailabilityZone: expZone, + OutpostArn: "arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0", + }, + expDisk: &Disk{ + VolumeID: "vol-test", + CapacityGiB: 1, + AvailabilityZone: expZone, + OutpostArn: "arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0", + }, + expErr: nil, + }, + { + name: "success: empty outpost arn", + volumeName: "vol-test-name", + diskOptions: &DiskOptions{ + CapacityBytes: util.GiBToBytes(1), + Tags: map[string]string{VolumeNameTagKey: "vol-test"}, + AvailabilityZone: expZone, + }, + expDisk: &Disk{ + VolumeID: "vol-test", + CapacityGiB: 1, + AvailabilityZone: expZone, + OutpostArn: "", + }, + expErr: nil, + }, { name: "fail: CreateVolume returned CreateVolume error", volumeName: "vol-test-name-error", @@ -162,6 +195,7 @@ func TestCreateDisk(t *testing.T) { Size: aws.Int64(util.BytesToGiB(tc.diskOptions.CapacityBytes)), State: aws.String(volState), AvailabilityZone: aws.String(tc.diskOptions.AvailabilityZone), + OutpostArn: aws.String(tc.diskOptions.OutpostArn), } snapshot := &ec2.Snapshot{ SnapshotId: aws.String(tc.diskOptions.SnapshotID), @@ -203,6 +237,9 @@ func TestCreateDisk(t *testing.T) { if tc.expDisk.AvailabilityZone != disk.AvailabilityZone { t.Fatalf("CreateDisk() failed: expected availabilityZone %q, got %q", tc.expDisk.AvailabilityZone, disk.AvailabilityZone) } + if tc.expDisk.OutpostArn != disk.OutpostArn { + t.Fatalf("CreateDisk() failed: expected outpoustArn %q, got %q", tc.expDisk.OutpostArn, disk.OutpostArn) + } } } @@ -380,6 +417,7 @@ func TestGetDiskByName(t *testing.T) { volumeName string volumeCapacity int64 availabilityZone string + outpostArn string expErr error }{ { @@ -389,6 +427,14 @@ func TestGetDiskByName(t *testing.T) { availabilityZone: expZone, expErr: nil, }, + { + name: "success: outpost volume", + volumeName: "vol-test-1234", + volumeCapacity: util.GiBToBytes(1), + availabilityZone: expZone, + outpostArn: "arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0", + expErr: nil, + }, { name: "fail: DescribeVolumes returned generic error", volumeName: "vol-test-1234", @@ -407,6 +453,7 @@ func TestGetDiskByName(t *testing.T) { VolumeId: aws.String(tc.volumeName), Size: aws.Int64(util.BytesToGiB(tc.volumeCapacity)), AvailabilityZone: aws.String(tc.availabilityZone), + OutpostArn: aws.String(tc.outpostArn), } ctx := context.Background() @@ -427,6 +474,9 @@ func TestGetDiskByName(t *testing.T) { if tc.availabilityZone != disk.AvailabilityZone { t.Fatalf("GetDiskByName() failed: expected availabilityZone %q, got %q", tc.availabilityZone, disk.AvailabilityZone) } + if tc.outpostArn != disk.OutpostArn { + t.Fatalf("GetDiskByName() failed: expected outpostArn %q, got %q", tc.outpostArn, disk.OutpostArn) + } } mockCtrl.Finish() @@ -439,6 +489,7 @@ func TestGetDiskByID(t *testing.T) { name string volumeID string availabilityZone string + outpostArn string expErr error }{ { @@ -447,6 +498,13 @@ func TestGetDiskByID(t *testing.T) { availabilityZone: expZone, expErr: nil, }, + { + name: "success: outpost volume", + volumeID: "vol-test-1234", + availabilityZone: expZone, + outpostArn: "arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0", + expErr: nil, + }, { name: "fail: DescribeVolumes returned generic error", volumeID: "vol-test-1234", @@ -467,6 +525,7 @@ func TestGetDiskByID(t *testing.T) { { VolumeId: aws.String(tc.volumeID), AvailabilityZone: aws.String(tc.availabilityZone), + OutpostArn: aws.String(tc.outpostArn), }, }, }, @@ -488,6 +547,9 @@ func TestGetDiskByID(t *testing.T) { if tc.availabilityZone != disk.AvailabilityZone { t.Fatalf("GetDiskByName() failed: expected availabilityZone %q, got %q", tc.availabilityZone, disk.AvailabilityZone) } + if disk.OutpostArn != tc.outpostArn { + t.Fatalf("GetDisk() failed: expected outpostArn %q, got %q", tc.outpostArn, disk.OutpostArn) + } } mockCtrl.Finish() diff --git a/pkg/cloud/metadata.go b/pkg/cloud/metadata.go index 649684139..60f6cf78c 100644 --- a/pkg/cloud/metadata.go +++ b/pkg/cloud/metadata.go @@ -18,14 +18,20 @@ package cloud import ( "fmt" + "strings" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" + + "k8s.io/klog" ) type EC2Metadata interface { Available() bool + // ec2 instance metadata endpoints: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + GetMetadata(string) (string, error) GetInstanceIdentityDocument() (ec2metadata.EC2InstanceIdentityDocument, error) } @@ -35,6 +41,7 @@ type MetadataService interface { GetInstanceType() string GetRegion() string GetAvailabilityZone() string + GetOutpostArn() arn.ARN } type Metadata struct { @@ -42,8 +49,12 @@ type Metadata struct { InstanceType string Region string AvailabilityZone string + OutpostArn arn.ARN } +// OutpostArnEndpoint is the ec2 instance metadata endpoint to query to get the outpost arn +const OutpostArnEndpoint string = "outpost-arn" + var _ MetadataService = &Metadata{} // GetInstanceID returns the instance identification. @@ -51,7 +62,7 @@ func (m *Metadata) GetInstanceID() string { return m.InstanceID } -// GetInstanceID returns the instance type. +// GetInstanceType returns the instance type. func (m *Metadata) GetInstanceType() string { return m.InstanceType } @@ -66,6 +77,11 @@ func (m *Metadata) GetAvailabilityZone() string { return m.AvailabilityZone } +// GetOutpostArn returns outpost arn if instance is running on an outpost. empty otherwise. +func (m *Metadata) GetOutpostArn() arn.ARN { + return m.OutpostArn +} + func NewMetadata() (MetadataService, error) { sess := session.Must(session.NewSession(&aws.Config{})) svc := ec2metadata.New(sess) @@ -99,10 +115,28 @@ func NewMetadataService(svc EC2Metadata) (MetadataService, error) { return nil, fmt.Errorf("could not get valid EC2 availavility zone") } - return &Metadata{ + outpostArn, err := svc.GetMetadata(OutpostArnEndpoint) + // "outpust-arn" returns 404 for non-outpost instances. note that the request is made to a link-local address. + // it's guaranteed to be in the form `arn::outposts:::outpost/` + // There's a case to be made here to ignore the error so a failure here wouldn't affect non-outpost calls. + if err != nil && !strings.Contains(err.Error(), "404") { + return nil, fmt.Errorf("something went wrong while getting EC2 outpost arn") + } + + metadata := Metadata{ InstanceID: doc.InstanceID, InstanceType: doc.InstanceType, Region: doc.Region, AvailabilityZone: doc.AvailabilityZone, - }, nil + } + + outpostArn = strings.ReplaceAll(outpostArn, "outpost/", "") + parsedArn, err := arn.Parse(outpostArn) + if err != nil { + klog.Warningf("Failed to parse the outpost arn: %s", outpostArn) + } else { + metadata.OutpostArn = parsedArn + } + + return &metadata, nil } diff --git a/pkg/cloud/metadata_test.go b/pkg/cloud/metadata_test.go index 7a9db3fb0..8898e2383 100644 --- a/pkg/cloud/metadata_test.go +++ b/pkg/cloud/metadata_test.go @@ -18,8 +18,10 @@ package cloud import ( "fmt" + "strings" "testing" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/golang/mock/gomock" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud/mocks" @@ -33,12 +35,19 @@ var ( ) func TestNewMetadataService(t *testing.T) { + + validRawOutpostArn := "arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0" + validOutpostArn, _ := arn.Parse(strings.ReplaceAll(validRawOutpostArn, "outpost/", "")) + testCases := []struct { name string isAvailable bool isPartial bool identityDocument ec2metadata.EC2InstanceIdentityDocument + rawOutpostArn string + outpostArn arn.ARN err error + getOutpostArnErr error // We should keep this specific to outpost-arn until we need to use more endpoints }{ { name: "success: normal", @@ -51,6 +60,42 @@ func TestNewMetadataService(t *testing.T) { }, err: nil, }, + { + name: "success: outpost-arn is available", + isAvailable: true, + identityDocument: ec2metadata.EC2InstanceIdentityDocument{ + InstanceID: stdInstanceID, + InstanceType: stdInstanceType, + Region: stdRegion, + AvailabilityZone: stdAvailabilityZone, + }, + rawOutpostArn: validRawOutpostArn, + outpostArn: validOutpostArn, + err: nil, + }, + { + name: "success: outpost-arn is invalid", + isAvailable: true, + identityDocument: ec2metadata.EC2InstanceIdentityDocument{ + InstanceID: stdInstanceID, + InstanceType: stdInstanceType, + Region: stdRegion, + AvailabilityZone: stdAvailabilityZone, + }, + err: nil, + }, + { + name: "success: outpost-arn is not found", + isAvailable: true, + identityDocument: ec2metadata.EC2InstanceIdentityDocument{ + InstanceID: stdInstanceID, + InstanceType: stdInstanceType, + Region: stdRegion, + AvailabilityZone: stdAvailabilityZone, + }, + err: nil, + getOutpostArnErr: fmt.Errorf("404"), + }, { name: "fail: metadata not available", isAvailable: false, @@ -109,6 +154,18 @@ func TestNewMetadataService(t *testing.T) { }, err: nil, }, + { + name: "fail: outpost-arn failed", + isAvailable: true, + identityDocument: ec2metadata.EC2InstanceIdentityDocument{ + InstanceID: stdInstanceID, + InstanceType: stdInstanceType, + Region: stdRegion, + AvailabilityZone: stdAvailabilityZone, + }, + err: nil, + getOutpostArnErr: fmt.Errorf("405"), + }, } for _, tc := range testCases { @@ -121,8 +178,12 @@ func TestNewMetadataService(t *testing.T) { mockEC2Metadata.EXPECT().GetInstanceIdentityDocument().Return(tc.identityDocument, tc.err) } - m, err := NewMetadataService(mockEC2Metadata) if tc.isAvailable && tc.err == nil && !tc.isPartial { + mockEC2Metadata.EXPECT().GetMetadata(OutpostArnEndpoint).Return(tc.rawOutpostArn, tc.getOutpostArnErr) + } + + m, err := NewMetadataService(mockEC2Metadata) + if tc.isAvailable && tc.err == nil && tc.getOutpostArnErr == nil && !tc.isPartial { if err != nil { t.Fatalf("NewMetadataService() failed: expected no error, got %v", err) } @@ -142,8 +203,12 @@ func TestNewMetadataService(t *testing.T) { if m.GetAvailabilityZone() != tc.identityDocument.AvailabilityZone { t.Fatalf("GetAvailabilityZone() failed: expected %v, got %v", tc.identityDocument.AvailabilityZone, m.GetAvailabilityZone()) } + + if m.GetOutpostArn() != tc.outpostArn { + t.Fatalf("GetOutpostArn() failed: expected %v, got %v", tc.outpostArn, m.GetOutpostArn()) + } } else { - if err == nil { + if err == nil && tc.getOutpostArnErr == nil { t.Fatal("NewMetadataService() failed: expected error when GetInstanceIdentityDocument returns partial data, got nothing") } } diff --git a/pkg/cloud/mocks/mock_ec2metadata.go b/pkg/cloud/mocks/mock_ec2metadata.go index 37e2ea0d0..2719846d8 100644 --- a/pkg/cloud/mocks/mock_ec2metadata.go +++ b/pkg/cloud/mocks/mock_ec2metadata.go @@ -61,3 +61,18 @@ func (mr *MockEC2MetadataMockRecorder) GetInstanceIdentityDocument() *gomock.Cal mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceIdentityDocument", reflect.TypeOf((*MockEC2Metadata)(nil).GetInstanceIdentityDocument)) } + +// GetMetadata mocks base method +func (m *MockEC2Metadata) GetMetadata(arg0 string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMetadata", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMetadata indicates an expected call of GetMetadata +func (mr *MockEC2MetadataMockRecorder) GetMetadata(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMetadata", reflect.TypeOf((*MockEC2Metadata)(nil).GetMetadata), arg0) +} diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 1e37dc000..814e1a59f 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -18,10 +18,12 @@ package driver import ( "context" + "fmt" "os" "strconv" "strings" + "github.com/aws/aws-sdk-go/aws/arn" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/protobuf/ptypes" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" @@ -189,6 +191,7 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol // create a new volume zone := pickAvailabilityZone(req.GetAccessibilityRequirements()) + outpostArn := getOutpostArn(req.GetAccessibilityRequirements()) // fill volume tags if d.driverOptions.kubernetesClusterID != "" { @@ -206,6 +209,7 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol VolumeType: volumeType, IOPSPerGB: iopsPerGB, AvailabilityZone: zone, + OutpostArn: outpostArn, Encrypted: isEncrypted, KmsKeyID: kmsKeyID, SnapshotID: snapshotID, @@ -534,6 +538,26 @@ func pickAvailabilityZone(requirement *csi.TopologyRequirement) string { return "" } +func getOutpostArn(requirement *csi.TopologyRequirement) string { + if requirement == nil { + return "" + } + for _, topology := range requirement.GetPreferred() { + _, exists := topology.GetSegments()[AwsOutpostIDKey] + if exists { + return BuildOutpostArn(topology.GetSegments()) + } + } + for _, topology := range requirement.GetRequisite() { + _, exists := topology.GetSegments()[AwsOutpostIDKey] + if exists { + return BuildOutpostArn(topology.GetSegments()) + } + } + + return "" +} + func newCreateVolumeResponse(disk *cloud.Disk) *csi.CreateVolumeResponse { var src *csi.VolumeContentSource if disk.SnapshotID != "" { @@ -545,6 +569,18 @@ func newCreateVolumeResponse(disk *cloud.Disk) *csi.CreateVolumeResponse { }, } } + + segments := map[string]string{TopologyKey: disk.AvailabilityZone} + + arn, err := arn.Parse(disk.OutpostArn) + + if err == nil { + segments[AwsRegionKey] = arn.Region + segments[AwsPartitionKey] = arn.Partition + segments[AwsAccountIDKey] = arn.AccountID + segments[AwsOutpostIDKey] = strings.ReplaceAll(arn.Resource, "outpost/", "") + } + return &csi.CreateVolumeResponse{ Volume: &csi.Volume{ VolumeId: disk.VolumeID, @@ -552,7 +588,7 @@ func newCreateVolumeResponse(disk *cloud.Disk) *csi.CreateVolumeResponse { VolumeContext: map[string]string{}, AccessibleTopology: []*csi.Topology{ { - Segments: map[string]string{TopologyKey: disk.AvailabilityZone}, + Segments: segments, }, }, ContentSource: src, @@ -622,3 +658,28 @@ func getVolSizeBytes(req *csi.CreateVolumeRequest) (int64, error) { } return volSizeBytes, nil } + +// BuildOutpostArn returns the string representation of the outpost ARN from the given csi.TopologyRequirement.segments +func BuildOutpostArn(segments map[string]string) string { + + if len(segments[AwsPartitionKey]) <= 0 { + return "" + } + + if len(segments[AwsRegionKey]) <= 0 { + return "" + } + if len(segments[AwsOutpostIDKey]) <= 0 { + return "" + } + if len(segments[AwsAccountIDKey]) <= 0 { + return "" + } + + return fmt.Sprintf("arn:%s:outposts:%s:%s:outpost/%s", + segments[AwsPartitionKey], + segments[AwsRegionKey], + segments[AwsAccountIDKey], + segments[AwsOutpostIDKey], + ) +} diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 9a4601734..2c5592e08 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -23,9 +23,11 @@ import ( "math/rand" "os" "reflect" + "strings" "testing" "time" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" @@ -162,6 +164,8 @@ func TestCreateVolume(t *testing.T) { stdVolSize := int64(5 * 1024 * 1024 * 1024) stdCapRange := &csi.CapacityRange{RequiredBytes: stdVolSize} stdParams := map[string]string{} + rawOutpostArn := "arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0" + strippedOutpostArn, _ := arn.Parse(strings.ReplaceAll(rawOutpostArn, "outpost/", "")) testCases := []struct { name string @@ -206,6 +210,96 @@ func TestCreateVolume(t *testing.T) { } }, }, + { + name: "success outposts", + testFunc: func(t *testing.T) { + outpostArn := strippedOutpostArn + req := &csi.CreateVolumeRequest{ + Name: "test-vol", + CapacityRange: stdCapRange, + VolumeCapabilities: stdVolCap, + Parameters: map[string]string{}, + AccessibilityRequirements: &csi.TopologyRequirement{ + Requisite: []*csi.Topology{ + { + Segments: map[string]string{ + TopologyKey: expZone, + AwsAccountIDKey: outpostArn.AccountID, + AwsOutpostIDKey: outpostArn.Resource, + AwsRegionKey: outpostArn.Region, + AwsPartitionKey: outpostArn.Partition, + }, + }, + }, + }, + } + expVol := &csi.Volume{ + CapacityBytes: stdVolSize, + VolumeId: "vol-test", + VolumeContext: map[string]string{}, + AccessibleTopology: []*csi.Topology{ + { + Segments: map[string]string{ + TopologyKey: expZone, + AwsAccountIDKey: outpostArn.AccountID, + AwsOutpostIDKey: outpostArn.Resource, + AwsRegionKey: outpostArn.Region, + AwsPartitionKey: outpostArn.Partition, + }, + }, + }, + } + + ctx := context.Background() + + mockDisk := &cloud.Disk{ + VolumeID: req.Name, + AvailabilityZone: expZone, + CapacityGiB: util.BytesToGiB(stdVolSize), + OutpostArn: outpostArn.String(), + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) + mockCloud.EXPECT().CreateDisk(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Any()).Return(mockDisk, nil) + + awsDriver := controllerService{ + cloud: mockCloud, + driverOptions: &DriverOptions{}, + } + + resp, err := awsDriver.CreateVolume(ctx, req) + if err != nil { + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + t.Fatalf("Unexpected error: %v", srvErr.Code()) + } + + // mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(mockDisk, nil) + vol := resp.GetVolume() + if vol == nil { + t.Fatalf("Expected volume %v, got nil", expVol) + } + + for expKey, expVal := range expVol.GetVolumeContext() { + ctx := vol.GetVolumeContext() + if gotVal, ok := ctx[expKey]; !ok || gotVal != expVal { + t.Fatalf("Expected volume context for key %v: %v, got: %v", expKey, expVal, gotVal) + } + } + + if expVol.GetAccessibleTopology() != nil { + if !reflect.DeepEqual(expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) { + t.Fatalf("Expected AccessibleTopology to be %+v, got: %+v", expVol.GetAccessibleTopology(), vol.GetAccessibleTopology()) + } + } + }, + }, { name: "restore snapshot", testFunc: func(t *testing.T) { @@ -1396,7 +1490,7 @@ func TestPickAvailabilityZone(t *testing.T) { requirement: &csi.TopologyRequirement{ Requisite: []*csi.Topology{ { - Segments: map[string]string{TopologyKey: expZone}, + Segments: map[string]string{TopologyKey: ""}, }, }, Preferred: []*csi.Topology{ @@ -1443,6 +1537,147 @@ func TestPickAvailabilityZone(t *testing.T) { } } +func TestGetOutpostArn(t *testing.T) { + expRawOutpostArn := "arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0" + outpostArn, _ := arn.Parse(strings.ReplaceAll(expRawOutpostArn, "outpost/", "")) + testCases := []struct { + name string + requirement *csi.TopologyRequirement + expZone string + expOutpostArn string + }{ + { + name: "Get from preferred", + requirement: &csi.TopologyRequirement{ + Requisite: []*csi.Topology{ + { + Segments: map[string]string{TopologyKey: expZone}, + }, + }, + Preferred: []*csi.Topology{ + { + Segments: map[string]string{ + TopologyKey: expZone, + AwsAccountIDKey: outpostArn.AccountID, + AwsOutpostIDKey: outpostArn.Resource, + AwsRegionKey: outpostArn.Region, + AwsPartitionKey: outpostArn.Partition, + }, + }, + }, + }, + expZone: expZone, + expOutpostArn: expRawOutpostArn, + }, + { + name: "Get from requisite", + requirement: &csi.TopologyRequirement{ + Requisite: []*csi.Topology{ + { + Segments: map[string]string{ + TopologyKey: expZone, + AwsAccountIDKey: outpostArn.AccountID, + AwsOutpostIDKey: outpostArn.Resource, + AwsRegionKey: outpostArn.Region, + AwsPartitionKey: outpostArn.Partition, + }, + }, + }, + }, + expZone: expZone, + expOutpostArn: expRawOutpostArn, + }, + { + name: "Get from empty topology", + requirement: &csi.TopologyRequirement{ + Preferred: []*csi.Topology{{}}, + Requisite: []*csi.Topology{{}}, + }, + expZone: "", + expOutpostArn: "", + }, + { + name: "Topology Requirement is nil", + requirement: nil, + expZone: "", + expOutpostArn: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := getOutpostArn(tc.requirement) + if actual != tc.expOutpostArn { + t.Fatalf("Expected %v, got outpostArn: %v", tc.expOutpostArn, actual) + } + }) + } +} + +func TestBuildOutpostArn(t *testing.T) { + expRawOutpostArn := "arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0" + testCases := []struct { + name string + awsPartition string + awsRegion string + awsAccountID string + awsOutpostID string + expectedArn string + }{ + { + name: "all fields are present", + awsPartition: "aws", + awsRegion: "us-west-2", + awsOutpostID: "op-0aaa000a0aaaa00a0", + awsAccountID: "111111111111", + expectedArn: expRawOutpostArn, + }, + { + name: "partition is missing", + awsRegion: "us-west-2", + awsOutpostID: "op-0aaa000a0aaaa00a0", + awsAccountID: "111111111111", + expectedArn: "", + }, + { + name: "region is missing", + awsPartition: "aws", + awsOutpostID: "op-0aaa000a0aaaa00a0", + awsAccountID: "111111111111", + expectedArn: "", + }, + { + name: "account id is missing", + awsPartition: "aws", + awsRegion: "us-west-2", + awsOutpostID: "op-0aaa000a0aaaa00a0", + expectedArn: "", + }, + { + name: "outpost id is missing", + awsPartition: "aws", + awsRegion: "us-west-2", + awsAccountID: "111111111111", + expectedArn: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + segment := map[string]string{ + AwsRegionKey: tc.awsRegion, + AwsPartitionKey: tc.awsPartition, + AwsAccountIDKey: tc.awsAccountID, + AwsOutpostIDKey: tc.awsOutpostID, + } + actual := BuildOutpostArn(segment) + if actual != tc.expectedArn { + t.Fatalf("Expected %v, got outpostArn: %v", tc.expectedArn, actual) + } + }) + } +} + func TestCreateSnapshot(t *testing.T) { testCases := []struct { name string diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 1be2e86dc..186bdf9e5 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -40,8 +40,12 @@ const ( ) const ( - DriverName = "ebs.csi.aws.com" - TopologyKey = "topology." + DriverName + "/zone" + DriverName = "ebs.csi.aws.com" + TopologyKey = "topology." + DriverName + "/zone" + AwsPartitionKey = "topology." + DriverName + "/partition" + AwsAccountIDKey = "topology." + DriverName + "/account-id" + AwsRegionKey = "topology." + DriverName + "/region" + AwsOutpostIDKey = "topology." + DriverName + "/outpost-id" ) type Driver struct { diff --git a/pkg/driver/mocks/mock_metadata_service.go b/pkg/driver/mocks/mock_metadata_service.go index dba05fda5..ee2f67173 100644 --- a/pkg/driver/mocks/mock_metadata_service.go +++ b/pkg/driver/mocks/mock_metadata_service.go @@ -5,6 +5,7 @@ package mocks import ( + arn "github.com/aws/aws-sdk-go/aws/arn" gomock "github.com/golang/mock/gomock" reflect "reflect" ) @@ -74,6 +75,20 @@ func (mr *MockMetadataServiceMockRecorder) GetInstanceType() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceType", reflect.TypeOf((*MockMetadataService)(nil).GetInstanceType)) } +// GetOutpostArn mocks base method +func (m *MockMetadataService) GetOutpostArn() arn.ARN { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOutpostArn") + ret0, _ := ret[0].(arn.ARN) + return ret0 +} + +// GetOutpostArn indicates an expected call of GetOutpostArn +func (mr *MockMetadataServiceMockRecorder) GetOutpostArn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOutpostArn", reflect.TypeOf((*MockMetadataService)(nil).GetOutpostArn)) +} + // GetRegion mocks base method func (m *MockMetadataService) GetRegion() string { m.ctrl.T.Helper() diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 456823649..2dc3dc28f 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -366,10 +366,22 @@ func (d *nodeService) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetC func (d *nodeService) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { klog.V(4).Infof("NodeGetInfo: called with args %+v", *req) - topology := &csi.Topology{ - Segments: map[string]string{TopologyKey: d.metadata.GetAvailabilityZone()}, + segments := map[string]string{ + TopologyKey: d.metadata.GetAvailabilityZone(), } + outpostArn := d.metadata.GetOutpostArn() + + // to my surprise ARN's string representation is not empty for empty ARN + if len(outpostArn.Resource) > 0 { + segments[AwsRegionKey] = outpostArn.Region + segments[AwsPartitionKey] = outpostArn.Partition + segments[AwsAccountIDKey] = outpostArn.AccountID + segments[AwsOutpostIDKey] = outpostArn.Resource + } + + topology := &csi.Topology{Segments: segments} + return &csi.NodeGetInfoResponse{ NodeId: d.metadata.GetInstanceID(), MaxVolumesPerNode: d.getVolumesLimit(), diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 67da47ace..7430d5fd1 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -20,8 +20,10 @@ import ( "context" "errors" "reflect" + "strings" "testing" + "github.com/aws/aws-sdk-go/aws/arn" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/golang/mock/gomock" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver/internal" @@ -1242,6 +1244,8 @@ func TestNodeGetCapabilities(t *testing.T) { } func TestNodeGetInfo(t *testing.T) { + validOutpostArn, _ := arn.Parse(strings.ReplaceAll("arn:aws:outposts:us-west-2:111111111111:outpost/op-0aaa000a0aaaa00a0", "outpost/", "")) + emptyOutpostArn := arn.ARN{} testCases := []struct { name string instanceID string @@ -1249,6 +1253,7 @@ func TestNodeGetInfo(t *testing.T) { availabilityZone string volumeAttachLimit int64 expMaxVolumes int64 + outpostArn arn.ARN }{ { name: "success normal", @@ -1257,6 +1262,7 @@ func TestNodeGetInfo(t *testing.T) { availabilityZone: "us-west-2b", volumeAttachLimit: -1, expMaxVolumes: 39, + outpostArn: emptyOutpostArn, }, { name: "success normal with overwrite", @@ -1265,6 +1271,7 @@ func TestNodeGetInfo(t *testing.T) { availabilityZone: "us-west-2b", volumeAttachLimit: 42, expMaxVolumes: 42, + outpostArn: emptyOutpostArn, }, { name: "success normal with NVMe", @@ -1273,6 +1280,7 @@ func TestNodeGetInfo(t *testing.T) { availabilityZone: "us-west-2b", volumeAttachLimit: -1, expMaxVolumes: 25, + outpostArn: emptyOutpostArn, }, { name: "success normal with NVMe and overwrite", @@ -1281,6 +1289,16 @@ func TestNodeGetInfo(t *testing.T) { availabilityZone: "us-west-2b", volumeAttachLimit: 30, expMaxVolumes: 30, + outpostArn: emptyOutpostArn, + }, + { + name: "success normal outposts", + instanceID: "i-123456789abcdef01", + instanceType: "m5d.large", + availabilityZone: "us-west-2b", + volumeAttachLimit: 30, + expMaxVolumes: 30, + outpostArn: validOutpostArn, }, } for _, tc := range testCases { @@ -1297,6 +1315,7 @@ func TestNodeGetInfo(t *testing.T) { mockMetadata := mocks.NewMockMetadataService(mockCtl) mockMetadata.EXPECT().GetInstanceID().Return(tc.instanceID) mockMetadata.EXPECT().GetAvailabilityZone().Return(tc.availabilityZone) + mockMetadata.EXPECT().GetOutpostArn().Return(tc.outpostArn) if tc.volumeAttachLimit < 0 { mockMetadata.EXPECT().GetInstanceType().Return(tc.instanceType) @@ -1327,6 +1346,22 @@ func TestNodeGetInfo(t *testing.T) { t.Fatalf("Expected topology %q, got %q", tc.availabilityZone, at.Segments[TopologyKey]) } + if at.Segments[AwsAccountIDKey] != tc.outpostArn.AccountID { + t.Fatalf("Expected AwsAccountId %q, got %q", tc.outpostArn.AccountID, at.Segments[AwsAccountIDKey]) + } + + if at.Segments[AwsRegionKey] != tc.outpostArn.Region { + t.Fatalf("Expected AwsRegion %q, got %q", tc.outpostArn.Region, at.Segments[AwsRegionKey]) + } + + if at.Segments[AwsOutpostIDKey] != tc.outpostArn.Resource { + t.Fatalf("Expected AwsOutpostID %q, got %q", tc.outpostArn.Resource, at.Segments[AwsOutpostIDKey]) + } + + if at.Segments[AwsPartitionKey] != tc.outpostArn.Partition { + t.Fatalf("Expected AwsPartition %q, got %q", tc.outpostArn.Partition, at.Segments[AwsPartitionKey]) + } + if resp.GetMaxVolumesPerNode() != tc.expMaxVolumes { t.Fatalf("Expected %d max volumes per node, got %d", tc.expMaxVolumes, resp.GetMaxVolumesPerNode()) } diff --git a/tests/e2e/driver/driver.go b/tests/e2e/driver/driver.go index a3a25dd47..4b54e6912 100644 --- a/tests/e2e/driver/driver.go +++ b/tests/e2e/driver/driver.go @@ -16,7 +16,7 @@ package driver import ( "github.com/kubernetes-csi/external-snapshotter/v2/pkg/apis/volumesnapshot/v1beta1" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" storagev1 "k8s.io/api/storage/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) diff --git a/vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go b/vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go new file mode 100644 index 000000000..1c4967429 --- /dev/null +++ b/vendor/github.com/aws/aws-sdk-go/aws/arn/arn.go @@ -0,0 +1,93 @@ +// Package arn provides a parser for interacting with Amazon Resource Names. +package arn + +import ( + "errors" + "strings" +) + +const ( + arnDelimiter = ":" + arnSections = 6 + arnPrefix = "arn:" + + // zero-indexed + sectionPartition = 1 + sectionService = 2 + sectionRegion = 3 + sectionAccountID = 4 + sectionResource = 5 + + // errors + invalidPrefix = "arn: invalid prefix" + invalidSections = "arn: not enough sections" +) + +// ARN captures the individual fields of an Amazon Resource Name. +// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more information. +type ARN struct { + // The partition that the resource is in. For standard AWS regions, the partition is "aws". If you have resources in + // other partitions, the partition is "aws-partitionname". For example, the partition for resources in the China + // (Beijing) region is "aws-cn". + Partition string + + // The service namespace that identifies the AWS product (for example, Amazon S3, IAM, or Amazon RDS). For a list of + // namespaces, see + // http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces. + Service string + + // The region the resource resides in. Note that the ARNs for some resources do not require a region, so this + // component might be omitted. + Region string + + // The ID of the AWS account that owns the resource, without the hyphens. For example, 123456789012. Note that the + // ARNs for some resources don't require an account number, so this component might be omitted. + AccountID string + + // The content of this part of the ARN varies by service. It often includes an indicator of the type of resource — + // for example, an IAM user or Amazon RDS database - followed by a slash (/) or a colon (:), followed by the + // resource name itself. Some services allows paths for resource names, as described in + // http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#arns-paths. + Resource string +} + +// Parse parses an ARN into its constituent parts. +// +// Some example ARNs: +// arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment +// arn:aws:iam::123456789012:user/David +// arn:aws:rds:eu-west-1:123456789012:db:mysql-db +// arn:aws:s3:::my_corporate_bucket/exampleobject.png +func Parse(arn string) (ARN, error) { + if !strings.HasPrefix(arn, arnPrefix) { + return ARN{}, errors.New(invalidPrefix) + } + sections := strings.SplitN(arn, arnDelimiter, arnSections) + if len(sections) != arnSections { + return ARN{}, errors.New(invalidSections) + } + return ARN{ + Partition: sections[sectionPartition], + Service: sections[sectionService], + Region: sections[sectionRegion], + AccountID: sections[sectionAccountID], + Resource: sections[sectionResource], + }, nil +} + +// IsARN returns whether the given string is an ARN by looking for +// whether the string starts with "arn:" and contains the correct number +// of sections delimited by colons(:). +func IsARN(arn string) bool { + return strings.HasPrefix(arn, arnPrefix) && strings.Count(arn, ":") >= arnSections-1 +} + +// String returns the canonical representation of the ARN +func (arn ARN) String() string { + return arnPrefix + + arn.Partition + arnDelimiter + + arn.Service + arnDelimiter + + arn.Region + arnDelimiter + + arn.AccountID + arnDelimiter + + arn.Resource +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 8a546655b..961ba40b6 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,5 +1,6 @@ # github.com/aws/aws-sdk-go v1.29.11 github.com/aws/aws-sdk-go/aws +github.com/aws/aws-sdk-go/aws/arn github.com/aws/aws-sdk-go/aws/awserr github.com/aws/aws-sdk-go/aws/awsutil github.com/aws/aws-sdk-go/aws/client