diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 0000000..fca3293
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,42 @@
+name: build
+
+on:
+  push:
+    branches: [ main ]
+    tags: [ v* ]
+  pull_request: {}
+
+permissions:
+  contents: write
+  pull-requests: write
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Check out code
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+
+      - name: Set up Go
+        uses: actions/setup-go@v4
+        with:
+          go-version: 1.20.x
+
+      - name: Run Linter
+        uses: golangci/golangci-lint-action@v3
+
+      - name: Run Tests
+        run: make test
+
+      - name: Run GoReleaser
+        uses: goreleaser/goreleaser-action@v4
+        if: startsWith(github.ref, 'refs/tags/')
+        with:
+          distribution: goreleaser
+          version: latest
+          args: release --clean
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 3b735ec..2dd5137 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,5 @@
 
 # Go workspace file
 go.work
+
+.env
diff --git a/.golangci.yml b/.golangci.yml
new file mode 100644
index 0000000..0ade7ff
--- /dev/null
+++ b/.golangci.yml
@@ -0,0 +1,57 @@
+linters-settings:
+  errcheck:
+    check-type-assertions: true
+  goconst:
+    min-len: 2
+    min-occurrences: 3
+  gocritic:
+    enabled-tags:
+      - diagnostic
+      - experimental
+      - opinionated
+      - performance
+      - style
+  govet:
+    check-shadowing: true
+  nolintlint:
+    require-explanation: true
+    require-specific: true
+
+linters:
+  disable-all: true
+  enable:
+    - bodyclose
+    #- depguard
+    - dogsled
+    #- dupl
+    - errcheck
+    - exportloopref
+    - exhaustive
+    #- goconst TODO
+    - gofmt
+    - goimports
+    #- gomnd
+    - gocyclo
+    - gosec
+    - gosimple
+    - govet
+    - ineffassign
+    - misspell
+    #- nolintlint
+    - nakedret
+    - prealloc
+    - predeclared
+    #- revive TODO
+    - staticcheck
+    - stylecheck
+    - thelper
+    - tparallel
+    - typecheck
+    - unconvert
+    - unparam
+    - unused
+    - whitespace
+    - wsl
+
+run:
+  issues-exit-code: 1
\ No newline at end of file
diff --git a/.goreleaser.yml b/.goreleaser.yml
new file mode 100644
index 0000000..10b9449
--- /dev/null
+++ b/.goreleaser.yml
@@ -0,0 +1,2 @@
+builds:
+- skip: true
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..d3f71bf
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,29 @@
+PROJECTNAME=$(shell basename "$(PWD)")
+
+# Go related variables.
+# Make is verbose in Linux. Make it silent.
+MAKEFLAGS += --silent
+
+.PHONY: setup
+## setup: Installs dependencies
+setup:
+	@go mod tidy
+
+.PHONY: lint
+## lint: Runs the linter
+lint:
+	golangci-lint run --color=always --sort-results ./...
+
+.PHONY: test
+## test: Runs go test with default values
+test:
+	@go test -v -race -count=1 ./...
+
+.PHONY: help
+## help: Prints this help message
+help: Makefile
+	@echo
+	@echo " Choose a command run in "$(PROJECTNAME)":"
+	@echo
+	@sed -n 's/^##//p' $< | column -t -s ':' | sed -e 's/^/ /'
+	@echo
\ No newline at end of file
diff --git a/README.md b/README.md
index a1f7b2a..feb5f89 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,50 @@
-# go-huggingface
\ No newline at end of file
+# go-huggingface
+> The Hugging Face Inference Client in Golang is a module designed to interact with the Hugging Face model repository and perform inference tasks using state-of-the-art natural language processing models. It provides a seamless and efficient way to integrate Hugging Face models into your Golang applications.
+
+## Installation
+```
+go get github.com/hupe1980/go-huggingface
+```
+
+## How to use
+```golang
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+)
+
+func main() {
+	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))
+
+	res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotRequest{
+		Inputs: []string{"Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"},
+		Parameters: huggingface.ZeroShotParameters{
+			CandidateLabels: []string{"refund", "faq", "legal"},
+		},
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println(res[0].Sequence)
+	fmt.Println("Labels:", res[0].Labels)
+	fmt.Println("Scores:", res[0].Scores)
+}
+```
+Output:
+```text
+Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!
+Labels: [refund faq legal]
+Scores: [0.8777876496315002 0.10522633790969849 0.016985949128866196]
+```
+
+For more example usage, see [_examples](./_examples).
+
+## License
+[MIT](LICENCE)
\ No newline at end of file
diff --git a/_examples/summarization/main.go b/_examples/summarization/main.go
new file mode 100644
index 0000000..3aca867
--- /dev/null
+++ b/_examples/summarization/main.go
@@ -0,0 +1,23 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+)
+
+func main() {
+	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))
+
+	res, err := ic.Summarization(context.Background(), &huggingface.SummarizationRequest{
+		Inputs: []string{"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."},
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println(res[0].SummaryText)
+}
diff --git a/_examples/text2text_generation/main.go b/_examples/text2text_generation/main.go
new file mode 100644
index 0000000..8d3b2c5
--- /dev/null
+++ b/_examples/text2text_generation/main.go
@@ -0,0 +1,24 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+)
+
+func main() {
+	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))
+
+	res, err := ic.Text2TextGeneration(context.Background(), &huggingface.Text2TextGenerationRequest{
+		Inputs: "The answer to the universe is",
+		Model: "gpt2", // overwrite recommended model
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println(res[0].GeneratedText)
+}
diff --git a/_examples/text_generation/main.go b/_examples/text_generation/main.go
new file mode 100644
index 0000000..a543100
--- /dev/null
+++ b/_examples/text_generation/main.go
@@ -0,0 +1,24 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+)
+
+func main() {
+	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))
+
+	res, err := ic.TextGeneration(context.Background(), &huggingface.TextGenerationRequest{
+		Inputs: "The answer to the universe is",
+		Model: "gpt2", // overwrite recommended model
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println(res[0].GeneratedText)
+}
diff --git a/_examples/zero_shot_classification/main.go b/_examples/zero_shot_classification/main.go
new file mode 100644
index 0000000..c6f6baa
--- /dev/null
+++ b/_examples/zero_shot_classification/main.go
@@ -0,0 +1,28 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+)
+
+func main() {
+	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))
+
+	res, err := ic.ZeroShotClassification(context.Background(), &huggingface.ZeroShotRequest{
+		Inputs: []string{"Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"},
+		Parameters: huggingface.ZeroShotParameters{
+			CandidateLabels: []string{"refund", "faq", "legal"},
+		},
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println(res[0].Sequence)
+	fmt.Println("Labels:", res[0].Labels)
+	fmt.Println("Scores:", res[0].Scores)
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..52485ff
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,11 @@
+module github.com/hupe1980/go-huggingface
+
+go 1.20
+
+require github.com/stretchr/testify v1.8.4
+
+require (
+	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/pmezard/go-difflib v1.0.0 // indirect
+	gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..8cf6655
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,9 @@
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/huggingface.go b/huggingface.go new file mode 100644 index 0000000..0178ee8 --- /dev/null +++ b/huggingface.go @@ -0,0 +1,271 @@ +package gohuggingface + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" +) + +var ( + recommendedModels map[string]string +) + +// HTTPClient is an interface representing an HTTP client. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// InferenceClientOptions represents options for the InferenceClient. +type InferenceClientOptions struct { + Model string + Endpoint string + InferenceEndpoint string +} + +// InferenceClient is a client for performing inference using Hugging Face models. +type InferenceClient struct { + httpClient HTTPClient + token string + opts InferenceClientOptions +} + +// NewInferenceClient creates a new InferenceClient instance with the specified token. +func NewInferenceClient(token string) *InferenceClient { + opts := InferenceClientOptions{ + Endpoint: "https://huggingface.co", + InferenceEndpoint: "https://api-inference.huggingface.co", + } + + return &InferenceClient{ + httpClient: http.DefaultClient, + token: token, + opts: opts, + } +} + +// Summarization performs text summarization using the specified model. +// It sends a POST request to the Hugging Face inference endpoint with the provided inputs. +// The response contains the generated summary or an error if the request fails. +func (ic *InferenceClient) Summarization(ctx context.Context, req *SummarizationRequest) (SummarizationResponse, error) { + if len(req.Inputs) == 0 { + return nil, errors.New("inputs are required") + } + + body, err := ic.post(ctx, req.Model, "summarization", req) + if err != nil { + return nil, err + } + + summarizationResponse := SummarizationResponse{} + if err := json.Unmarshal(body, &summarizationResponse); err != nil { + return nil, err + } + + return summarizationResponse, nil +} + +// TextGeneration performs text generation using the specified model. +// It sends a POST request to the Hugging Face inference endpoint with the provided inputs. +// The response contains the generated text or an error if the request fails. +func (ic *InferenceClient) TextGeneration(ctx context.Context, req *TextGenerationRequest) (TextGenerationResponse, error) { + if req.Inputs == "" { + return nil, errors.New("inputs are required") + } + + body, err := ic.post(ctx, req.Model, "text-generation", req) + if err != nil { + return nil, err + } + + textGenerationResponse := TextGenerationResponse{} + if err := json.Unmarshal(body, &textGenerationResponse); err != nil { + return nil, err + } + + return textGenerationResponse, nil +} + +// Text2TextGeneration performs text-to-text generation using the specified model. +// It sends a POST request to the Hugging Face inference endpoint with the provided inputs. +// The response contains the generated text or an error if the request fails. 
+func (ic *InferenceClient) Text2TextGeneration(ctx context.Context, req *Text2TextGenerationRequest) (Text2TextGenerationResponse, error) { + if req.Inputs == "" { + return nil, errors.New("inputs are required") + } + + body, err := ic.post(ctx, req.Model, "text2text-generation", req) + if err != nil { + return nil, err + } + + text2TextGenerationResponse := Text2TextGenerationResponse{} + if err := json.Unmarshal(body, &text2TextGenerationResponse); err != nil { + return nil, err + } + + return text2TextGenerationResponse, nil +} + +// ZeroShotClassification performs zero-shot classification using the specified model. +// It sends a POST request to the Hugging Face inference endpoint with the provided inputs. +// The response contains the classification results or an error if the request fails. +func (ic *InferenceClient) ZeroShotClassification(ctx context.Context, req *ZeroShotRequest) (ZeroShotResponse, error) { + if len(req.Inputs) == 0 { + return nil, errors.New("inputs are required") + } + + body, err := ic.post(ctx, req.Model, "zero-shot-classification", req) + if err != nil { + return nil, err + } + + zeroShotResponse := ZeroShotResponse{} + if err := json.Unmarshal(body, &zeroShotResponse); err != nil { + return nil, err + } + + return zeroShotResponse, nil +} + +// post sends a POST request to the specified model and task with the provided payload. +// It returns the response body or an error if the request fails. +func (ic *InferenceClient) post(ctx context.Context, model, task string, payload any) ([]byte, error) { + url, err := ic.resolveURL(ctx, model, task) + if err != nil { + return nil, err + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json") + + if ic.token != "" { + httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ic.token)) + } + + res, err := ic.httpClient.Do(httpReq) + if err != nil { + return nil, err + } + + defer res.Body.Close() + + resBody, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("huggingfaces error: %s", resBody) + } + + return resBody, nil +} + +// resolveURL resolves the URL for the specified model and task. +// It returns the resolved URL or an error if resolution fails. +func (ic *InferenceClient) resolveURL(ctx context.Context, model, task string) (string, error) { + if model == "" { + model = ic.opts.Model + } + + // If model is already a URL, ignore `task` and return directly + if model != "" && (strings.HasPrefix(model, "http://") || strings.HasPrefix(model, "https://")) { + return model, nil + } + + if model == "" { + var err error + + model, err = ic.getRecommendedModel(ctx, task) + if err != nil { + return "", err + } + } + + // Feature-extraction and sentence-similarity are the only cases where models support multiple tasks + if contains([]string{"feature-extraction", "sentence-similarity"}, task) { + return fmt.Sprintf("%s/pipeline/%s/%s", ic.opts.InferenceEndpoint, task, model), nil + } + + return fmt.Sprintf("%s/models/%s", ic.opts.InferenceEndpoint, model), nil +} + +// getRecommendedModel retrieves the recommended model for the specified task. +// It returns the recommended model or an error if retrieval fails. 
+func (ic *InferenceClient) getRecommendedModel(ctx context.Context, task string) (string, error) { + rModels, err := ic.fetchRecommendedModels(ctx) + if err != nil { + return "", err + } + + model, ok := rModels[task] + if !ok { + return "", fmt.Errorf("task %s has no recommended model", task) + } + + return model, nil +} + +// fetchRecommendedModels retrieves the recommended models for all available tasks. +// It returns a map of task names to recommended models or an error if retrieval fails. +func (ic *InferenceClient) fetchRecommendedModels(ctx context.Context) (map[string]string, error) { + if recommendedModels == nil { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/api/tasks", ic.opts.Endpoint), nil) + if err != nil { + return nil, err + } + + res, err := ic.httpClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var jsonResponse map[string]interface{} + + err = json.NewDecoder(res.Body).Decode(&jsonResponse) + if err != nil { + return nil, err + } + + recommendedModels = make(map[string]string) + + for task, details := range jsonResponse { + widgetModels, ok := details.(map[string]interface{})["widgetModels"].([]interface{}) + if !ok || len(widgetModels) == 0 { + recommendedModels[task] = "" + } else { + firstModel, _ := widgetModels[0].(string) + recommendedModels[task] = firstModel + } + } + } + + return recommendedModels, nil +} + +// Contains checks if the given element is present in the collection. +func contains[T comparable](collection []T, element T) bool { + for _, item := range collection { + if item == element { + return true + } + } + + return false +} diff --git a/huggingface_test.go b/huggingface_test.go new file mode 100644 index 0000000..01ce100 --- /dev/null +++ b/huggingface_test.go @@ -0,0 +1,78 @@ +package gohuggingface + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Mock HTTP Client for testing purposes +type mockHTTPClient struct { + Response []byte + Err error +} + +func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { + if c.Err != nil { + return nil, c.Err + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBuffer(c.Response)), + }, nil +} + +func TestSummarization(t *testing.T) { + client := NewInferenceClient("your-token") + mockResponse := []byte(`[{"summary_text": "This is a summary"}]`) + + t.Run("Successful Request", func(t *testing.T) { + // Mock HTTP Client with successful response + mockHTTP := &mockHTTPClient{Response: mockResponse} + client.httpClient = mockHTTP + + req := &SummarizationRequest{ + Inputs: []string{"This is a test input"}, + Model: "t5-base", + } + + response, err := client.Summarization(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, response) + assert.Equal(t, "This is a summary", response[0].SummaryText) + }) + + t.Run("Empty Inputs", func(t *testing.T) { + req := &SummarizationRequest{ + Inputs: nil, // Empty inputs + Model: "t5-base", + } + + response, err := client.Summarization(context.Background(), req) + assert.Error(t, err) + assert.Nil(t, response) + assert.Equal(t, "inputs are required", err.Error()) + }) + + t.Run("HTTP Request Error", func(t *testing.T) { + // Mock HTTP Client with error response + mockHTTP := &mockHTTPClient{Err: errors.New("request error")} + client.httpClient = mockHTTP + + req := &SummarizationRequest{ + Inputs: []string{"This is a test input"}, + Model: "t5-base", + } + + 
response, err := client.Summarization(context.Background(), req) + assert.Error(t, err) + assert.Nil(t, response) + assert.Equal(t, "request error", err.Error()) + }) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..79ca272 --- /dev/null +++ b/types.go @@ -0,0 +1,186 @@ +package gohuggingface + +type Options struct { + // (Default: true). There is a cache layer on the inference API to speedup + // requests we have already seen. Most models can use those results as is + // as models are deterministic (meaning the results will be the same anyway). + // However if you use a non deterministic model, you can set this parameter + // to prevent the caching mechanism from being used resulting in a real new query. + UseCache *bool `json:"use_cache,omitempty"` + + // (Default: false) If the model is not ready, wait for it instead of receiving 503. + // It limits the number of requests required to get your inference done. It is advised + // to only set this flag to true after receiving a 503 error as it will limit hanging + // in your application to known places. + WaitForModel *bool `json:"wait_for_model,omitempty"` +} + +type SummarizationParameters struct { + // (Default: None). Integer to define the minimum length in tokens of the output summary. + MinLength *int `json:"min_length,omitempty"` + + // (Default: None). Integer to define the maximum length in tokens of the output summary. + MaxLength *int `json:"max_length,omitempty"` + + // (Default: None). Integer to define the top tokens considered within the sample operation to create + // new text. + TopK *int `json:"top_k,omitempty"` + + // (Default: None). Float to define the tokens that are within the sample` operation of text generation. + // Add tokens in the sample for more probable to least probable until the sum of the probabilities is + // greater than top_p. + TopP *float64 `json:"top_p,omitempty"` + + // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, + // 0 mens top_k=1, 100.0 is getting closer to uniform probability. + Temperature *float64 `json:"temperature,omitempty"` + + // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized + // to not be picked in successive generation passes. + RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"` + + // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. + // Network can cause some overhead so it will be a soft limit. + MaxTime *float64 `json:"maxtime,omitempty"` +} + +type SummarizationRequest struct { + // String to be summarized + Inputs []string `json:"inputs"` + Parameters SummarizationParameters `json:"parameters,omitempty"` + Options Options `json:"options,omitempty"` + Model string `json:"-"` +} + +type SummarizationResponse []struct { + // The summarized input string + SummaryText string `json:"summary_text,omitempty"` +} + +type TextGenerationParameters struct { + // (Default: None). Integer to define the top tokens considered within the sample operation to create new text. + TopK *int `json:"top_k,omitempty"` + + // (Default: None). Float to define the tokens that are within the sample` operation of text generation. Add + // tokens in the sample for more probable to least probable until the sum of the probabilities is greater + // than top_p. + TopP *float64 `json:"top_p,omitempty"` + + // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 
1 means regular sampling, + // 0 means top_k=1, 100.0 is getting closer to uniform probability. + Temperature *float64 `json:"temperature,omitempty"` + + // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized + // to not be picked in successive generation passes. + RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"` + + // (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input + // length it is a estimate of the size of generated text you want. Each new tokens slows down the request, + // so look for balance between response times and length of text generated. + MaxNewTokens *int `json:"max_new_tokens,omitempty"` + + // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. + // Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens + // for best results. + MaxTime *float64 `json:"max_time,omitempty"` + + // (Default: True). Bool. If set to False, the return results will not contain the original query making it + // easier for prompting. + ReturnFullText *bool `json:"return_full_text,omitempty"` + + // (Default: 1). Integer. The number of proposition you want to be returned. + NumReturnSequences *int `json:"num_return_sequences,omitempty"` +} + +type TextGenerationRequest struct { + // String to generated from + Inputs string `json:"inputs"` + Parameters TextGenerationParameters `json:"parameters,omitempty"` + Options Options `json:"options,omitempty"` + Model string `json:"-"` +} + +// A list of generated texts. The length of this list is the value of +// NumReturnSequences in the request. +type TextGenerationResponse []struct { + GeneratedText string `json:"generated_text,omitempty"` +} + +type Text2TextGenerationParameters struct { + // (Default: None). Integer to define the top tokens considered within the sample operation to create new text. + TopK *int `json:"top_k,omitempty"` + + // (Default: None). Float to define the tokens that are within the sample` operation of text generation. Add + // tokens in the sample for more probable to least probable until the sum of the probabilities is greater + // than top_p. + TopP *float64 `json:"top_p,omitempty"` + + // (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, + // 0 means top_k=1, 100.0 is getting closer to uniform probability. + Temperature *float64 `json:"temperature,omitempty"` + + // (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized + // to not be picked in successive generation passes. + RepetitionPenalty *float64 `json:"repetition_penalty,omitempty"` + + // (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input + // length it is a estimate of the size of generated text you want. Each new tokens slows down the request, + // so look for balance between response times and length of text generated. + MaxNewTokens *int `json:"max_new_tokens,omitempty"` + + // (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. + // Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens + // for best results. + MaxTime *float64 `json:"max_time,omitempty"` + + // (Default: True). Bool. If set to False, the return results will not contain the original query making it + // easier for prompting. 
+ ReturnFullText *bool `json:"return_full_text,omitempty"` + + // (Default: 1). Integer. The number of proposition you want to be returned. + NumReturnSequences *int `json:"num_return_sequences,omitempty"` +} + +type Text2TextGenerationRequest struct { + // String to generated from + Inputs string `json:"inputs"` + Parameters Text2TextGenerationParameters `json:"parameters,omitempty"` + Options Options `json:"options,omitempty"` + Model string `json:"-"` +} + +type Text2TextGenerationResponse []struct { + GeneratedText string `json:"generated_text,omitempty"` +} + +type ZeroShotParameters struct { + // (Required) A list of strings that are potential classes for inputs. Max 10 candidate_labels, + // for more, simply run multiple requests, results are going to be misleading if using + // too many candidate_labels anyway. If you want to keep the exact same, you can + // simply run multi_label=True and do the scaling on your end. + CandidateLabels []string `json:"candidate_labels"` + + // (Default: false) Boolean that is set to True if classes can overlap + MultiLabel *bool `json:"multi_label,omitempty"` +} + +type ZeroShotRequest struct { + // (Required) Input or Inputs are required request fields + Inputs []string `json:"inputs"` + // (Required) + Parameters ZeroShotParameters `json:"parameters,omitempty"` + Options Options `json:"options,omitempty"` + Model string `json:"-"` +} + +type ZeroShotResponse []struct { + // The string sent as an input + Sequence string `json:"sequence,omitempty"` + + // The list of labels sent in the request, sorted in descending order + // by probability that the input corresponds to the to the label. + Labels []string `json:"labels,omitempty"` + + // a list of floats that correspond the the probability of label, in the same order as labels. + Scores []float64 `json:"scores,omitempty"` +}
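
The parameter and option structs added in `types.go` use pointer fields so that unset values are omitted from the JSON payload. The sketch below shows how a caller might combine them on a summarization request, assuming the package API exactly as introduced in this diff; the generic `ptr` helper is purely illustrative (not part of the package), and `facebook/bart-large-cnn` is just an example model ID.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	huggingface "github.com/hupe1980/go-huggingface"
)

// ptr is an illustrative helper for filling the pointer-typed parameter
// fields; it is not part of the go-huggingface package.
func ptr[T any](v T) *T { return &v }

func main() {
	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

	res, err := ic.Summarization(context.Background(), &huggingface.SummarizationRequest{
		Inputs: []string{"The Eiffel Tower is 324 metres tall and was the tallest man-made structure in the world for 41 years."},
		Parameters: huggingface.SummarizationParameters{
			MaxLength: ptr(30), // cap the summary length in tokens
		},
		Options: huggingface.Options{
			WaitForModel: ptr(true), // wait instead of receiving 503 while the model loads
		},
		Model: "facebook/bart-large-cnn", // example model ID
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res[0].SummaryText)
}
```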
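Model resolution in `resolveURL` accepts three forms of `Model`: empty (the client fetches and caches the recommended model for the task from `/api/tasks`), a plain repository ID (routed to `/models/<model>` on the inference endpoint), and a full URL (used as-is, e.g. for a dedicated Inference Endpoint, with the task ignored). A minimal sketch of the three cases under those assumptions; the endpoint URL in the last call is a placeholder, not a real deployment.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))
	ctx := context.Background()

	// 1. Model left empty: the client resolves the recommended model for the
	//    "text-generation" task via https://huggingface.co/api/tasks and caches it.
	res, err := ic.TextGeneration(ctx, &huggingface.TextGenerationRequest{
		Inputs: "The answer to the universe is",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(res[0].GeneratedText)

	// 2. Model set to a repository ID: the request goes to
	//    https://api-inference.huggingface.co/models/gpt2.
	res, err = ic.TextGeneration(ctx, &huggingface.TextGenerationRequest{
		Inputs: "The answer to the universe is",
		Model:  "gpt2",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(res[0].GeneratedText)

	// 3. Model set to a full URL (placeholder below): the URL is used directly.
	res, err = ic.TextGeneration(ctx, &huggingface.TextGenerationRequest{
		Inputs: "The answer to the universe is",
		Model:  "https://my-endpoint.example.com",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(res[0].GeneratedText)
}
```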
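The `mockHTTPClient` in `huggingface_test.go` swaps out the client's unexported `httpClient` field so responses can be canned without network access. A possible extension of that pattern to the text-generation endpoint might look like the following test sketch (not part of this diff, and the canned response body is invented for illustration).

```go
package gohuggingface

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestTextGeneration sketches how the mockHTTPClient pattern covers other
// endpoints: the mock returns a fixed JSON body instead of calling the API.
func TestTextGeneration(t *testing.T) {
	client := NewInferenceClient("your-token")

	t.Run("Successful Request", func(t *testing.T) {
		client.httpClient = &mockHTTPClient{
			Response: []byte(`[{"generated_text": "The answer to the universe is 42"}]`),
		}

		res, err := client.TextGeneration(context.Background(), &TextGenerationRequest{
			Inputs: "The answer to the universe is",
			Model:  "gpt2",
		})
		assert.NoError(t, err)
		assert.Equal(t, "The answer to the universe is 42", res[0].GeneratedText)
	})

	t.Run("Empty Inputs", func(t *testing.T) {
		res, err := client.TextGeneration(context.Background(), &TextGenerationRequest{
			Inputs: "",
		})
		assert.Error(t, err)
		assert.Nil(t, res)
	})
}
```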