From 5b4a36328887e2de06bf5a5b5cc1e2873fc2db67 Mon Sep 17 00:00:00 2001 From: hupe1980 Date: Sat, 16 Sep 2023 08:41:22 +0200 Subject: [PATCH] Change float32 to float64 --- _examples/feature_extraction/main.go | 10 ++++++++++ feature_extraction.go | 4 ++-- huggingface.go | 4 ++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/_examples/feature_extraction/main.go b/_examples/feature_extraction/main.go index 7c1492d..3b85c78 100644 --- a/_examples/feature_extraction/main.go +++ b/_examples/feature_extraction/main.go @@ -14,20 +14,30 @@ func main() { res1, err1 := ic.FeatureExtraction(context.Background(), &huggingface.FeatureExtractionRequest{ Inputs: []string{"Hello World"}, + Options: huggingface.Options{ + WaitForModel: huggingface.PTR(true), + }, }) if err1 != nil { log.Fatal(err1) } + fmt.Println("FeatureExtraction:") fmt.Println(res1[0]) + fmt.Println() res2, err2 := ic.FeatureExtractionWithAutomaticReduction(context.Background(), &huggingface.FeatureExtractionRequest{ Inputs: []string{"Hello World"}, Model: "sentence-transformers/all-mpnet-base-v2", + Options: huggingface.Options{ + WaitForModel: huggingface.PTR(true), + }, }) if err2 != nil { log.Fatal(err2) } + fmt.Println("FeatureExtractionWithAutomaticReduction:") fmt.Println(res2[0]) + fmt.Println() } diff --git a/feature_extraction.go b/feature_extraction.go index f7ebfe0..6ef1435 100644 --- a/feature_extraction.go +++ b/feature_extraction.go @@ -15,10 +15,10 @@ type FeatureExtractionRequest struct { } // Response structure for the feature extraction endpoint -type FeatureExtractionResponse [][][][]float32 +type FeatureExtractionResponse [][][][]float64 // Response structure for the feature extraction endpoint -type FeatureExtractionWithAutomaticReductionResponse [][]float32 +type FeatureExtractionWithAutomaticReductionResponse [][]float64 // FeatureExtraction performs feature extraction using the specified model. // It sends a POST request to the Hugging Face inference endpoint with the provided input data. diff --git a/huggingface.go b/huggingface.go index 0582d5c..6a2bd1b 100644 --- a/huggingface.go +++ b/huggingface.go @@ -203,3 +203,7 @@ func contains[T comparable](collection []T, element T) bool { return false } + +func PTR[T any](input T) *T { + return &input +}