Skip to content

Commit

Permalink
Add FeatureExtractionWithAutomaticReduction
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Sep 15, 2023
1 parent ff32446 commit 54e89c6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
18 changes: 14 additions & 4 deletions _examples/feature_extraction/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,22 @@ import (
func main() {
ic := huggingface.NewInferenceClient(os.Getenv("HUGGINGFACEHUB_API_TOKEN"))

res, err := ic.FeatureExtraction(context.Background(), &huggingface.FeatureExtractionRequest{
res1, err1 := ic.FeatureExtraction(context.Background(), &huggingface.FeatureExtractionRequest{
Inputs: []string{"Hello World"},
})
if err != nil {
log.Fatal(err)
if err1 != nil {
log.Fatal(err1)
}

fmt.Println(res[0][0])
fmt.Println(res1[0])

res2, err2 := ic.FeatureExtractionWithAutomaticReduction(context.Background(), &huggingface.FeatureExtractionRequest{
Inputs: []string{"Hello World"},
Model: "sentence-transformers/all-mpnet-base-v2",
})
if err2 != nil {
log.Fatal(err2)
}

fmt.Println(res2[0])
}
29 changes: 25 additions & 4 deletions feature_extraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
)

// Request structure for the feature extraction endpoint
Expand All @@ -16,7 +15,10 @@ type FeatureExtractionRequest struct {
}

// Response structure for the feature extraction endpoint
type FeatureExtractionResponse [][][][]float64
type FeatureExtractionResponse [][][][]float32

// Response structure for the feature extraction endpoint
type FeatureExtractionWithAutomaticReductionResponse [][]float32

// FeatureExtraction performs feature extraction using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided input data.
Expand All @@ -31,12 +33,31 @@ func (ic *InferenceClient) FeatureExtraction(ctx context.Context, req *FeatureEx
return nil, err
}

fmt.Println(string(body))

featureExtractionResponse := FeatureExtractionResponse{}
if err := json.Unmarshal(body, &featureExtractionResponse); err != nil {
return nil, err
}

return featureExtractionResponse, nil
}

// FeatureExtractionWithAutomaticReduction performs feature extraction using the specified model.
// It sends a POST request to the Hugging Face inference endpoint with the provided input data.
// The response contains the extracted features or an error if the request fails.
func (ic *InferenceClient) FeatureExtractionWithAutomaticReduction(ctx context.Context, req *FeatureExtractionRequest) (FeatureExtractionWithAutomaticReductionResponse, error) {
if len(req.Inputs) == 0 {
return nil, errors.New("inputs are required")
}

body, err := ic.post(ctx, req.Model, "feature-extraction", req)
if err != nil {
return nil, err
}

featureExtractionResponse := FeatureExtractionWithAutomaticReductionResponse{}
if err := json.Unmarshal(body, &featureExtractionResponse); err != nil {
return nil, err
}

return featureExtractionResponse, nil
}

0 comments on commit 54e89c6

Please sign in to comment.