Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dolly/redpajama/bloomz models support #214

Merged
merged 5 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ BINARY_NAME=local-ai

GOLLAMA_VERSION?=c03e8adbc45c866e0f6d876af1887d6b01d57eb4
GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa
GOGPT2_VERSION?=d498232
mudler marked this conversation as resolved.
Show resolved Hide resolved
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=af62fcc432be2847acb6e0688b2c2491d6588d58
WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993
Expand Down
24 changes: 24 additions & 0 deletions api/prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,30 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
predictOptions...,
)
}
case *gpt2.Dolly:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(c.Temperature),
gpt2.SetTopP(c.TopP),
gpt2.SetTopK(c.TopK),
gpt2.SetTokens(c.Maxtokens),
gpt2.SetThreads(c.Threads),
}

if c.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
}

if c.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
}

return model.Predict(
s,
predictOptions...,
)
}
case *gpt2.GPT2:
fn = func() (string, error) {
// Generate the prediction using the language model
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ require (
github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be
github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35
github.com/go-audio/wav v1.1.0
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6
github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c
github.com/go-skynet/go-llama.cpp v0.0.0-20230509080828-f4d26f43f1d3
github.com/gofiber/fiber/v2 v2.45.0
Expand Down
10 changes: 10 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM=
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708 h1:cfOi4TWvQ6JsAm9Q1A8I8j9YfNy10bmIfwOiyGyU5wQ=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 h1:XshpypO6ekU09CI19vuzke2a1Es1lV5ZaxA7CUehu0E=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c h1:48I7jpLNGiQeBmF0SFVVbREh8vlG0zN13v9LH5ctXis=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
github.com/go-skynet/go-llama.cpp v0.0.0-20230508165257-c03e8adbc45c h1:JoW2+LKrSemoV32QRwrEC5f53erym96NCsUSM3wSVbM=
github.com/go-skynet/go-llama.cpp v0.0.0-20230508165257-c03e8adbc45c/go.mod h1:DLfsPD7tYYnpksERH83HSf7qVNW3FIwmz7/zfYO0/6I=
github.com/go-skynet/go-llama.cpp v0.0.0-20230509080828-f4d26f43f1d3 h1:YNi1oetK5kGJoUgT3/r/Wj3XPOICWf3nwHsz5v89iSs=
github.com/go-skynet/go-llama.cpp v0.0.0-20230509080828-f4d26f43f1d3/go.mod h1:DLfsPD7tYYnpksERH83HSf7qVNW3FIwmz7/zfYO0/6I=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
Expand Down
47 changes: 46 additions & 1 deletion pkg/model/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ type ModelLoader struct {
gptmodels map[string]*gptj.GPTJ
gpt2models map[string]*gpt2.GPT2
gptstablelmmodels map[string]*gpt2.StableLM
dollymodels map[string]*gpt2.Dolly
rwkv map[string]*rwkv.RwkvState
bert map[string]*bert.Bert

promptsTemplates map[string]*template.Template
}

Expand All @@ -40,6 +40,9 @@ func NewModelLoader(modelPath string) *ModelLoader {
gpt2models: make(map[string]*gpt2.GPT2),
gptmodels: make(map[string]*gptj.GPTJ),
gptstablelmmodels: make(map[string]*gpt2.StableLM),
dollymodels: make(map[string]*gpt2.Dolly),


models: make(map[string]*llama.LLama),
rwkv: make(map[string]*rwkv.RwkvState),
bert: make(map[string]*bert.Bert),
Expand Down Expand Up @@ -128,6 +131,38 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
return nil
}

func (ml *ModelLoader) LoadDollyModel(modelName string) (*gpt2.Dolly, error) {
ml.mu.Lock()
defer ml.mu.Unlock()

// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}

if m, ok := ml.dollymodels[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}

// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.ModelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)

model, err := gpt2.NewDolly(modelFile)
if err != nil {
return nil, err
}

// If there is a prompt template, load it
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return nil, err
}

ml.dollymodels[modelName] = model
return model, err
}

func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
Expand Down Expand Up @@ -331,6 +366,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
return ml.LoadLLaMAModel(modelFile, llamaOpts...)
case "stablelm":
return ml.LoadStableLMModel(modelFile)
case "dolly":
return ml.LoadDollyModel(modelFile)
case "gpt2":
return ml.LoadGPT2Model(modelFile)
case "gptj":
Expand Down Expand Up @@ -391,6 +428,14 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt
err = multierror.Append(err, modelerr)
}

model, modelerr = ml.LoadDollyModel(modelFile)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}

model, modelerr = ml.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
if modelerr == nil {
updateModels(model)
Expand Down