feat: support conversational hugging face (#74)
henomis committed May 31, 2023
1 parent 7d8223c commit 28d8bf3
Showing 5 changed files with 293 additions and 0 deletions.
18 changes: 18 additions & 0 deletions examples/llm/huggingface/conversational/main.go
@@ -0,0 +1,18 @@
package main

import (
	"context"

	"github.com/henomis/lingoose/llm/huggingface"
)

func main() {

	llm := huggingface.New("microsoft/DialoGPT-medium", 1.0, true)

	_, err := llm.Completion(context.Background(), "What is NATO's purpose?")
	if err != nil {
		panic(err)
	}

}
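For reference, the constructor defaults used above can be tuned through the fluent `With*` setters defined in `llm/huggingface/huggingface.go` below. A minimal sketch of a more explicit setup, assuming a placeholder `hf_...` token in place of the `HUGGING_FACE_HUB_TOKEN` environment variable:

```go
package main

import (
	"context"
	"fmt"

	"github.com/henomis/lingoose/llm/huggingface"
)

func main() {
	llm := huggingface.New("microsoft/DialoGPT-medium", 1.0, false).
		WithToken("hf_..."). // placeholder; by default the token comes from HUGGING_FACE_HUB_TOKEN
		WithMaxLength(64).   // cap the generated reply length
		WithTopP(0.9)        // nucleus sampling

	output, err := llm.Completion(context.Background(), "What is NATO's purpose?")
	if err != nil {
		panic(err)
	}
	fmt.Println(output)
}
```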
78 changes: 78 additions & 0 deletions llm/huggingface/conversational.go
@@ -0,0 +1,78 @@
package huggingface

import (
"context"
"encoding/json"
)

type conversationalRequest struct {
	Inputs     conversationalInputs     `json:"inputs,omitempty"`
	Parameters conversationalParameters `json:"parameters,omitempty"`
	Options    Options                  `json:"options,omitempty"`
}

type conversationalParameters struct {
	MinLength         *int     `json:"min_length,omitempty"`
	MaxLength         *int     `json:"max_length,omitempty"`
	TopK              *int     `json:"top_k,omitempty"`
	TopP              *float32 `json:"top_p,omitempty"`
	Temperature       *float32 `json:"temperature,omitempty"`
	RepetitionPenalty *float32 `json:"repetition_penalty,omitempty"`
	MaxTime           *float32 `json:"max_time,omitempty"`
}

type conversationalInputs struct {
	Text               string   `json:"text,omitempty"`
	GeneratedResponses []string `json:"generated_responses,omitempty"`
	PastUserInputs     []string `json:"past_user_inputs,omitempty"`
}

type conversationalResponse struct {
	GeneratedText string       `json:"generated_text,omitempty"`
	Conversation  conversation `json:"conversation,omitempty"`
}

type conversation struct {
	GeneratedResponses []string `json:"generated_responses,omitempty"`
	PastUserInputs     []string `json:"past_user_inputs,omitempty"`
}

func (h *huggingFace) conversationalCompletion(ctx context.Context, prompt string) (string, error) {

	isTrue := true
	request := conversationalRequest{
		Inputs: conversationalInputs{
			Text: prompt,
		},
		Parameters: conversationalParameters{
			Temperature: &h.temperature,
			MinLength:   h.minLength,
			MaxLength:   h.maxLength,
			TopK:        h.topK,
			TopP:        h.topP,
		},
		Options: Options{
			WaitForModel: &isTrue,
		},
	}

	jsonBuf, err := json.Marshal(request)
	if err != nil {
		return "", err
	}

	respBody, err := h.doRequest(ctx, jsonBuf, h.model)
	if err != nil {
		return "", err
	}

	cresp := conversationalResponse{}
	err = json.Unmarshal(respBody, &cresp)
	if err != nil {
		return "", err
	}

	if h.verbose {
		debugCompletion(prompt, cresp.GeneratedText)
	}

	return cresp.GeneratedText, nil
}
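Note that the response echoes back the full conversation state, but `conversationalCompletion` sends only the new prompt, so each call starts a fresh conversation. A hypothetical in-package sketch (not part of this commit) of how the echoed state could be threaded into the next request for true multi-turn behavior:

```go
// Hypothetical helper in package huggingface; it reuses the structs above.
func (h *huggingFace) conversationalTurn(ctx context.Context, prompt string, prev *conversation) (string, *conversation, error) {
	inputs := conversationalInputs{Text: prompt}
	if prev != nil {
		// Replay earlier turns so the model sees the dialogue history.
		inputs.GeneratedResponses = prev.GeneratedResponses
		inputs.PastUserInputs = prev.PastUserInputs
	}

	jsonBuf, err := json.Marshal(conversationalRequest{Inputs: inputs})
	if err != nil {
		return "", nil, err
	}

	respBody, err := h.doRequest(ctx, jsonBuf, h.model)
	if err != nil {
		return "", nil, err
	}

	cresp := conversationalResponse{}
	if err := json.Unmarshal(respBody, &cresp); err != nil {
		return "", nil, err
	}

	return cresp.GeneratedText, &cresp.Conversation, nil
}
```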
79 changes: 79 additions & 0 deletions llm/huggingface/http.go
@@ -0,0 +1,79 @@
package huggingface

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
)

func (h *huggingFace) doRequest(ctx context.Context, jsonBody []byte, model string) ([]byte, error) {

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, APIBaseURL+model, bytes.NewBuffer(jsonBody))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+h.token)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	err = checkRespForError(respBody)
	if err != nil {
		return nil, err
	}

	return respBody, nil
}

// The Inference API reports failures as either {"error": "..."} or
// {"error": ["...", ...]}; model both shapes.
type apiError struct {
	Error string `json:"error,omitempty"`
}

type apiErrors struct {
	Errors []string `json:"error,omitempty"`
}

func checkRespForError(respJSON []byte) error {
	// An unmarshal failure only means the body does not match that error
	// shape, so it is not propagated.
	apiErr := apiError{}
	if err := json.Unmarshal(respJSON, &apiErr); err == nil && apiErr.Error != "" {
		return errors.New(string(respJSON))
	}

	apiErrs := apiErrors{}
	if err := json.Unmarshal(respJSON, &apiErrs); err == nil && apiErrs.Errors != nil {
		return errors.New(string(respJSON))
	}

	return nil
}
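A quick sketch of the payload shapes `checkRespForError` is written for; this would live in-package, e.g. in a test file (the loading message mirrors the typical wait-for-model error body, quoted here from memory):

```go
package huggingface

import "fmt"

// Hypothetical in-package sketch; not part of this commit.
func exampleCheckRespForError() {
	fmt.Println(checkRespForError([]byte(`{"error":"Model microsoft/DialoGPT-medium is currently loading"}`))) // non-nil
	fmt.Println(checkRespForError([]byte(`{"error":["unauthorized"]}`)))                                       // non-nil
	fmt.Println(checkRespForError([]byte(`{"generated_text":"hi"}`)))                                          // <nil>
}
```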
104 changes: 104 additions & 0 deletions llm/huggingface/huggingface.go
@@ -0,0 +1,104 @@
package huggingface

import (
	"context"
	"fmt"
	"os"
)

const APIBaseURL = "https://api-inference.huggingface.co/models/"

const (
	ErrHuggingFaceCompletion = "huggingface completion error"
)

type HuggingFaceMode int

const (
	HuggingFaceModeConversational HuggingFaceMode = iota
)

type huggingFace struct {
	mode        HuggingFaceMode
	token       string
	model       string
	temperature float32
	maxLength   *int
	minLength   *int
	topK        *int
	topP        *float32
	verbose     bool
}

func New(model string, temperature float32, verbose bool) *huggingFace {
	return &huggingFace{
		mode:        HuggingFaceModeConversational,
		token:       os.Getenv("HUGGING_FACE_HUB_TOKEN"),
		model:       model,
		temperature: temperature,
		verbose:     verbose,
	}
}

func (h *huggingFace) WithModel(model string) *huggingFace {
	h.model = model
	return h
}

func (h *huggingFace) WithTemperature(temperature float32) *huggingFace {
	h.temperature = temperature
	return h
}

func (h *huggingFace) WithMaxLength(maxLength int) *huggingFace {
	h.maxLength = &maxLength
	return h
}

func (h *huggingFace) WithMinLength(minLength int) *huggingFace {
	h.minLength = &minLength
	return h
}

func (h *huggingFace) WithToken(token string) *huggingFace {
	h.token = token
	return h
}

func (h *huggingFace) WithVerbose(verbose bool) *huggingFace {
	h.verbose = verbose
	return h
}

func (h *huggingFace) WithTopK(topK int) *huggingFace {
	h.topK = &topK
	return h
}

func (h *huggingFace) WithTopP(topP float32) *huggingFace {
	h.topP = &topP
	return h
}

func (h *huggingFace) WithMode(mode HuggingFaceMode) *huggingFace {
	h.mode = mode
	return h
}

func (h *huggingFace) Completion(ctx context.Context, prompt string) (string, error) {

	var output string
	var err error
	switch h.mode {
	case HuggingFaceModeConversational:
		fallthrough
	default:
		output, err = h.conversationalCompletion(ctx, prompt)
	}

	if err != nil {
		return "", fmt.Errorf("%s: %w", ErrHuggingFaceCompletion, err)
	}

	return output, nil
}
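Because `Completion` wraps failures with `%w`, a caller can still inspect the underlying cause with the standard `errors` helpers. A minimal sketch, assuming a context timeout as the failure mode:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/henomis/lingoose/llm/huggingface"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	llm := huggingface.New("microsoft/DialoGPT-medium", 1.0, false)

	output, err := llm.Completion(ctx, "Hello!")
	if err != nil {
		// The message is prefixed with "huggingface completion error",
		// and %w keeps the cause visible to errors.Is.
		if errors.Is(err, context.DeadlineExceeded) {
			fmt.Println("request timed out:", err)
			return
		}
		panic(err)
	}
	fmt.Println(output)
}
```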
14 changes: 14 additions & 0 deletions llm/huggingface/shared.go
@@ -0,0 +1,14 @@
package huggingface

import "fmt"

type Options struct {
	UseGPU       *bool `json:"use_gpu,omitempty"`
	UseCache     *bool `json:"use_cache,omitempty"`
	WaitForModel *bool `json:"wait_for_model,omitempty"`
}

func debugCompletion(prompt string, content string) {
	fmt.Printf("---USER---\n%s\n", prompt)
	fmt.Printf("---AI---\n%s\n", content)
}
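The pointer fields are what let `omitempty` distinguish "unset" from an explicit `false`: a nil pointer is dropped from the payload, while a pointer to `false` is still serialized. A quick demonstration using the exported `Options` type:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/henomis/lingoose/llm/huggingface"
)

func main() {
	wait, cache := true, false

	b, _ := json.Marshal(huggingface.Options{WaitForModel: &wait})
	fmt.Println(string(b)) // {"wait_for_model":true}

	b, _ = json.Marshal(huggingface.Options{UseCache: &cache})
	fmt.Println(string(b)) // {"use_cache":false}; an explicit false is still sent
}
```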
