diff --git a/pkg/cri/resource-manager/flags.go b/pkg/cri/resource-manager/flags.go index a141309dc6..a2a6ba7006 100644 --- a/pkg/cri/resource-manager/flags.go +++ b/pkg/cri/resource-manager/flags.go @@ -28,6 +28,7 @@ type options struct { RuntimeSocket string RelaySocket string RelayDir string + AllowDocker bool AgentSocket string ConfigSocket string PidFile string @@ -46,6 +47,10 @@ type options struct { // Relay command line options. var opt = options{} +const ( + allowDockerFlag = "allow-docker" +) + // Register us for command line option processing. func init() { flag.StringVar(&opt.ImageSocket, "image-socket", sockets.Containerd, @@ -56,6 +61,9 @@ func init() { "Unix domain socket path where the resource manager should serve requests on.") flag.StringVar(&opt.RelayDir, "relay-dir", "/var/lib/cri-resmgr", "Permanent storage directory path for the resource manager to store its state in.") + flag.BoolVar(&opt.AllowDocker, allowDockerFlag, false, + "Allow cri-dockerd/docker-shim as a CRI runtime. Usually this is not a good idea.") + flag.StringVar(&opt.AgentSocket, "agent-socket", sockets.ResourceManagerAgent, "local socket of the cri-resmgr agent to connect") flag.StringVar(&opt.ConfigSocket, "config-socket", sockets.ResourceManagerConfig, diff --git a/pkg/cri/resource-manager/requests.go b/pkg/cri/resource-manager/requests.go index a2146bc765..8f5efc1b2c 100644 --- a/pkg/cri/resource-manager/requests.go +++ b/pkg/cri/resource-manager/requests.go @@ -17,6 +17,7 @@ package resmgr import ( "context" "fmt" + "strings" criapi "k8s.io/cri-api/pkg/apis/runtime/v1alpha2" @@ -28,6 +29,11 @@ import ( "github.com/intel/cri-resource-manager/pkg/cri/server" ) +const ( + kubeAPIVersion = "0.1.0" + dockerRuntimeName = "docker" +) + // setupRequestProcessing prepares the resource manager for CRI request processing. func (m *resmgr) setupRequestProcessing() error { interceptors := map[string]server.Interceptor{ @@ -158,6 +164,18 @@ func (m *resmgr) startRequestProcessing() error { // syncWithCRI synchronizes cache pods and containers with the CRI runtime. func (m *resmgr) syncWithCRI(ctx context.Context) ([]cache.Container, []cache.Container, error) { + version, err := m.relay.Client().Version(ctx, &criapi.VersionRequest{ + Version: kubeAPIVersion, + }) + if err != nil { + return nil, nil, resmgrError("failed to query runtime version: %v", err) + } + if strings.HasPrefix(version.RuntimeName, dockerRuntimeName) { + if !opt.AllowDocker { + return nil, nil, rejectRuntimeError(version.RuntimeName) + } + } + if m.policy.Bypassed() || !m.relay.Client().HasRuntimeService() { return nil, nil, nil } @@ -898,3 +916,8 @@ func (m *resmgr) sendCRIRequest(ctx context.Context, request interface{}) (inter return nil, resmgrError("sendCRIRequest: unhandled request type %T", request) } } + +func rejectRuntimeError(name string) error { + return resmgrError("rejecting disallowed runtime %s, use --%s to allow it", + name, allowDockerFlag) +} diff --git a/test/functional/fake_cri_server_test.go b/test/functional/fake_cri_server_test.go index 7945c885c1..69889e1688 100644 --- a/test/functional/fake_cri_server_test.go +++ b/test/functional/fake_cri_server_test.go @@ -31,6 +31,13 @@ import ( api "k8s.io/cri-api/pkg/apis/runtime/v1alpha2" ) +const ( + fakeVersion = "0.1.0" + fakeRuntimeName = "fake-CRI-runtime" + fakeRuntimeVersion = "v0.0.0" + fakeRuntimeApiVersion = "v1" +) + type fakeCriServer struct { t *testing.T socket string @@ -125,7 +132,16 @@ func (s *fakeCriServer) callHandler(ctx context.Context, request interface{}, de // Implementation of api.RuntimeServiceServer func (s *fakeCriServer) Version(ctx context.Context, req *api.VersionRequest) (*api.VersionResponse, error) { - response, err := s.callHandler(ctx, req, nil) + response, err := s.callHandler(ctx, req, + func(context.Context, *api.VersionRequest) (*api.VersionResponse, error) { + return &api.VersionResponse{ + Version: fakeVersion, + RuntimeName: fakeRuntimeName, + RuntimeVersion: fakeRuntimeVersion, + RuntimeApiVersion: fakeRuntimeApiVersion, + }, nil + }, + ) return response.(*api.VersionResponse), err }