diff --git a/pkg/volume/csi/BUILD b/pkg/volume/csi/BUILD index 2067ebfeb9dd3..63e72468708e8 100644 --- a/pkg/volume/csi/BUILD +++ b/pkg/volume/csi/BUILD @@ -42,7 +42,8 @@ go_test( "//pkg/volume:go_default_library", "//pkg/volume/csi/fake:go_default_library", "//pkg/volume/testing:go_default_library", - "//vendor/google.golang.org/grpc:go_default_library", + "//pkg/volume/util:go_default_library", + "//vendor/github.com/container-storage-interface/spec/lib/go/csi/v0:go_default_library", "//vendor/k8s.io/api/core/v1:go_default_library", "//vendor/k8s.io/api/storage/v1beta1:go_default_library", "//vendor/k8s.io/apimachinery/pkg/api/errors:go_default_library", diff --git a/pkg/volume/csi/csi_attacher.go b/pkg/volume/csi/csi_attacher.go index 8c280d86ddde9..82fb8e6620cb7 100644 --- a/pkg/volume/csi/csi_attacher.go +++ b/pkg/volume/csi/csi_attacher.go @@ -270,11 +270,7 @@ func (c *csiAttacher) MountDevice(spec *volume.Spec, devicePath string, deviceMo } if c.csiClient == nil { - if csiSource.Driver == "" { - return fmt.Errorf("attacher.MountDevice failed, driver name is empty") - } - addr := fmt.Sprintf(csiAddrTemplate, csiSource.Driver) - c.csiClient = newCsiDriverClient("unix", addr) + c.csiClient = newCsiDriverClient(csiSource.Driver) } csi := c.csiClient @@ -472,8 +468,7 @@ func (c *csiAttacher) UnmountDevice(deviceMountPath string) error { } if c.csiClient == nil { - addr := fmt.Sprintf(csiAddrTemplate, driverName) - c.csiClient = newCsiDriverClient("unix", addr) + c.csiClient = newCsiDriverClient(driverName) } csi := c.csiClient diff --git a/pkg/volume/csi/csi_attacher_test.go b/pkg/volume/csi/csi_attacher_test.go index dda1a1a66af32..83703ddd491d1 100644 --- a/pkg/volume/csi/csi_attacher_test.go +++ b/pkg/volume/csi/csi_attacher_test.go @@ -33,7 +33,6 @@ import ( core "k8s.io/client-go/testing" utiltesting "k8s.io/client-go/util/testing" "k8s.io/kubernetes/pkg/volume" - "k8s.io/kubernetes/pkg/volume/csi/fake" volumetest "k8s.io/kubernetes/pkg/volume/testing" ) @@ -583,8 +582,8 @@ func TestAttacherMountDevice(t *testing.T) { numStaged = 0 } - cdc := csiAttacher.csiClient.(*csiDriverClient) - staged := cdc.nodeClient.(*fake.NodeClient).GetNodeStagedVolumes() + cdc := csiAttacher.csiClient.(*fakeCsiDriverClient) + staged := cdc.nodeClient.GetNodeStagedVolumes() if len(staged) != numStaged { t.Errorf("got wrong number of staged volumes, expecting %v got: %v", numStaged, len(staged)) } @@ -668,8 +667,8 @@ func TestAttacherUnmountDevice(t *testing.T) { csiAttacher.csiClient = setupClient(t, tc.stageUnstageSet) // Add the volume to NodeStagedVolumes - cdc := csiAttacher.csiClient.(*csiDriverClient) - cdc.nodeClient.(*fake.NodeClient).AddNodeStagedVolume(tc.volID, tc.deviceMountPath) + cdc := csiAttacher.csiClient.(*fakeCsiDriverClient) + cdc.nodeClient.AddNodeStagedVolume(tc.volID, tc.deviceMountPath) // Make the PV for this object dir := filepath.Dir(tc.deviceMountPath) @@ -700,7 +699,7 @@ func TestAttacherUnmountDevice(t *testing.T) { if !tc.stageUnstageSet { expectedSet = 1 } - staged := cdc.nodeClient.(*fake.NodeClient).GetNodeStagedVolumes() + staged := cdc.nodeClient.GetNodeStagedVolumes() if len(staged) != expectedSet { t.Errorf("got wrong number of staged volumes, expecting %v got: %v", expectedSet, len(staged)) } diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index e9283efbf18bd..6398ce455d1da 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -19,6 +19,7 @@ package csi import ( "context" "errors" + "fmt" "net" "time" @@ -61,45 +62,15 @@ type csiClient interface { // csiClient encapsulates all csi-plugin methods type csiDriverClient struct { - network string - addr string - conn *grpc.ClientConn - idClient csipb.IdentityClient - nodeClient csipb.NodeClient - ctrlClient csipb.ControllerClient - versionAsserted bool - versionSupported bool - publishAsserted bool - publishCapable bool + driverName string + nodeClient csipb.NodeClient } -func newCsiDriverClient(network, addr string) *csiDriverClient { - return &csiDriverClient{network: network, addr: addr} -} - -// assertConnection ensures a valid connection has been established -// if not, it creates a new connection and associated clients -func (c *csiDriverClient) assertConnection() error { - if c.conn == nil { - conn, err := grpc.Dial( - c.addr, - grpc.WithInsecure(), - grpc.WithDialer(func(target string, timeout time.Duration) (net.Conn, error) { - return net.Dial(c.network, target) - }), - ) - if err != nil { - return err - } - c.conn = conn - c.idClient = csipb.NewIdentityClient(conn) - c.nodeClient = csipb.NewNodeClient(conn) - c.ctrlClient = csipb.NewControllerClient(conn) - - // set supported version - } +var _ csiClient = &csiDriverClient{} - return nil +func newCsiDriverClient(driverName string) *csiDriverClient { + c := &csiDriverClient{driverName: driverName} + return c } func (c *csiDriverClient) NodePublishVolume( @@ -121,10 +92,13 @@ func (c *csiDriverClient) NodePublishVolume( if targetPath == "" { return errors.New("missing target path") } - if err := c.assertConnection(); err != nil { - glog.Errorf("%v: failed to assert a connection: %v", csiPluginName, err) + + conn, err := newGrpcConn(c.driverName) + if err != nil { return err } + defer conn.Close() + nodeClient := csipb.NewNodeClient(conn) req := &csipb.NodePublishVolumeRequest{ VolumeId: volID, @@ -148,7 +122,7 @@ func (c *csiDriverClient) NodePublishVolume( req.StagingTargetPath = stagingTargetPath } - _, err := c.nodeClient.NodePublishVolume(ctx, req) + _, err = nodeClient.NodePublishVolume(ctx, req) return err } @@ -160,17 +134,20 @@ func (c *csiDriverClient) NodeUnpublishVolume(ctx context.Context, volID string, if targetPath == "" { return errors.New("missing target path") } - if err := c.assertConnection(); err != nil { - glog.Error(log("failed to assert a connection: %v", err)) + + conn, err := newGrpcConn(c.driverName) + if err != nil { return err } + defer conn.Close() + nodeClient := csipb.NewNodeClient(conn) req := &csipb.NodeUnpublishVolumeRequest{ VolumeId: volID, TargetPath: targetPath, } - _, err := c.nodeClient.NodeUnpublishVolume(ctx, req) + _, err = nodeClient.NodeUnpublishVolume(ctx, req) return err } @@ -190,10 +167,13 @@ func (c *csiDriverClient) NodeStageVolume(ctx context.Context, if stagingTargetPath == "" { return errors.New("missing staging target path") } - if err := c.assertConnection(); err != nil { - glog.Errorf("%v: failed to assert a connection: %v", csiPluginName, err) + + conn, err := newGrpcConn(c.driverName) + if err != nil { return err } + defer conn.Close() + nodeClient := csipb.NewNodeClient(conn) req := &csipb.NodeStageVolumeRequest{ VolumeId: volID, @@ -213,7 +193,7 @@ func (c *csiDriverClient) NodeStageVolume(ctx context.Context, VolumeAttributes: volumeAttribs, } - _, err := c.nodeClient.NodeStageVolume(ctx, req) + _, err = nodeClient.NodeStageVolume(ctx, req) return err } @@ -225,27 +205,34 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT if stagingTargetPath == "" { return errors.New("missing staging target path") } - if err := c.assertConnection(); err != nil { - glog.Errorf("%v: failed to assert a connection: %v", csiPluginName, err) + + conn, err := newGrpcConn(c.driverName) + if err != nil { return err } + defer conn.Close() + nodeClient := csipb.NewNodeClient(conn) req := &csipb.NodeUnstageVolumeRequest{ VolumeId: volID, StagingTargetPath: stagingTargetPath, } - _, err := c.nodeClient.NodeUnstageVolume(ctx, req) + _, err = nodeClient.NodeUnstageVolume(ctx, req) return err } func (c *csiDriverClient) NodeGetCapabilities(ctx context.Context) ([]*csipb.NodeServiceCapability, error) { glog.V(4).Info(log("calling NodeGetCapabilities rpc")) - if err := c.assertConnection(); err != nil { - glog.Errorf("%v: failed to assert a connection: %v", csiPluginName, err) + + conn, err := newGrpcConn(c.driverName) + if err != nil { return nil, err } + defer conn.Close() + nodeClient := csipb.NewNodeClient(conn) + req := &csipb.NodeGetCapabilitiesRequest{} - resp, err := c.nodeClient.NodeGetCapabilities(ctx, req) + resp, err := nodeClient.NodeGetCapabilities(ctx, req) if err != nil { return nil, err } @@ -263,3 +250,21 @@ func asCSIAccessMode(am api.PersistentVolumeAccessMode) csipb.VolumeCapability_A } return csipb.VolumeCapability_AccessMode_UNKNOWN } + +func newGrpcConn(driverName string) (*grpc.ClientConn, error) { + if driverName == "" { + return nil, fmt.Errorf("driver name is empty") + } + + network := "unix" + addr := fmt.Sprintf(csiAddrTemplate, driverName) + glog.V(4).Infof(log("creating new gRPC connection for [%s://%s]", network, addr)) + + return grpc.Dial( + addr, + grpc.WithInsecure(), + grpc.WithDialer(func(target string, timeout time.Duration) (net.Conn, error) { + return net.Dial(network, target) + }), + ) +} diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go index d4c843ef1bf49..c5b8cf01f78e6 100644 --- a/pkg/volume/csi/csi_client_test.go +++ b/pkg/volume/csi/csi_client_test.go @@ -21,21 +21,124 @@ import ( "errors" "testing" - "google.golang.org/grpc" + csipb "github.com/container-storage-interface/spec/lib/go/csi/v0" api "k8s.io/api/core/v1" "k8s.io/kubernetes/pkg/volume/csi/fake" ) -func setupClient(t *testing.T, stageUnstageSet bool) *csiDriverClient { - client := newCsiDriverClient("unix", "/tmp/test.sock") - client.conn = new(grpc.ClientConn) //avoids creating conn object +type fakeCsiDriverClient struct { + t *testing.T + nodeClient *fake.NodeClient +} + +func newFakeCsiDriverClient(t *testing.T, stagingCapable bool) *fakeCsiDriverClient { + return &fakeCsiDriverClient{ + t: t, + nodeClient: fake.NewNodeClient(stagingCapable), + } +} + +func (c *fakeCsiDriverClient) NodePublishVolume( + ctx context.Context, + volID string, + readOnly bool, + stagingTargetPath string, + targetPath string, + accessMode api.PersistentVolumeAccessMode, + volumeInfo map[string]string, + volumeAttribs map[string]string, + nodePublishSecrets map[string]string, + fsType string, +) error { + c.t.Log("calling fake.NodePublishVolume...") + req := &csipb.NodePublishVolumeRequest{ + VolumeId: volID, + TargetPath: targetPath, + Readonly: readOnly, + PublishInfo: volumeInfo, + VolumeAttributes: volumeAttribs, + NodePublishSecrets: nodePublishSecrets, + VolumeCapability: &csipb.VolumeCapability{ + AccessMode: &csipb.VolumeCapability_AccessMode{ + Mode: asCSIAccessMode(accessMode), + }, + AccessType: &csipb.VolumeCapability_Mount{ + Mount: &csipb.VolumeCapability_MountVolume{ + FsType: fsType, + }, + }, + }, + } - // setup mock grpc clients - client.idClient = fake.NewIdentityClient() - client.nodeClient = fake.NewNodeClient(stageUnstageSet) - client.ctrlClient = fake.NewControllerClient() + _, err := c.nodeClient.NodePublishVolume(ctx, req) + return err +} + +func (c *fakeCsiDriverClient) NodeUnpublishVolume(ctx context.Context, volID string, targetPath string) error { + c.t.Log("calling fake.NodeUnpublishVolume...") + req := &csipb.NodeUnpublishVolumeRequest{ + VolumeId: volID, + TargetPath: targetPath, + } + + _, err := c.nodeClient.NodeUnpublishVolume(ctx, req) + return err +} + +func (c *fakeCsiDriverClient) NodeStageVolume(ctx context.Context, + volID string, + publishInfo map[string]string, + stagingTargetPath string, + fsType string, + accessMode api.PersistentVolumeAccessMode, + nodeStageSecrets map[string]string, + volumeAttribs map[string]string, +) error { + c.t.Log("calling fake.NodeStageVolume...") + req := &csipb.NodeStageVolumeRequest{ + VolumeId: volID, + PublishInfo: publishInfo, + StagingTargetPath: stagingTargetPath, + VolumeCapability: &csipb.VolumeCapability{ + AccessMode: &csipb.VolumeCapability_AccessMode{ + Mode: asCSIAccessMode(accessMode), + }, + AccessType: &csipb.VolumeCapability_Mount{ + Mount: &csipb.VolumeCapability_MountVolume{ + FsType: fsType, + }, + }, + }, + NodeStageSecrets: nodeStageSecrets, + VolumeAttributes: volumeAttribs, + } + + _, err := c.nodeClient.NodeStageVolume(ctx, req) + return err +} + +func (c *fakeCsiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingTargetPath string) error { + c.t.Log("calling fake.NodeUnstageVolume...") + req := &csipb.NodeUnstageVolumeRequest{ + VolumeId: volID, + StagingTargetPath: stagingTargetPath, + } + _, err := c.nodeClient.NodeUnstageVolume(ctx, req) + return err +} + +func (c *fakeCsiDriverClient) NodeGetCapabilities(ctx context.Context) ([]*csipb.NodeServiceCapability, error) { + c.t.Log("calling fake.NodeGetCapabilities...") + req := &csipb.NodeGetCapabilitiesRequest{} + resp, err := c.nodeClient.NodeGetCapabilities(ctx, req) + if err != nil { + return nil, err + } + return resp.GetCapabilities(), nil +} - return client +func setupClient(t *testing.T, stageUnstageSet bool) csiClient { + return newFakeCsiDriverClient(t, stageUnstageSet) } func TestClientNodePublishVolume(t *testing.T) { @@ -58,7 +161,7 @@ func TestClientNodePublishVolume(t *testing.T) { for _, tc := range testCases { t.Logf("test case: %s", tc.name) - client.nodeClient.(*fake.NodeClient).SetNextError(tc.err) + client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) err := client.NodePublishVolume( context.Background(), tc.volID, @@ -96,7 +199,7 @@ func TestClientNodeUnpublishVolume(t *testing.T) { for _, tc := range testCases { t.Logf("test case: %s", tc.name) - client.nodeClient.(*fake.NodeClient).SetNextError(tc.err) + client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) err := client.NodeUnpublishVolume(context.Background(), tc.volID, tc.targetPath) if tc.mustFail && err == nil { t.Error("test must fail, but err is nil") @@ -125,7 +228,7 @@ func TestClientNodeStageVolume(t *testing.T) { for _, tc := range testCases { t.Logf("Running test case: %s", tc.name) - client.nodeClient.(*fake.NodeClient).SetNextError(tc.err) + client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) err := client.NodeStageVolume( context.Background(), tc.volID, @@ -161,7 +264,7 @@ func TestClientNodeUnstageVolume(t *testing.T) { for _, tc := range testCases { t.Logf("Running test case: %s", tc.name) - client.nodeClient.(*fake.NodeClient).SetNextError(tc.err) + client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) err := client.NodeUnstageVolume( context.Background(), tc.volID, tc.stagingTargetPath, diff --git a/pkg/volume/csi/csi_mounter.go b/pkg/volume/csi/csi_mounter.go index 53908c76b93fb..cb3de6475fa42 100644 --- a/pkg/volume/csi/csi_mounter.go +++ b/pkg/volume/csi/csi_mounter.go @@ -55,8 +55,8 @@ var ( ) type csiMountMgr struct { - k8s kubernetes.Interface csiClient csiClient + k8s kubernetes.Interface plugin *csiPlugin driverName string volumeID string @@ -121,6 +121,7 @@ func (c *csiMountMgr) SetUpAt(dir string, fsGroup *int64) error { ctx, cancel := context.WithTimeout(context.Background(), csiTimeout) defer cancel() + // Check for STAGE_UNSTAGE_VOLUME set and populate deviceMountPath if so deviceMountPath := "" stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) @@ -170,24 +171,6 @@ func (c *csiMountMgr) SetUpAt(dir string, fsGroup *int64) error { } glog.V(4).Info(log("created target path successfully [%s]", dir)) - // persist volume info data for teardown - volData := map[string]string{ - volDataKey.specVolID: c.spec.Name(), - volDataKey.volHandle: csiSource.VolumeHandle, - volDataKey.driverName: csiSource.Driver, - volDataKey.nodeName: nodeName, - volDataKey.attachmentID: attachID, - } - - if err := saveVolumeData(c.plugin, c.podUID, c.spec.Name(), volData); err != nil { - glog.Error(log("mounter.SetUpAt failed to save volume info data: %v", err)) - if err := removeMountDir(c.plugin, dir); err != nil { - glog.Error(log("mounter.SetUpAt failed to remove mount dir after a saveVolumeData() error [%s]: %v", dir, err)) - return err - } - return err - } - //TODO (vladimirvivien) implement better AccessModes mapping between k8s and CSI accessMode := api.ReadWriteOnce if c.spec.PersistentVolume.Spec.AccessModes != nil { @@ -290,29 +273,12 @@ func (c *csiMountMgr) TearDownAt(dir string) error { return nil } - // load volume info from file - dataDir := path.Dir(dir) // dropoff /mount at end - data, err := loadVolumeData(dataDir, volDataFileName) - if err != nil { - glog.Error(log("unmounter.Teardown failed to load volume data file using dir [%s]: %v", dir, err)) - return err - } - - volID := data[volDataKey.volHandle] - driverName := data[volDataKey.driverName] - - if c.csiClient == nil { - addr := fmt.Sprintf(csiAddrTemplate, driverName) - client := newCsiDriverClient("unix", addr) - glog.V(4).Infof(log("unmounter csiClient setup [volume=%v,driver=%v]", volID, driverName)) - c.csiClient = client - } + volID := c.volumeID + csi := c.csiClient ctx, cancel := context.WithTimeout(context.Background(), csiTimeout) defer cancel() - csi := c.csiClient - if err := csi.NodeUnpublishVolume(ctx, volID, dir); err != nil { glog.Errorf(log("mounter.TearDownAt failed: %v", err)) return err @@ -328,12 +294,10 @@ func (c *csiMountMgr) TearDownAt(dir string) error { return nil } -// saveVolumeData persists parameter data as json file using the location -// generated by /var/lib/kubelet/pods//volumes/kubernetes.io~csi//volume_data.json -func saveVolumeData(p *csiPlugin, podUID types.UID, specVolID string, data map[string]string) error { - dir := getTargetPath(podUID, specVolID, p.host) - dataFilePath := path.Join(dir, volDataFileName) - +// saveVolumeData persists parameter data as json file at the provided location +func saveVolumeData(dir string, fileName string, data map[string]string) error { + dataFilePath := path.Join(dir, fileName) + glog.V(4).Info(log("saving volume data file [%s]", dataFilePath)) file, err := os.Create(dataFilePath) if err != nil { glog.Error(log("failed to save volume data file %s: %v", dataFilePath, err)) @@ -348,10 +312,7 @@ func saveVolumeData(p *csiPlugin, podUID types.UID, specVolID string, data map[s return nil } -// loadVolumeData uses the directory returned by mounter.GetPath with value -// /var/lib/kubelet/pods//volumes/kubernetes.io~csi//mount. -// The function extracts specVolumeID and uses it to load the json data file from dir -// /var/lib/kubelet/pods//volumes/kubernetes.io~csi//volume_data.json +// loadVolumeData loads volume info from specified json file/location func loadVolumeData(dir string, fileName string) (map[string]string, error) { // remove /mount at the end dataFileName := path.Join(dir, fileName) diff --git a/pkg/volume/csi/csi_mounter_test.go b/pkg/volume/csi/csi_mounter_test.go index c64792c5d11ce..d1e942871d064 100644 --- a/pkg/volume/csi/csi_mounter_test.go +++ b/pkg/volume/csi/csi_mounter_test.go @@ -31,8 +31,8 @@ import ( "k8s.io/apimachinery/pkg/types" fakeclient "k8s.io/client-go/kubernetes/fake" "k8s.io/kubernetes/pkg/volume" - "k8s.io/kubernetes/pkg/volume/csi/fake" volumetest "k8s.io/kubernetes/pkg/volume/testing" + "k8s.io/kubernetes/pkg/volume/util" ) var ( @@ -78,7 +78,6 @@ func TestMounterGetPath(t *testing.T) { csiMounter := mounter.(*csiMountMgr) path := csiMounter.GetPath() - t.Logf("*** GetPath: %s", path) if tc.path != path { t.Errorf("expecting path %s, got %s", tc.path, path) @@ -114,7 +113,7 @@ func TestMounterSetUp(t *testing.T) { } csiMounter := mounter.(*csiMountMgr) - csiMounter.csiClient = setupClient(t, false) + csiMounter.csiClient = setupClient(t, true) attachID := getAttachmentName(csiMounter.volumeID, csiMounter.driverName, string(plug.host.GetNodeName())) @@ -155,7 +154,7 @@ func TestMounterSetUp(t *testing.T) { } // ensure call went all the way - pubs := csiMounter.csiClient.(*csiDriverClient).nodeClient.(*fake.NodeClient).GetNodePublishedVolumes() + pubs := csiMounter.csiClient.(*fakeCsiDriverClient).nodeClient.GetNodePublishedVolumes() if pubs[csiMounter.volumeID] != csiMounter.GetPath() { t.Error("csi server may not have received NodePublishVolume call") } @@ -164,39 +163,46 @@ func TestMounterSetUp(t *testing.T) { func TestUnmounterTeardown(t *testing.T) { plug, tmpDir := newTestPlugin(t) defer os.RemoveAll(tmpDir) - pv := makeTestPV("test-pv", 10, testDriver, testVol) - unmounter, err := plug.NewUnmounter(pv.ObjectMeta.Name, testPodUID) - if err != nil { - t.Fatalf("failed to make a new Unmounter: %v", err) - } - - csiUnmounter := unmounter.(*csiMountMgr) - csiUnmounter.csiClient = setupClient(t, false) - - dir := csiUnmounter.GetPath() - // save the data file prior to unmount + dir := path.Join(getTargetPath(testPodUID, pv.ObjectMeta.Name, plug.host), "/mount") if err := os.MkdirAll(dir, 0755); err != nil && !os.IsNotExist(err) { t.Errorf("failed to create dir [%s]: %v", dir, err) } + + // do a fake local mount + diskMounter := util.NewSafeFormatAndMountFromHost(plug.GetPluginName(), plug.host) + if err := diskMounter.FormatAndMount("/fake/device", dir, "testfs", nil); err != nil { + t.Errorf("failed to mount dir [%s]: %v", dir, err) + } + if err := saveVolumeData( - plug, - testPodUID, - "test-pv", - map[string]string{volDataKey.specVolID: "test-pv", volDataKey.driverName: "driver", volDataKey.volHandle: "vol-handle"}, + path.Dir(dir), + volDataFileName, + map[string]string{ + volDataKey.specVolID: pv.ObjectMeta.Name, + volDataKey.driverName: testDriver, + volDataKey.volHandle: testVol, + }, ); err != nil { t.Fatalf("failed to save volume data: %v", err) } + unmounter, err := plug.NewUnmounter(pv.ObjectMeta.Name, testPodUID) + if err != nil { + t.Fatalf("failed to make a new Unmounter: %v", err) + } + + csiUnmounter := unmounter.(*csiMountMgr) + csiUnmounter.csiClient = setupClient(t, true) err = csiUnmounter.TearDownAt(dir) if err != nil { t.Fatal(err) } // ensure csi client call - pubs := csiUnmounter.csiClient.(*csiDriverClient).nodeClient.(*fake.NodeClient).GetNodePublishedVolumes() + pubs := csiUnmounter.csiClient.(*fakeCsiDriverClient).nodeClient.GetNodePublishedVolumes() if _, ok := pubs[csiUnmounter.volumeID]; ok { t.Error("csi server may not have received NodeUnpublishVolume call") } @@ -223,7 +229,7 @@ func TestSaveVolumeData(t *testing.T) { t.Errorf("failed to create dir [%s]: %v", mountDir, err) } - err := saveVolumeData(plug, testPodUID, specVolID, tc.data) + err := saveVolumeData(path.Dir(mountDir), volDataFileName, tc.data) if !tc.shouldFail && err != nil { t.Errorf("unexpected failure: %v", err) diff --git a/pkg/volume/csi/csi_plugin.go b/pkg/volume/csi/csi_plugin.go index d63296877d02c..8a993afb92050 100644 --- a/pkg/volume/csi/csi_plugin.go +++ b/pkg/volume/csi/csi_plugin.go @@ -19,6 +19,8 @@ package csi import ( "errors" "fmt" + "os" + "path" "time" "github.com/golang/glog" @@ -103,17 +105,14 @@ func (p *csiPlugin) NewMounter( return nil, err } - // before it is used in any paths such as socket etc - addr := fmt.Sprintf(csiAddrTemplate, pvSource.Driver) - glog.V(4).Infof(log("setting up mounter for [volume=%v,driver=%v]", pvSource.VolumeHandle, pvSource.Driver)) - client := newCsiDriverClient("unix", addr) - k8s := p.host.GetKubeClient() if k8s == nil { glog.Error(log("failed to get a kubernetes client")) return nil, errors.New("failed to get a Kubernetes client") } + csi := newCsiDriverClient(pvSource.Driver) + mounter := &csiMountMgr{ plugin: p, k8s: k8s, @@ -123,19 +122,66 @@ func (p *csiPlugin) NewMounter( driverName: pvSource.Driver, volumeID: pvSource.VolumeHandle, specVolumeID: spec.Name(), - csiClient: client, + csiClient: csi, readOnly: readOnly, } + + // Save volume info in pod dir + dir := mounter.GetPath() + dataDir := path.Dir(dir) // dropoff /mount at end + + if err := os.MkdirAll(dataDir, 0750); err != nil { + glog.Error(log("failed to create dir %#v: %v", dataDir, err)) + return nil, err + } + glog.V(4).Info(log("created path successfully [%s]", dataDir)) + + // persist volume info data for teardown + node := string(p.host.GetNodeName()) + attachID := getAttachmentName(pvSource.VolumeHandle, pvSource.Driver, node) + volData := map[string]string{ + volDataKey.specVolID: spec.Name(), + volDataKey.volHandle: pvSource.VolumeHandle, + volDataKey.driverName: pvSource.Driver, + volDataKey.nodeName: node, + volDataKey.attachmentID: attachID, + } + + if err := saveVolumeData(dataDir, volDataFileName, volData); err != nil { + glog.Error(log("failed to save volume info data: %v", err)) + if err := os.RemoveAll(dataDir); err != nil { + glog.Error(log("failed to remove dir after error [%s]: %v", dataDir, err)) + return nil, err + } + return nil, err + } + + glog.V(4).Info(log("mounter created successfully")) + return mounter, nil } func (p *csiPlugin) NewUnmounter(specName string, podUID types.UID) (volume.Unmounter, error) { glog.V(4).Infof(log("setting up unmounter for [name=%v, podUID=%v]", specName, podUID)) + unmounter := &csiMountMgr{ plugin: p, podUID: podUID, specVolumeID: specName, } + + // load volume info from file + dir := unmounter.GetPath() + dataDir := path.Dir(dir) // dropoff /mount at end + data, err := loadVolumeData(dataDir, volDataFileName) + if err != nil { + glog.Error(log("unmounter failed to load volume data file [%s]: %v", dir, err)) + return nil, err + } + unmounter.driverName = data[volDataKey.driverName] + unmounter.volumeID = data[volDataKey.volHandle] + unmounter.csiClient = newCsiDriverClient(unmounter.driverName) + return unmounter, nil } diff --git a/pkg/volume/csi/csi_plugin_test.go b/pkg/volume/csi/csi_plugin_test.go index 66151e0ac7d05..ed7a92ffe5183 100644 --- a/pkg/volume/csi/csi_plugin_test.go +++ b/pkg/volume/csi/csi_plugin_test.go @@ -160,7 +160,7 @@ func TestPluginConstructVolumeSpec(t *testing.T) { if err := os.MkdirAll(mountDir, 0755); err != nil && !os.IsNotExist(err) { t.Errorf("failed to create dir [%s]: %v", mountDir, err) } - if err := saveVolumeData(plug, testPodUID, tc.specVolID, tc.data); err != nil { + if err := saveVolumeData(path.Dir(mountDir), volDataFileName, tc.data); err != nil { t.Fatal(err) } } @@ -225,6 +225,25 @@ func TestPluginNewUnmounter(t *testing.T) { pv := makeTestPV("test-pv", 10, testDriver, testVol) + // save the data file to re-create client + dir := path.Join(getTargetPath(testPodUID, pv.ObjectMeta.Name, plug.host), "/mount") + if err := os.MkdirAll(dir, 0755); err != nil && !os.IsNotExist(err) { + t.Errorf("failed to create dir [%s]: %v", dir, err) + } + + if err := saveVolumeData( + path.Dir(dir), + volDataFileName, + map[string]string{ + volDataKey.specVolID: pv.ObjectMeta.Name, + volDataKey.driverName: testDriver, + volDataKey.volHandle: testVol, + }, + ); err != nil { + t.Fatalf("failed to save volume data: %v", err) + } + + // test unmounter unmounter, err := plug.NewUnmounter(pv.ObjectMeta.Name, testPodUID) csiUnmounter := unmounter.(*csiMountMgr)