Add question answering
hupe1980 committed Jun 25, 2023
1 parent 57d46c6 commit 54ac4ef
Showing 13 changed files with 323 additions and 196 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -1,4 +1,4 @@
# go-huggingface
# 🤗 go-huggingface
![Build Status](https://github.com/hupe1980/go-huggingface/workflows/build/badge.svg)
[![Go Reference](https://pkg.go.dev/badge/github.com/hupe1980/go-huggingface.svg)](https://pkg.go.dev/github.com/hupe1980/go-huggingface)
> The Hugging Face Inference Client in Golang is a module designed to interact with the Hugging Face model repository and perform inference tasks using state-of-the-art natural language processing models. Developed in Golang, it provides a seamless and efficient way to integrate Hugging Face models into your Golang applications.
@@ -24,9 +24,9 @@ import (
func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotRequest{
res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotClassificationRequest{
Inputs: []string{"Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"},
Parameters: huggingface.ZeroShotParameters{
Parameters: huggingface.ZeroShotClassificationParameters{
CandidateLabels: []string{"refund", "faq", "legal"},
},
})
26 changes: 26 additions & 0 deletions _examples/question_answering/main.go
@@ -0,0 +1,26 @@
package main

import (
"context"
"fmt"
"log"
"os"

huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.QuestionAnswering(context.Background(), &huggingface.QuestionAnsweringRequest{
Inputs: huggingface.QuestionAnsweringInputs{
Question: "What's my name?",
Context: "My name is Clara and I live in Berkeley.",
},
})
if err != nil {
log.Fatal(err)
}

fmt.Println(res[0].Answer)
}
4 changes: 2 additions & 2 deletions _examples/zero_shot_classification/main.go
@@ -12,9 +12,9 @@ import (
func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotRequest{
res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotClassificationRequest{
Inputs: []string{"Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"},
Parameters: huggingface.ZeroShotParameters{
Parameters: huggingface.ZeroShotClassificationParameters{
CandidateLabels: []string{"refund", "faq", "legal"},
},
})
5 changes: 5 additions & 0 deletions error.go
@@ -0,0 +1,5 @@
package gohuggingface

type ErrorResponse struct {
Error string `json:"error"`
}
44 changes: 39 additions & 5 deletions huggingface.go
@@ -118,22 +118,51 @@ func (ic *InferenceClient) Text2TextGeneration(ctx context.Context, req *Text2Te
// ZeroShotClassification performs zero-shot classification using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
// The response contains the classification results or an error if the request fails.
func (ic *InferenceClient) ZeroShotClassification(ctx context.Context, req *ZeroShotRequest) (ZeroShotResponse, error) {
func (ic *InferenceClient) ZeroShotClassification(ctx context.Context, req *ZeroShotClassificationRequest) (ZeroShotClassificationResponse, error) {
if len(req.Inputs) == 0 {
return nil, errors.New("inputs are required")
}

if len(req.Parameters.CandidateLabels) == 0 {
return nil, errors.New("canidateLabels are required")
}

body, err := ic.post(ctx, req.Model, "zero-shot-classification", req)
if err != nil {
return nil, err
}

zeroShotResponse := ZeroShotResponse{}
if err := json.Unmarshal(body, &zeroShotResponse); err != nil {
zeroShotClassificationResponse := ZeroShotClassificationResponse{}
if err := json.Unmarshal(body, &zeroShotClassificationResponse); err != nil {
return nil, err
}

return zeroShotResponse, nil
return zeroShotClassificationResponse, nil
}

// QuestionAnswering performs question answering using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided question and context inputs.
// The response contains the answer or an error if the request fails.
func (ic *InferenceClient) QuestionAnswering(ctx context.Context, req *QuestionAnsweringRequest) (QuestionAnsweringResponse, error) {
if req.Inputs.Question == "" {
return nil, errors.New("question is required")
}

if req.Inputs.Context == "" {
return nil, errors.New("context is required")
}

body, err := ic.post(ctx, req.Model, "question-answering", req)
if err != nil {
return nil, err
}

questionAnsweringResponse := QuestionAnsweringResponse{}
if err := json.Unmarshal(body, &questionAnsweringResponse); err != nil {
return nil, err
}

return questionAnsweringResponse, nil
}

// post sends a POST request to the specified model and task with the provided payload.
Expand Down Expand Up @@ -174,7 +203,12 @@ func (ic *InferenceClient) post(ctx context.Context, model, task string, payload
}

if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("huggingfaces error: %s", resBody)
errResp := ErrorResponse{}
if err := json.Unmarshal(resBody, &errResp); err != nil {
return nil, fmt.Errorf("huggingfaces error: %s", resBody)
}

return nil, fmt.Errorf("huggingfaces error: %s", errResp.Error)
}

return resBody, nil
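For context, a minimal standalone sketch of what the new error path does with a typical non-200 response body. The message text is illustrative; the local ErrorResponse type mirrors the one added in error.go:

package main

import (
	"encoding/json"
	"fmt"
)

// ErrorResponse mirrors the struct added in error.go.
type ErrorResponse struct {
	Error string `json:"error"`
}

func main() {
	// Illustrative non-200 body from the Inference API;
	// the exact message text is made up for this sketch.
	body := []byte(`{"error": "Model distilbert-base-uncased-distilled-squad is currently loading"}`)

	errResp := ErrorResponse{}
	if err := json.Unmarshal(body, &errResp); err != nil {
		// Fall back to the raw body, as post() does for non-JSON bodies.
		fmt.Printf("huggingface error: %s\n", body)
		return
	}

	// With a parseable body, only the message is surfaced.
	fmt.Printf("huggingface error: %s\n", errResp.Error)
}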
26 changes: 26 additions & 0 deletions huggingface_test.go
@@ -76,3 +76,29 @@ func TestSummarization(t *testing.T) {
assert.Equal(t, "request error", err.Error())
})
}

func TestQuestionAnswering(t *testing.T) {
client := NewInferenceClient("your-token")

t.Run("Missing question input", func(t *testing.T) {
req := &QuestionAnsweringRequest{
Model: "distilbert-base-uncased-distilled-squad",
Inputs: QuestionAnsweringInputs{
Context: "Paris is the capital of France.",
},
}
_, err := client.QuestionAnswering(context.Background(), req)
assert.EqualError(t, err, "question is required")
})

t.Run("Missing context input", func(t *testing.T) {
req := &QuestionAnsweringRequest{
Model: "distilbert-base-uncased-distilled-squad",
Inputs: QuestionAnsweringInputs{
Question: "What is the capital of France?",
},
}
_, err := client.QuestionAnswering(context.Background(), req)
assert.EqualError(t, err, "context is required")
})
}
16 changes: 16 additions & 0 deletions options.go
@@ -0,0 +1,16 @@
package gohuggingface

type Options struct {
// (Default: true). There is a cache layer on the Inference API to speed up
// requests we have already seen. Most models can use those results as is,
// as models are deterministic (meaning the results will be the same anyway).
// However, if you use a non-deterministic model, you can set this parameter
// to prevent the caching mechanism from being used, resulting in a real new query.
UseCache *bool `json:"use_cache,omitempty"`

// (Default: false). If the model is not ready, wait for it instead of receiving a 503 error.
// It limits the number of requests required to get your inference done. It is advised
// to only set this flag to true after receiving a 503 error, as it will limit hanging
// in your application to known places.
WaitForModel *bool `json:"wait_for_model,omitempty"`
}
32 changes: 32 additions & 0 deletions question_answering.go
@@ -0,0 +1,32 @@
package gohuggingface

type QuestionAnsweringInputs struct {
// (Required) The question as a string that has an answer within Context.
Question string `json:"question"`

// (Required) A string that contains the answer to the question.
Context string `json:"context"`
}

// Request structure for the question answering model
type QuestionAnsweringRequest struct {
// (Required)
Inputs QuestionAnsweringInputs `json:"inputs,omitempty"`
Options Options `json:"options,omitempty"`
Model string `json:"-"`
}

// Response structure for the question answering model
type QuestionAnsweringResponse []struct {
// A string that’s the answer within the Context text.
Answer string `json:"answer,omitempty"`

// A float that represents how likely it is that the answer is correct.
Score float64 `json:"score,omitempty"`

// The string index of the start of the answer within Context.
Start int `json:"start,omitempty"`

// The string index of the end of the answer within Context.
End int `json:"end,omitempty"`
}
43 changes: 43 additions & 0 deletions summarization.go
@@ -0,0 +1,43 @@
package gohuggingface

type SummarizationParameters struct {
// (Default: None). Integer to define the minimum length in tokens of the output summary.
MinLength *int `json:"min_length,omitempty"`

// (Default: None). Integer to define the maximum length in tokens of the output summary.
MaxLength *int `json:"max_length,omitempty"`

// (Default: None). Integer to define the top tokens considered within the sample operation to create
// new text.
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample operation of text generation.
// Add tokens in the sample from most probable to least probable until the sum of the probabilities is
// greater than top_p.
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 means top_k=1, 100.0 is getting closer to uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
// to not be picked in successive generation passes.
RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"`

// (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum.
// Network can cause some overhead so it will be a soft limit.
MaxTime *float64 `json:"maxtime,omitempty"`
}

type SummarizationRequest struct {
// Strings to be summarized
Inputs []string `json:"inputs"`
Parameters SummarizationParameters `json:"parameters,omitempty"`
Options Options `json:"options,omitempty"`
Model string `json:"-"`
}

type SummarizationResponse []struct {
// The summarized input string
SummaryText string `json:"summary_text,omitempty"`
}
48 changes: 48 additions & 0 deletions text2text_generation.go
@@ -0,0 +1,48 @@
package gohuggingface

type Text2TextGenerationParameters struct {
// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample operation of text generation. Add
// tokens in the sample from most probable to least probable until the sum of the probabilities is greater
// than top_p.
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 means top_k=1, 100.0 is getting closer to uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
// to not be picked in successive generation passes.
RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`

// (Default: None). Int (0-250). The amount of new tokens to be generated. This does not include the input
// length; it is an estimate of the size of the generated text you want. Each new token slows down the request,
// so look for a balance between response times and length of text generated.
MaxNewTokens *int `json:"max_new_tokens,omitempty"`

// (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum.
// Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens
// for best results.
MaxTime *float64 `json:"max_time,omitempty"`

// (Default: True). Bool. If set to False, the return results will not contain the original query, making it
// easier for prompting.
ReturnFullText *bool `json:"return_full_text,omitempty"`

// (Default: 1). Integer. The number of propositions you want returned.
NumReturnSequences *int `json:"num_return_sequences,omitempty"`
}

type Text2TextGenerationRequest struct {
// String to generate from
Inputs string `json:"inputs"`
Parameters Text2TextGenerationParameters `json:"parameters,omitempty"`
Options Options `json:"options,omitempty"`
Model string `json:"-"`
}

type Text2TextGenerationResponse []struct {
GeneratedText string `json:"generated_text,omitempty"`
}
50 changes: 50 additions & 0 deletions text_generation.go
@@ -0,0 +1,50 @@
package gohuggingface

type TextGenerationParameters struct {
// (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample operation of text generation. Add
// tokens in the sample from most probable to least probable until the sum of the probabilities is greater
// than top_p.
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 means top_k=1, 100.0 is getting closer to uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
// to not be picked in successive generation passes.
RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"`

// (Default: None). Int (0-250). The amount of new tokens to be generated. This does not include the input
// length; it is an estimate of the size of the generated text you want. Each new token slows down the request,
// so look for a balance between response times and length of text generated.
MaxNewTokens *int `json:"max_new_tokens,omitempty"`

// (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum.
// Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens
// for best results.
MaxTime *float64 `json:"max_time,omitempty"`

// (Default: True). Bool. If set to False, the return results will not contain the original query, making it
// easier for prompting.
ReturnFullText *bool `json:"return_full_text,omitempty"`

// (Default: 1). Integer. The number of propositions you want returned.
NumReturnSequences *int `json:"num_return_sequences,omitempty"`
}

type TextGenerationRequest struct {
// String to generate from
Inputs string `json:"inputs"`
Parameters TextGenerationParameters `json:"parameters,omitempty"`
Options Options `json:"options,omitempty"`
Model string `json:"-"`
}

// A list of generated texts. The length of this list is the value of
// NumReturnSequences in the request.
type TextGenerationResponse []struct {
GeneratedText string `json:"generated_text,omitempty"`
}