Skip to content

Commit

Permalink
Merge pull request #257 from leakingtapan/refactor-controller-node
Browse files Browse the repository at this point in the history
Refactor driver to modularize node service and controller service
  • Loading branch information
Cheng Pan committed Apr 3, 2019
2 parents 7632fba + 84ccecf commit daaa95d
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 78 deletions.
2 changes: 0 additions & 2 deletions pkg/cloud/cloud.go
Expand Up @@ -167,12 +167,10 @@ type cloud struct {
var _ Cloud = &cloud{}

// NewCloud returns a new instance of AWS cloud
// Pass in nil metadata to use an auto created EC2Metadata service
// It panics if session is invalid
func NewCloud() (Cloud, error) {
svc := newEC2MetadataSvc()

var err error
metadata, err := NewMetadataService(svc)
if err != nil {
return nil, fmt.Errorf("could not get metadata from AWS: %v", err)
Expand Down
48 changes: 33 additions & 15 deletions pkg/driver/controller.go
Expand Up @@ -47,7 +47,25 @@ var (
}
)

func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
// controllerService represents the controller service of CSI driver
type controllerService struct {
cloud cloud.Cloud
}

// newControllerService creates a new controller service
// it panics if failed to create the service
func newControllerService() controllerService {
cloud, err := cloud.NewCloud()
if err != nil {
panic(err)
}

return controllerService{
cloud: cloud,
}
}

func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
klog.V(4).Infof("CreateVolume: called with args %+v", *req)
volName := req.GetName()
if len(volName) == 0 {
Expand All @@ -71,7 +89,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
return nil, status.Error(codes.InvalidArgument, "Volume capabilities not provided")
}

if !d.isValidVolumeCapabilities(volCaps) {
if !isValidVolumeCapabilities(volCaps) {
return nil, status.Error(codes.InvalidArgument, "Volume capabilities not supported")
}

Expand Down Expand Up @@ -147,7 +165,7 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
return newCreateVolumeResponse(disk), nil
}

func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
func (d *controllerService) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
klog.V(4).Infof("DeleteVolume: called with args: %+v", *req)
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
Expand All @@ -165,7 +183,7 @@ func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest)
return &csi.DeleteVolumeResponse{}, nil
}

func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
func (d *controllerService) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
klog.V(4).Infof("ControllerPublishVolume: called with args %+v", *req)
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
Expand All @@ -183,7 +201,7 @@ func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.Controlle
}

caps := []*csi.VolumeCapability{volCap}
if !d.isValidVolumeCapabilities(caps) {
if !isValidVolumeCapabilities(caps) {
return nil, status.Error(codes.InvalidArgument, "Volume capability not supported")
}

Expand Down Expand Up @@ -211,7 +229,7 @@ func (d *Driver) ControllerPublishVolume(ctx context.Context, req *csi.Controlle
return &csi.ControllerPublishVolumeResponse{PublishContext: pvInfo}, nil
}

func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
klog.V(4).Infof("ControllerUnpublishVolume: called with args %+v", *req)
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
Expand All @@ -231,7 +249,7 @@ func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.Control
return &csi.ControllerUnpublishVolumeResponse{}, nil
}

func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
func (d *controllerService) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
klog.V(4).Infof("ControllerGetCapabilities: called with args %+v", *req)
var caps []*csi.ControllerServiceCapability
for _, cap := range controllerCaps {
Expand All @@ -247,17 +265,17 @@ func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.Control
return &csi.ControllerGetCapabilitiesResponse{Capabilities: caps}, nil
}

func (d *Driver) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
func (d *controllerService) GetCapacity(ctx context.Context, req *csi.GetCapacityRequest) (*csi.GetCapacityResponse, error) {
klog.V(4).Infof("GetCapacity: called with args %+v", *req)
return nil, status.Error(codes.Unimplemented, "")
}

func (d *Driver) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
func (d *controllerService) ListVolumes(ctx context.Context, req *csi.ListVolumesRequest) (*csi.ListVolumesResponse, error) {
klog.V(4).Infof("ListVolumes: called with args %+v", *req)
return nil, status.Error(codes.Unimplemented, "")
}

func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
func (d *controllerService) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
klog.V(4).Infof("ValidateVolumeCapabilities: called with args %+v", *req)
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
Expand All @@ -277,15 +295,15 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida
}

var confirmed *csi.ValidateVolumeCapabilitiesResponse_Confirmed
if d.isValidVolumeCapabilities(volCaps) {
if isValidVolumeCapabilities(volCaps) {
confirmed = &csi.ValidateVolumeCapabilitiesResponse_Confirmed{VolumeCapabilities: volCaps}
}
return &csi.ValidateVolumeCapabilitiesResponse{
Confirmed: confirmed,
}, nil
}

