Skip to content

Commit

Permalink
Merge pull request #133 from carlory/patch-002
Browse files Browse the repository at this point in the history
Add GetGroupControllerCapabilities common function
  • Loading branch information
k8s-ci-robot committed Jun 14, 2023
2 parents a0d716c + 9ea2545 commit 1b2426d
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 5 deletions.
27 changes: 27 additions & 0 deletions rpc/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,33 @@ func GetControllerCapabilities(ctx context.Context, conn *grpc.ClientConn) (Cont
return caps, nil
}

// GroupControllerCapabilitySet is set of CSI groupcontroller capabilities. Only supported capabilities are in the map.
type GroupControllerCapabilitySet map[csi.GroupControllerServiceCapability_RPC_Type]bool

// GetGroupControllerCapabilities returns set of supported group controller capabilities of CSI driver.
func GetGroupControllerCapabilities(ctx context.Context, conn *grpc.ClientConn) (GroupControllerCapabilitySet, error) {
client := csi.NewGroupControllerClient(conn)
req := csi.GroupControllerGetCapabilitiesRequest{}
rsp, err := client.GroupControllerGetCapabilities(ctx, &req)
if err != nil {
return nil, err
}

caps := GroupControllerCapabilitySet{}
for _, cap := range rsp.GetCapabilities() {
if cap == nil {
continue
}
rpc := cap.GetRpc()
if rpc == nil {
continue
}
t := rpc.GetType()
caps[t] = true
}
return caps, nil
}

// ProbeForever calls Probe() of a CSI driver and waits until the driver becomes ready.
// Any error other than timeout is returned.
func ProbeForever(conn *grpc.ClientConn, singleProbeTimeout time.Duration) error {
Expand Down
130 changes: 125 additions & 5 deletions rpc/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ const (
// startServer creates a gRPC server without any registered services.
// The returned address can be used to connect to it. The cleanup
// function stops it. It can be called multiple times.
func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controller csi.ControllerServer) (string, func()) {
func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controller csi.ControllerServer, groupCtrl csi.GroupControllerServer) (string, func()) {
addr := path.Join(tmp, serverSock)
listener, err := net.Listen("unix", addr)
require.NoError(t, err, "listening on %s", addr)
Expand All @@ -63,6 +63,9 @@ func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controll
if controller != nil {
csi.RegisterControllerServer(server, controller)
}
if groupCtrl != nil {
csi.RegisterGroupControllerServer(server, groupCtrl)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
Expand Down Expand Up @@ -127,7 +130,7 @@ func TestGetDriverName(t *testing.T) {
pluginInfoResponse: out,
err: injectedErr,
}
addr, stopServer := startServer(t, tmp, identity, nil)
addr, stopServer := startServer(t, tmp, identity, nil, nil)
defer func() {
stopServer()
}()
Expand Down Expand Up @@ -247,7 +250,7 @@ func TestGetPluginCapabilities(t *testing.T) {
// and 1.11.1 (which will be used by new Prow job) via an extra blank line.
err: injectedErr,
}
addr, stopServer := startServer(t, tmp, identity, nil)
addr, stopServer := startServer(t, tmp, identity, nil, nil)
defer func() {
stopServer()
}()
Expand Down Expand Up @@ -375,7 +378,7 @@ func TestGetControllerCapabilities(t *testing.T) {
// and 1.11.1 (which will be used by new Prow job) via an extra blank line.
err: injectedErr,
}
addr, stopServer := startServer(t, tmp, nil, controller)
addr, stopServer := startServer(t, tmp, nil, controller, nil)
defer func() {
stopServer()
}()
Expand All @@ -399,6 +402,100 @@ func TestGetControllerCapabilities(t *testing.T) {
}
}

func TestGetGroupControllerCapabilities(t *testing.T) {
tests := []struct {
name string
output *csi.GroupControllerGetCapabilitiesResponse
injectError bool
expectCapabilities GroupControllerCapabilitySet
expectError bool
}{
{
name: "success",
output: &csi.GroupControllerGetCapabilitiesResponse{
Capabilities: []*csi.GroupControllerServiceCapability{
{
Type: &csi.GroupControllerServiceCapability_Rpc{
Rpc: &csi.GroupControllerServiceCapability_RPC{
Type: csi.GroupControllerServiceCapability_RPC_CREATE_DELETE_GET_VOLUME_GROUP_SNAPSHOT,
},
},
},
},
},
expectCapabilities: GroupControllerCapabilitySet{
csi.GroupControllerServiceCapability_RPC_CREATE_DELETE_GET_VOLUME_GROUP_SNAPSHOT: true,
},
expectError: false,
},
{
name: "gRPC error",
output: nil,
injectError: true,
expectError: true,
},
{
name: "empty capability",
output: &csi.GroupControllerGetCapabilitiesResponse{
Capabilities: []*csi.GroupControllerServiceCapability{
{
Type: nil,
},
},
},
expectCapabilities: GroupControllerCapabilitySet{},
expectError: false,
},
{
name: "no capabilities",
output: &csi.GroupControllerGetCapabilitiesResponse{
Capabilities: []*csi.GroupControllerServiceCapability{},
},
expectCapabilities: GroupControllerCapabilitySet{},
expectError: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var injectedErr error
if test.injectError {
injectedErr = fmt.Errorf("mock error")
}

tmp := tmpDir(t)
defer os.RemoveAll(tmp)
groupCtrl := &fakeGroupControllerServer{
groupControllerGetCapabilitiesResponse: test.output,

// Make code compatible with gofmt 1.10.2 (used by pull-sig-storage-csi-lib-utils-stable)
// and 1.11.1 (which will be used by new Prow job) via an extra blank line.
err: injectedErr,
}
addr, stopServer := startServer(t, tmp, nil, nil, groupCtrl)
defer func() {
stopServer()
}()

conn, err := connection.Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if err != nil {
t.Fatalf("Failed to connect to CSI driver: %s", err)
}

caps, err := GetGroupControllerCapabilities(context.Background(), conn)
if test.expectError && err == nil {
t.Errorf("Expected error, got none")
}
if !test.expectError && err != nil {
t.Errorf("Got error: %v", err)
}
if !reflect.DeepEqual(test.expectCapabilities, caps) {
t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps)
}
})
}
}

