feat: support instance segmentation task (#183)
Because

- support the instance segmentation task in VDP

This commit

- add the instance segmentation task
Phelan164 committed Oct 17, 2022
1 parent 585996c commit d28cfdc
Showing 11 changed files with 207 additions and 24 deletions.
11 changes: 6 additions & 5 deletions config/config.go
@@ -82,11 +82,12 @@ type PipelineBackendConfig struct {
}

type MaxBatchSizeConfig struct {
-Unspecified int `koanf:"unspecified"`
-Classification int `koanf:"classification"`
-Detection int `koanf:"detection"`
-Keypoint int `koanf:"keypoint"`
-Ocr int `koanf:"ocr"`
+Unspecified int `koanf:"unspecified"`
+Classification int `koanf:"classification"`
+Detection int `koanf:"detection"`
+Keypoint int `koanf:"keypoint"`
+Ocr int `koanf:"ocr"`
+InstanceSegmentation int `koanf:"instancesegmentation"`
}

// AppConfig defines
3 changes: 2 additions & 1 deletion config/config.yaml
@@ -51,4 +51,5 @@ maxbatchsizelimitation:
classification: 16
detection: 8
keypoint: 8
-ocr: 2
+ocr: 2
+instancesegmentation: 8
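
For reference on how this key is consumed: the service loads config.yaml through koanf, so the lowercase instancesegmentation key has to match the struct tag in MaxBatchSizeConfig exactly. A minimal standalone sketch of that mapping (not the service's actual loader; it assumes the file layout above):

package main

import (
	"fmt"
	"log"

	"github.com/knadh/koanf"
	"github.com/knadh/koanf/parsers/yaml"
	"github.com/knadh/koanf/providers/file"
)

// maxBatchSize mirrors the field added to MaxBatchSizeConfig above.
type maxBatchSize struct {
	InstanceSegmentation int `koanf:"instancesegmentation"`
}

func main() {
	k := koanf.New(".")
	if err := k.Load(file.Provider("config/config.yaml"), yaml.Parser()); err != nil {
		log.Fatal(err)
	}
	var m maxBatchSize
	// "maxbatchsizelimitation" is the parent key shown in the YAML hunk above.
	if err := k.Unmarshal("maxbatchsizelimitation", &m); err != nil {
		log.Fatal(err)
	}
	fmt.Println(m.InstanceSegmentation) // 8 with the values above
}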
2 changes: 1 addition & 1 deletion go.mod
@@ -12,7 +12,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.11.3
github.com/iancoleman/strcase v0.2.0
-github.com/instill-ai/protogen-go v0.3.3-alpha.0.20220923153409-661a55f8a69d
+github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221017142212-274800e9123d
github.com/instill-ai/usage-client v0.2.1-alpha
github.com/instill-ai/x v0.2.0-alpha
github.com/knadh/koanf v1.4.3
4 changes: 2 additions & 2 deletions go.sum
@@ -715,8 +715,8 @@ github.com/imdario/mergo v0.3.10/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH
github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
-github.com/instill-ai/protogen-go v0.3.3-alpha.0.20220923153409-661a55f8a69d h1:DbCTM5zP8tUekPR1q15qqZ7vv48iRny/DfnSF1rx1GM=
-github.com/instill-ai/protogen-go v0.3.3-alpha.0.20220923153409-661a55f8a69d/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
+github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221017142212-274800e9123d h1:4hSv0RoOTCtPuEF80FdYfL94dEeDlf8brDUh1x4nCZc=
+github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221017142212-274800e9123d/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/usage-client v0.2.1-alpha h1:XXMCTDT2BWOgGwerOpxghzt6hW9J7/yUR1tkNRuGjjM=
github.com/instill-ai/usage-client v0.2.1-alpha/go.mod h1:ThySPYe08Jy7OpfdtCZDckm19ET39K+KXGJ4lr+rOss=
github.com/instill-ai/x v0.2.0-alpha h1:8yszKP9DE8bvSRAtEpOwqhG2wwqU3olhTqhwoiLrHfc=
3 changes: 2 additions & 1 deletion internal/db/migration/000001_init.up.sql
@@ -15,7 +15,8 @@ CREATE TYPE valid_task AS ENUM (
'TASK_CLASSIFICATION',
'TASK_DETECTION',
'TASK_KEYPOINT',
-'TASK_OCR'
+'TASK_OCR',
+'TASK_INSTANCE_SEGMENTATION'
);

CREATE TYPE valid_release_stage AS ENUM (
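Note that the new enum value is added to the initial migration only; a deployment whose database predates this commit would need a follow-up migration (something like ALTER TYPE valid_task ADD VALUE 'TASK_INSTANCE_SEGMENTATION'; — hypothetical here, not part of this commit) before rows may reference the new task.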
7 changes: 7 additions & 0 deletions internal/triton/const.go
@@ -34,3 +34,10 @@ type SingleOutputUnspecifiedTaskOutput struct {
type UnspecifiedTaskOutput struct {
RawOutput []SingleOutputUnspecifiedTaskOutput
}

type InstanceSegmentationOutput struct {
Rles [][]string
Boxes [][][]float32
Scores [][]float32
Labels [][]string
}
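
Each string in Rles is one run-length-encoded binary mask per detected object. The exact encoding is not pinned down in this commit; assuming an uncompressed, comma-separated counts format (alternating background/foreground runs, as in COCO's uncompressed RLE), a client-side decode could look like the following sketch:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// decodeRLE expands a counts-style RLE string such as "3,2,4,3" into a flat
// binary mask of width*height pixels. The comma-separated format is an
// assumption, not something this commit specifies.
func decodeRLE(rle string, width, height int) ([]uint8, error) {
	mask := make([]uint8, 0, width*height)
	val := uint8(0) // runs alternate, starting with background
	for _, tok := range strings.Split(rle, ",") {
		n, err := strconv.Atoi(strings.TrimSpace(tok))
		if err != nil {
			return nil, fmt.Errorf("bad RLE run %q: %w", tok, err)
		}
		for i := 0; i < n; i++ {
			mask = append(mask, val)
		}
		val = 1 - val
	}
	if len(mask) != width*height {
		return nil, fmt.Errorf("RLE decoded to %d pixels, want %d", len(mask), width*height)
	}
	return mask, nil
}

func main() {
	mask, err := decodeRLE("3,2,4,3", 4, 3)
	fmt.Println(mask, err) // [0 0 0 1 1 0 0 0 0 1 1 1] <nil>
}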
88 changes: 88 additions & 0 deletions internal/triton/triton.go
@@ -519,6 +519,85 @@ func postProcessKeypoint(modelInferResponse *inferenceserver.ModelInferResponse,
}, nil
}

func postProcessInstanceSegmentation(modelInferResponse *inferenceserver.ModelInferResponse, outputNameRles string, outputNameBboxes string, outputNameLabels string, outputNameScores string) (interface{}, error) {
outputTensorRles, rawOutputContentRles, err := GetOutputFromInferResponse(outputNameRles, modelInferResponse)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to find inference output for RLEs")
}
if rawOutputContentRles == nil {
return nil, fmt.Errorf("Unable to find output content for RLEs")
}

outputTensorBboxes, rawOutputContentBboxes, err := GetOutputFromInferResponse(outputNameBboxes, modelInferResponse)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to find inference output for boxes")
}
if rawOutputContentBboxes == nil {
return nil, fmt.Errorf("Unable to find output content for boxes")
}
outputTensorLabels, rawOutputContentLabels, err := GetOutputFromInferResponse(outputNameLabels, modelInferResponse)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to find inference output for labels")
}
if rawOutputContentLabels == nil {
return nil, fmt.Errorf("Unable to find output content for labels")
}

outputDataLabels := DeserializeBytesTensor(rawOutputContentLabels, outputTensorLabels.Shape[0]*outputTensorLabels.Shape[1])
batchedOutputDataLabels, err := Reshape1DArrayStringTo2D(outputDataLabels, outputTensorLabels.Shape)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to reshape inference output for labels")
}

outputTensorScores, rawOutputContentScores, err := GetOutputFromInferResponse(outputNameScores, modelInferResponse)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to find inference output for scores")
}
if rawOutputContentScores == nil {
return nil, fmt.Errorf("Unable to find output content for scores")
}
outputDataRles := DeserializeBytesTensor(rawOutputContentRles, outputTensorRles.Shape[0]*outputTensorRles.Shape[1])
batchedOutputDataRles, err := Reshape1DArrayStringTo2D(outputDataRles, outputTensorRles.Shape)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to reshape inference output for RLEs")
}

outputDataBboxes := DeserializeFloat32Tensor(rawOutputContentBboxes)
batchedOutputDataBboxes, err := Reshape1DArrayFloat32To3D(outputDataBboxes, outputTensorBboxes.Shape)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to reshape inference output for boxes")
}

outputDataScores := DeserializeFloat32Tensor(rawOutputContentScores)
batchedOutputDataScores, err := Reshape1DArrayFloat32To2D(outputDataScores, outputTensorScores.Shape)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to reshape inference output for scores")
}

if len(batchedOutputDataBboxes) != len(batchedOutputDataLabels) ||
len(batchedOutputDataBboxes) != len(batchedOutputDataRles) ||
len(batchedOutputDataBboxes) != len(batchedOutputDataScores) {
log.Printf("Rles output has length %v Bboxes output has length %v but labels has length %v scores have length %v",
len(batchedOutputDataRles), len(batchedOutputDataBboxes), len(batchedOutputDataLabels), len(batchedOutputDataScores))
return nil, fmt.Errorf("Inconsistent batch size for rles, bboxes, labels and scores")
}

return InstanceSegmentationOutput{
Rles: batchedOutputDataRles,
Boxes: batchedOutputDataBboxes,
Labels: batchedOutputDataLabels,
Scores: batchedOutputDataScores,
}, nil
}

func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse, modelMetadata *inferenceserver.ModelMetadataResponse, task modelPB.ModelInstance_Task) (interface{}, error) {
var (
outputs interface{}
@@ -564,6 +643,15 @@ func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse,
}
}

case modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION:
if len(modelMetadata.Outputs) < 4 {
return nil, fmt.Errorf("Wrong output format of instance segmentation task")
}
outputs, err = postProcessInstanceSegmentation(inferResponse, modelMetadata.Outputs[0].Name, modelMetadata.Outputs[1].Name, modelMetadata.Outputs[2].Name, modelMetadata.Outputs[3].Name)
if err != nil {
return nil, fmt.Errorf("Unable to post-process instance segmentation output: %w", err)
}

default:
outputs, err = postProcessUnspecifiedTask(inferResponse, modelMetadata.Outputs)
if err != nil {
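One constraint worth flagging in the PostProcess case above: the four output names are taken positionally from modelMetadata.Outputs, so the model's Triton configuration must declare its outputs in exactly the order RLEs, boxes, labels, scores. A name-based lookup would lift that ordering requirement; a standalone sketch (the canonical names "rles", "boxes", "labels", "scores" are assumptions, not taken from this commit):

package main

import (
	"fmt"
	"strings"
)

// pickByName resolves Triton output tensors by (case-insensitive) name so the
// post-processor does not depend on declaration order in the model config.
func pickByName(outputNames []string) (rles, boxes, labels, scores string, err error) {
	byName := map[string]string{}
	for _, n := range outputNames {
		byName[strings.ToLower(n)] = n
	}
	for _, want := range []string{"rles", "boxes", "labels", "scores"} {
		if _, ok := byName[want]; !ok {
			return "", "", "", "", fmt.Errorf("missing model output %q", want)
		}
	}
	return byName["rles"], byName["boxes"], byName["labels"], byName["scores"], nil
}

func main() {
	r, b, l, s, err := pickByName([]string{"Boxes", "Labels", "Rles", "Scores"})
	fmt.Println(r, b, l, s, err) // Rles Boxes Labels Scores <nil>
}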
31 changes: 31 additions & 0 deletions internal/triton/util.go
@@ -139,6 +139,37 @@ func Reshape1DArrayFloat32To3D(array []float32, shape []int64) ([][][]float32, e
return res, nil
}

func Reshape1DArrayInt32To3D(array []int32, shape []int64) ([][][]int32, error) {
if len(array) == 0 {
return [][][]int32{}, nil
}

if len(shape) != 3 {
return nil, fmt.Errorf("Expected a 3D shape, got %vD shape %v", len(shape), shape)
}

var prod int64 = 1
for _, s := range shape {
prod *= s
}
if prod != int64(len(array)) {
return nil, fmt.Errorf("Cannot reshape array of length %v into shape %v", len(array), shape)
}

res := make([][][]int32, shape[0])
for i := int64(0); i < shape[0]; i++ {
res[i] = make([][]int32, shape[1])
for j := int64(0); j < shape[1]; j++ {
start := i*shape[1]*shape[2] + j*shape[2]
end := start + shape[2]
res[i][j] = array[start:end]
}
}

return res, nil
}

func Reshape1DArrayFloat32To4D(array []float32, shape []int64) ([][][][]float32, error) {
if len(array) == 0 {
return [][][][]float32{}, nil
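Reshape1DArrayInt32To3D is introduced here without a caller in this commit, presumably anticipating int32 mask tensors. A quick sanity check of its behavior (as it would appear in a _test.go of package triton, with fmt imported); note that each res[i][j] is a reslice of the input rather than a copy, which is fine for read-only post-processing:

func ExampleReshape1DArrayInt32To3D() {
	flat := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
	out, _ := Reshape1DArrayInt32To3D(flat, []int64{2, 2, 3})
	fmt.Println(out)
	// Output: [[[0 1 2] [3 4 5]] [[6 7 8] [9 10 11]]]
}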
24 changes: 14 additions & 10 deletions internal/util/const.go
@@ -7,19 +7,23 @@ import (
)

var Tasks = map[string]modelPB.ModelInstance_Task{
"TASK_CLASSIFICATION": modelPB.ModelInstance_TASK_CLASSIFICATION,
"TASK_DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"TASK_KEYPOINT": modelPB.ModelInstance_TASK_KEYPOINT,
"TASK_OCR": modelPB.ModelInstance_TASK_OCR,
"TASK_CLASSIFICATION": modelPB.ModelInstance_TASK_CLASSIFICATION,
"TASK_DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"TASK_KEYPOINT": modelPB.ModelInstance_TASK_KEYPOINT,
"TASK_OCR": modelPB.ModelInstance_TASK_OCR,
"TASK_INSTANCESEGMENTATION": modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
"TASK_INSTANCE_SEGMENTATION": modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
}

var Tags = map[string]modelPB.ModelInstance_Task{
"CLASSIFICATION": modelPB.ModelInstance_TASK_CLASSIFICATION,
"DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"IMAGE-CLASSIFICATION": modelPB.ModelInstance_TASK_CLASSIFICATION,
"IMAGE-DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"OBJECT-DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"OCR": modelPB.ModelInstance_TASK_OCR,
"CLASSIFICATION": modelPB.ModelInstance_TASK_CLASSIFICATION,
"DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"IMAGE-CLASSIFICATION": modelPB.ModelInstance_TASK_CLASSIFICATION,
"IMAGE-DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"OBJECT-DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"OCR": modelPB.ModelInstance_TASK_OCR,
"INSTANCESEGMENTATION": modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
"INSTANCE_SEGMENTATION": modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
}

var Visibility = map[string]modelPB.Model_Visibility{
10 changes: 10 additions & 0 deletions pkg/handler/handler.go
@@ -707,6 +707,8 @@ func HandleCreateModelByMultiPartFormData(w http.ResponseWriter, r *http.Request
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
}

if maxBatchSize > allowedMaxBatchSize {
@@ -840,6 +842,8 @@ func (h *handler) CreateModelBinaryFileUpload(stream modelPB.ModelService_Create
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
}

if maxBatchSize > allowedMaxBatchSize {
@@ -1047,6 +1051,8 @@ func createGitHubModel(h *handler, ctx context.Context, req *modelPB.CreateModel
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
}
if maxBatchSize > allowedMaxBatchSize {
st, e := sterr.CreateErrorPreconditionFailure(
@@ -1265,6 +1271,8 @@ func createArtiVCModel(h *handler, ctx context.Context, req *modelPB.CreateModel
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
}
if maxBatchSize > allowedMaxBatchSize {
st, e := sterr.CreateErrorPreconditionFailure(
@@ -1503,6 +1511,8 @@ func createHuggingFaceModel(h *handler, ctx context.Context, req *modelPB.Create
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
}
if maxBatchSize > allowedMaxBatchSize {
st, e := sterr.CreateErrorPreconditionFailure(
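The task-to-limit switch above now appears at four call sites (multipart form upload, binary-file upload stream, GitHub and ArtiVC/HuggingFace model creation). A hedged consolidation sketch a follow-up could adopt, using only names already present in this commit (it would live in pkg/handler alongside the existing config and modelPB imports):

// allowedMaxBatchSize returns the configured cap for a task; unknown tasks
// fall back to the Unspecified limit.
func allowedMaxBatchSize(task modelPB.ModelInstance_Task) int {
	c := config.Config.MaxBatchSizeLimitation
	switch task {
	case modelPB.ModelInstance_TASK_CLASSIFICATION:
		return c.Classification
	case modelPB.ModelInstance_TASK_DETECTION:
		return c.Detection
	case modelPB.ModelInstance_TASK_KEYPOINT:
		return c.Keypoint
	case modelPB.ModelInstance_TASK_OCR:
		return c.Ocr
	case modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION:
		return c.InstanceSegmentation
	default:
		return c.Unspecified
	}
}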
48 changes: 44 additions & 4 deletions pkg/service/service.go
@@ -304,10 +304,10 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas

return keypointOutputs, nil
case modelPB.ModelInstance_TASK_OCR:
-detResponses := postprocessResponse.(triton.OcrOutput)
-batchedOutputDataBboxes := detResponses.Boxes
-batchedOutputDataTexts := detResponses.Texts
-batchedOutputDataScores := detResponses.Scores
+ocrResponses := postprocessResponse.(triton.OcrOutput)
+batchedOutputDataBboxes := ocrResponses.Boxes
+batchedOutputDataTexts := ocrResponses.Texts
+batchedOutputDataScores := ocrResponses.Scores
var ocrOutputs []*modelPB.TaskOutput
for i := range batchedOutputDataBboxes {
var ocrOutput = modelPB.TaskOutput{
@@ -338,6 +338,46 @@
ocrOutputs = append(ocrOutputs, &ocrOutput)
}
return ocrOutputs, nil

case modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION:
instanceSegmentationResponses := postprocessResponse.(triton.InstanceSegmentationOutput)
batchedOutputDataRles := instanceSegmentationResponses.Rles
batchedOutputDataBboxes := instanceSegmentationResponses.Boxes
batchedOutputDataLabels := instanceSegmentationResponses.Labels
batchedOutputDataScores := instanceSegmentationResponses.Scores
var instanceSegmentationOutputs []*modelPB.TaskOutput
for i := range batchedOutputDataBboxes {
var instanceSegmentationOutput = modelPB.TaskOutput{
Output: &modelPB.TaskOutput_InstanceSegmentation{
InstanceSegmentation: &modelPB.InstanceSegmentationOutput{
Objects: []*modelPB.InstanceSegmentationObject{},
},
},
}
for j := range batchedOutputDataBboxes[i] {
rle := batchedOutputDataRles[i][j]
box := batchedOutputDataBboxes[i][j]
label := batchedOutputDataLabels[i][j]
score := batchedOutputDataScores[i][j]
// Non-meaningful detections are padded with bbox coords [-1, -1, -1, -1] and empty label/RLE so that Triton can batch tensors
if label != "" && rle != "" {
instanceSegmentationOutput.GetInstanceSegmentation().Objects = append(instanceSegmentationOutput.GetInstanceSegmentation().Objects, &modelPB.InstanceSegmentationObject{
Rle: rle,
BoundingBox: &modelPB.BoundingBox{
Left: box[0],
Top: box[1],
Width: box[2],
Height: box[3],
},
Score: score,
Label: label,
})
}
}
instanceSegmentationOutputs = append(instanceSegmentationOutputs, &instanceSegmentationOutput)
}
return instanceSegmentationOutputs, nil

default:
outputs := postprocessResponse.([]triton.BatchUnspecifiedTaskOutputs)
var rawOutputs []*modelPB.TaskOutput
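For a concrete sense of the payload, one element of the returned instanceSegmentationOutputs slice carries a single image's objects. A sketch with the protogen-go types this commit bumps to (modelPB aliasing the model protobuf package as in service.go; the RLE and box values are illustrative only):

out := &modelPB.TaskOutput{
	Output: &modelPB.TaskOutput_InstanceSegmentation{
		InstanceSegmentation: &modelPB.InstanceSegmentationOutput{
			Objects: []*modelPB.InstanceSegmentationObject{{
				Rle:   "0,10,5,85", // illustrative counts-style RLE string
				Label: "dog",
				Score: 0.98,
				BoundingBox: &modelPB.BoundingBox{
					Left: 12, Top: 8, Width: 120, Height: 96,
				},
			}},
		},
	},
}
_ = out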
