Skip to content

Commit

Permalink
Fix the unsafe casbin Model (#21132)
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <bang.fu@zilliz.com>

Signed-off-by: SimFG <bang.fu@zilliz.com>
  • Loading branch information
SimFG committed Dec 12, 2022
1 parent a974d72 commit f209c9c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
1 change: 0 additions & 1 deletion internal/distributed/proxy/service.go
Expand Up @@ -338,7 +338,6 @@ func (s *Server) init() error {
}
s.etcdCli = etcdCli
s.proxy.SetEtcdClient(s.etcdCli)
proxy.InitPolicyModel()

errChan := make(chan error, 1)
{
Expand Down
14 changes: 4 additions & 10 deletions internal/proxy/privilege_interceptor.go
Expand Up @@ -40,16 +40,12 @@ m = r.sub == p.sub && globMatch(r.obj, p.obj) && globMatch(r.act, p.act) || r.su
`
)

var (
casbinModel model.Model
)

func InitPolicyModel() {
var err error
casbinModel, err = model.NewModelFromString(ModelStr)
func getPolicyModel(modelString string) model.Model {
model, err := model.NewModelFromString(modelString)
if err != nil {
log.Panic("NewModelFromString fail", zap.String("model", ModelStr), zap.Error(err))
}
return model
}

// UnaryServerInterceptor returns a new unary server interceptors that performs per-request privilege access.
Expand Down Expand Up @@ -107,9 +103,7 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context
policy := fmt.Sprintf("[%s]", policyInfo)
b := []byte(policy)
a := jsonadapter.NewAdapter(&b)
if casbinModel == nil {
log.Panic("fail to get policy model")
}
casbinModel := getPolicyModel(ModelStr)
e, err := casbin.NewEnforcer(casbinModel, a)
if err != nil {
log.Error("NewEnforcer fail", zap.String("policy", policy), zap.Error(err))
Expand Down
27 changes: 18 additions & 9 deletions internal/proxy/privilege_interceptor_test.go
Expand Up @@ -2,13 +2,13 @@ package proxy

import (
"context"
"sync"
"testing"

"github.com/milvus-io/milvus/internal/util/funcutil"

"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
)

Expand All @@ -21,7 +21,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
ctx := context.Background()

t.Run("Authorization Disabled", func(t *testing.T) {
InitPolicyModel()
Params.CommonCfg.AuthorizationEnabled = false
_, err := PrivilegeInterceptor(ctx, &milvuspb.LoadCollectionRequest{
DbName: "db_test",
Expand All @@ -31,7 +30,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
})

t.Run("Authorization Enabled", func(t *testing.T) {
InitPolicyModel()
Params.CommonCfg.AuthorizationEnabled = true

_, err := PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{})
Expand Down Expand Up @@ -113,12 +111,23 @@ func TestPrivilegeInterceptor(t *testing.T) {
})
assert.Nil(t, err)

casbinModel = nil
g := sync.WaitGroup{}
for i := 0; i < 20; i++ {
g.Add(1)
go func() {
defer g.Done()
assert.NotPanics(t, func() {
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
})
}()
}
g.Wait()

assert.Panics(t, func() {
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
getPolicyModel("foo")
})
})

Expand Down

0 comments on commit f209c9c

Please sign in to comment.