Skip to content

Commit

Permalink
Merge 7dda588 into d4f2e0c
Browse files Browse the repository at this point in the history
  • Loading branch information
anonymint committed Apr 20, 2021
2 parents d4f2e0c + 7dda588 commit 813904e
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 12 deletions.
1 change: 1 addition & 0 deletions hack/update-gomock
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ set -euo pipefail
IMPORT_PATH=github.com/kubernetes-sigs/aws-efs-csi-driver
mockgen -package=mocks -destination=./pkg/driver/mocks/mock_mount.go ${IMPORT_PATH}/pkg/driver Mounter
mockgen -package=mocks -destination=./pkg/cloud/mocks/mock_ec2metadata.go ${IMPORT_PATH}/pkg/cloud EC2Metadata
mockgen -package=mocks -destination=./pkg/cloud/mocks/mock_taskmetadata.go ${IMPORT_PATH}/pkg/cloud TaskMetadataService
5 changes: 1 addition & 4 deletions pkg/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/efs"
Expand Down Expand Up @@ -91,9 +90,7 @@ type cloud struct {
// It panics if session is invalid
func NewCloud() (Cloud, error) {
sess := session.Must(session.NewSession(&aws.Config{}))
svc := ec2metadata.New(sess)

metadata, err := NewMetadataService(svc)
metadata, err := NewMetadataService(sess)
if err != nil {
return nil, fmt.Errorf("could not get metadata from AWS: %v", err)
}
Expand Down
17 changes: 14 additions & 3 deletions pkg/cloud/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ package cloud

import (
"fmt"

"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"os"
)

type EC2Metadata interface {
Expand Down Expand Up @@ -57,8 +58,18 @@ func (m *metadata) GetAvailabilityZone() string {
return m.availabilityZone
}

// NewMetadataService returns a new MetadataServiceImplementation.
func NewMetadataService(svc EC2Metadata) (MetadataService, error) {
// NewMetadataService return either EC2 or ECS Task MetadataServiceImplementation.
func NewMetadataService(sess *session.Session) (MetadataService, error) {
// check if it is running in ECS otherwise default fall back to ec2
if ecsContainerMetadataUri := os.Getenv(taskMetadataV4EnvName); ecsContainerMetadataUri != "" {
return getTaskMetadata(&taskMetadata{})
} else {
return getEC2Metadata(ec2metadata.New(sess))
}
}

// getEC2Metadata returns a new MetadataServiceImplementation.
func getEC2Metadata(svc EC2Metadata) (MetadataService, error) {
if !svc.Available() {
return nil, fmt.Errorf("EC2 instance metadata is not available")
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/cloud/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var (
stdAvailabilityZone = "az-1"
)

func TestNewMetadataService(t *testing.T) {
func TestGetEC2MetadataService(t *testing.T) {
testCases := []struct {
name string
isAvailable bool
Expand Down Expand Up @@ -114,10 +114,10 @@ func TestNewMetadataService(t *testing.T) {
mockEC2Metadata.EXPECT().GetInstanceIdentityDocument().Return(tc.identityDocument, tc.err)
}

m, err := NewMetadataService(mockEC2Metadata)
m, err := getEC2Metadata(mockEC2Metadata)
if tc.isAvailable && tc.err == nil && !tc.isPartial {
if err != nil {
t.Fatalf("NewMetadataService() failed: expected no error, got %v", err)
t.Fatalf("getEC2Metadata() failed: expected no error, got %v", err)
}

if m.GetInstanceID() != tc.identityDocument.InstanceID {
Expand All @@ -133,7 +133,7 @@ func TestNewMetadataService(t *testing.T) {
}
} else {
if err == nil {
t.Fatal("NewMetadataService() failed: expected error when GetInstanceIdentityDocument returns partial data, got nothing")
t.Fatal("getEC2Metadata() failed: expected error when GetInstanceIdentityDocument returns partial data, got nothing")
}
}

Expand Down
49 changes: 49 additions & 0 deletions pkg/cloud/mocks/mock_taskmetadata.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

84 changes: 84 additions & 0 deletions pkg/cloud/task_metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
Copyright 2021 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package cloud

import (
"encoding/json"
"fmt"
"github.com/kubernetes-sigs/aws-efs-csi-driver/pkg/util"
"net/http"
"os"
"strings"
"time"
)

const (
taskMetadataV4EnvName = "ECS_CONTAINER_METADATA_URI_V4"
)

type TaskMetadataService interface {
GetTMDSV4Response() ([]byte, error)
}

type taskMetadata struct {
}

type TMDSV4Response struct {
Cluster string `json:"Cluster"`
TaskARN string `json:"TaskARN"`
AvailabilityZone string `json:"AvailabilityZone"`
}

func (taskMetadata taskMetadata) GetTMDSV4Response() ([]byte, error) {
client := &http.Client{
Timeout: 5 * time.Second,
}
metadataUrl := os.Getenv(taskMetadataV4EnvName)
if metadataUrl == "" {
return nil, fmt.Errorf("unable to get taskMetadataV4 environment variable")
}
respBody, err := util.GetHttpResponse(client, metadataUrl+"/task")
if err != nil {
return nil, fmt.Errorf("unable to get task metadata response: %v", err)
}

return respBody, nil
}

// getTaskMetadata return a new ECS MetadataServiceImplementation
func getTaskMetadata(svc TaskMetadataService) (MetadataService, error) {
metadataResp, err := svc.GetTMDSV4Response()
if err != nil {
return nil, fmt.Errorf("unable to get TaskMetadataService %v", err)
}

tmdsResp := &TMDSV4Response{}
err = json.Unmarshal(metadataResp, tmdsResp)
if err != nil {
return nil, fmt.Errorf("unable to parse task metadata response body %v", metadataResp)
}
taskSplit := strings.Split(tmdsResp.TaskARN, "/")
taskId := taskSplit[len(taskSplit)-1]
az := tmdsResp.AvailabilityZone
region := az[:len(az)-1]
return &metadata{
// does not need, but taskId would be a unique better choice
instanceID: taskId,
availabilityZone: az,
region: region,
}, nil
}
78 changes: 78 additions & 0 deletions pkg/cloud/task_metadata_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package cloud

import (
"encoding/json"
"fmt"
"github.com/golang/mock/gomock"
"github.com/kubernetes-sigs/aws-efs-csi-driver/pkg/cloud/mocks"
"testing"
)

var (
clusterId = "default"
taskId = "158d1c8083dd49d6b527399fd6414f5c"
region = "us-west-2"
az = fmt.Sprintf(`%sa`, region)
taskArn = fmt.Sprintf(`arn:aws:ecs:us-west-2:111122223333:task/%s/%s`, clusterId, taskId)
)

func TestGetTaskMetadataService(t *testing.T) {
tests := []struct {
name string
returnTMDSV4Response TMDSV4Response
err error
}{
{
"success: normal",
TMDSV4Response{
Cluster: clusterId,
TaskARN: taskArn,
AvailabilityZone: az,
},
nil,
},
{
"fail: GetTMDSV4Response returned error",
TMDSV4Response{
Cluster: clusterId,
TaskARN: taskArn,
AvailabilityZone: az,
},
fmt.Errorf(""),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockTaskMetadata := mocks.NewMockTaskMetadataService(mockCtrl)
jsonData, _ := json.Marshal(tc.returnTMDSV4Response)
mockTaskMetadata.EXPECT().GetTMDSV4Response().Return(jsonData, tc.err)

m, err := getTaskMetadata(mockTaskMetadata)

if tc.err == nil {
if err != nil {
t.Fatalf("getTaskMetadata failed: expected no error, got %v", err)
}

if m.GetInstanceID() != taskId {
t.Fatalf("GetInstanceID() failed: expeted %v, got %v", taskId, m.GetInstanceID())
}

if m.GetRegion() != region {
t.Fatalf("GetRegion() failed: expeted %v, got %v", region, m.GetRegion())
}

if m.GetAvailabilityZone() != az {
t.Fatalf("GetAvailabilityZone() failed: expeted %v, got %v", az, m.GetAvailabilityZone())
}
} else {
if err == nil {
t.Fatalf("getTaskMetadata() failed: expected error")
}
}
})
}
}
9 changes: 8 additions & 1 deletion pkg/driver/efs_watch_dog.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ state_file_dir_mode = 750
dns_name_format = {fs_id}.efs.{region}.{dns_name_suffix}
dns_name_suffix = amazonaws.com
#The region of the file system when mounting from on-premises or cross region.
{{if .Region -}}
region = {{.Region -}}
{{else -}}
#region = us-east-1
{{- end}}
stunnel_debug_enabled = false
#Uncomment the below option to save all stunnel logs for a file system to the same file
#stunnel_logs_file = /var/log/amazon/efs/{fs_id}.stunnel.log
Expand Down Expand Up @@ -121,6 +125,7 @@ type execWatchdog struct {

type efsUtilsConfig struct {
EfsClientSource string
Region string
}

func newExecWatchdog(efsUtilsCfgPath, efsUtilsStaticFilesPath, cmd string, arg ...string) Watchdog {
Expand Down Expand Up @@ -214,7 +219,9 @@ func (w *execWatchdog) updateConfig(efsClientSource string) error {
return fmt.Errorf("cannot create config file %s for efs-utils. Error: %v", w.efsUtilsCfgPath, err)
}
defer f.Close()
efsCfg := efsUtilsConfig{EfsClientSource: efsClientSource}
// used on Fargate, IMDS queries suffice otherwise
region := os.Getenv("AWS_REGION")
efsCfg := efsUtilsConfig{EfsClientSource: efsClientSource, Region: region}
if err = efsCfgTemplate.Execute(f, efsCfg); err != nil {
return fmt.Errorf("cannot update config %s for efs-utils. Error: %v", w.efsUtilsCfgPath, err)
}
Expand Down
20 changes: 20 additions & 0 deletions pkg/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package util

import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"path"
Expand Down Expand Up @@ -47,3 +49,21 @@ func ParseEndpoint(endpoint string) (string, string, error) {

return scheme, addr, nil
}

func GetHttpResponse(client *http.Client, endpoint string) ([]byte, error) {
resp, err := client.Get(endpoint)
if err != nil {
return nil, fmt.Errorf("could not get data from %v %v", endpoint, err)
}
if resp.Body != nil {
defer resp.Body.Close()
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("incorrect status code %d", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}
return body, nil
}

0 comments on commit 813904e

Please sign in to comment.