Skip to content

Commit

Permalink
Add support for more models
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 28, 2023
1 parent c8796d8 commit 969b1f1
Show file tree
Hide file tree
Showing 12 changed files with 398 additions and 160 deletions.
33 changes: 33 additions & 0 deletions _examples/conversational/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package main

import (
"context"
"fmt"
"log"
"os"

huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.Conversational(context.Background(), &huggingface.ConversationalRequest{
Inputs: huggingface.ConverstationalInputs{
PastUserInputs: []string{
"Which movie is the best ?",
"Can you explain why ?",
},
GeneratedResponses: []string{
"It's Die Hard for sure.",
"It's the best movie ever.",
},
Text: "Can you explain why ?",
},
})
if err != nil {
log.Fatal(err)
}

fmt.Println(res.GeneratedText)
}
23 changes: 23 additions & 0 deletions _examples/feature_extraction/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package main

import (
"context"
"fmt"
"log"
"os"

huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.FeatureExtraction(context.Background(), &huggingface.FeatureExtractionRequest{
Inputs: []string{"Hello World"},
})
if err != nil {
log.Fatal(err)
}

fmt.Println(res[0][0])
}
99 changes: 99 additions & 0 deletions conversational.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package gohuggingface

import (
"context"
"encoding/json"
"errors"
)

// Used with ConversationalRequest
type ConversationalParameters struct {
// (Default: None). Integer to define the minimum length in tokens of the output summary.
MinLength *int `json:"min_length,omitempty"`

// (Default: None). Integer to define the maximum length in tokens of the output summary.
MaxLength *int `json:"max_length,omitempty"`

// (Default: None). Integer to define the top tokens considered within the sample operation to create
// new text.
TopK *int `json:"top_k,omitempty"`

// (Default: None). Float to define the tokens that are within the sample` operation of text generation.
// Add tokens in the sample for more probable to least probable until the sum of the probabilities is
// greater than top_p.
TopP *float64 `json:"top_p,omitempty"`

// (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling,
// 0 mens top_k=1, 100.0 is getting closer to uniform probability.
Temperature *float64 `json:"temperature,omitempty"`

// (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized
// to not be picked in successive generation passes.
RepetitionPenalty *float64 `json:"repetitionpenalty,omitempty"`

// (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum.
// Network can cause some overhead so it will be a soft limit.
MaxTime *float64 `json:"maxtime,omitempty"`
}

// ConverstationalInputs is the conversation-state payload used with
// ConversationalRequest.
// NOTE(review): the type name misspells "Conversational"; it is exported,
// so it is kept as-is for backward compatibility with existing callers.
type ConverstationalInputs struct {
// (Required) The last input from the user in the conversation.
Text string `json:"text"`

// A list of strings corresponding to the earlier replies from the model.
GeneratedResponses []string `json:"generated_responses,omitempty"`

// A list of strings corresponding to the earlier replies from the user.
// Should be of the same length as GeneratedResponses.
PastUserInputs []string `json:"past_user_inputs,omitempty"`
}

// ConversationalRequest is the request structure for the conversational endpoint.
type ConversationalRequest struct {
// (Required) The conversation state: past turns plus the new user text.
Inputs ConverstationalInputs `json:"inputs,omitempty"`

// Optional generation parameters and inference options.
// NOTE(review): `omitempty` has no effect on non-pointer struct fields,
// so Parameters and Options are always serialized into the body.
Parameters ConversationalParameters `json:"parameters,omitempty"`
Options Options `json:"options,omitempty"`

// Model selects the model to run; it is excluded from the JSON body
// and passed separately to the client's post call.
Model string `json:"-"`
}

// Conversation carries the updated conversation history returned by the
// model. Used with ConversationalResponse.
type Conversation struct {
// The last outputs from the model in the conversation, after the model has run.
GeneratedResponses []string `json:"generated_responses,omitempty"`

// The last inputs from the user in the conversation, after the model has run.
PastUserInputs []string `json:"past_user_inputs,omitempty"`
}

// ConversationalResponse is the response structure for the conversational endpoint.
type ConversationalResponse struct {
// The answer of the model.
GeneratedText string `json:"generated_text,omitempty"`

// A facility dictionary to send back for the next input (with the new user input addition).
Conversation Conversation `json:"conversation,omitempty"`
}

// Conversational performs conversational AI using the specified model.
// It POSTs the conversation state to the Hugging Face inference endpoint
// and returns the decoded reply, or an error if the request or decoding fails.
func (ic *InferenceClient) Conversational(ctx context.Context, req *ConversationalRequest) (*ConversationalResponse, error) {
	// The API requires at least the latest user input.
	if req.Inputs.Text == "" {
		return nil, errors.New("text is required")
	}

	raw, err := ic.post(ctx, req.Model, "conversational", req)
	if err != nil {
		return nil, err
	}

	var resp ConversationalResponse
	if err := json.Unmarshal(raw, &resp); err != nil {
		return nil, err
	}

	return &resp, nil
}
42 changes: 42 additions & 0 deletions feature_extraction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package gohuggingface

import (
	"context"
	"encoding/json"
	"errors"
)

// FeatureExtractionRequest is the request structure for the feature extraction endpoint.
type FeatureExtractionRequest struct {
	// Strings to get the features from.
	Inputs []string `json:"inputs"`

	// Optional inference options.
	Options Options `json:"options,omitempty"`

	// Model selects the model to run; it is excluded from the JSON body
	// and passed separately to the client's post call.
	Model string `json:"-"`
}

// FeatureExtractionResponse is the response structure for the feature
// extraction endpoint: a nested array of extracted feature values.
type FeatureExtractionResponse [][][][]float64

// FeatureExtraction performs feature extraction using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided input data.
// The response contains the extracted features or an error if the request fails.
func (ic *InferenceClient) FeatureExtraction(ctx context.Context, req *FeatureExtractionRequest) (FeatureExtractionResponse, error) {
	if len(req.Inputs) == 0 {
		return nil, errors.New("inputs are required")
	}

	body, err := ic.post(ctx, req.Model, "feature-extraction", req)
	if err != nil {
		return nil, err
	}

	// FIX: removed leftover debug statement fmt.Println(string(body)), which
	// dumped every raw API response to stdout; the unused "fmt" import was
	// dropped accordingly.
	featureExtractionResponse := FeatureExtractionResponse{}
	if err := json.Unmarshal(body, &featureExtractionResponse); err != nil {
		return nil, err
	}

	return featureExtractionResponse, nil
}
27 changes: 27 additions & 0 deletions fill_mask.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package gohuggingface

import (
"context"
"encoding/json"
"errors"
)

// Request structure for the Fill Mask endpoint
type FillMaskRequest struct {
// (Required) a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask)
Expand All @@ -22,3 +28,24 @@ type FillMaskResponse []struct {
// The string representation of the token
TokenStr string `json:"token_str,omitempty"`
}

// FillMask performs masked language modeling using the specified model.
// It POSTs the inputs to the Hugging Face inference endpoint and returns
// the candidate fills for the masked token, or an error if the request fails.
func (ic *InferenceClient) FillMask(ctx context.Context, req *FillMaskRequest) (FillMaskResponse, error) {
	// The endpoint needs at least one input to fill.
	if len(req.Inputs) == 0 {
		return nil, errors.New("inputs are required")
	}

	raw, err := ic.post(ctx, req.Model, "fill-mask", req)
	if err != nil {
		return nil, err
	}

	var resp FillMaskResponse
	if err := json.Unmarshal(raw, &resp); err != nil {
		return nil, err
	}

	return resp, nil
}
Loading

0 comments on commit 969b1f1

Please sign in to comment.