feat: add support for text generation tasks (#252)
Because

- support text generation tasks in VDP

This commit

- add support for text generation tasks
heiruwu committed Feb 17, 2023
1 parent 88adc61 commit 767ec45
Showing 10 changed files with 221 additions and 2 deletions.
1 change: 1 addition & 0 deletions config/config.go
@@ -95,6 +95,7 @@ type MaxBatchSizeConfig struct {
Ocr int `koanf:"ocr"`
InstanceSegmentation int `koanf:"instancesegmentation"`
SemanticSegmentation int `koanf:"semanticsegmentation"`
TextGeneration int `koanf:"textgeneration"`
}

// TemporalConfig related to Temporal
1 change: 1 addition & 0 deletions config/config.yaml
@@ -52,6 +52,7 @@ maxbatchsizelimitation:
ocr: 2
instancesegmentation: 8
semanticsegmentation: 8
textgeneration: 1
temporal:
clientoptions:
hostport: temporal:7233
3 changes: 2 additions & 1 deletion internal/db/migration/000001_init.up.sql
@@ -19,7 +19,8 @@ CREATE TYPE valid_task AS ENUM (
'TASK_OCR',
'TASK_INSTANCE_SEGMENTATION',
'TASK_SEMANTIC_SEGMENTATION',
'TASK_TEXT_TO_IMAGE'
'TASK_TEXT_TO_IMAGE',
'TASK_TEXT_GENERATION'
);

CREATE TYPE valid_release_stage AS ENUM (
4 changes: 4 additions & 0 deletions internal/triton/const.go
@@ -50,3 +50,7 @@ type SemanticSegmentationOutput struct {
type TextToImageOutput struct {
Images [][]string
}

type TextGenerationOutput struct {
Text []string
}
50 changes: 50 additions & 0 deletions internal/triton/triton.go
@@ -34,6 +34,15 @@ type TextToImageInput struct {
Samples int64
}

type TextGenerationInput struct {
Prompt string
OutputLen int64
BadWordsList string
StopWordsList string
TopK int64
Seed int64
}

type VisionInput struct {
ImgUrl string
ImgBase64 string
@@ -163,6 +172,12 @@ func (ts *triton) ModelInferRequest(task modelPB.ModelInstance_Task, inferInput
Datatype: modelMetadata.Inputs[i].Datatype,
Shape: []int64{1},
})
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
inferInputs = append(inferInputs, &inferenceserver.ModelInferRequest_InferInputTensor{
Name: modelMetadata.Inputs[i].Name,
Datatype: modelMetadata.Inputs[i].Datatype,
Shape: []int64{1, 1},
})
case modelPB.ModelInstance_TASK_CLASSIFICATION,
modelPB.ModelInstance_TASK_DETECTION,
modelPB.ModelInstance_TASK_KEYPOINT,
@@ -252,6 +267,20 @@ func (ts *triton) ModelInferRequest(task modelPB.ModelInstance_Task, inferInput
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, steps)
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, guidanceScale)
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, seed)
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
textGenerationInputs := inferInput.([]TextGenerationInput)
outputLen := make([]byte, 4)
binary.LittleEndian.PutUint32(outputLen, uint32(textGenerationInputs[0].OutputLen))
topK := make([]byte, 4)
binary.LittleEndian.PutUint32(topK, uint32(textGenerationInputs[0].TopK))
seed := make([]byte, 8)
binary.LittleEndian.PutUint64(seed, uint64(textGenerationInputs[0].Seed))
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, SerializeBytesTensor([][]byte{[]byte(textGenerationInputs[0].Prompt)}))
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, outputLen)
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, SerializeBytesTensor([][]byte{[]byte(textGenerationInputs[0].BadWordsList)}))
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, SerializeBytesTensor([][]byte{[]byte(textGenerationInputs[0].StopWordsList)}))
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, topK)
modelInferRequest.RawInputContents = append(modelInferRequest.RawInputContents, seed)
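Editor's note: a minimal, self-contained sketch of the byte layout assembled above, assuming Triton's binary tensor protocol (BYTES elements are 4-byte little-endian length-prefixed; numeric scalars are raw little-endian). The serializeBytesTensor below is a hypothetical stand-in for the repository's SerializeBytesTensor helper, not a copy of it:

package main

import (
	"encoding/binary"
	"fmt"
)

// serializeBytesTensor prefixes each element with its 4-byte little-endian
// length, the layout Triton expects for BYTES tensors in raw input contents.
func serializeBytesTensor(elems [][]byte) []byte {
	var buf []byte
	for _, e := range elems {
		size := make([]byte, 4)
		binary.LittleEndian.PutUint32(size, uint32(len(e)))
		buf = append(buf, size...)
		buf = append(buf, e...)
	}
	return buf
}

func main() {
	prompt := serializeBytesTensor([][]byte{[]byte("once upon a time")})
	outputLen := make([]byte, 4)
	binary.LittleEndian.PutUint32(outputLen, 100) // OutputLen scalar, as above
	fmt.Printf("prompt tensor: % x\n", prompt)
	fmt.Printf("output_len tensor: % x\n", outputLen)
}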
case modelPB.ModelInstance_TASK_CLASSIFICATION,
modelPB.ModelInstance_TASK_DETECTION,
modelPB.ModelInstance_TASK_KEYPOINT,
@@ -753,6 +782,21 @@ func postProcessTextToImage(modelInferResponse *inferenceserver.ModelInferRespon
}, nil
}

func postProcessTextGeneration(modelInferResponse *inferenceserver.ModelInferResponse, outputNameTexts string) (interface{}, error) {
outputTensorTexts, rawOutputContentTexts, err := GetOutputFromInferResponse(outputNameTexts, modelInferResponse)
if err != nil {
return nil, fmt.Errorf("unable to find inference output for generated texts")
}
if outputTensorTexts == nil {
return nil, fmt.Errorf("unable to find output content for generated texts")
}
outputTexts := DeserializeBytesTensor(rawOutputContentTexts, outputTensorTexts.Shape[0])

return TextGenerationOutput{
Text: outputTexts,
}, nil
}
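
Editor's note: a complementary sketch of what DeserializeBytesTensor is assumed to do with the raw output above — walk the buffer, reading a 4-byte little-endian length followed by that many bytes, once per element. This is a hypothetical stand-in, not the repository's implementation:

package main

import (
	"encoding/binary"
	"fmt"
)

// deserializeBytesTensor reverses the length-prefixed BYTES encoding: each
// element is a 4-byte little-endian size followed by its payload.
func deserializeBytesTensor(raw []byte, n int64) []string {
	out := make([]string, 0, n)
	off := 0
	for i := int64(0); i < n && off+4 <= len(raw); i++ {
		size := int(binary.LittleEndian.Uint32(raw[off : off+4]))
		off += 4
		out = append(out, string(raw[off:off+size]))
		off += size
	}
	return out
}

func main() {
	raw := []byte{5, 0, 0, 0, 'h', 'e', 'l', 'l', 'o'}
	fmt.Println(deserializeBytesTensor(raw, 1)) // [hello]
}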

func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse, modelMetadata *inferenceserver.ModelMetadataResponse, task modelPB.ModelInstance_Task) (interface{}, error) {
var (
outputs interface{}
@@ -822,6 +866,12 @@ func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse,
return nil, fmt.Errorf("unable to post-process text to image output: %w", err)
}

case modelPB.ModelInstance_TASK_TEXT_GENERATION:
outputs, err = postProcessTextGeneration(inferResponse, modelMetadata.Outputs[0].Name)
if err != nil {
return nil, fmt.Errorf("unable to post-process text to image output: %w", err)
}

default:
outputs, err = postProcessUnspecifiedTask(inferResponse, modelMetadata.Outputs)
if err != nil {
10 changes: 10 additions & 0 deletions internal/util/const.go
@@ -17,6 +17,8 @@ var Tasks = map[string]modelPB.ModelInstance_Task{
"TASK_SEMANTICSEGMENTATION": modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION,
"TASK_TEXT_TO_IMAGE": modelPB.ModelInstance_TASK_TEXT_TO_IMAGE,
"TASK_TEXTTOIMAGE": modelPB.ModelInstance_TASK_TEXT_TO_IMAGE,
"TASK_TEXT_GENERATION": modelPB.ModelInstance_TASK_TEXT_GENERATION,
"TASK_TEXTGENERATION": modelPB.ModelInstance_TASK_TEXT_GENERATION,
}

var Tags = map[string]modelPB.ModelInstance_Task{
@@ -32,6 +34,8 @@ var Tags = map[string]modelPB.ModelInstance_Task{
"SEMANTICSEGMENTATION": modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION,
"TEXT_TO_IMAGE": modelPB.ModelInstance_TASK_TEXT_TO_IMAGE,
"TEXTTOIMAGE": modelPB.ModelInstance_TASK_TEXT_TO_IMAGE,
"TEXT_GENERATION": modelPB.ModelInstance_TASK_TEXT_GENERATION,
"TEXTGENERATION": modelPB.ModelInstance_TASK_TEXT_GENERATION,
}

var Visibility = map[string]modelPB.Model_Visibility{
@@ -78,3 +82,9 @@ const (
IMAGE_TO_TEXT_SEED = int64(1024)
IMAGE_TO_TEXT_SAMPLES = int64(1)
)

const (
TEXT_GENERATION_OUTPUT_LEN = int64(100)
TEXT_GENERATION_TOP_K = int64(1)
TEXT_GENERATION_SEED = int64(0)
)
15 changes: 14 additions & 1 deletion internal/util/util.go
@@ -78,7 +78,8 @@ func findDVCPaths(dir string) []string {
func findModelFiles(dir string) []string {
var modelPaths []string = []string{}
_ = filepath.Walk(dir, func(path string, f os.FileInfo, err error) error {
if strings.HasSuffix(f.Name(), ".onnx") || strings.HasSuffix(f.Name(), ".pt") || strings.HasSuffix(f.Name(), ".bias") || strings.HasSuffix(f.Name(), ".weight") || strings.HasPrefix(f.Name(), "onnx__") {
if strings.HasSuffix(f.Name(), ".onnx") || strings.HasSuffix(f.Name(), ".pt") || strings.HasSuffix(f.Name(), ".bias") ||
strings.HasSuffix(f.Name(), ".weight") || strings.HasSuffix(f.Name(), ".ini") || strings.HasSuffix(f.Name(), ".bin") {
modelPaths = append(modelPaths, path)
}
return nil
@@ -178,6 +179,16 @@ func CopyModelFileToModelRepository(modelRepository string, dir string, tritonMo
if err := cmd.Run(); err != nil {
return err
}
// TODO: add a general function to check whether the backend uses fastertransformer, which has a different model file structure
} else if modelSubNames[len(modelSubNames)-3] == "fastertransformer" && tritonSubNames[len(tritonSubNames)-2] == modelSubNames[len(modelSubNames)-3] {
targetPath := fmt.Sprintf("%s/%s/%s/%s/", modelRepository, tritonModelName, modelSubNames[len(modelSubNames)-2], modelSubNames[len(modelSubNames)-1])
if err := os.MkdirAll(targetPath, os.ModePerm); err != nil {
return err
}
cmd := exec.Command("/bin/sh", "-c", fmt.Sprintf("cp %s %s/", modelPath, targetPath))
if err := cmd.Run(); err != nil {
return err
}
}
}
}
@@ -728,6 +739,8 @@ func GetSupportedBatchSize(task datamodel.ModelInstanceTask) int {
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.SemanticSegmentation
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_TEXT_GENERATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.TextGeneration
}
return allowedMaxBatchSize
}
26 changes: 26 additions & 0 deletions pkg/handler/handler.go
@@ -1949,6 +1949,13 @@ func (h *handler) TriggerModelInstance(ctx context.Context, req *modelPB.Trigger
}
lenInputs = len(textToImage)
inputInfer = textToImage
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
textGeneration, err := parseTexGenerationRequestInputs(req)
if err != nil {
return &modelPB.TriggerModelInstanceResponse{}, status.Error(codes.InvalidArgument, err.Error())
}
lenInputs = len(textGeneration)
inputInfer = textGeneration
}
// check whether model support batching or not. If not, raise an error
if lenInputs > 1 {
@@ -2051,6 +2058,17 @@ func (h *handler) TestModelInstance(ctx context.Context, req *modelPB.TestModelI
}
lenInputs = len(textToImage)
inputInfer = textToImage
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
textGeneration, err := parseTexGenerationRequestInputs(
&modelPB.TriggerModelInstanceRequest{
Name: req.Name,
TaskInputs: req.TaskInputs,
})
if err != nil {
return &modelPB.TestModelInstanceResponse{}, status.Error(codes.InvalidArgument, err.Error())
}
lenInputs = len(textGeneration)
inputInfer = textGeneration
}

// check whether model support batching or not. If not, raise an error
@@ -2190,6 +2208,14 @@ func inferModelInstanceByUpload(w http.ResponseWriter, r *http.Request, pathPara
}
lenInputs = len(textToImage)
inputInfer = textToImage
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
textGeneration, err := parseTextFormDataTextGenerationInputs(r)
if err != nil {
makeJSONResponse(w, 400, "File Input Error", err.Error())
return
}
lenInputs = len(textGeneration)
inputInfer = textGeneration
}

// check whether model support batching or not. If not, raise an error
91 changes: 91 additions & 0 deletions pkg/handler/payload.go
@@ -189,6 +189,41 @@ func parseTexToImageRequestInputs(req *modelPB.TriggerModelInstanceRequest) (tex
return textToImageInputs, nil
}

func parseTexGenerationRequestInputs(req *modelPB.TriggerModelInstanceRequest) (textGenerationInput []triton.TextGenerationInput, err error) {
var textGenerationInputs []triton.TextGenerationInput
for _, taskInput := range req.TaskInputs {
outputLen := int64(util.TEXT_GENERATION_OUTPUT_LEN)
if taskInput.GetTextGeneration().OutputLen != nil {
outputLen = int64(*taskInput.GetTextGeneration().OutputLen)
}
badWordsList := ""
if taskInput.GetTextGeneration().BadWordsList != nil {
badWordsList = *taskInput.GetTextGeneration().BadWordsList
}
stopWordsList := ""
if taskInput.GetTextGeneration().StopWordsList != nil {
stopWordsList = *taskInput.GetTextGeneration().StopWordsList
}
topK := int64(util.TEXT_GENERATION_TOP_K)
if taskInput.GetTextGeneration().Topk != nil {
topK = int64(*taskInput.GetTextGeneration().Topk)
}
seed := int64(util.TEXT_GENERATION_SEED)
if taskInput.GetTextGeneration().Seed != nil {
seed = int64(*taskInput.GetTextGeneration().Seed)
}
textGenerationInputs = append(textGenerationInputs, triton.TextGenerationInput{
Prompt: taskInput.GetTextGeneration().Prompt,
OutputLen: outputLen,
BadWordsList: badWordsList,
StopWordsList: stopWordsList,
TopK: topK,
Seed: seed,
})
}
return textGenerationInputs, nil
}
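
Editor's note: a hypothetical caller-side sketch of a request that exercises the parser above, leaving the optional fields unset so the TEXT_GENERATION_* defaults apply. The TaskInput_TextGeneration wrapper and the Input field name mirror the TaskOutput side of this commit but are assumptions about the generated proto code, not taken from it:

// Assumed oneof wrapper, mirroring modelPB.TaskOutput_TextGeneration below.
req := &modelPB.TriggerModelInstanceRequest{
	Name: "models/my-model/instances/latest", // hypothetical resource name
	TaskInputs: []*modelPB.TaskInput{{
		Input: &modelPB.TaskInput_TextGeneration{
			TextGeneration: &modelPB.TextGenerationInput{
				Prompt: "Once upon a time",
				// OutputLen, TopK, Seed left nil: the parser falls back to
				// TEXT_GENERATION_OUTPUT_LEN, TEXT_GENERATION_TOP_K,
				// TEXT_GENERATION_SEED.
			},
		},
	}},
}
inputs, err := parseTexGenerationRequestInputs(req)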

func parseImageFormDataInputsToBytes(req *http.Request) (imgsBytes [][]byte, err error) {

logger, _ := logger.GetZapLogger()
@@ -292,3 +327,59 @@ func parseImageFormDataTextToImageInputs(req *http.Request) (textToImageInput []
Samples: int64(samples),
}}, nil
}

func parseTextFormDataTextGenerationInputs(req *http.Request) (textGeneration []triton.TextGenerationInput, err error) {
prompts := req.MultipartForm.Value["prompt"]
if len(prompts) != 1 {
return nil, fmt.Errorf("only support batchsize 1")
}
badWordsListInput := req.MultipartForm.Value["bad_words_list"]
stopWordsListInput := req.MultipartForm.Value["stop_words_list"]
outputLenInput := req.MultipartForm.Value["output_len"]
topKInput := req.MultipartForm.Value["topk"]
seedInput := req.MultipartForm.Value["seed"]

badWordsList := ""
if len(badWordsListInput) > 0 {
badWordsList = badWordsListInput[0]
}

stopWordsList := ""
if len(stopWordsListInput) > 0 {
stopWordsList = stopWordsListInput[0]
}

outputLen := 100
if len(outputLenInput) > 0 {
outputLen, err = strconv.Atoi(outputLenInput[0])
if err != nil {
return nil, fmt.Errorf("invalid input %w", err)
}
}

topK := 1
if len(topKInput) > 0 {
topK, err = strconv.Atoi(topKInput[0])
if err != nil {
return nil, fmt.Errorf("invalid input %w", err)
}
}

seed := 0
if len(seedInput) > 0 {
seed, err = strconv.Atoi(seedInput[0])
if err != nil {
return nil, fmt.Errorf("invalid input %w", err)
}
}

// TODO: add support for bad/stop words
return []triton.TextGenerationInput{{
Prompt: prompts[0],
OutputLen: int64(outputLen),
BadWordsList: badWordsList,
StopWordsList: stopWordsList,
TopK: int64(topK),
Seed: int64(seed),
}}, nil
}
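
Editor's note: a hypothetical client-side sketch of the multipart form this handler reads. The field names match the handler above; the endpoint URL is made up for illustration:

package main

import (
	"bytes"
	"mime/multipart"
	"net/http"
)

func main() {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)
	_ = w.WriteField("prompt", "Once upon a time") // required; batch size 1 only
	_ = w.WriteField("output_len", "64")           // optional, defaults to 100
	_ = w.WriteField("topk", "1")                  // optional, defaults to 1
	_ = w.WriteField("seed", "0")                  // optional, defaults to 0
	_ = w.Close()

	// Hypothetical endpoint; substitute the real model instance trigger URL.
	req, _ := http.NewRequest(http.MethodPost,
		"http://localhost:8080/model-instances/trigger-multipart", &body)
	req.Header.Set("Content-Type", w.FormDataContentType())
	_, _ = http.DefaultClient.Do(req)
}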
22 changes: 22 additions & 0 deletions pkg/service/service.go
@@ -173,6 +173,12 @@ func (s *service) ModelInferTestMode(owner string, modelInstanceUID uuid.UUID, i
} else if strings.HasPrefix(owner, "orgs/") {
s.redisClient.IncrBy(ctx, fmt.Sprintf("org:%s:test.num", uid), int64(len(inferInput.([]triton.TextToImageInput))))
}
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
if strings.HasPrefix(owner, "users/") {
s.redisClient.IncrBy(ctx, fmt.Sprintf("user:%s:test.num", uid), int64(len(inferInput.([]triton.TextGenerationInput))))
} else if strings.HasPrefix(owner, "orgs/") {
s.redisClient.IncrBy(ctx, fmt.Sprintf("org:%s:test.num", uid), int64(len(inferInput.([]triton.TextGenerationOutput))))
}
default:
return nil, fmt.Errorf("unknown task input type")
}
@@ -450,6 +456,22 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, inferInput InferInput,
textToImageOutputs = append(textToImageOutputs, &textToImageOutput)
}
return textToImageOutputs, nil
case modelPB.ModelInstance_TASK_TEXT_GENERATION:
textGenerationResponses := postprocessResponse.(triton.TextGenerationOutput)
batchedOutputDataTexts := textGenerationResponses.Text
var textGenerationOutputs []*modelPB.TaskOutput
for i := range batchedOutputDataTexts {
var textGenerationOutput = modelPB.TaskOutput{
Output: &modelPB.TaskOutput_TextGeneration{
TextGeneration: &modelPB.TextGenerationOutput{
Text: batchedOutputDataTexts[i],
},
},
}

textGenerationOutputs = append(textGenerationOutputs, &textGenerationOutput)
}
return textGenerationOutputs, nil
default:
outputs := postprocessResponse.([]triton.BatchUnspecifiedTaskOutputs)
var rawOutputs []*modelPB.TaskOutput