Skip to content

Commit

Permalink
Add fill mask
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 25, 2023
1 parent 9c5541f commit 6a94b33
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
29 changes: 29 additions & 0 deletions _examples/fill_mask/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package main

import (
"context"
"fmt"
"log"
"os"

huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.FillMask(context.Background(), &huggingface.FillMaskRequest{
Inputs: []string{"The answer to the universe is <mask>."},
})
if err != nil {
log.Fatal(err)
}

for _, r := range res {
fmt.Println("Sequence:", r.Sequence)
fmt.Println("Score:", r.Score)
fmt.Println("TokenID:", r.TokenID)
fmt.Println("TokenStr", r.TokenStr)
fmt.Println("---")
}
}
24 changes: 24 additions & 0 deletions fill_mask.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package gohuggingface

// Request structure for the Fill Mask endpoint
type FillMaskRequest struct {
// (Required) a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask)
Inputs []string `json:"inputs"`
Options Options `json:"options,omitempty"`
Model string `json:"-"`
}

// Response structure for the Fill Mask endpoint
type FillMaskResponse []struct {
// The actual sequence of tokens that ran against the model (may contain special tokens)
Sequence string `json:"sequence,omitempty"`

// The probability for this token.
Score float64 `json:"score,omitempty"`

// The id of the token
TokenID int `json:"token,omitempty"`

// The string representation of the token
TokenStr string `json:"token_str,omitempty"`
}
21 changes: 21 additions & 0 deletions huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,27 @@ func (ic *InferenceClient) QuestionAnswering(ctx context.Context, req *QuestionA
return questionAnsweringResponse, nil
}

// FillMask performs masked language modeling using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
// The response contains the generated text with the masked tokens filled or an error if the request fails.
func (ic *InferenceClient) FillMask(ctx context.Context, req *FillMaskRequest) (FillMaskResponse, error) {
if len(req.Inputs) == 0 {
return nil, errors.New("inputs are required")
}

body, err := ic.post(ctx, req.Model, "fill-mask", req)
if err != nil {
return nil, err
}

fillMaskResponse := FillMaskResponse{}
if err := json.Unmarshal(body, &fillMaskResponse); err != nil {
return nil, err
}

return fillMaskResponse, nil
}

// post sends a POST request to the specified model and task with the provided payload.
// It returns the response body or an error if the request fails.
func (ic *InferenceClient) post(ctx context.Context, model, task string, payload any) ([]byte, error) {
Expand Down

0 comments on commit 6a94b33

Please sign in to comment.