Skip to content

Commit

Permalink
Merge pull request #67110 from verult/kubelet-nodeid
Browse files Browse the repository at this point in the history
Automatic merge from submit-queue (batch tested with PRs 67017, 67190, 67110, 67140, 66873). If you want to cherry-pick this change to another branch, please follow the instructions <a href="https://github.com/kubernetes/community/blob/master/contributors/devel/cherry-picks.md">here</a>.

CSI plugin now calls NodeGetInfo() to get driver's node ID

**Which issue(s) this PR fixes** *(optional, in `fixes #<issue number>(, fixes #<issue_number>, ...)` format, will close the issue(s) when PR gets merged)*:
Fixes #67040

**Special notes for your reviewer**:

**Release note**:

```release-note
NONE
```
/sig storage
@sbezverk @vladimirvivien @saad-ali
  • Loading branch information
Kubernetes Submit Queue committed Aug 11, 2018
2 parents 29e167e + 7fa120c commit 032a096
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 9 deletions.
23 changes: 23 additions & 0 deletions pkg/volume/csi/csi_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ import (
)

type csiClient interface {
NodeGetInfo(ctx context.Context) (
nodeID string,
maxVolumePerNode int64,
accessibleTopology *csipb.Topology,
err error)
NodePublishVolume(
ctx context.Context,
volumeid string,
Expand Down Expand Up @@ -75,6 +80,24 @@ func newCsiDriverClient(driverName string) *csiDriverClient {
return c
}

func (c *csiDriverClient) NodeGetInfo(ctx context.Context) (
nodeID string,
maxVolumePerNode int64,
accessibleTopology *csipb.Topology,
err error) {
glog.V(4).Info(log("calling NodeGetInfo rpc"))

conn, err := newGrpcConn(c.driverName)
if err != nil {
return "", 0, nil, err
}
defer conn.Close()
nodeClient := csipb.NewNodeClient(conn)

res, err := nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{})
return res.GetNodeId(), res.GetMaxVolumesPerNode(), res.GetAccessibleTopology(), nil
}

func (c *csiDriverClient) NodePublishVolume(
ctx context.Context,
volID string,
Expand Down
64 changes: 64 additions & 0 deletions pkg/volume/csi/csi_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
csipb "github.com/container-storage-interface/spec/lib/go/csi/v0"
api "k8s.io/api/core/v1"
"k8s.io/kubernetes/pkg/volume/csi/fake"
"reflect"
)

type fakeCsiDriverClient struct {
Expand All @@ -38,6 +39,15 @@ func newFakeCsiDriverClient(t *testing.T, stagingCapable bool) *fakeCsiDriverCli
}
}

func (c *fakeCsiDriverClient) NodeGetInfo(ctx context.Context) (
nodeID string,
maxVolumePerNode int64,
accessibleTopology *csipb.Topology,
err error) {
resp, err := c.nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{})
return resp.GetNodeId(), resp.GetMaxVolumesPerNode(), resp.GetAccessibleTopology(), err
}

func (c *fakeCsiDriverClient) NodePublishVolume(
ctx context.Context,
volID string,
Expand Down Expand Up @@ -141,6 +151,60 @@ func setupClient(t *testing.T, stageUnstageSet bool) csiClient {
return newFakeCsiDriverClient(t, stageUnstageSet)
}

func TestClientNodeGetInfo(t *testing.T) {
testCases := []struct {
name string
expectedNodeID string
expectedMaxVolumePerNode int64
expectedAccessibleTopology *csipb.Topology
mustFail bool
err error
}{
{
name: "test ok",
expectedNodeID: "node1",
expectedMaxVolumePerNode: 16,
expectedAccessibleTopology: &csipb.Topology{
Segments: map[string]string{"com.example.csi-topology/zone": "zone1"},
},
},
{name: "grpc error", mustFail: true, err: errors.New("grpc error")},
}

client := setupClient(t, false /* stageUnstageSet */)

for _, tc := range testCases {
t.Logf("test case: %s", tc.name)
client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
client.(*fakeCsiDriverClient).nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{
NodeId: tc.expectedNodeID,
MaxVolumesPerNode: tc.expectedMaxVolumePerNode,
AccessibleTopology: tc.expectedAccessibleTopology,
})
nodeID, maxVolumePerNode, accessibleTopology, err := client.NodeGetInfo(context.Background())

if tc.mustFail && err == nil {
t.Error("expected an error but got none")
}

if !tc.mustFail && err != nil {
t.Errorf("expected no errors but got: %v", err)
}

if nodeID != tc.expectedNodeID {
t.Errorf("expected nodeID: %v; got: %v", tc.expectedNodeID, nodeID)
}

if maxVolumePerNode != tc.expectedMaxVolumePerNode {
t.Errorf("expected maxVolumePerNode: %v; got: %v", tc.expectedMaxVolumePerNode, maxVolumePerNode)
}

if !reflect.DeepEqual(accessibleTopology, tc.expectedAccessibleTopology) {
t.Errorf("expected accessibleTopology: %v; got: %v", *tc.expectedAccessibleTopology, *accessibleTopology)
}
}
}

