Skip to content

Commit

Permalink
feat(instill): adopt latest Model endpoints (#146)
Browse files Browse the repository at this point in the history
Because

- We are going to introduce the new Instill Model services and which
endpoints are not compatible with old ones.

This commit

- Adopts latest Model endpoints.
  • Loading branch information
donch1989 committed May 31, 2024
1 parent ad35e10 commit 7f2537b
Show file tree
Hide file tree
Showing 14 changed files with 107 additions and 156 deletions.
40 changes: 40 additions & 0 deletions ai/instill/v0/client.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package instill

import (
"context"
"crypto/tls"
"strings"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"

mgmtPB "github.com/instill-ai/protogen-go/core/mgmt/v1beta"
modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
Expand Down Expand Up @@ -58,3 +61,40 @@ func stripProtocolFromURL(url string) string {
}
return url
}

func trigger(gRPCClient modelPB.ModelPublicServiceClient, vars map[string]any, modelName string, taskInputs []*modelPB.TaskInput) ([]*modelPB.TaskOutput, error) {

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()

ctx = metadata.NewOutgoingContext(ctx, getRequestMetadata(vars))

nameSplits := strings.Split(modelName, "/")

if strings.HasPrefix(modelName, "user") {
req := modelPB.TriggerUserModelRequest{
Name: strings.Join(nameSplits[0:4], "/"),
TaskInputs: taskInputs,
Version: nameSplits[5],
}

res, err := gRPCClient.TriggerUserModel(ctx, &req)
if err != nil || res == nil {
return nil, err
}
return res.TaskOutputs, nil
} else {
req := modelPB.TriggerOrganizationModelRequest{
Name: strings.Join(nameSplits[0:4], "/"),
TaskInputs: taskInputs,
Version: nameSplits[5],
}

res, err := gRPCClient.TriggerOrganizationModel(ctx, &req)
if err != nil || res == nil {
return nil, err
}
return res.TaskOutputs, nil
}

}
12 changes: 2 additions & 10 deletions ai/instill/v0/image_classification.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package instill

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"

Expand Down Expand Up @@ -43,16 +41,10 @@ func (e *execution) executeImageClassification(grpcClient modelPB.ModelPublicSer
taskInputs = append(taskInputs, &modelPB.TaskInput{Input: taskInput})
}

req := modelPB.TriggerUserModelRequest{
Name: modelName,
TaskInputs: taskInputs,
}
ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
res, err := grpcClient.TriggerUserModel(ctx, &req)
if err != nil || res == nil {
taskOutputs, err := trigger(grpcClient, e.SystemVariables, modelName, taskInputs)
if err != nil {
return nil, err
}
taskOutputs := res.GetTaskOutputs()
if len(taskOutputs) <= 0 {
return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
}
Expand Down
12 changes: 2 additions & 10 deletions ai/instill/v0/image_to_image.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package instill

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"

Expand Down Expand Up @@ -61,16 +59,10 @@ func (e *execution) executeImageToImage(grpcClient modelPB.ModelPublicServiceCli
}

// only support batch 1
req := modelPB.TriggerUserModelRequest{
Name: modelName,
TaskInputs: []*modelPB.TaskInput{{Input: taskInput}},
}
ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
res, err := grpcClient.TriggerUserModel(ctx, &req)
if err != nil || res == nil {
taskOutputs, err := trigger(grpcClient, e.SystemVariables, modelName, []*modelPB.TaskInput{{Input: taskInput}})
if err != nil {
return nil, err
}
taskOutputs := res.GetTaskOutputs()
if len(taskOutputs) <= 0 {
return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
}
Expand Down
12 changes: 2 additions & 10 deletions ai/instill/v0/instance_segmentation.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package instill

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"

Expand Down Expand Up @@ -41,16 +39,10 @@ func (e *execution) executeInstanceSegmentation(grpcClient modelPB.ModelPublicSe
}
taskInputs = append(taskInputs, &modelPB.TaskInput{Input: taskInput})
}
req := modelPB.TriggerUserModelRequest{
Name: modelName,
TaskInputs: taskInputs,
}
ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
res, err := grpcClient.TriggerUserModel(ctx, &req)
if err != nil || res == nil {
taskOutputs, err := trigger(grpcClient, e.SystemVariables, modelName, taskInputs)
if err != nil {
return nil, err
}
taskOutputs := res.GetTaskOutputs()
if len(taskOutputs) <= 0 {
return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
}
Expand Down
12 changes: 2 additions & 10 deletions ai/instill/v0/keypoint_detection.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package instill

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"

Expand Down Expand Up @@ -37,16 +35,10 @@ func (e *execution) executeKeyPointDetection(grpcClient modelPB.ModelPublicServi
taskInputs = append(taskInputs, &modelPB.TaskInput{Input: taskInput})
}

req := modelPB.TriggerUserModelRequest{
Name: modelName,
TaskInputs: taskInputs,
}
ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
res, err := grpcClient.TriggerUserModel(ctx, &req)
if err != nil || res == nil {
taskOutputs, err := trigger(grpcClient, e.SystemVariables, modelName, taskInputs)
if err != nil {
return nil, err
}
taskOutputs := res.GetTaskOutputs()
if len(taskOutputs) <= 0 {
return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
}
Expand Down
87 changes: 44 additions & 43 deletions ai/instill/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ var (

type component struct {
base.Component

// Workaround solution
cacheDefinition *pb.ComponentDefinition
}

type execution struct {
Expand Down Expand Up @@ -82,14 +79,6 @@ func getMgmtServerURL(vars map[string]any) string {
return ""
}

// This is a workaround solution for caching the definition in memory if the model list is static.
func useStaticModelList(vars map[string]any) bool {
if v, ok := vars["__STATIC_MODEL_LIST"]; ok {
return v.(bool)
}
return false
}

func getRequestMetadata(vars map[string]any) metadata.MD {
return metadata.Pairs(
"Authorization", getHeaderAuthorization(vars),
Expand All @@ -105,11 +94,11 @@ func (e *execution) Execute(ctx context.Context, inputs []*structpb.Struct) ([]*
return inputs, fmt.Errorf("invalid input")
}

gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(e.SystemVariables))
if gRPCCLientConn != nil {
defer gRPCCLientConn.Close()
// TODO, we should move this to CreateExecution
gRPCClient, gRPCCientConn := initModelPublicServiceClient(getModelServerURL(e.SystemVariables))
if gRPCCientConn != nil {
defer gRPCCientConn.Close()
}

mgmtGRPCCLient, mgmtGRPCCLientConn := initMgmtPublicServiceClient(getMgmtServerURL(e.SystemVariables))
if mgmtGRPCCLientConn != nil {
defer mgmtGRPCCLientConn.Close()
Expand All @@ -133,34 +122,34 @@ func (e *execution) Execute(ctx context.Context, inputs []*structpb.Struct) ([]*
nsType = "users"
}

modelName := fmt.Sprintf("%s/%s/models/%s", nsType, modelNameSplits[0], modelNameSplits[1])
modelName := fmt.Sprintf("%s/%s/models/%s/versions/%s", nsType, modelNameSplits[0], modelNameSplits[1], modelNameSplits[2])

var result []*structpb.Struct
switch e.Task {
case commonPB.Task_TASK_UNSPECIFIED.String():
result, err = e.executeUnspecified(gRPCCLient, modelName, inputs)
result, err = e.executeUnspecified(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_CLASSIFICATION.String():
result, err = e.executeImageClassification(gRPCCLient, modelName, inputs)
result, err = e.executeImageClassification(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_DETECTION.String():
result, err = e.executeObjectDetection(gRPCCLient, modelName, inputs)
result, err = e.executeObjectDetection(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_KEYPOINT.String():
result, err = e.executeKeyPointDetection(gRPCCLient, modelName, inputs)
result, err = e.executeKeyPointDetection(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_OCR.String():
result, err = e.executeOCR(gRPCCLient, modelName, inputs)
result, err = e.executeOCR(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_INSTANCE_SEGMENTATION.String():
result, err = e.executeInstanceSegmentation(gRPCCLient, modelName, inputs)
result, err = e.executeInstanceSegmentation(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_SEMANTIC_SEGMENTATION.String():
result, err = e.executeSemanticSegmentation(gRPCCLient, modelName, inputs)
result, err = e.executeSemanticSegmentation(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_TEXT_TO_IMAGE.String():
result, err = e.executeTextToImage(gRPCCLient, modelName, inputs)
result, err = e.executeTextToImage(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_TEXT_GENERATION.String():
result, err = e.executeTextGeneration(gRPCCLient, modelName, inputs)
result, err = e.executeTextGeneration(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_TEXT_GENERATION_CHAT.String():
result, err = e.executeTextGenerationChat(gRPCCLient, modelName, inputs)
result, err = e.executeTextGenerationChat(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_VISUAL_QUESTION_ANSWERING.String():
result, err = e.executeVisualQuestionAnswering(gRPCCLient, modelName, inputs)
result, err = e.executeVisualQuestionAnswering(gRPCClient, modelName, inputs)
case commonPB.Task_TASK_IMAGE_TO_IMAGE.String():
result, err = e.executeImageToImage(gRPCCLient, modelName, inputs)
result, err = e.executeImageToImage(gRPCClient, modelName, inputs)
default:
return inputs, fmt.Errorf("unsupported task: %s", e.Task)
}
Expand Down Expand Up @@ -193,12 +182,7 @@ type ModelsResp struct {
}

// Generate the `model_name` enum based on the task.
// This implementation is a temporary solution due to the incomplete feature set of Instill Model.
// We'll re-implement this after Instill Model is stable.
func (c *component) Definition(sysVars map[string]any, compConfig *base.ComponentConfig) (*pb.ComponentDefinition, error) {
if useStaticModelList(sysVars) && c.cacheDefinition != nil {
return c.cacheDefinition, nil
}
func (c *component) GetDefinition(sysVars map[string]any, compConfig *base.ComponentConfig) (*pb.ComponentDefinition, error) {

oriDef, err := c.Component.GetDefinition(nil, nil)
if err != nil {
Expand Down Expand Up @@ -240,14 +224,33 @@ func (c *component) Definition(sysVars map[string]any, compConfig *base.Componen

modelNameMap := map[string]*structpb.ListValue{}

modelName := &structpb.ListValue{}
for _, model := range models {
if _, ok := modelNameMap[model.Task.String()]; !ok {
modelNameMap[model.Task.String()] = &structpb.ListValue{}

versions := []*modelPB.ModelVersion{}
switch model.Owner.Owner.(type) {
case *mgmtPB.Owner_Organization:
resp, err := gRPCCLient.ListOrganizationModelVersions(ctx, &modelPB.ListOrganizationModelVersionsRequest{Name: model.Name})
if err != nil {
return nil, err
}
versions = resp.Versions

case *mgmtPB.Owner_User:
resp, err := gRPCCLient.ListUserModelVersions(ctx, &modelPB.ListUserModelVersionsRequest{Name: model.Name})
if err != nil {
return nil, err
}
versions = resp.Versions
}
namePaths := strings.Split(model.Name, "/")
modelName.Values = append(modelName.Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3])))
modelNameMap[model.Task.String()].Values = append(modelNameMap[model.Task.String()].Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3])))

for _, version := range versions {
if _, ok := modelNameMap[model.Task.String()]; !ok {
modelNameMap[model.Task.String()] = &structpb.ListValue{}
}
namePaths := strings.Split(version.Name, "/")
modelNameMap[model.Task.String()].Values = append(modelNameMap[model.Task.String()].Values, structpb.NewStringValue(fmt.Sprintf("%s/%s/%s", namePaths[1], namePaths[3], namePaths[5])))
}

}
for _, sch := range def.Spec.ComponentSpecification.Fields["oneOf"].GetListValue().Values {
task := sch.GetStructValue().Fields["properties"].GetStructValue().Fields["task"].GetStructValue().Fields["const"].GetStringValue()
Expand All @@ -256,9 +259,7 @@ func (c *component) Definition(sysVars map[string]any, compConfig *base.Componen
}

}
if useStaticModelList(sysVars) {
c.cacheDefinition = def
}

return def, nil
}

Expand Down
14 changes: 2 additions & 12 deletions ai/instill/v0/object_detection.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package instill

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"

Expand Down Expand Up @@ -43,18 +41,10 @@ func (e *execution) executeObjectDetection(grpcClient modelPB.ModelPublicService
taskInputs = append(taskInputs, &modelPB.TaskInput{Input: modelInput})
}

req := modelPB.TriggerUserModelRequest{
Name: modelName,
TaskInputs: taskInputs,
}

ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
res, err := grpcClient.TriggerUserModel(ctx, &req)
if err != nil || res == nil {
taskOutputs, err := trigger(grpcClient, e.SystemVariables, modelName, taskInputs)
if err != nil {
return nil, err
}

taskOutputs := res.GetTaskOutputs()
if len(taskOutputs) <= 0 {
return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
}
Expand Down
12 changes: 2 additions & 10 deletions ai/instill/v0/ocr.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package instill

import (
"context"
"fmt"

"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"

Expand Down Expand Up @@ -37,16 +35,10 @@ func (e *execution) executeOCR(grpcClient modelPB.ModelPublicServiceClient, mode
}

// only support batch 1
req := modelPB.TriggerUserModelRequest{
Name: modelName,
TaskInputs: []*modelPB.TaskInput{{Input: taskInput}},
}
ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
res, err := grpcClient.TriggerUserModel(ctx, &req)
if err != nil || res == nil {
taskOutputs, err := trigger(grpcClient, e.SystemVariables, modelName, []*modelPB.TaskInput{{Input: taskInput}})
if err != nil {
return nil, err
}
taskOutputs := res.GetTaskOutputs()
if len(taskOutputs) <= 0 {
return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
}
Expand Down
Loading

0 comments on commit 7f2537b

Please sign in to comment.