func TestProbeForever(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -509,7 +606,7 @@ func TestProbeForever(t *testing.T) {
identity := &fakeIdentityServer{
probeCalls: test.probeCalls,
}
addr, stopServer := startServer(t, tmp, identity, nil)
addr, stopServer := startServer(t, tmp, identity, nil, nil)
defer func() {
stopServer()
}()
Expand Down Expand Up @@ -624,3 +721,26 @@ func (c *fakeControllerServer) ListSnapshots(context.Context, *csi.ListSnapshots
func (c *fakeControllerServer) ControllerExpandVolume(context.Context, *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
return nil, fmt.Errorf("unimplemented")
}

type fakeGroupControllerServer struct {
groupControllerGetCapabilitiesResponse *csi.GroupControllerGetCapabilitiesResponse
err error
}

var _ csi.GroupControllerServer = &fakeGroupControllerServer{}

func (c *fakeGroupControllerServer) GroupControllerGetCapabilities(context.Context, *csi.GroupControllerGetCapabilitiesRequest) (*csi.GroupControllerGetCapabilitiesResponse, error) {
return c.groupControllerGetCapabilitiesResponse, c.err
}

func (c *fakeGroupControllerServer) CreateVolumeGroupSnapshot(context.Context, *csi.CreateVolumeGroupSnapshotRequest) (*csi.CreateVolumeGroupSnapshotResponse, error) {
return nil, fmt.Errorf("unimplemented")
}

func (c *fakeGroupControllerServer) DeleteVolumeGroupSnapshot(context.Context, *csi.DeleteVolumeGroupSnapshotRequest) (*csi.DeleteVolumeGroupSnapshotResponse, error) {
return nil, fmt.Errorf("unimplemented")
}

func (c *fakeGroupControllerServer) GetVolumeGroupSnapshot(context.Context, *csi.GetVolumeGroupSnapshotRequest) (*csi.GetVolumeGroupSnapshotResponse, error) {
return nil, fmt.Errorf("unimplemented")
}

0 comments on commit 1b2426d

Please sign in to comment.