Skip to content

Commit

Permalink
feat(handler): implement get latest operation (#589)
Browse files Browse the repository at this point in the history
Because

- Model overview page needs to retrieve input and output data from
latest model async trigger

This commit

- implement methods to get latest model operation with embedded request
and response
  • Loading branch information
heiruwu committed Jun 3, 2024
1 parent e23ba99 commit 33d2395
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 10 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1
github.com/iancoleman/strcase v0.3.0
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20240528180658-a8ebced10a42
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20240603193226-becf22655052
github.com/instill-ai/usage-client v0.2.4-alpha.0.20240123081026-6c78d9a5197a
github.com/instill-ai/x v0.4.0-alpha
github.com/knadh/koanf v1.5.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1314,8 +1314,8 @@ github.com/imdario/mergo v0.3.10/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH
github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20240528180658-a8ebced10a42 h1:17S3vNvi+TzvYqZFpwHhGhiK03MHaiVoG+TblOWZjLQ=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20240528180658-a8ebced10a42/go.mod h1:2blmpUwiTwxIDnrjIqT6FhR5ewshZZF554wzjXFvKpQ=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20240603193226-becf22655052 h1:/LUY1yR6oJpr2WCiguqE5P7D7klA2tUcRkPnJTRXAGQ=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20240603193226-becf22655052/go.mod h1:2blmpUwiTwxIDnrjIqT6FhR5ewshZZF554wzjXFvKpQ=
github.com/instill-ai/usage-client v0.2.4-alpha.0.20240123081026-6c78d9a5197a h1:gmy8BcCFDZQan40c/D3f62DwTYtlCwi0VrSax+pKffw=
github.com/instill-ai/usage-client v0.2.4-alpha.0.20240123081026-6c78d9a5197a/go.mod h1:EpX3Yr661uWULtZf5UnJHfr5rw2PDyX8ku4Kx0UtYFw=
github.com/instill-ai/x v0.4.0-alpha h1:zQV2VLbSHjMv6gyBN/2mwwrvWk0/mJM6ZKS12AzjfQg=
Expand Down
66 changes: 66 additions & 0 deletions pkg/handler/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ package handler
import (
"context"

"cloud.google.com/go/longrunning/autogen/longrunningpb"
"go.opentelemetry.io/otel/trace"

"github.com/gofrs/uuid"
"github.com/instill-ai/model-backend/internal/resource"

custom_logger "github.com/instill-ai/model-backend/pkg/logger"
custom_otel "github.com/instill-ai/model-backend/pkg/logger/otel"
modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
)

Expand All @@ -29,3 +33,65 @@ func (h *PublicHandler) GetModelOperation(ctx context.Context, req *modelPB.GetM
Operation: operation,
}, nil
}

type GetNamespaceLatestModelOperationRequestInterface interface {
GetName() string
GetView() modelPB.View
}

func (h *PublicHandler) GetUserLatestModelOperation(ctx context.Context, req *modelPB.GetUserLatestModelOperationRequest) (resp *modelPB.GetUserLatestModelOperationResponse, err error) {

resp = &modelPB.GetUserLatestModelOperationResponse{}

resp.Operation, err = h.getNamespaceLatestModelOperation(ctx, req)

return resp, err
}

func (h *PublicHandler) GetOrganizationLatestModelOperation(ctx context.Context, req *modelPB.GetOrganizationLatestModelOperationRequest) (resp *modelPB.GetOrganizationLatestModelOperationResponse, err error) {

resp = &modelPB.GetOrganizationLatestModelOperationResponse{}

resp.Operation, err = h.getNamespaceLatestModelOperation(ctx, req)

return resp, err
}

func (h *PublicHandler) getNamespaceLatestModelOperation(ctx context.Context, req GetNamespaceLatestModelOperationRequestInterface) (*longrunningpb.Operation, error) {
eventName := "GetNamespaceLatestModelOperation"

ctx, span := tracer.Start(ctx, eventName,
trace.WithSpanKind(trace.SpanKindServer))
defer span.End()

logUUID, _ := uuid.NewV4()

logger, _ := custom_logger.GetZapLogger(ctx)

ns, modelID, err := h.service.GetRscNamespaceAndNameID(req.GetName())
if err != nil {
span.SetStatus(1, err.Error())
return nil, err
}

if err := authenticateUser(ctx, false); err != nil {
span.SetStatus(1, err.Error())
logger.Info(string(custom_otel.NewLogMessage(
ctx,
span,
logUUID.String(),
eventName,
custom_otel.SetEventResource(req.GetName()),
custom_otel.SetErrorMessage(err.Error()),
)))
return nil, err
}

operation, err := h.service.GetNamespaceLatestModelOperation(ctx, ns, modelID, req.GetView())
if err != nil {
span.SetStatus(1, err.Error())
return nil, err
}

return operation, nil
}
23 changes: 23 additions & 0 deletions pkg/handler/trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/reflect/protoreflect"