func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool {
func isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool {
hasSupport := func(cap *csi.VolumeCapability) bool {
for _, c := range volumeCaps {
if c.GetMode() == cap.AccessMode.GetMode() {
Expand All @@ -304,7 +322,7 @@ func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool
return foundAll
}

func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
klog.V(4).Infof("CreateSnapshot: called with args %+v", req)
snapshotName := req.GetName()
if len(snapshotName) == 0 {
Expand Down Expand Up @@ -339,7 +357,7 @@ func (d *Driver) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequ
return newCreateSnapshotResponse(snapshot)
}

func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
func (d *controllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
klog.V(4).Infof("DeleteSnapshot: called with args %+v", req)
snapshotID := req.GetSnapshotId()
if len(snapshotID) == 0 {
Expand All @@ -357,7 +375,7 @@ func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequ
return &csi.DeleteSnapshotResponse{}, nil
}

func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
func (d *controllerService) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
return nil, status.Error(codes.Unimplemented, "")
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/driver/controller_test.go
Expand Up @@ -268,7 +268,7 @@ func TestCreateVolume(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
awsDriver := NewFakeDriver("", cloud.NewFakeCloudProvider(), NewFakeMounter())
awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()}

resp, err := awsDriver.CreateVolume(context.TODO(), tc.req)
if err != nil {
Expand Down Expand Up @@ -353,7 +353,7 @@ func TestDeleteVolume(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
awsDriver := NewFakeDriver("", cloud.NewFakeCloudProvider(), NewFakeMounter())
awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()}
_, err := awsDriver.DeleteVolume(context.TODO(), tc.req)
if err != nil {
srvErr, ok := status.FromError(err)
Expand Down Expand Up @@ -499,7 +499,7 @@ func TestCreateSnapshot(t *testing.T) {
}
for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)
awsDriver := NewFakeDriver("", cloud.NewFakeCloudProvider(), NewFakeMounter())
awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()}
resp, err := awsDriver.CreateSnapshot(context.TODO(), tc.req)
if err != nil {
srvErr, ok := status.FromError(err)
Expand Down Expand Up @@ -565,7 +565,7 @@ func TestDeleteSnapshot(t *testing.T) {
}
for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)
awsDriver := NewFakeDriver("", cloud.NewFakeCloudProvider(), NewFakeMounter())
awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()}
snapResp, err := awsDriver.CreateSnapshot(context.TODO(), snapReq)
if err != nil {
t.Fatalf("Error creating testing snapshot: %v", err)
Expand Down Expand Up @@ -707,7 +707,7 @@ func TestControllerPublishVolume(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.setup(tc.req)
awsDriver := NewFakeDriver("", fakeCloud, NewFakeMounter())
awsDriver := controllerService{cloud: fakeCloud}
_, err := awsDriver.ControllerPublishVolume(context.TODO(), tc.req)
if err != nil {
srvErr, ok := status.FromError(err)
Expand Down Expand Up @@ -772,7 +772,7 @@ func TestControllerUnpublishVolume(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.setup(tc.req)
awsDriver := NewFakeDriver("", fakeCloud, NewFakeMounter())
awsDriver := controllerService{cloud: cloud.NewFakeCloudProvider()}
_, err := awsDriver.ControllerUnpublishVolume(context.TODO(), tc.req)
if err != nil {
srvErr, ok := status.FromError(err)
Expand Down
35 changes: 7 additions & 28 deletions pkg/driver/driver.go
Expand Up @@ -21,12 +21,9 @@ import (
"net"

csi "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver/internal"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util"
"google.golang.org/grpc"
"k8s.io/klog"
"k8s.io/kubernetes/pkg/util/mount"
)

const (
Expand All @@ -35,31 +32,20 @@ const (
)

type Driver struct {
endpoint string
nodeID string

cloud cloud.Cloud
srv *grpc.Server
controllerService
nodeService

mounter *mount.SafeFormatAndMount
inFlight *internal.InFlight
srv *grpc.Server
endpoint string
}

func NewDriver(endpoint string) (*Driver, error) {
klog.Infof("Driver: %v Version: %v", DriverName, driverVersion)

cloud, err := cloud.NewCloud()
if err != nil {
return nil, err
}

m := cloud.GetMetadata()
return &Driver{
endpoint: endpoint,
nodeID: m.GetInstanceID(),
cloud: cloud,
mounter: newSafeMounter(),
inFlight: internal.NewInFlight(),
endpoint: endpoint,
controllerService: newControllerService(),
nodeService: newNodeService(),
}, nil
}

Expand Down Expand Up @@ -98,10 +84,3 @@ func (d *Driver) Stop() {
klog.Infof("Stopping server")
d.srv.Stop()
}

func newSafeMounter() *mount.SafeFormatAndMount {
return &mount.SafeFormatAndMount{
Interface: mount.New(""),
Exec: mount.NewOsExec(),
}
}
18 changes: 12 additions & 6 deletions pkg/driver/fakes.go
Expand Up @@ -29,21 +29,27 @@ func NewFakeMounter() *mount.FakeMounter {
}
}

func NewFakeSafeFormatAndMounter(fakeMounter *mount.FakeMounter) *mount.SafeFormatAndMount {
func NewFakeSafeFormatAndMounter(fakeMounter mount.Interface) *mount.SafeFormatAndMount {
return &mount.SafeFormatAndMount{
Interface: fakeMounter,
Exec: mount.NewFakeExec(nil),
}

}

// NewFakeDriver creates a new mock driver used for testing
func NewFakeDriver(endpoint string, fakeCloud *cloud.FakeCloudProvider, fakeMounter *mount.FakeMounter) *Driver {
return &Driver{
endpoint: endpoint,
nodeID: fakeCloud.GetMetadata().GetInstanceID(),
cloud: fakeCloud,
mounter: NewFakeSafeFormatAndMounter(fakeMounter),
inFlight: internal.NewInFlight(),
controllerService: controllerService{
cloud: fakeCloud,
},
nodeService: nodeService{
metadata: fakeCloud.GetMetadata(),
mounter: &mount.SafeFormatAndMount{
Interface: fakeMounter,
Exec: mount.NewFakeExec(nil),
},
inFlight: internal.NewInFlight(),
},
}
}

0 comments on commit daaa95d

Please sign in to comment.