Skip to content

Commit

Permalink
fix: support multipart trigger for new tasks (#109)
Browse files Browse the repository at this point in the history
Because

- VDP supports text-to-image and text-generation tasks

This commit

- supports triggering those two tasks from the pipeline-backend
  • Loading branch information
Phelan164 authored and pinglin committed Mar 8, 2023
1 parent f36239a commit 0e7e9fa
Show file tree
Hide file tree
Showing 7 changed files with 666 additions and 90 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.11.3
github.com/iancoleman/strcase v0.2.0
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20230217111731-b78c700241b2
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20230308122400-51986736325a
github.com/instill-ai/usage-client v0.2.2-alpha
github.com/instill-ai/x v0.2.0-alpha
github.com/knadh/koanf v1.4.3
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,10 @@ github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20230217111731-b78c700241b2 h1:TLK82ewEE54IgE71Er+rY5wq7kXVSS0pQd17C6hW+34=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20230217111731-b78c700241b2/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20230307101000-a8348ab7c390 h1:x7wlV2IGhW8XzG/K8s9gsu1cGDEIug+eJmFvElmNxvk=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20230307101000-a8348ab7c390/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20230308122400-51986736325a h1:pXAll60F53JR0Tyny5Glq8DQmWuU9BOqmoSMGl2oFf4=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20230308122400-51986736325a/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/usage-client v0.2.2-alpha h1:EQyHpgzZ26TEIL9UoaqchTf+LnKaidUGhKlUEFR68I8=
github.com/instill-ai/usage-client v0.2.2-alpha/go.mod h1:RpVnioKQBoJZsE1qTiZlPQUQXUALTGzhBl8ju9rm5+U=
github.com/instill-ai/x v0.2.0-alpha h1:8yszKP9DE8bvSRAtEpOwqhG2wwqU3olhTqhwoiLrHfc=
Expand Down
10 changes: 10 additions & 0 deletions internal/constant/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,13 @@ const (
)

// MaxBatchSize is the maximum number of files accepted in a single
// pipeline trigger request.
const MaxBatchSize int = 32

// Default parameters for the text-to-image task, applied when the
// caller does not supply explicit values.
const (
	DefaultStep     int     = 10
	DefaultCfgScale float64 = 7.0
	DefaultSeed     int     = 1024
	DefaultSamples  int     = 1
)

// Default parameters for the text-generation task, applied when the
// caller does not supply explicit values.
const (
	DefaultOutputLen int = 100
	DefaultTopK      int = 40
)
129 changes: 116 additions & 13 deletions pkg/handler/handler.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handler

import (
"bytes"
"context"
"fmt"
"io"
Expand All @@ -26,6 +25,7 @@ import (
"github.com/instill-ai/x/checkfield"

healthcheckPB "github.com/instill-ai/protogen-go/vdp/healthcheck/v1alpha"
modelPB "github.com/instill-ai/protogen-go/vdp/model/v1alpha"
pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1alpha"
)

Expand Down Expand Up @@ -480,10 +480,16 @@ func (h *handler) TriggerPipelineBinaryFileUpload(stream pipelinePB.PipelineServ
return err
}

// Read chuck
var fileNames []string
var textToImageInput service.TextToImageInput
var textGenerationInput service.TextGenerationInput

var allContentFiles []byte
var fileLengths []uint64
content := bytes.Buffer{}

var modelInstance *modelPB.ModelInstance

var firstChunk = true

for {
data, err := stream.Recv()
if err != nil {
Expand All @@ -492,21 +498,118 @@ func (h *handler) TriggerPipelineBinaryFileUpload(stream pipelinePB.PipelineServ
}
return status.Errorf(codes.Internal, "failed unexpectedly while reading chunks from stream: %s", err.Error())
}
if len(fileNames) == 0 {
fileNames = data.GetFileNames()
if firstChunk { // Get one time for first chunk.
firstChunk = false
pipelineName := data.GetName()
pipeline, err := h.service.GetPipelineByID(strings.TrimSuffix(pipelineName, "pipelines/"), owner, false)
if err != nil {
return status.Errorf(codes.Internal, "do not find the pipeline: %s", err.Error())
}
if pipeline.Recipe == nil || len(pipeline.Recipe.ModelInstances) == 0 {
return status.Errorf(codes.Internal, "there is no model instance in pipeline's recipe")
}
modelInstance, err = h.service.GetModelInstanceByName(dbPipeline.Recipe.ModelInstances[0])
if err != nil {
return status.Errorf(codes.Internal, "could not find model instance: %s", err.Error())
}

switch modelInstance.Task {
case modelPB.ModelInstance_TASK_CLASSIFICATION:
fileLengths = data.TaskInput.GetClassification().FileLengths
if data.TaskInput.GetClassification().GetContent() != nil {
allContentFiles = append(allContentFiles, data.TaskInput.GetClassification().GetContent()...)
}
case modelPB.ModelInstance_TASK_DETECTION:
fileLengths = data.TaskInput.GetDetection().FileLengths
if data.TaskInput.GetDetection().GetContent() != nil {
allContentFiles = append(allContentFiles, data.TaskInput.GetDetection().GetContent()...)
}
case modelPB.ModelInstance_TASK_KEYPOINT:
fileLengths = data.TaskInput.GetKeypoint().FileLengths
if data.TaskInput.GetKeypoint().GetContent() != nil {
allContentFiles = append(allContentFiles, data.TaskInput.GetKeypoint().GetContent()...)
}
case modelPB.ModelInstance_TASK_OCR:
fileLengths = data.TaskInput.GetOcr().FileLengths
if data.TaskInput.GetOcr().GetContent() != nil {
allContentFiles = append(allContentFiles, data.TaskInput.GetOcr().GetContent()...)
}
case modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION:
fileLengths = data.TaskInput.GetInstanceSegmentation().FileLengths
if data.TaskInput.GetInstanceSegmentation().GetContent() != nil {
allContentFiles = append(allContentFiles, data.TaskInput.GetInstanceSegmentation().GetContent()...)
}
case modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION:
fileLengths = data.TaskInput.GetSemanticSegmentation().FileLengths
if data.TaskInput.GetSemanticSegmentation().GetContent() != nil {
allContentFiles = append(allContentFiles, data.TaskInput.GetSemanticSegmentation().GetContent()...)
}
case modelPB.ModelInstance_TASK_TEXT_TO_IMAGE:
textToImageInput = service.TextToImageInput{
Prompt: data.TaskInput.GetTextToImage().GetPrompt(),
Steps: data.TaskInput.GetTextToImage().GetSteps(),
CfgScale: data.TaskInput.GetTextToImage().GetCfgScale(),
Seed: data.TaskInput.GetTextToImage().GetSeed(),
Samples: data.TaskInput.GetTextToImage().GetSamples(),
}
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
textGenerationInput = service.TextGenerationInput{
Prompt: data.TaskInput.GetTextGeneration().GetPrompt(),
OutputLen: data.TaskInput.GetTextGeneration().GetOutputLen(),
BadWordsList: data.TaskInput.GetTextGeneration().GetBadWordsList(),
StopWordsList: data.TaskInput.GetTextGeneration().GetStopWordsList(),
TopK: data.TaskInput.GetTextGeneration().GetTopk(),
Seed: data.TaskInput.GetTextGeneration().GetSeed(),
}
default:
return fmt.Errorf("unsupported task input type")
}
continue
}

switch modelInstance.Task {
case modelPB.ModelInstance_TASK_CLASSIFICATION:
allContentFiles = append(allContentFiles, data.TaskInput.GetClassification().Content...)
case modelPB.ModelInstance_TASK_DETECTION:
allContentFiles = append(allContentFiles, data.TaskInput.GetDetection().Content...)
case modelPB.ModelInstance_TASK_KEYPOINT:
allContentFiles = append(allContentFiles, data.TaskInput.GetKeypoint().Content...)
case modelPB.ModelInstance_TASK_OCR:
allContentFiles = append(allContentFiles, data.TaskInput.GetOcr().Content...)
case modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION:
allContentFiles = append(allContentFiles, data.TaskInput.GetInstanceSegmentation().Content...)
case modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION:
allContentFiles = append(allContentFiles, data.TaskInput.GetSemanticSegmentation().Content...)
default:
return fmt.Errorf("unsupported task input type")
}

}

var obj *pipelinePB.TriggerPipelineBinaryFileUploadResponse
switch modelInstance.Task {
case modelPB.ModelInstance_TASK_CLASSIFICATION,
modelPB.ModelInstance_TASK_DETECTION,
modelPB.ModelInstance_TASK_KEYPOINT,
modelPB.ModelInstance_TASK_OCR,
modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION:
if len(fileLengths) == 0 {
fileLengths = data.GetFileLengths()
return status.Errorf(codes.InvalidArgument, "no file lengths")
}
if data.Content == nil {
continue
if len(allContentFiles) == 0 {
return status.Errorf(codes.InvalidArgument, "no content files")
}
if _, err := content.Write(data.Content); err != nil {
return status.Errorf(codes.Internal, "failed unexpectedly while reading chunks from stream: %s", err.Error())
imageInput := service.ImageInput{
Content: allContentFiles,
FileLengths: fileLengths,
}
obj, err = h.service.TriggerPipelineBinaryFileUpload(dbPipeline, modelInstance.Task, &imageInput)
case modelPB.ModelInstance_TASK_TEXT_TO_IMAGE:
obj, err = h.service.TriggerPipelineBinaryFileUpload(dbPipeline, modelInstance.Task, &textToImageInput)
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
obj, err = h.service.TriggerPipelineBinaryFileUpload(dbPipeline, modelInstance.Task, &textGenerationInput)
}

obj, err := h.service.TriggerPipelineBinaryFileUpload(content, fileNames, fileLengths, dbPipeline)
if err != nil {
return err
}
Expand Down
73 changes: 33 additions & 40 deletions pkg/handler/handlercustom.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handler

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
Expand All @@ -15,13 +14,14 @@ import (
"google.golang.org/protobuf/encoding/protojson"

"github.com/instill-ai/pipeline-backend/config"
"github.com/instill-ai/pipeline-backend/internal/constant"
"github.com/instill-ai/pipeline-backend/internal/db"
"github.com/instill-ai/pipeline-backend/internal/external"
"github.com/instill-ai/pipeline-backend/internal/logger"
"github.com/instill-ai/pipeline-backend/internal/sterr"
"github.com/instill-ai/pipeline-backend/pkg/repository"
"github.com/instill-ai/pipeline-backend/pkg/service"

modelPB "github.com/instill-ai/protogen-go/vdp/model/v1alpha"
)

// HandleTriggerPipelineBinaryFileUpload is for POST multipart form data
Expand Down Expand Up @@ -91,9 +91,24 @@ func HandleTriggerPipelineBinaryFileUpload(w http.ResponseWriter, req *http.Requ
return
}

modelInstance, err := service.GetModelInstanceByName(dbPipeline.Recipe.ModelInstances[0])
if err != nil {
st := sterr.CreateErrorResourceInfo(
codes.NotFound,
"[handler] cannot get pipeline by id",
"pipelines",
fmt.Sprintf("id %s", id),
owner,
err.Error(),
)
errorResponse(w, st)
logger.Error(st.String())
return
}

if err := req.ParseMultipartForm(4 << 20); err != nil {
st := sterr.CreateErrorPreconditionFailure(
"[handler] error while reading file from request",
"[handler] error while get model instance information",
"TriggerPipelineBinaryFileUpload",
fmt.Sprintf("id %s", id),
err.Error(),
Expand All @@ -103,7 +118,20 @@ func HandleTriggerPipelineBinaryFileUpload(w http.ResponseWriter, req *http.Requ
return
}

content, fileNames, fileLengths, err := parseImageFormDataInputsToBytes(req)
var inp interface{}
switch modelInstance.Task {
case modelPB.ModelInstance_TASK_CLASSIFICATION,
modelPB.ModelInstance_TASK_DETECTION,
modelPB.ModelInstance_TASK_KEYPOINT,
modelPB.ModelInstance_TASK_OCR,
modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION:
inp, err = parseImageFormDataInputsToBytes(req)
case modelPB.ModelInstance_TASK_TEXT_TO_IMAGE:
inp, err = parseImageFormDataTextToImageInputs(req)
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
inp, err = parseTextFormDataTextGenerationInputs(req)
}
if err != nil {
st := sterr.CreateErrorPreconditionFailure(
"[handler] error while reading file from request",
Expand All @@ -116,7 +144,7 @@ func HandleTriggerPipelineBinaryFileUpload(w http.ResponseWriter, req *http.Requ
return
}

obj, err := service.TriggerPipelineBinaryFileUpload(*bytes.NewBuffer(content), fileNames, fileLengths, dbPipeline)
obj, err := service.TriggerPipelineBinaryFileUpload(dbPipeline, modelInstance.Task, inp)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
logger.Error(err.Error())
Expand Down Expand Up @@ -144,41 +172,6 @@ func HandleTriggerPipelineBinaryFileUpload(w http.ResponseWriter, req *http.Requ
}
}

// parseImageFormDataInputsToBytes reads every "file" part of the
// already-parsed multipart form in req and concatenates their raw bytes.
//
// It returns the concatenated contents of all files, their original file
// names, and the per-file byte lengths; the three slices are index-aligned.
// An error is returned when a part cannot be opened or read, or when a
// single file exceeds config.Config.Server.MaxDataSize megabytes.
func parseImageFormDataInputsToBytes(req *http.Request) (content []byte, fileNames []string, fileLengths []uint64, err error) {

	inputs := req.MultipartForm.File["file"]

	for _, input := range inputs {
		file, err := input.Open()
		if err != nil {
			// NOTE: the original registered `defer file.Close()` before this
			// check, which would dereference a nil file on Open failure.
			return nil, nil, nil, fmt.Errorf("Unable to open file for image")
		}

		buff := new(bytes.Buffer)
		numBytes, readErr := buff.ReadFrom(file)

		// Close each file as soon as it has been consumed instead of
		// deferring inside the loop: this releases descriptors promptly and
		// avoids the original bug where a late deferred Close clobbered the
		// named return value `err` after a successful return.
		if closeErr := file.Close(); closeErr != nil && readErr == nil {
			readErr = closeErr
		}
		if readErr != nil {
			return nil, nil, nil, fmt.Errorf("Unable to read content body from image")
		}

		if numBytes > int64(config.Config.Server.MaxDataSize*constant.MB) {
			return nil, nil, nil, fmt.Errorf(
				"Image size must be smaller than %vMB. Got %vMB",
				config.Config.Server.MaxDataSize,
				float32(numBytes)/float32(constant.MB),
			)
		}

		content = append(content, buff.Bytes()...)
		fileNames = append(fileNames, input.Filename)
		fileLengths = append(fileLengths, uint64(buff.Len()))
	}

	return content, fileNames, fileLengths, nil
}

func errorResponse(w http.ResponseWriter, s *status.Status) {
w.Header().Add("Content-Type", "application/problem+json")
switch {
Expand Down

0 comments on commit 0e7e9fa

Please sign in to comment.