diff --git a/README.md b/README.md index 42d0aab..c410366 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# go-huggingface +# 🤗 go-huggingface ![Build Status](https://github.com/hupe1980/go-huggingface/workflows/build/badge.svg) [![Go Reference](https://pkg.go.dev/badge/github.com/hupe1980/go-huggingface.svg)](https://pkg.go.dev/github.com/hupe1980/go-huggingface) > The Hugging Face Inference Client in Golang is a modul designed to interact with the Hugging Face model repository and perform inference tasks using state-of-the-art natural language processing models. Developed in Golang, it provides a seamless and efficient way to integrate Hugging Face models into your Golang applications. @@ -24,9 +24,9 @@ import ( func main() { ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN")) - res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotRequest{ + res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotClassificationRequest{ Inputs: []string{"Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"}, - Parameters: huggingface.ZeroShotParameters{ + Parameters: huggingface.ZeroShotClassificationParameters{ CandidateLabels: []string{"refund", "faq", "legal"}, }, }) diff --git a/_examples/question_answering/main.go b/_examples/question_answering/main.go new file mode 100644 index 0000000..615012d --- /dev/null +++ b/_examples/question_answering/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + + huggingface "github.com/hupe1980/go-huggingface" +) + +func main() { + ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN")) + + res, err := ic.QuestionAnswering(context.Background(), &huggingface.QuestionAnsweringRequest{ + Inputs: huggingface.QuestionAnsweringInputs{ + Question: "What's my name?", + Context: "My name is Clara and I live in Berkeley.", + }, + }) + if err != nil { + log.Fatal(err) + } + + fmt.Println(res[0].Answer) +} diff --git a/_examples/zero_shot_classification/main.go b/_examples/zero_shot_classification/main.go index c6f6baa..474a30f 100644 --- a/_examples/zero_shot_classification/main.go +++ b/_examples/zero_shot_classification/main.go @@ -12,9 +12,9 @@ import ( func main() { ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN")) - res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotRequest{ + res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotClassificationRequest{ Inputs: []string{"Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"}, - Parameters: huggingface.ZeroShotParameters{ + Parameters: huggingface.ZeroShotClassificationParameters{ CandidateLabels: []string{"refund", "faq", "legal"}, }, }) diff --git a/error.go b/error.go new file mode 100644 index 0000000..ff931bc --- /dev/null +++ b/error.go @@ -0,0 +1,5 @@ +package gohuggingface + +type ErrorResponse struct { + Error string `json:"error"` +} diff --git a/huggingface.go b/huggingface.go index 613512b..ad06b85 100644 --- a/huggingface.go +++ b/huggingface.go @@ -118,22 +118,51 @@ func (ic *InferenceClient) Text2TextGeneration(ctx context.Context, req *Text2Te // ZeroShotClassification performs zero-shot classification using the specified model. // It sends a POST request to the Hugging Face inference endpoint with the provided inputs. 
 // The response contains the classification results or an error if the request fails.
-func (ic *InferenceClient) ZeroShotClassification(ctx context.Context, req *ZeroShotRequest) (ZeroShotResponse, error) {
+func (ic *InferenceClient) ZeroShotClassification(ctx context.Context, req *ZeroShotClassificationRequest) (ZeroShotClassificationResponse, error) {
 	if len(req.Inputs) == 0 {
 		return nil, errors.New("inputs are required")
 	}
 
+	if len(req.Parameters.CandidateLabels) == 0 {
+		return nil, errors.New("candidateLabels are required")
+	}
+
 	body, err := ic.post(ctx, req.Model, "zero-shot-classification", req)
 	if err != nil {
 		return nil, err
 	}
 
-	zeroShotResponse := ZeroShotResponse{}
-	if err := json.Unmarshal(body, &zeroShotResponse); err != nil {
+	zeroShotClassificationResponse := ZeroShotClassificationResponse{}
+	if err := json.Unmarshal(body, &zeroShotClassificationResponse); err != nil {
 		return nil, err
 	}
 
-	return zeroShotResponse, nil
+	return zeroShotClassificationResponse, nil
+}
+
+// QuestionAnswering performs question answering using the specified model.
+// It sends a POST request to the Hugging Face inference endpoint with the provided question and context inputs.
+// The response contains the answer or an error if the request fails.
+func (ic *InferenceClient) QuestionAnswering(ctx context.Context, req *QuestionAnsweringRequest) (QuestionAnsweringResponse, error) {
+	if req.Inputs.Question == "" {
+		return nil, errors.New("question is required")
+	}
+
+	if req.Inputs.Context == "" {
+		return nil, errors.New("context is required")
+	}
+
+	body, err := ic.post(ctx, req.Model, "question-answering", req)
+	if err != nil {
+		return nil, err
+	}
+
+	questionAnsweringResponse := QuestionAnsweringResponse{}
+	if err := json.Unmarshal(body, &questionAnsweringResponse); err != nil {
+		return nil, err
+	}
+
+	return questionAnsweringResponse, nil
 }
 
 // post sends a POST request to the specified model and task with the provided payload.
@@ -174,7 +203,12 @@ func (ic *InferenceClient) post(ctx context.Context, model, task string, payload
 	}
 
 	if res.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("huggingfaces error: %s", resBody)
+		errResp := ErrorResponse{}
+		if err := json.Unmarshal(resBody, &errResp); err != nil {
+			return nil, fmt.Errorf("huggingface error: %s", resBody)
+		}
+
+		return nil, fmt.Errorf("huggingface error: %s", errResp.Error)
 	}
 
 	return resBody, nil
diff --git a/huggingface_test.go b/huggingface_test.go
index 01ce100..1c9ec49 100644
--- a/huggingface_test.go
+++ b/huggingface_test.go
@@ -76,3 +76,29 @@ func TestSummarization(t *testing.T) {
 		assert.Equal(t, "request error", err.Error())
 	})
 }
+
+func TestQuestionAnswering(t *testing.T) {
+	client := NewInferenceClient("your-token")
+
+	t.Run("Missing question input", func(t *testing.T) {
+		req := &QuestionAnsweringRequest{
+			Model: "distilbert-base-uncased-distilled-squad",
+			Inputs: QuestionAnsweringInputs{
+				Context: "Paris is the capital of France.",
+			},
+		}
+		_, err := client.QuestionAnswering(context.Background(), req)
+		assert.EqualError(t, err, "question is required")
+	})
+
+	t.Run("Missing context input", func(t *testing.T) {
+		req := &QuestionAnsweringRequest{
+			Model: "distilbert-base-uncased-distilled-squad",
+			Inputs: QuestionAnsweringInputs{
+				Question: "What is the capital of France?",
+			},
+		}
+		_, err := client.QuestionAnswering(context.Background(), req)
+		assert.EqualError(t, err, "context is required")
+	})
+}
diff --git a/options.go b/options.go
new file mode 100644
index 0000000..acbc382
--- /dev/null
+++ b/options.go
@@ -0,0 +1,16 @@
+package gohuggingface
+
+type Options struct {
+	// (Default: true). There is a cache layer on the inference API to speed up
+	// requests we have already seen. Most models can use those results as is,
+	// as models are deterministic (meaning the results will be the same anyway).
+	// However, if you use a non-deterministic model, you can set this parameter
+	// to prevent the caching mechanism from being used, resulting in a real new query.
+	UseCache *bool `json:"use_cache,omitempty"`
+
+	// (Default: false) If the model is not ready, wait for it instead of receiving 503.
+	// It limits the number of requests required to get your inference done. It is advised
+	// to only set this flag to true after receiving a 503 error, as it will limit hanging
+	// in your application to known places.
+	WaitForModel *bool `json:"wait_for_model,omitempty"`
+}
diff --git a/question_answering.go b/question_answering.go
new file mode 100644
index 0000000..8af327c
--- /dev/null
+++ b/question_answering.go
@@ -0,0 +1,32 @@
+package gohuggingface
+
+type QuestionAnsweringInputs struct {
+	// (Required) The question as a string that has an answer within Context.
+	Question string `json:"question"`
+
+	// (Required) A string that contains the answer to the question.
+	Context string `json:"context"`
+}
+
+// Request structure for question answering model
+type QuestionAnsweringRequest struct {
+	// (Required)
+	Inputs QuestionAnsweringInputs `json:"inputs,omitempty"`
+	Options Options `json:"options,omitempty"`
+	Model string `json:"-"`
+}
+
+// Response structure for question answering model
+type QuestionAnsweringResponse []struct {
+	// A string that’s the answer within the Context text.
+	Answer string `json:"answer,omitempty"`
+
+	// A float that represents how likely it is that the answer is correct.
+	Score float64 `json:"score,omitempty"`
+
+	// The string index of the start of the answer within Context.
+	Start int `json:"start,omitempty"`
+
+	// The string index of the end of the answer within Context.
+	End int `json:"end,omitempty"`
+}
diff --git a/summarization.go b/summarization.go
new file mode 100644
index 0000000..fcc7356
--- /dev/null
+++ b/summarization.go
@@ -0,0 +1,43 @@
+package gohuggingface
+
+type SummarizationParameters struct {
+	// (Default: None). Integer to define the minimum length in tokens of the output summary.
+	MinLength *int `json:"min_length,omitempty"`
+
+	// (Default: None). Integer to define the maximum length in tokens of the output summary.
+	MaxLength *int `json:"max_length,omitempty"`
+
+	// (Default: None). Integer to define the top tokens considered within the sample operation to create
+	// new text.
+	TopK *int `json:"top_k,omitempty"`
+
+	// (Default: None). Float to define the tokens that are within the sample operation of text generation.
+	// Add tokens in the sample from most probable to least probable until the sum of the probabilities is
+	// greater than top_p.
+	TopP *float64 `json:"top_p,omitempty"`
+
+	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
+	// 0 means top_k=1, 100.0 is getting closer to uniform probability.
+	Temperature *float64 `json:"temperature,omitempty"`
+
+	// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
+	// to not be picked in successive generation passes.
+	RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"`
+
+	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
+	// Network can cause some overhead so it will be a soft limit.
+	MaxTime *float64 `json:"maxtime,omitempty"`
+}
+
+type SummarizationRequest struct {
+	// String(s) to be summarized
+	Inputs []string `json:"inputs"`
+	Parameters SummarizationParameters `json:"parameters,omitempty"`
+	Options Options `json:"options,omitempty"`
+	Model string `json:"-"`
+}
+
+type SummarizationResponse []struct {
+	// The summarized input string
+	SummaryText string `json:"summary_text,omitempty"`
+}
diff --git a/text2text_generation.go b/text2text_generation.go
new file mode 100644
index 0000000..9fbceae
--- /dev/null
+++ b/text2text_generation.go
@@ -0,0 +1,48 @@
+package gohuggingface
+
+type Text2TextGenerationParameters struct {
+	// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
+	TopK *int `json:"top_k,omitempty"`
+
+	// (Default: None). Float to define the tokens that are within the sample operation of text generation. Add
+	// tokens in the sample from most probable to least probable until the sum of the probabilities is greater
+	// than top_p.
+	TopP *float64 `json:"top_p,omitempty"`
+
+	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
+	// 0 means top_k=1, 100.0 is getting closer to uniform probability.
+	Temperature *float64 `json:"temperature,omitempty"`
+
+	// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
+	// to not be picked in successive generation passes.
+	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`
+
+	// (Default: None). Int (0-250). The number of new tokens to be generated; this does not include the input
+	// length. It is an estimate of the size of the generated text you want.
+	// Each new token slows down the request, so look for a balance between response time and length of text generated.
+	MaxNewTokens *int `json:"max_new_tokens,omitempty"`
+
+	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
+	// Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens
+	// for best results.
+	MaxTime *float64 `json:"max_time,omitempty"`
+
+	// (Default: True). Bool. If set to False, the return results will not contain the original query, making it
+	// easier for prompting.
+	ReturnFullText *bool `json:"return_full_text,omitempty"`
+
+	// (Default: 1). Integer. The number of propositions you want to be returned.
+	NumReturnSequences *int `json:"num_return_sequences,omitempty"`
+}
+
+type Text2TextGenerationRequest struct {
+	// String to generate from
+	Inputs string `json:"inputs"`
+	Parameters Text2TextGenerationParameters `json:"parameters,omitempty"`
+	Options Options `json:"options,omitempty"`
+	Model string `json:"-"`
+}
+
+type Text2TextGenerationResponse []struct {
+	GeneratedText string `json:"generated_text,omitempty"`
+}
diff --git a/text_generation.go b/text_generation.go
new file mode 100644
index 0000000..dede2ba
--- /dev/null
+++ b/text_generation.go
@@ -0,0 +1,50 @@
+package gohuggingface
+
+type TextGenerationParameters struct {
+	// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
+	TopK *int `json:"top_k,omitempty"`
+
+	// (Default: None). Float to define the tokens that are within the sample operation of text generation. Add
+	// tokens in the sample from most probable to least probable until the sum of the probabilities is greater
+	// than top_p.
+	TopP *float64 `json:"top_p,omitempty"`
+
+	// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
+	// 0 means top_k=1, 100.0 is getting closer to uniform probability.
+	Temperature *float64 `json:"temperature,omitempty"`
+
+	// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
+	// to not be picked in successive generation passes.
+	RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`
+
+	// (Default: None). Int (0-250). The number of new tokens to be generated; this does not include the input
+	// length. It is an estimate of the size of the generated text you want. Each new token slows down the request,
+	// so look for a balance between response time and length of text generated.
+	MaxNewTokens *int `json:"max_new_tokens,omitempty"`
+
+	// (Default: None). Float (0-120.0). The maximum amount of time in seconds that the query should take.
+	// Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens
+	// for best results.
+	MaxTime *float64 `json:"max_time,omitempty"`
+
+	// (Default: True). Bool. If set to False, the return results will not contain the original query, making it
+	// easier for prompting.
+	ReturnFullText *bool `json:"return_full_text,omitempty"`
+
+	// (Default: 1). Integer. The number of propositions you want to be returned.
+	NumReturnSequences *int `json:"num_return_sequences,omitempty"`
+}
+
+type TextGenerationRequest struct {
+	// String to generate from
+	Inputs string `json:"inputs"`
+	Parameters TextGenerationParameters `json:"parameters,omitempty"`
+	Options Options `json:"options,omitempty"`
+	Model string `json:"-"`
+}
+
+// A list of generated texts.
The length of this list is the value of +// NumReturnSequences in the request. +type TextGenerationResponse []struct { + GeneratedText string `json:"generated_text,omitempty"` +} diff --git a/types.go b/types.go deleted file mode 100644 index 79ca272..0000000 --- a/types.go +++ /dev/null @@ -1,186 +0,0 @@ -package gohuggingface - -type Options struct { - // (Default: true). There is a cache layer on the inference API to speedup - // requests we have already seen. Most models can use those results as is - // as models are deterministic (meaning the results will be the same anyway). - // However if you use a non deterministic model, you can set this parameter - // to prevent the caching mechanism from being used resulting in a real new query. - UseCache *bool `json:"use_cache,omitempty"` - - // (Default: false) If the model is not ready, wait for it instead of receiving 503. - // It limits the number of requests required to get your inference done. It is advised - // to only set this flag to true after receiving a 503 error as it will limit hanging - // in your application to known places. - WaitForModel *bool `json:"wait_for_model,omitempty"` -} - -type SummarizationParameters struct { - // (Default: None). Integer to define the minimum length in tokens of the output summary. - MinLength *int `json:"min_length,omitempty"` - - // (Default: None). Integer to define the maximum length in tokens of the output summary. - MaxLength *int `json:"max_length,omitempty"` - - // (Default: None). Integer to define the top tokens considered within the sample operation to create - // new text. - TopK *int `json:"top_k,omitempty"` - - // (Default: None). Float to define the tokens that are within the sample` operation of text generation. - // Add tokens in the sample for more probable to least probable until the sum of the probabilities is - // greater than top_p. - TopP *float64 `json:"top_p,omitempty"` - - // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, - // 0 mens top_k=1, 100.0 is getting closer to uniform probability. - Temperature *float64 `json:"temperature,omitempty"` - - // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized - // to not be picked in successive generation passes. - RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"` - - // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. - // Network can cause some overhead so it will be a soft limit. - MaxTime *float64 `json:"maxtime,omitempty"` -} - -type SummarizationRequest struct { - // String to be summarized - Inputs []string `json:"inputs"` - Parameters SummarizationParameters `json:"parameters,omitempty"` - Options Options `json:"options,omitempty"` - Model string `json:"-"` -} - -type SummarizationResponse []struct { - // The summarized input string - SummaryText string `json:"summary_text,omitempty"` -} - -type TextGenerationParameters struct { - // (Default: None). Integer to define the top tokens considered within the sample operation to create new text. - TopK *int `json:"top_k,omitempty"` - - // (Default: None). Float to define the tokens that are within the sample` operation of text generation. Add - // tokens in the sample for more probable to least probable until the sum of the probabilities is greater - // than top_p. - TopP *float64 `json:"top_p,omitempty"` - - // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 
1 means regular sampling, - // 0 means top_k=1, 100.0 is getting closer to uniform probability. - Temperature *float64 `json:"temperature,omitempty"` - - // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized - // to not be picked in successive generation passes. - RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"` - - // (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input - // length it is a estimate of the size of generated text you want. Each new tokens slows down the request, - // so look for balance between response times and length of text generated. - MaxNewTokens *int `json:"max_new_tokens,omitempty"` - - // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. - // Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens - // for best results. - MaxTime *float64 `json:"max_time,omitempty"` - - // (Default: True). Bool. If set to False, the return results will not contain the original query making it - // easier for prompting. - ReturnFullText *bool `json:"return_full_text,omitempty"` - - // (Default: 1). Integer. The number of proposition you want to be returned. - NumReturnSequences *int `json:"num_return_sequences,omitempty"` -} - -type TextGenerationRequest struct { - // String to generated from - Inputs string `json:"inputs"` - Parameters TextGenerationParameters `json:"parameters,omitempty"` - Options Options `json:"options,omitempty"` - Model string `json:"-"` -} - -// A list of generated texts. The length of this list is the value of -// NumReturnSequences in the request. -type TextGenerationResponse []struct { - GeneratedText string `json:"generated_text,omitempty"` -} - -type Text2TextGenerationParameters struct { - // (Default: None). Integer to define the top tokens considered within the sample operation to create new text. - TopK *int `json:"top_k,omitempty"` - - // (Default: None). Float to define the tokens that are within the sample` operation of text generation. Add - // tokens in the sample for more probable to least probable until the sum of the probabilities is greater - // than top_p. - TopP *float64 `json:"top_p,omitempty"` - - // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, - // 0 means top_k=1, 100.0 is getting closer to uniform probability. - Temperature *float64 `json:"temperature,omitempty"` - - // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized - // to not be picked in successive generation passes. - RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"` - - // (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input - // length it is a estimate of the size of generated text you want. Each new tokens slows down the request, - // so look for balance between response times and length of text generated. - MaxNewTokens *int `json:"max_new_tokens,omitempty"` - - // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. - // Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens - // for best results. - MaxTime *float64 `json:"max_time,omitempty"` - - // (Default: True). Bool. If set to False, the return results will not contain the original query making it - // easier for prompting. 
-	ReturnFullText *bool `json:"return_full_text,omitempty"`
-
-	// (Default: 1). Integer. The number of proposition you want to be returned.
-	NumReturnSequences *int `json:"num_return_sequences,omitempty"`
-}
-
-type Text2TextGenerationRequest struct {
-	// String to generated from
-	Inputs string `json:"inputs"`
-	Parameters Text2TextGenerationParameters `json:"parameters,omitempty"`
-	Options Options `json:"options,omitempty"`
-	Model string `json:"-"`
-}
-
-type Text2TextGenerationResponse []struct {
-	GeneratedText string `json:"generated_text,omitempty"`
-}
-
-type ZeroShotParameters struct {
-	// (Required) A list of strings that are potential classes for inputs. Max 10 candidate_labels,
-	// for more, simply run multiple requests, results are going to be misleading if using
-	// too many candidate_labels anyway. If you want to keep the exact same, you can
-	// simply run multi_label=True and do the scaling on your end.
-	CandidateLabels []string `json:"candidate_labels"`
-
-	// (Default: false) Boolean that is set to True if classes can overlap
-	MultiLabel *bool `json:"multi_label,omitempty"`
-}
-
-type ZeroShotRequest struct {
-	// (Required) Input or Inputs are required request fields
-	Inputs []string `json:"inputs"`
-	// (Required)
-	Parameters ZeroShotParameters `json:"parameters,omitempty"`
-	Options Options `json:"options,omitempty"`
-	Model string `json:"-"`
-}
-
-type ZeroShotResponse []struct {
-	// The string sent as an input
-	Sequence string `json:"sequence,omitempty"`
-
-	// The list of labels sent in the request, sorted in descending order
-	// by probability that the input corresponds to the to the label.
-	Labels []string `json:"labels,omitempty"`
-
-	// a list of floats that correspond the the probability of label, in the same order as labels.
-	Scores []float64 `json:"scores,omitempty"`
-}
diff --git a/zero_shot_classification.go b/zero_shot_classification.go
new file mode 100644
index 0000000..ee67f2c
--- /dev/null
+++ b/zero_shot_classification.go
@@ -0,0 +1,33 @@
+package gohuggingface
+
+type ZeroShotClassificationParameters struct {
+	// (Required) A list of strings that are potential classes for inputs. Max 10 candidate_labels;
+	// for more, simply run multiple requests. Results are going to be misleading if using
+	// too many candidate_labels anyway. If you want to keep the exact same scoring, you can
+	// simply run multi_label=True and do the scaling on your end.
+	CandidateLabels []string `json:"candidate_labels"`
+
+	// (Default: false) Boolean that is set to True if classes can overlap.
+	MultiLabel *bool `json:"multi_label,omitempty"`
+}
+
+type ZeroShotClassificationRequest struct {
+	// (Required) Input or Inputs are required request fields
+	Inputs []string `json:"inputs"`
+	// (Required)
+	Parameters ZeroShotClassificationParameters `json:"parameters,omitempty"`
+	Options Options `json:"options,omitempty"`
+	Model string `json:"-"`
+}
+
+type ZeroShotClassificationResponse []struct {
+	// The string sent as an input
+	Sequence string `json:"sequence,omitempty"`
+
+	// The list of labels sent in the request, sorted in descending order
+	// by probability that the input corresponds to the label.
+	Labels []string `json:"labels,omitempty"`
+
+	// A list of floats that correspond to the probability of each label, in the same order as labels.
+	Scores []float64 `json:"scores,omitempty"`
+}
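Both fields on the new `Options` struct are pointers, so an unset value is omitted from the JSON payload instead of being sent as `false`. A minimal sketch of passing `WaitForModel` along with a question-answering request; the `boolPtr` helper is a local convenience for this example, not part of the library:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	huggingface "github.com/hupe1980/go-huggingface"
)

// boolPtr is a small local helper for filling the *bool option fields.
func boolPtr(b bool) *bool { return &b }

func main() {
	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

	res, err := ic.QuestionAnswering(context.Background(), &huggingface.QuestionAnsweringRequest{
		Inputs: huggingface.QuestionAnsweringInputs{
			Question: "What's my name?",
			Context:  "My name is Clara and I live in Berkeley.",
		},
		Options: huggingface.Options{
			// Wait for the model to load instead of failing with a 503.
			WaitForModel: boolPtr(true),
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res[0].Answer)
}
```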
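The generation parameter structs follow the same pointer pattern, so only the fields you set are serialized. A sketch of a text-to-text generation call with a couple of parameters; the model name is purely illustrative, and the return type of `Text2TextGeneration` is assumed to mirror the other task methods (its signature is truncated in the hunk header above):

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	huggingface "github.com/hupe1980/go-huggingface"
)

// Local helpers for the pointer-typed parameter fields; not part of the library.
func intPtr(i int) *int { return &i }
func float64Ptr(f float64) *float64 { return &f }

func main() {
	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

	res, err := ic.Text2TextGeneration(context.Background(), &huggingface.Text2TextGenerationRequest{
		Model:  "google/flan-t5-base", // illustrative model name
		Inputs: "Translate to German: Hello, how are you?",
		Parameters: huggingface.Text2TextGenerationParameters{
			Temperature:  float64Ptr(0.7),
			MaxNewTokens: intPtr(50),
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res[0].GeneratedText)
}
```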
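Summarization works the same way: the request takes a slice of input strings plus optional token-length bounds, and a non-200 response now surfaces the decoded `ErrorResponse` message in the returned error. The `Summarization` method itself is only referenced via `TestSummarization` in this diff, so the call below assumes it has the same shape as the other task methods:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

	minLen, maxLen := 10, 60

	res, err := ic.Summarization(context.Background(), &huggingface.SummarizationRequest{
		Inputs: []string{"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris."},
		Parameters: huggingface.SummarizationParameters{
			MinLength: &minLen,
			MaxLength: &maxLen,
		},
	})
	if err != nil {
		// Errors include the API's {"error": "..."} message when the status is not 200.
		log.Fatal(err)
	}

	fmt.Println(res[0].SummaryText)
}
```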