fix: keep format for empty inference output (#258)
Because

- the inference output format needs to be consistent even when the output is empty

This commit

- adds an empty `TaskOutput` for each task type when there is no result (a minimal sketch of the pattern follows below)
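
The guard added in each `case` of `ModelInfer` follows the same shape. Below is a minimal, self-contained sketch of that pattern for the classification case; the struct definitions and the `ensureNonEmpty` helper are simplified stand-ins for the generated `modelPB` types and the inline code in `pkg/service/service.go`, not the actual implementation.

```go
package main

import "fmt"

// Simplified stand-ins for the generated modelPB types used in
// pkg/service/service.go; the real structs carry more fields.
type ClassificationOutput struct {
	Category string
	Score    float32
}

type TaskOutput struct {
	Classification *ClassificationOutput
}

// ensureNonEmpty mirrors the guard this commit adds inline in every task
// case: when post-processing yields nothing, append one empty but
// correctly typed output so the response shape stays the same.
func ensureNonEmpty(clsOutputs []*TaskOutput) []*TaskOutput {
	if len(clsOutputs) == 0 {
		clsOutputs = append(clsOutputs, &TaskOutput{
			Classification: &ClassificationOutput{},
		})
	}
	return clsOutputs
}

func main() {
	outputs := ensureNonEmpty(nil)
	// Prints 1: callers always get at least one (possibly empty) output.
	fmt.Println(len(outputs))
}
```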
Phelan164 committed Feb 17, 2023
1 parent 767ec45 commit e2a2e48
Showing 1 changed file with 79 additions and 2 deletions.
81 changes: 79 additions & 2 deletions pkg/service/service.go
@@ -259,7 +259,13 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,
return nil, fmt.Errorf("unable to decode inference output")
}
}

if len(clsOutputs) == 0 {
clsOutputs = append(clsOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_Classification{
Classification: &modelPB.ClassificationOutput{},
},
})
}
return clsOutputs, nil
case modelPB.ModelInstance_TASK_DETECTION:
detResponses := postprocessResponse.(triton.DetectionOutput)
@@ -295,6 +301,15 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,
}
detOutputs = append(detOutputs, &detOutput)
}
if len(detOutputs) == 0 {
detOutputs = append(detOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_Detection{
Detection: &modelPB.DetectionOutput{
Objects: []*modelPB.DetectionObject{},
},
},
})
}
return detOutputs, nil
case modelPB.ModelInstance_TASK_KEYPOINT:
keypointResponse := postprocessResponse.(triton.KeypointOutput)
@@ -336,7 +351,15 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,
},
})
}

if len(keypointOutputs) == 0 {
keypointOutputs = append(keypointOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_Keypoint{
Keypoint: &modelPB.KeypointOutput{
Objects: []*modelPB.KeypointObject{},
},
},
})
}
return keypointOutputs, nil
case modelPB.ModelInstance_TASK_OCR:
ocrResponses := postprocessResponse.(triton.OcrOutput)
@@ -372,6 +395,15 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,
}
ocrOutputs = append(ocrOutputs, &ocrOutput)
}
if len(ocrOutputs) == 0 {
ocrOutputs = append(ocrOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_Ocr{
Ocr: &modelPB.OcrOutput{
Objects: []*modelPB.OcrObject{},
},
},
})
}
return ocrOutputs, nil

case modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION:
@@ -411,6 +443,15 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,
}
instanceSegmentationOutputs = append(instanceSegmentationOutputs, &instanceSegmentationOutput)
}
if len(instanceSegmentationOutputs) == 0 {
instanceSegmentationOutputs = append(instanceSegmentationOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_InstanceSegmentation{
InstanceSegmentation: &modelPB.InstanceSegmentationOutput{
Objects: []*modelPB.InstanceSegmentationObject{},
},
},
})
}
return instanceSegmentationOutputs, nil

case modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION:
@@ -439,6 +480,15 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,
}
semanticSegmentationOutputs = append(semanticSegmentationOutputs, &semanticSegmentationOutput)
}
if len(semanticSegmentationOutputs) == 0 {
semanticSegmentationOutputs = append(semanticSegmentationOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_SemanticSegmentation{
SemanticSegmentation: &modelPB.SemanticSegmentationOutput{
Stuffs: []*modelPB.SemanticSegmentationStuff{},
},
},
})
}
return semanticSegmentationOutputs, nil
case modelPB.ModelInstance_TASK_TEXT_TO_IMAGE:
textToImageResponses := postprocessResponse.(triton.TextToImageOutput)
@@ -455,6 +505,15 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,

textToImageOutputs = append(textToImageOutputs, &textToImageOutput)
}
if len(textToImageOutputs) == 0 {
textToImageOutputs = append(textToImageOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_TextToImage{
TextToImage: &modelPB.TextToImageOutput{
Images: []string{},
},
},
})
}
return textToImageOutputs, nil
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
textGenerationResponses := postprocessResponse.(triton.TextGenerationOutput)
@@ -471,6 +530,15 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,

textGenerationOutputs = append(textGenerationOutputs, &textGenerationOutput)
}
if len(textGenerationOutputs) == 0 {
textGenerationOutputs = append(textGenerationOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_TextGeneration{
TextGeneration: &modelPB.TextGenerationOutput{
Text: "",
},
},
})
}
return textGenerationOutputs, nil
default:
outputs := postprocessResponse.([]triton.BatchUnspecifiedTaskOutputs)
@@ -521,6 +589,15 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,
},
})
}
if len(rawOutputs) == 0 {
rawOutputs = append(rawOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_Unspecified{
Unspecified: &modelPB.UnspecifiedOutput{
RawOutputs: []*structpb.Struct{},
},
},
})
}
return rawOutputs, nil
}
}
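
After this change, every `case` in the `switch` returns a slice with at least one element, so clients of `ModelInfer` no longer have to special-case an empty response. The guard itself is repeated across all task cases, including the default unspecified branch; a hypothetical generic helper along the lines below (not part of this commit) could express it once.

```go
package main

import "fmt"

// orDefault is a hypothetical helper, not part of this commit: it captures
// the "return a single empty output when the slice is empty" guard that
// each task case in ModelInfer now repeats (requires Go 1.18+ generics).
func orDefault[T any](outputs []T, empty T) []T {
	if len(outputs) == 0 {
		return []T{empty}
	}
	return outputs
}

func main() {
	fmt.Println(len(orDefault([]string{}, "")))      // 1
	fmt.Println(len(orDefault([]string{"cat"}, ""))) // 1
}
```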