func TestClientNodePublishVolume(t *testing.T) {
testCases := []struct {
name string
Expand Down
28 changes: 23 additions & 5 deletions pkg/volume/csi/csi_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"sync"
"time"

"context"
"github.com/golang/glog"
api "k8s.io/api/core/v1"
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -76,6 +77,7 @@ type csiDriversStore struct {
sync.RWMutex
}

// TODO (verult) consider using a struct instead of global variables
// csiDrivers map keep track of all registered CSI drivers on the node and their
// corresponding sockets
var csiDrivers csiDriversStore
Expand All @@ -92,17 +94,33 @@ func RegistrationCallback(pluginName string, endpoint string, versions []string,
if endpoint == "" {
endpoint = socketPath
}
// Calling nodeLabelManager to update label for newly registered CSI driver
err := lm.AddLabels(pluginName)
if err != nil {
return nil, err
}

// Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key
// all other CSI components will be able to get the actual socket of CSI drivers by its name.
csiDrivers.Lock()
defer csiDrivers.Unlock()
csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint}

// Get node info from the driver.
csi := newCsiDriverClient(pluginName)
// TODO (verult) retry with exponential backoff, possibly added in csi client library.
ctx, cancel := context.WithTimeout(context.Background(), csiTimeout)
defer cancel()
driverNodeID, _, _, err := csi.NodeGetInfo(ctx)
if err != nil {
return nil, fmt.Errorf("error during CSI NodeGetInfo() call: %v", err)
}

// Calling nodeLabelManager to update annotations and labels for newly registered CSI driver
err = lm.AddLabels(pluginName, driverNodeID)
if err != nil {
// Unregister the driver and return error
csiDrivers.Lock()
defer csiDrivers.Unlock()
delete(csiDrivers.driversMap, pluginName)
return nil, err
}

return nil, nil
}

Expand Down
13 changes: 13 additions & 0 deletions pkg/volume/csi/fake/fake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type NodeClient struct {
nodePublishedVolumes map[string]string
nodeStagedVolumes map[string]string
stageUnstageSet bool
nodeGetInfoResp *csipb.NodeGetInfoResponse
nextErr error
}

Expand All @@ -78,6 +79,10 @@ func (f *NodeClient) SetNextError(err error) {
f.nextErr = err
}

func (f *NodeClient) SetNodeGetInfoResp(resp *csipb.NodeGetInfoResponse) {
f.nodeGetInfoResp = resp
}

// GetNodePublishedVolumes returns node published volumes
func (f *NodeClient) GetNodePublishedVolumes() map[string]string {
return f.nodePublishedVolumes
Expand Down Expand Up @@ -179,6 +184,14 @@ func (f *NodeClient) NodeGetId(ctx context.Context, in *csipb.NodeGetIdRequest,
return nil, nil
}

// NodeGetId implements csi method
func (f *NodeClient) NodeGetInfo(ctx context.Context, in *csipb.NodeGetInfoRequest, opts ...grpc.CallOption) (*csipb.NodeGetInfoResponse, error) {
if f.nextErr != nil {
return nil, f.nextErr
}
return f.nodeGetInfoResp, nil
}

// NodeGetCapabilities implements csi method
func (f *NodeClient) NodeGetCapabilities(ctx context.Context, in *csipb.NodeGetCapabilitiesRequest, opts ...grpc.CallOption) (*csipb.NodeGetCapabilitiesResponse, error) {
resp := &csipb.NodeGetCapabilitiesResponse{
Expand Down
7 changes: 3 additions & 4 deletions pkg/volume/csi/labelmanager/labelmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ const (
// Name of node annotation that contains JSON map of driver names to node
// names
annotationKey = "csi.volume.kubernetes.io/nodeid"
csiPluginName = "kubernetes.io/csi"
)

// labelManagementStruct is struct of channels used for communication between the driver registration
Expand All @@ -46,7 +45,7 @@ type labelManagerStruct struct {

// Interface implements an interface for managing labels of a node
type Interface interface {
AddLabels(driverName string) error
AddLabels(driverName string, driverNodeId string) error
}

// NewLabelManager initializes labelManagerStruct and returns available interfaces
Expand All @@ -59,8 +58,8 @@ func NewLabelManager(nodeName types.NodeName, kubeClient kubernetes.Interface) I

// nodeLabelManager waits for labeling requests initiated by the driver's registration
// process.
func (lm labelManagerStruct) AddLabels(driverName string) error {
err := verifyAndAddNodeId(string(lm.nodeName), lm.k8s.CoreV1().Nodes(), driverName, string(lm.nodeName))
func (lm labelManagerStruct) AddLabels(driverName string, driverNodeId string) error {
err := verifyAndAddNodeId(string(lm.nodeName), lm.k8s.CoreV1().Nodes(), driverName, driverNodeId)
if err != nil {
return fmt.Errorf("failed to update node %s's annotation with error: %+v", lm.nodeName, err)
}
Expand Down

0 comments on commit 032a096

Please sign in to comment.