Skip to content

Commit

Permalink
Add text classification
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Sep 13, 2023
1 parent 900fe46 commit ff32446
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
24 changes: 24 additions & 0 deletions _examples/text_classification/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package main

import (
"context"
"fmt"
"log"
"os"

huggingface "github.com/hupe1980/go-huggingface"
)

func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.TextClassification(context.Background(), &huggingface.TextClassificationRequest{
Inputs: "The answer to the universe is 42",
//Model: "deepset/deberta-v3-base-injection", // overwrite recommended model
})
if err != nil {
log.Fatal(err)
}

fmt.Println(res[0])
}
45 changes: 45 additions & 0 deletions text_classification.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package gohuggingface

import (
"context"
"encoding/json"
"errors"
)

// TextClassificationRequest represents a request for text classification.
type TextClassificationRequest struct {
// Inputs is the string to be generated from.
Inputs string `json:"inputs"`
// Options represents optional settings for the classification.
Options Options `json:"options,omitempty"`
// Model is the name of the model to use for classification.
Model string `json:"-"`
}

// TextClassificationResponse represents a response for text classification.
type TextClassificationResponse [][]struct {
// Label is the label for the class (model-specific).
Label string `json:"label,omitempty"`
// Score is a float that represents how likely it is that the text belongs to this class.
Score float32 `json:"score,omitempty"`
}

// TextClassification performs text classification using the provided request.
func (ic *InferenceClient) TextClassification(ctx context.Context, req *TextClassificationRequest) (TextClassificationResponse, error) {
// Check if inputs are provided.
if len(req.Inputs) == 0 {
return nil, errors.New("inputs are required")
}

body, err := ic.post(ctx, req.Model, "text-classification", req)
if err != nil {
return nil, err
}

textClassificationResponse := TextClassificationResponse{}
if err := json.Unmarshal(body, &textClassificationResponse); err != nil {
return nil, err
}

return textClassificationResponse, nil
}

0 comments on commit ff32446

Please sign in to comment.