diff --git a/_examples/fill_mask/main.go b/_examples/fill_mask/main.go
new file mode 100644
index 0000000..82ee092
--- /dev/null
+++ b/_examples/fill_mask/main.go
@@ -0,0 +1,29 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+
+	huggingface "github.com/hupe1980/go-huggingface"
+)
+
+func main() {
+	ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))
+
+	res, err := ic.FillMask(context.Background(), &huggingface.FillMaskRequest{
+		Inputs: []string{"The answer to the universe is [MASK]."},
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	for _, r := range res {
+		fmt.Println("Sequence:", r.Sequence)
+		fmt.Println("Score:", r.Score)
+		fmt.Println("TokenID:", r.TokenID)
+		fmt.Println("TokenStr:", r.TokenStr)
+		fmt.Println("---")
+	}
+}
diff --git a/fill_mask.go b/fill_mask.go
new file mode 100644
index 0000000..e7ac36c
--- /dev/null
+++ b/fill_mask.go
@@ -0,0 +1,24 @@
+package gohuggingface
+
+// FillMaskRequest is the request structure for the Fill Mask endpoint.
+type FillMaskRequest struct {
+	// (Required) One or more strings to be filled; each must contain the [MASK] token (check the model card for the exact name of the mask token)
+	Inputs  []string `json:"inputs"`
+	Options Options  `json:"options,omitempty"`
+	Model   string   `json:"-"`
+}
+
+// FillMaskResponse is the response structure for the Fill Mask endpoint.
+type FillMaskResponse []struct {
+	// The actual sequence of tokens that ran against the model (may contain special tokens)
+	Sequence string `json:"sequence,omitempty"`
+
+	// The probability for this token
+	Score float64 `json:"score,omitempty"`
+
+	// The id of the token
+	TokenID int `json:"token,omitempty"`
+
+	// The string representation of the token
+	TokenStr string `json:"token_str,omitempty"`
+}
diff --git a/huggingface.go b/huggingface.go
index 969de84..43206b3 100644
--- a/huggingface.go
+++ b/huggingface.go
@@ -170,6 +170,27 @@ func (ic *InferenceClient) QuestionAnswering(ctx context.Context, req *QuestionA
 	return questionAnsweringResponse, nil
 }
 
+// FillMask performs masked language modeling using the specified model.
+// It sends a POST request to the Hugging Face inference endpoint with the provided inputs.
+// The response contains candidate completions for the masked token, each with a score, or an error if the request fails.
+func (ic *InferenceClient) FillMask(ctx context.Context, req *FillMaskRequest) (FillMaskResponse, error) {
+	if len(req.Inputs) == 0 {
+		return nil, errors.New("inputs are required")
+	}
+
+	body, err := ic.post(ctx, req.Model, "fill-mask", req)
+	if err != nil {
+		return nil, err
+	}
+
+	fillMaskResponse := FillMaskResponse{}
+	if err := json.Unmarshal(body, &fillMaskResponse); err != nil {
+		return nil, err
+	}
+
+	return fillMaskResponse, nil
+}
+
 // post sends a POST request to the specified model and task with the provided payload.
 // It returns the response body or an error if the request fails.
 func (ic *InferenceClient) post(ctx context.Context, model, task string, payload any) ([]byte, error) {