From f209c9cc21ebbb31ba2b57a488d81da790879fbc Mon Sep 17 00:00:00 2001 From: SimFG Date: Mon, 12 Dec 2022 14:19:22 +0800 Subject: [PATCH] Fix the unsafe casbin `Model` (#21132) Signed-off-by: SimFG Signed-off-by: SimFG --- internal/distributed/proxy/service.go | 1 - internal/proxy/privilege_interceptor.go | 14 +++------- internal/proxy/privilege_interceptor_test.go | 27 +++++++++++++------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index bbb3cde857c7..c9fc7258b082 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -338,7 +338,6 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.proxy.SetEtcdClient(s.etcdCli) - proxy.InitPolicyModel() errChan := make(chan error, 1) { diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index 5db7f016f10e..a813943e0902 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -40,16 +40,12 @@ m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.su ` ) -var ( - casbinModel model.Model -) - -func InitPolicyModel() { - var err error - casbinModel, err = model.NewModelFromString(ModelStr) +func getPolicyModel(modelString string) model.Model { + model, err := model.NewModelFromString(modelString) if err != nil { log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err)) } + return model } // UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access. @@ -107,9 +103,7 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context policy := fmt.Sprintf("[%s]", policyInfo) b := []byte(policy) a := jsonadapter.NewAdapter(&b) - if casbinModel == nil { - log.Panic("fail to get policy model") - } + casbinModel := getPolicyModel(ModelStr) e, err := casbin.NewEnforcer(casbinModel, a) if err != nil { log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err)) diff --git a/internal/proxy/privilege_interceptor_test.go b/internal/proxy/privilege_interceptor_test.go index 00929e162eac..e6ffc2713999 100644 --- a/internal/proxy/privilege_interceptor_test.go +++ b/internal/proxy/privilege_interceptor_test.go @@ -2,13 +2,13 @@ package proxy import ( "context" + "sync" "testing" - "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/stretchr/testify/assert" ) @@ -21,7 +21,6 @@ func TestPrivilegeInterceptor(t *testing.T) { ctx := context.Background() t.Run("Authorization Disabled", func(t *testing.T) { - InitPolicyModel() Params.CommonCfg.AuthorizationEnabled = false _, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{ DbName: "db_test", @@ -31,7 +30,6 @@ func TestPrivilegeInterceptor(t *testing.T) { }) t.Run("Authorization Enabled", func(t *testing.T) { - InitPolicyModel() Params.CommonCfg.AuthorizationEnabled = true _, err := PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{}) @@ -113,12 +111,23 @@ func TestPrivilegeInterceptor(t *testing.T) { }) assert.Nil(t, err) - casbinModel = nil + g := sync.WaitGroup{} + for i := 0; i < 20; i++ { + g.Add(1) + go func() { + defer g.Done() + assert.NotPanics(t, func() { + PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ + DbName: "db_test", + CollectionName: "col1", + }) + }) + }() + } + g.Wait() + assert.Panics(t, func() { - PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ - DbName: "db_test", - CollectionName: "col1", - }) + getPolicyModel("foo") }) })