"github.com/instill-ai/model-backend/config"
"github.com/instill-ai/model-backend/internal/resource"
Expand Down Expand Up @@ -328,6 +330,7 @@ func (h *PublicHandler) triggerNamespaceModel(ctx context.Context, req TriggerNa
}

type TriggerAsyncNamespaceModelRequestInterface interface {
protoreflect.ProtoMessage
GetName() string
GetVersion() string
GetTaskInputs() []*modelPB.TaskInput
Expand Down Expand Up @@ -435,6 +438,18 @@ func (h *PublicHandler) triggerAsyncNamespaceModel(ctx context.Context, req Trig

userUID := resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey)

// TODO: temporary solution to store input json
inputRequestJSON, err := protojson.Marshal(req)
if err != nil {
return nil, err
}
h.service.GetRedisClient().Set(
ctx,
fmt.Sprintf("model_trigger_input:%s:%s", userUID, pbModel.Uid),
inputRequestJSON,
0,
)

usageData := &utils.UsageMetricData{
OwnerUID: ns.NsUID.String(),
OwnerType: mgmtPB.OwnerType_OWNER_TYPE_USER,
Expand Down Expand Up @@ -589,6 +604,14 @@ func (h *PublicHandler) triggerAsyncNamespaceModel(ctx context.Context, req Trig
}
}

// TODO: temporary solution to store output json
h.service.GetRedisClient().Set(
ctx,
fmt.Sprintf("model_trigger_output_key:%s:%s", userUID, pbModel.Uid),
operation.GetName(),
0,
)

logger.Info(string(custom_otel.NewLogMessage(
ctx,
span,
Expand Down
5 changes: 5 additions & 0 deletions pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type Service interface {
CreateModelPrediction(ctx context.Context, prediction *datamodel.ModelPrediction) error

GetOperation(ctx context.Context, workflowID string) (*longrunningpb.Operation, error)
GetNamespaceLatestModelOperation(ctx context.Context, ns resource.Namespace, modelID string, view modelPB.View) (*longrunningpb.Operation, error)

// Private
GetModelByIDAdmin(ctx context.Context, ns resource.Namespace, modelID string, view modelPB.View) (*modelPB.Model, error)
Expand Down Expand Up @@ -705,6 +706,8 @@ func (s *service) DeleteNamespaceModelByID(ctx context.Context, ns resource.Name

ownerPermalink := ns.Permalink()

userUID := resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey)

dbModel, err := s.repository.GetNamespaceModelByID(ctx, ownerPermalink, modelID, false, false)
if err != nil {
return ErrNotFound
Expand Down Expand Up @@ -741,6 +744,8 @@ func (s *service) DeleteNamespaceModelByID(ctx context.Context, ns resource.Name
return err
}

s.redisClient.Del(ctx, fmt.Sprintf("model_trigger_input:%s:%s", userUID, dbModel.UID.String()))

return s.repository.DeleteNamespaceModelByID(ctx, ownerPermalink, dbModel.ID)
}

Expand Down
80 changes: 76 additions & 4 deletions pkg/service/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package service

import (
"context"
"errors"
"fmt"

workflowpb "go.temporal.io/api/workflow/v1"
Expand All @@ -12,6 +13,10 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/anypb"

"github.com/go-redis/redis/v9"
"github.com/instill-ai/model-backend/internal/resource"
"github.com/instill-ai/model-backend/pkg/constant"

modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
)

Expand All @@ -21,15 +26,80 @@ func (s *service) GetOperation(ctx context.Context, workflowID string) (*longrun
return nil, err
}

return s.getOperationFromWorkflowInfo(ctx, workflowExecutionRes.WorkflowExecutionInfo)
return s.getOperationFromWorkflowInfo(ctx, workflowExecutionRes.WorkflowExecutionInfo, nil)
}

func (s *service) getOperationFromWorkflowInfo(ctx context.Context, workflowExecutionInfo *workflowpb.WorkflowExecutionInfo) (*longrunningpb.Operation, error) {
func (s *service) GetNamespaceLatestModelOperation(ctx context.Context, ns resource.Namespace, modelID string, view modelPB.View) (*longrunningpb.Operation, error) {
ownerPermalink := ns.Permalink()

userUID := resource.GetRequestSingleHeader(ctx, constant.HeaderUserUIDKey)

dbModel, err := s.repository.GetNamespaceModelByID(ctx, ownerPermalink, modelID, true, false)
if err != nil {
return nil, ErrNotFound
}

if granted, err := s.aclClient.CheckPermission(ctx, "model_", dbModel.UID, "reader"); err != nil {
return nil, err
} else if !granted {
return nil, ErrNotFound
}

if granted, err := s.aclClient.CheckPermission(ctx, "model_", dbModel.UID, "executor"); err != nil {
return nil, err
} else if !granted {
return nil, ErrNoPermission
}

triggerModelReq := &modelPB.TriggerUserModelRequest{}

inputJSON, err := s.redisClient.Get(ctx, fmt.Sprintf("model_trigger_input:%s:%s", userUID, dbModel.UID.String())).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, nil
}
return nil, err
}
err = protojson.Unmarshal(inputJSON, triggerModelReq)
if err != nil {
return nil, err
}

outputWorkflowID := s.redisClient.Get(ctx, fmt.Sprintf("model_trigger_output_key:%s:%s", userUID, dbModel.UID.String())).Val()
operationID, err := resource.GetOperationID(outputWorkflowID)
if err != nil {
return nil, err
}

workflowExecutionRes, err := s.temporalClient.DescribeWorkflowExecution(ctx, operationID, "")
if err != nil {
fmt.Println(err)
return nil, err
}

operation, err := s.getOperationFromWorkflowInfo(ctx, workflowExecutionRes.WorkflowExecutionInfo, triggerModelReq)
if err != nil {
return nil, err
}

if view != modelPB.View_VIEW_FULL {
operation.Result = nil
}

return operation, nil

}

func (s *service) getOperationFromWorkflowInfo(ctx context.Context, workflowExecutionInfo *workflowpb.WorkflowExecutionInfo, triggerModelReq *modelPB.TriggerUserModelRequest) (*longrunningpb.Operation, error) {
operation := longrunningpb.Operation{}

switch workflowExecutionInfo.Status {
case enums.WORKFLOW_EXECUTION_STATUS_COMPLETED:

latestOperation := &modelPB.LatestOperation{
Request: triggerModelReq,
}

triggerModelResp := &modelPB.TriggerUserModelResponse{}

blobRedisKey := fmt.Sprintf("async_model_response:%s", workflowExecutionInfo.Execution.WorkflowId)
Expand All @@ -43,11 +113,13 @@ func (s *service) getOperationFromWorkflowInfo(ctx context.Context, workflowExec
return nil, err
}

resp, err := anypb.New(triggerModelResp)
latestOperation.Response = triggerModelResp

resp, err := anypb.New(latestOperation)
if err != nil {
return nil, err
}
resp.TypeUrl = "buf.build/instill-ai/protobufs/model.model.v1alpha.TriggerUserModelResponse"
resp.TypeUrl = "buf.build/instill-ai/protobufs/model.model.v1alpha.LatestOperation"
operation = longrunningpb.Operation{
Done: true,
Result: &longrunningpb.Operation_Response{
Expand Down
17 changes: 14 additions & 3 deletions pkg/worker/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,20 @@ func (w *worker) TriggerModelActivity(ctx context.Context, param *TriggerModelAc
if err != nil {
return nil, w.toApplicationError(err, param.ModelID, ModelActivityError)
}
defer func() {
w.redisClient.Del(ctx, param.ParsedInputKey)
w.redisClient.Del(ctx, param.InputKey)
w.redisClient.Expire(
ctx,
fmt.Sprintf("model_trigger_input:%s:%s", param.UserUID, param.ModelUID.String()),
time.Duration(config.Config.Server.Workflow.MaxWorkflowTimeout)*time.Second,
)
w.redisClient.Expire(
ctx,
fmt.Sprintf("model_trigger_output_key:%s:%s", param.UserUID, param.ModelUID.String()),
time.Duration(config.Config.Server.Workflow.MaxWorkflowTimeout)*time.Second,
)
}()

var inferInput InferInput
switch param.Task {
Expand Down Expand Up @@ -279,9 +293,6 @@ func (w *worker) TriggerModelActivity(ctx context.Context, param *TriggerModelAc

logger.Info("TriggerModelActivity completed")

w.redisClient.Del(ctx, param.ParsedInputKey)
w.redisClient.Del(ctx, param.InputKey)

return &TriggerModelActivityResponse{
TaskOutputBytes: jsonOutput,
OutputKey: outputKey,
Expand Down

0 comments on commit 33d2395

Please sign in to comment.