Skip to content

Commit

Permalink
enhance: [2.4] the panic when db isn't existed in the rate limit inte…
Browse files Browse the repository at this point in the history
…rceptor (#33308)

issue: #33243
pr: #33244

1. fix: the panic when db isn't existed in the rate limit interceptor
#33244
2. enhance: check the auth in some rest v2 api #33256

---------

Signed-off-by: SimFG <bang.fu@zilliz.com>
  • Loading branch information
SimFG committed May 23, 2024
1 parent 8c9afd5 commit 37b2f90
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 20 deletions.
16 changes: 8 additions & 8 deletions internal/distributed/proxy/httpserver/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a
DbName: dbName,
CollectionName: collectionName,
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (any, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) {
return h.proxy.DescribeCollection(reqCtx, req.(*milvuspb.DescribeCollectionRequest))
})
if err != nil {
Expand Down Expand Up @@ -1601,7 +1601,7 @@ func (h *HandlersV2) listIndexes(ctx context.Context, c *gin.Context, anyReq any
DbName: dbName,
CollectionName: collectionGetter.GetCollectionName(),
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (any, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) {
resp, err := h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest))
if errors.Is(err, merr.ErrIndexNotFound) {
return &milvuspb.DescribeIndexResponse{
Expand Down Expand Up @@ -1633,7 +1633,7 @@ func (h *HandlersV2) describeIndex(ctx context.Context, c *gin.Context, anyReq a
CollectionName: collectionGetter.GetCollectionName(),
IndexName: indexGetter.GetIndexName(),
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest))
})
if err == nil {
Expand Down Expand Up @@ -1681,7 +1681,7 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any
for key, value := range indexParam.Params {
req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)})
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest))
})
if err != nil {
Expand All @@ -1700,7 +1700,7 @@ func (h *HandlersV2) dropIndex(ctx context.Context, c *gin.Context, anyReq any,
CollectionName: collGetter.GetCollectionName(),
IndexName: indexGetter.GetIndexName(),
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropIndex(reqCtx, req.(*milvuspb.DropIndexRequest))
})
if err == nil {
Expand Down Expand Up @@ -1752,7 +1752,7 @@ func (h *HandlersV2) createAlias(ctx context.Context, c *gin.Context, anyReq any
CollectionName: collectionGetter.GetCollectionName(),
Alias: aliasGetter.GetAliasName(),
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.CreateAlias(reqCtx, req.(*milvuspb.CreateAliasRequest))
})
if err == nil {
Expand All @@ -1767,7 +1767,7 @@ func (h *HandlersV2) dropAlias(ctx context.Context, c *gin.Context, anyReq any,
DbName: dbName,
Alias: getter.GetAliasName(),
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.DropAlias(reqCtx, req.(*milvuspb.DropAliasRequest))
})
if err == nil {
Expand All @@ -1784,7 +1784,7 @@ func (h *HandlersV2) alterAlias(ctx context.Context, c *gin.Context, anyReq any,
CollectionName: collectionGetter.GetCollectionName(),
Alias: aliasGetter.GetAliasName(),
}
resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) {
resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.AlterAlias(reqCtx, req.(*milvuspb.AlterAliasRequest))
})
if err == nil {
Expand Down
3 changes: 0 additions & 3 deletions internal/proxy/meta_cache_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ import (

"github.com/casbin/casbin/v2/model"
jsonadapter "github.com/casbin/json-adapter/v2"
"go.uber.org/zap"

"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
)

Expand All @@ -51,7 +49,6 @@ func (a *MetaCacheCasbinAdapter) LoadPolicy(model model.Model) error {
policyInfo := strings.Join(cache.GetPrivilegeInfo(context.Background()), ",")

policy := fmt.Sprintf("[%s]", policyInfo)
log.Ctx(context.Background()).Info("LoddPolicy update policyinfo", zap.String("policyInfo", policy))
byteSource := []byte(policy)
jAdapter := jsonadapter.NewAdapter(&byteSource)
return jAdapter.LoadPolicy(model)
Expand Down
14 changes: 9 additions & 5 deletions internal/proxy/rate_limit_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/requestutil"
Expand Down Expand Up @@ -119,6 +120,9 @@ func getCollectionAndPartitionIDs(ctx context.Context, r reqPartNames) (int64, m

func getCollectionID(r reqCollName) (int64, map[int64][]int64) {
db, _ := globalMetaCache.GetDatabaseInfo(context.TODO(), r.GetDbName())
if db == nil {
return util.InvalidDBID, map[int64][]int64{}
}
collectionID, _ := globalMetaCache.GetCollectionID(context.TODO(), r.GetDbName(), r.GetCollectionName())
return db.dbID, map[int64][]int64{collectionID: {}}
}
Expand Down Expand Up @@ -177,14 +181,14 @@ func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in
case *milvuspb.FlushRequest:
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}

collToPartIDs := make(map[int64][]int64, 0)
for _, collectionName := range r.GetCollectionNames() {
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs[collectionID] = []int64{}
}
Expand All @@ -193,16 +197,16 @@ func getRequestInfo(ctx context.Context, req interface{}) (int64, map[int64][]in
dbName := GetCurDBNameFromContextOrDefault(ctx)
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
if err != nil {
return 0, map[int64][]int64{}, 0, 0, err
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
return dbInfo.dbID, map[int64][]int64{
r.GetCollectionID(): {},
}, internalpb.RateType_DDLCompaction, 1, nil
default: // TODO: support more request
if req == nil {
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("null request")
}
return 0, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
return util.InvalidDBID, map[int64][]int64{}, 0, 0, fmt.Errorf("unsupported request type %s", reflect.TypeOf(req).Name())
}
}

Expand Down
8 changes: 7 additions & 1 deletion internal/proxy/rate_limit_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/merr"
)

Expand Down Expand Up @@ -367,7 +368,7 @@ func TestGetInfo(t *testing.T) {
}()

t.Run("fail to get database", func(t *testing.T) {
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(4)
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: get database info")).Times(5)
{
_, _, err := getCollectionAndPartitionID(ctx, &milvuspb.InsertRequest{
DbName: "foo",
Expand All @@ -394,6 +395,11 @@ func TestGetInfo(t *testing.T) {
_, _, _, _, err := getRequestInfo(ctx, &milvuspb.ManualCompactionRequest{})
assert.Error(t, err)
}
{
dbID, collectionIDInfos := getCollectionID(&milvuspb.CreateCollectionRequest{})
assert.Equal(t, util.InvalidDBID, dbID)
assert.Equal(t, 0, len(collectionIDInfos))
}
})

t.Run("fail to get collection", func(t *testing.T) {
Expand Down
9 changes: 8 additions & 1 deletion internal/proxy/simple_rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
rlinternal "github.com/milvus-io/milvus/internal/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/util/typeutil"
Expand Down Expand Up @@ -79,7 +80,7 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
}

// 2. check database level rate limits
if ret == nil {
if ret == nil && dbID != util.InvalidDBID {
dbRateLimiters := m.rateLimiter.GetOrCreateDatabaseLimiters(dbID, newDatabaseLimiter)
ret = dbRateLimiters.Check(rt, n)
if ret != nil {
Expand All @@ -92,6 +93,9 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
// 3. check collection level rate limits
if ret == nil && len(collectionIDToPartIDs) > 0 && !isNotCollectionLevelLimitRequest(rt) {
for collectionID := range collectionIDToPartIDs {
if collectionID == 0 || dbID == util.InvalidDBID {
continue
}
// only dml and dql have collection level rate limits
collectionRateLimiters := m.rateLimiter.GetOrCreateCollectionLimiters(dbID, collectionID,
newDatabaseLimiter, newCollectionLimiters)
Expand All @@ -108,6 +112,9 @@ func (m *SimpleLimiter) Check(dbID int64, collectionIDToPartIDs map[int64][]int6
if ret == nil && len(collectionIDToPartIDs) > 0 {
for collectionID, partitionIDs := range collectionIDToPartIDs {
for _, partID := range partitionIDs {
if collectionID == 0 || partID == 0 || dbID == util.InvalidDBID {
continue
}
partitionRateLimiters := m.rateLimiter.GetOrCreatePartitionLimiters(dbID, collectionID, partID,
newDatabaseLimiter, newCollectionLimiters, newPartitionLimiters)
ret = partitionRateLimiters.Check(rt, n)
Expand Down
2 changes: 2 additions & 0 deletions internal/proxy/simple_rate_limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ func TestSimpleRateLimiter(t *testing.T) {
clusterRateLimiters := simpleLimiter.rateLimiter.GetRootLimiters()

collectionIDToPartIDs := map[int64][]int64{
0: {},
1: {},
2: {},
3: {},
4: {0},
}

for i := 1; i <= 3; i++ {
Expand Down
25 changes: 23 additions & 2 deletions internal/rootcoord/quota_center.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,15 @@ func (q *QuotaCenter) collectMetrics() error {
}
}

datacoordQuotaCollections := make([]int64, 0)
q.diskMu.Lock()
if dataCoordTopology.Cluster.Self.QuotaMetrics != nil {
q.dataCoordMetrics = dataCoordTopology.Cluster.Self.QuotaMetrics
for _, metricCollections := range q.dataCoordMetrics.PartitionsBinlogSize {
for metricCollection := range metricCollections {
datacoordQuotaCollections = append(datacoordQuotaCollections, metricCollection)
}
}
}
q.diskMu.Unlock()

Expand All @@ -447,7 +453,6 @@ func (q *QuotaCenter) collectMetrics() error {
}
var rangeErr error
collections.Range(func(collectionID int64) bool {
var coll *model.Collection
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
if getErr != nil {
rangeErr = getErr
Expand Down Expand Up @@ -482,7 +487,23 @@ func (q *QuotaCenter) collectMetrics() error {
}
return true
})
return rangeErr
if rangeErr != nil {
return rangeErr
}
for _, collectionID := range datacoordQuotaCollections {
_, ok := q.collectionIDToDBID.Get(collectionID)
if ok {
continue
}
coll, getErr := q.meta.GetCollectionByIDWithMaxTs(context.TODO(), collectionID)
if getErr != nil {
return getErr
}
q.collectionIDToDBID.Insert(collectionID, coll.DBID)
q.collections.Insert(FormatCollectionKey(coll.DBID, coll.Name), collectionID)
}

return nil
})
// get Proxies metrics
group.Go(func() error {
Expand Down
1 change: 1 addition & 0 deletions pkg/util/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ const (
DefaultDBName = "default"
DefaultDBID = int64(1)
NonDBID = int64(0)
InvalidDBID = int64(-1)

PrivilegeWord = "Privilege"
AnyWord = "*"
Expand Down

0 comments on commit 37b2f90

Please sign in to comment.