From 73ba689103a37a07ea2eb529fe2bd33f13a7369f Mon Sep 17 00:00:00 2001 From: PowderLi Date: Tue, 2 Jul 2024 22:10:11 +0800 Subject: [PATCH] add hook for restful v2 Signed-off-by: PowderLi --- .../proxy/httpserver/handler_v2.go | 120 +++++++++--------- internal/proxy/hook_interceptor.go | 33 +++-- 2 files changed, 83 insertions(+), 70 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 9e5f6aa9cfac2..60e68db897277 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -246,7 +246,7 @@ func checkAuthorizationV2(ctx context.Context, c *gin.Context, ignoreErr bool, r return nil } -func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) { +func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, ignoreErr bool, fullMethod string, handler func(reqCtx context.Context, req any) (any, error)) (interface{}, error) { if baseGetter, ok := req.(BaseGetter); ok { span := trace.SpanFromContext(ctx) span.AddEvent(baseGetter.GetBase().GetMsgType().String()) @@ -258,7 +258,11 @@ func wrapperProxy(ctx context.Context, c *gin.Context, req any, checkAuth bool, } } log.Ctx(ctx).Debug("high level restful api, try to do a grpc call", zap.Any("grpcRequest", req)) - response, err := handler(ctx, req) + username, ok := c.Get(ContextUsername) + if !ok { + username = "" + } + response, err := proxy.RestfulHookInterceptor()(ctx, req, username.(string), fullMethod, handler) if err == nil { status, ok := requestutil.GetStatusFromResponse(response) if ok { @@ -279,7 +283,7 @@ func (h *HandlersV2) wrapperCheckDatabase(v2 handlerFuncV2) handlerFuncV2 { if dbName == DefaultDbName || proxy.CheckDatabase(ctx, dbName) { return v2(ctx, c, req, dbName) } - resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListDatabases", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ListDatabases(reqCtx, &milvuspb.ListDatabasesRequest{}) }) if err != nil { @@ -309,7 +313,7 @@ func (h *HandlersV2) hasCollection(ctx context.Context, c *gin.Context, anyReq a DbName: dbName, CollectionName: collectionName, } - resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/HasCollection", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.HasCollection(reqCtx, req.(*milvuspb.HasCollectionRequest)) }) if err != nil { @@ -326,7 +330,7 @@ func (h *HandlersV2) listCollections(ctx context.Context, c *gin.Context, anyReq DbName: dbName, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ShowCollections", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ShowCollections(reqCtx, req.(*milvuspb.ShowCollectionsRequest)) }) if err == nil { @@ -343,7 +347,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a CollectionName: collectionName, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeCollection", func(reqCtx context.Context, req any) (any, error) { return h.proxy.DescribeCollection(reqCtx, req.(*milvuspb.DescribeCollectionRequest)) }) if err != nil { @@ -362,7 +366,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a DbName: dbName, CollectionName: collectionName, } - stateResp, err := wrapperProxy(ctx, c, loadStateReq, h.checkAuth, true, func(reqCtx context.Context, req any) (any, error) { + stateResp, err := wrapperProxy(ctx, c, loadStateReq, h.checkAuth, true, "/milvus.proto.milvus.MilvusService/GetLoadState", func(reqCtx context.Context, req any) (any, error) { return h.proxy.GetLoadState(reqCtx, req.(*milvuspb.GetLoadStateRequest)) }) collLoadState := "" @@ -384,7 +388,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a CollectionName: collectionName, FieldName: vectorField, } - indexResp, err := wrapperProxy(ctx, c, descIndexReq, h.checkAuth, true, func(reqCtx context.Context, req any) (any, error) { + indexResp, err := wrapperProxy(ctx, c, descIndexReq, h.checkAuth, true, "/milvus.proto.milvus.MilvusService/DescribeIndex", func(reqCtx context.Context, req any) (any, error) { return h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest)) }) if err == nil { @@ -397,7 +401,7 @@ func (h *HandlersV2) getCollectionDetails(ctx context.Context, c *gin.Context, a DbName: dbName, CollectionName: collectionName, } - aliasResp, err := wrapperProxy(ctx, c, aliasReq, h.checkAuth, true, func(reqCtx context.Context, req any) (interface{}, error) { + aliasResp, err := wrapperProxy(ctx, c, aliasReq, h.checkAuth, true, "/milvus.proto.milvus.MilvusService/ListAliases", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ListAliases(reqCtx, req.(*milvuspb.ListAliasesRequest)) }) if err == nil { @@ -436,7 +440,7 @@ func (h *HandlersV2) getCollectionStats(ctx context.Context, c *gin.Context, any CollectionName: collectionGetter.GetCollectionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/GetCollectionStatistics", func(reqCtx context.Context, req any) (any, error) { return h.proxy.GetCollectionStatistics(reqCtx, req.(*milvuspb.GetCollectionStatisticsRequest)) }) if err == nil { @@ -452,7 +456,7 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context, CollectionName: collectionGetter.GetCollectionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/GetLoadState", func(reqCtx context.Context, req any) (any, error) { return h.proxy.GetLoadState(reqCtx, req.(*milvuspb.GetLoadStateRequest)) }) if err != nil { @@ -474,7 +478,7 @@ func (h *HandlersV2) getCollectionLoadState(ctx context.Context, c *gin.Context, PartitionNames: partitionsGetter.GetPartitionNames(), DbName: dbName, } - progressResp, err := wrapperProxy(ctx, c, progressReq, h.checkAuth, true, func(reqCtx context.Context, req any) (any, error) { + progressResp, err := wrapperProxy(ctx, c, progressReq, h.checkAuth, true, "/milvus.proto.milvus.MilvusService/GetLoadingProgress", func(reqCtx context.Context, req any) (any, error) { return h.proxy.GetLoadingProgress(reqCtx, req.(*milvuspb.GetLoadingProgressRequest)) }) progress := int64(-1) @@ -502,7 +506,7 @@ func (h *HandlersV2) dropCollection(ctx context.Context, c *gin.Context, anyReq CollectionName: getter.GetCollectionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropCollection", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropCollection(reqCtx, req.(*milvuspb.DropCollectionRequest)) }) if err == nil { @@ -523,7 +527,7 @@ func (h *HandlersV2) renameCollection(ctx context.Context, c *gin.Context, anyRe if req.NewDBName == "" { req.NewDBName = dbName } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/RenameCollection", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.RenameCollection(reqCtx, req.(*milvuspb.RenameCollectionRequest)) }) if err == nil { @@ -539,7 +543,7 @@ func (h *HandlersV2) loadCollection(ctx context.Context, c *gin.Context, anyReq CollectionName: getter.GetCollectionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.LoadCollection(reqCtx, req.(*milvuspb.LoadCollectionRequest)) }) if err == nil { @@ -555,7 +559,7 @@ func (h *HandlersV2) releaseCollection(ctx context.Context, c *gin.Context, anyR CollectionName: getter.GetCollectionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleaseCollection", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ReleaseCollection(reqCtx, req.(*milvuspb.ReleaseCollectionRequest)) }) if err == nil { @@ -587,7 +591,7 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa if httpReq.Limit > 0 && !matchCountRule(httpReq.OutputFields) { req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest)) }) if err == nil { @@ -635,7 +639,7 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName UseDefaultConsistency: true, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Query", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Query(reqCtx, req.(*milvuspb.QueryRequest)) }) if err == nil { @@ -684,7 +688,7 @@ func (h *HandlersV2) delete(ctx context.Context, c *gin.Context, anyReq any, dbN } req.Expr = filter } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Delete", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Delete(reqCtx, req.(*milvuspb.DeleteRequest)) }) if err == nil { @@ -730,7 +734,7 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN }) return nil, err } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Insert", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Insert(reqCtx, req.(*milvuspb.InsertRequest)) }) if err == nil { @@ -808,7 +812,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN }) return nil, err } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Upsert", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Upsert(reqCtx, req.(*milvuspb.UpsertRequest)) }) if err == nil { @@ -953,7 +957,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN } req.SearchParams = searchParams req.PlaceholderGroup = placeholderGroup - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/Search", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.Search(reqCtx, req.(*milvuspb.SearchRequest)) }) if err == nil { @@ -1033,7 +1037,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq {Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}, {Key: ParamRoundDecimal, Value: "-1"}, } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest)) }) if err == nil { @@ -1242,7 +1246,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe Value: fmt.Sprintf("%v", httpReq.Params["ttlSeconds"]), }) } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateCollection", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateCollection(reqCtx, req.(*milvuspb.CreateCollectionRequest)) }) if err != nil { @@ -1259,7 +1263,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe IndexName: httpReq.VectorFieldName, ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricType}}, } - statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest)) }) if err != nil { @@ -1288,7 +1292,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe for key, value := range indexParam.Params { createIndexReq.ExtraParams = append(createIndexReq.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) } - statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + statusResponse, err := wrapperProxy(ctx, c, createIndexReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateIndex(ctx, req.(*milvuspb.CreateIndexRequest)) }) if err != nil { @@ -1300,7 +1304,7 @@ func (h *HandlersV2) createCollection(ctx context.Context, c *gin.Context, anyRe DbName: dbName, CollectionName: httpReq.CollectionName, } - statusResponse, err := wrapperProxy(ctx, c, loadReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + statusResponse, err := wrapperProxy(ctx, c, loadReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadCollection", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.LoadCollection(ctx, req.(*milvuspb.LoadCollectionRequest)) }) if err == nil { @@ -1317,7 +1321,7 @@ func (h *HandlersV2) listPartitions(ctx context.Context, c *gin.Context, anyReq } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ShowPartitions", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ShowPartitions(reqCtx, req.(*milvuspb.ShowPartitionsRequest)) }) if err == nil { @@ -1335,7 +1339,7 @@ func (h *HandlersV2) hasPartitions(ctx context.Context, c *gin.Context, anyReq a PartitionName: partitionGetter.GetPartitionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HasPartition", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.HasPartition(reqCtx, req.(*milvuspb.HasPartitionRequest)) }) if err == nil { @@ -1355,7 +1359,7 @@ func (h *HandlersV2) statsPartition(ctx context.Context, c *gin.Context, anyReq PartitionName: partitionGetter.GetPartitionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/GetPartitionStatistics", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.GetPartitionStatistics(reqCtx, req.(*milvuspb.GetPartitionStatisticsRequest)) }) if err == nil { @@ -1373,7 +1377,7 @@ func (h *HandlersV2) createPartition(ctx context.Context, c *gin.Context, anyReq PartitionName: partitionGetter.GetPartitionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreatePartition", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreatePartition(reqCtx, req.(*milvuspb.CreatePartitionRequest)) }) if err == nil { @@ -1391,7 +1395,7 @@ func (h *HandlersV2) dropPartition(ctx context.Context, c *gin.Context, anyReq a PartitionName: partitionGetter.GetPartitionName(), } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropPartition", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropPartition(reqCtx, req.(*milvuspb.DropPartitionRequest)) }) if err == nil { @@ -1408,7 +1412,7 @@ func (h *HandlersV2) loadPartitions(ctx context.Context, c *gin.Context, anyReq PartitionNames: httpReq.PartitionNames, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/LoadPartitions", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.LoadPartitions(reqCtx, req.(*milvuspb.LoadPartitionsRequest)) }) if err == nil { @@ -1425,7 +1429,7 @@ func (h *HandlersV2) releasePartitions(ctx context.Context, c *gin.Context, anyR PartitionNames: httpReq.PartitionNames, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ReleasePartitions", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ReleasePartitions(reqCtx, req.(*milvuspb.ReleasePartitionsRequest)) }) if err == nil { @@ -1437,7 +1441,7 @@ func (h *HandlersV2) releasePartitions(ctx context.Context, c *gin.Context, anyR func (h *HandlersV2) listUsers(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { req := &milvuspb.ListCredUsersRequest{} c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ListCredUsers", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ListCredUsers(reqCtx, req.(*milvuspb.ListCredUsersRequest)) }) if err == nil { @@ -1457,7 +1461,7 @@ func (h *HandlersV2) describeUser(ctx context.Context, c *gin.Context, anyReq an } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/SelectUser", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.SelectUser(reqCtx, req.(*milvuspb.SelectUserRequest)) }) if err == nil { @@ -1480,7 +1484,7 @@ func (h *HandlersV2) createUser(ctx context.Context, c *gin.Context, anyReq any, Username: httpReq.UserName, Password: crypto.Base64Encode(httpReq.Password), } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateCredential", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateCredential(reqCtx, req.(*milvuspb.CreateCredentialRequest)) }) if err == nil { @@ -1496,7 +1500,7 @@ func (h *HandlersV2) updateUser(ctx context.Context, c *gin.Context, anyReq any, OldPassword: crypto.Base64Encode(httpReq.Password), NewPassword: crypto.Base64Encode(httpReq.NewPassword), } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/UpdateCredential", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.UpdateCredential(reqCtx, req.(*milvuspb.UpdateCredentialRequest)) }) if err == nil { @@ -1510,7 +1514,7 @@ func (h *HandlersV2) dropUser(ctx context.Context, c *gin.Context, anyReq any, d req := &milvuspb.DeleteCredentialRequest{ Username: getter.GetUserName(), } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DeleteCredential", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DeleteCredential(reqCtx, req.(*milvuspb.DeleteCredentialRequest)) }) if err == nil { @@ -1525,7 +1529,7 @@ func (h *HandlersV2) operateRoleToUser(ctx context.Context, c *gin.Context, user RoleName: roleName, Type: operateType, } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/OperateUserRole", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.OperateUserRole(reqCtx, req.(*milvuspb.OperateUserRoleRequest)) }) if err == nil { @@ -1544,7 +1548,7 @@ func (h *HandlersV2) removeRoleFromUser(ctx context.Context, c *gin.Context, any func (h *HandlersV2) listRoles(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { req := &milvuspb.SelectRoleRequest{} - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/SelectRole", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.SelectRole(reqCtx, req.(*milvuspb.SelectRoleRequest)) }) if err == nil { @@ -1562,7 +1566,7 @@ func (h *HandlersV2) describeRole(ctx context.Context, c *gin.Context, anyReq an req := &milvuspb.SelectGrantRequest{ Entity: &milvuspb.GrantEntity{Role: &milvuspb.RoleEntity{Name: getter.GetRoleName()}, DbName: dbName}, } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/SelectGrant", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.SelectGrant(reqCtx, req.(*milvuspb.SelectGrantRequest)) }) if err == nil { @@ -1587,7 +1591,7 @@ func (h *HandlersV2) createRole(ctx context.Context, c *gin.Context, anyReq any, req := &milvuspb.CreateRoleRequest{ Entity: &milvuspb.RoleEntity{Name: getter.GetRoleName()}, } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateRole", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateRole(reqCtx, req.(*milvuspb.CreateRoleRequest)) }) if err == nil { @@ -1601,7 +1605,7 @@ func (h *HandlersV2) dropRole(ctx context.Context, c *gin.Context, anyReq any, d req := &milvuspb.DropRoleRequest{ RoleName: getter.GetRoleName(), } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropRole", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropRole(reqCtx, req.(*milvuspb.DropRoleRequest)) }) if err == nil { @@ -1623,7 +1627,7 @@ func (h *HandlersV2) operatePrivilegeToRole(ctx context.Context, c *gin.Context, }, Type: operateType, } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/OperatePrivilege", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.OperatePrivilege(reqCtx, req.(*milvuspb.OperatePrivilegeRequest)) }) if err == nil { @@ -1649,7 +1653,7 @@ func (h *HandlersV2) listIndexes(ctx context.Context, c *gin.Context, anyReq any } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (any, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeIndex", func(reqCtx context.Context, req any) (any, error) { resp, err := h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest)) if errors.Is(err, merr.ErrIndexNotFound) { return &milvuspb.DescribeIndexResponse{ @@ -1683,7 +1687,7 @@ func (h *HandlersV2) describeIndex(ctx context.Context, c *gin.Context, anyReq a } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeIndex", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DescribeIndex(reqCtx, req.(*milvuspb.DescribeIndexRequest)) }) if err == nil { @@ -1733,7 +1737,7 @@ func (h *HandlersV2) createIndex(ctx context.Context, c *gin.Context, anyReq any for key, value := range indexParam.Params { req.ExtraParams = append(req.ExtraParams, &commonpb.KeyValuePair{Key: key, Value: fmt.Sprintf("%v", value)}) } - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateIndex", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateIndex(reqCtx, req.(*milvuspb.CreateIndexRequest)) }) if err != nil { @@ -1754,7 +1758,7 @@ func (h *HandlersV2) dropIndex(ctx context.Context, c *gin.Context, anyReq any, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropIndex", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropIndex(reqCtx, req.(*milvuspb.DropIndexRequest)) }) if err == nil { @@ -1771,7 +1775,7 @@ func (h *HandlersV2) listAlias(ctx context.Context, c *gin.Context, anyReq any, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ListAliases", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ListAliases(reqCtx, req.(*milvuspb.ListAliasesRequest)) }) if err == nil { @@ -1788,7 +1792,7 @@ func (h *HandlersV2) describeAlias(ctx context.Context, c *gin.Context, anyReq a } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeAlias", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DescribeAlias(reqCtx, req.(*milvuspb.DescribeAliasRequest)) }) if err == nil { @@ -1812,7 +1816,7 @@ func (h *HandlersV2) createAlias(ctx context.Context, c *gin.Context, anyReq any } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateAlias", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.CreateAlias(reqCtx, req.(*milvuspb.CreateAliasRequest)) }) if err == nil { @@ -1829,7 +1833,7 @@ func (h *HandlersV2) dropAlias(ctx context.Context, c *gin.Context, anyReq any, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropAlias", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DropAlias(reqCtx, req.(*milvuspb.DropAliasRequest)) }) if err == nil { @@ -1848,7 +1852,7 @@ func (h *HandlersV2) alterAlias(ctx context.Context, c *gin.Context, anyReq any, } c.Set(ContextRequest, req) - resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/AlterAlias", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.AlterAlias(reqCtx, req.(*milvuspb.AlterAliasRequest)) }) if err == nil { @@ -1877,7 +1881,7 @@ func (h *HandlersV2) listImportJob(ctx context.Context, c *gin.Context, anyReq a return nil, err } } - resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ListImports", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ListImports(reqCtx, req.(*internalpb.ListImportsRequest)) }) if err == nil { @@ -1930,7 +1934,7 @@ func (h *HandlersV2) createImportJob(ctx context.Context, c *gin.Context, anyReq return nil, err } } - resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/ImportV2", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.ImportV2(reqCtx, req.(*internalpb.ImportRequest)) }) if err == nil { @@ -1957,7 +1961,7 @@ func (h *HandlersV2) getImportJobProcess(ctx context.Context, c *gin.Context, an return nil, err } } - resp, err := wrapperProxy(ctx, c, req, false, false, func(reqCtx context.Context, req any) (interface{}, error) { + resp, err := wrapperProxy(ctx, c, req, false, false, "/milvus.proto.milvus.MilvusService/GetImportProgress", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.GetImportProgress(reqCtx, req.(*internalpb.GetImportProgressRequest)) }) if err == nil { @@ -2008,7 +2012,7 @@ func (h *HandlersV2) GetCollectionSchema(ctx context.Context, c *gin.Context, db DbName: dbName, CollectionName: collectionName, } - descResp, err := wrapperProxy(ctx, c, descReq, h.checkAuth, false, func(reqCtx context.Context, req any) (interface{}, error) { + descResp, err := wrapperProxy(ctx, c, descReq, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeCollection", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.DescribeCollection(reqCtx, req.(*milvuspb.DescribeCollectionRequest)) }) if err != nil { diff --git a/internal/proxy/hook_interceptor.go b/internal/proxy/hook_interceptor.go index 448e6d217c98e..447b512618b24 100644 --- a/internal/proxy/hook_interceptor.go +++ b/internal/proxy/hook_interceptor.go @@ -17,22 +17,31 @@ import ( var hoo hook.Hook +type restfulHandler func(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) + func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { - hookutil.InitOnceHook() - hoo = hookutil.Hoo return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return RestfulHookInterceptor()(ctx, req, getCurrentUser(ctx), info.FullMethod, handler) + } +} + +func RestfulHookInterceptor() restfulHandler { + return func(ctx context.Context, req any, userName, fullMethod string, handler grpc.UnaryHandler) (interface{}, error) { + if hoo == nil { + hookutil.InitOnceHook() + hoo = hookutil.Hoo + } var ( - fullMethod = info.FullMethod - newCtx context.Context - isMock bool - mockResp interface{} - realResp interface{} - realErr error - err error + newCtx context.Context + isMock bool + mockResp interface{} + realResp interface{} + realErr error + err error ) if isMock, mockResp, err = hoo.Mock(ctx, req, fullMethod); isMock { - log.Info("hook mock", zap.String("user", getCurrentUser(ctx)), + log.Info("hook mock", zap.String("user", userName), zap.String("full method", fullMethod), zap.Error(err)) metrics.ProxyHookFunc.WithLabelValues(metrics.HookMock, fullMethod).Inc() updateProxyFunctionCallMetric(fullMethod) @@ -40,7 +49,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { } if newCtx, err = hoo.Before(ctx, req, fullMethod); err != nil { - log.Warn("hook before error", zap.String("user", getCurrentUser(ctx)), zap.String("full method", fullMethod), + log.Warn("hook before error", zap.String("user", userName), zap.String("full method", fullMethod), zap.Any("request", req), zap.Error(err)) metrics.ProxyHookFunc.WithLabelValues(metrics.HookBefore, fullMethod).Inc() updateProxyFunctionCallMetric(fullMethod) @@ -48,7 +57,7 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { } realResp, realErr = handler(newCtx, req) if err = hoo.After(newCtx, realResp, realErr, fullMethod); err != nil { - log.Warn("hook after error", zap.String("user", getCurrentUser(ctx)), zap.String("full method", fullMethod), + log.Warn("hook after error", zap.String("user", userName), zap.String("full method", fullMethod), zap.Any("request", req), zap.Error(err)) metrics.ProxyHookFunc.WithLabelValues(metrics.HookAfter, fullMethod).Inc() updateProxyFunctionCallMetric(fullMethod)