Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Minor improvements to BackendConfigLoader #2353

Merged
merged 13 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 85 additions & 92 deletions core/config/backend_config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ type BackendConfigLoader struct {
sync.Mutex
}

func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}

type LoadOptions struct {
debug bool
threads, ctxSize int
Expand Down Expand Up @@ -61,46 +67,8 @@ func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
}
}

// Load a config file for a model
func (cl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {

// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}

cfgExisting, exists := cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := cl.LoadBackendConfig(
modelConfig, opts...,
); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = cl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
}
}
}

cfg.SetDefaults(opts...)

return cfg, nil
}

func NewBackendConfigLoader() *BackendConfigLoader {
return &BackendConfigLoader{
configs: make(map[string]BackendConfig),
}
}
func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
// TODO: either in the next PR or the next commit, I want to merge these down into a single function that looks at the first few characters of the file to determine if we need to deserialize to []BackendConfig or BackendConfig
func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*BackendConfig, error) {
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
Expand All @@ -117,7 +85,7 @@ func ReadBackendConfigFile(file string, opts ...ConfigLoaderOption) ([]*BackendC
return *c, nil
}

func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*BackendConfig, error) {
lo := &LoadOptions{}
lo.Apply(opts...)

Expand All @@ -134,44 +102,79 @@ func ReadBackendConfig(file string, opts ...ConfigLoaderOption) (*BackendConfig,
return c, nil
}

func (cm *BackendConfigLoader) LoadBackendConfigFile(file string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
c, err := ReadBackendConfigFile(file, opts...)
// Load a config file for a model
func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*BackendConfig, error) {

// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
},
}

cfgExisting, exists := bcl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
} else {
// Try loading a model config file
modelConfig := filepath.Join(modelPath, modelName+".yaml")
if _, err := os.Stat(modelConfig); err == nil {
if err := bcl.LoadBackendConfig(
modelConfig, opts...,
); err != nil {
return nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
}
cfgExisting, exists = bcl.GetBackendConfig(modelName)
if exists {
cfg = &cfgExisting
}
}
}

cfg.SetDefaults(opts...)

return cfg, nil
}

// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
c, err := readMultipleBackendConfigsFromFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot load config file: %w", err)
}

for _, cc := range c {
cm.configs[cc.Name] = *cc
bcl.configs[cc.Name] = *cc
}
return nil
}

func (cl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
cl.Lock()
defer cl.Unlock()
c, err := ReadBackendConfig(file, opts...)
func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
c, err := readBackendConfigFromFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
}

cl.configs[c.Name] = *c
bcl.configs[c.Name] = *c
return nil
}

func (cl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
cl.Lock()
defer cl.Unlock()
v, exists := cl.configs[m]
func (bcl *BackendConfigLoader) GetBackendConfig(m string) (BackendConfig, bool) {
bcl.Lock()
defer bcl.Unlock()
v, exists := bcl.configs[m]
return v, exists
}

func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
cl.Lock()
defer cl.Unlock()
func (bcl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
bcl.Lock()
defer bcl.Unlock()
var res []BackendConfig
for _, v := range cl.configs {
for _, v := range bcl.configs {
res = append(res, v)
}

Expand All @@ -182,26 +185,16 @@ func (cl *BackendConfigLoader) GetAllBackendConfigs() []BackendConfig {
return res
}

func (cl *BackendConfigLoader) RemoveBackendConfig(m string) {
cl.Lock()
defer cl.Unlock()
delete(cl.configs, m)
}

func (cl *BackendConfigLoader) ListBackendConfigs() []string {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unlike most of the other moves, this function was removed. All places that used it have been converted to GetAllBackendConfigs() - the one place in live code essentially did that anyway, and the test has been modified.

cl.Lock()
defer cl.Unlock()
var res []string
for k := range cl.configs {
res = append(res, k)
}
return res
func (bcl *BackendConfigLoader) RemoveBackendConfig(m string) {
bcl.Lock()
defer bcl.Unlock()
delete(bcl.configs, m)
}

// Preload prepare models if they are not local but url or huggingface repositories
func (cl *BackendConfigLoader) Preload(modelPath string) error {
cl.Lock()
defer cl.Unlock()
func (bcl *BackendConfigLoader) Preload(modelPath string) error {
bcl.Lock()
defer bcl.Unlock()

status := func(fileName, current, total string, percent float64) {
utils.DisplayDownloadFunction(fileName, current, total, percent)
Expand All @@ -223,7 +216,7 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
}
}

for i, config := range cl.configs {
for i, config := range bcl.configs {

// Download files and verify their SHA
for i, file := range config.DownloadFiles {
Expand Down Expand Up @@ -252,10 +245,10 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
}
}

cc := cl.configs[i]
cc := bcl.configs[i]
c := &cc
c.PredictionOptions.Model = modelFileName
cl.configs[i] = *c
bcl.configs[i] = *c
}

if config.IsMMProjURL() {
Expand All @@ -269,32 +262,32 @@ func (cl *BackendConfigLoader) Preload(modelPath string) error {
}
}

cc := cl.configs[i]
cc := bcl.configs[i]
c := &cc
c.MMProj = modelFileName
cl.configs[i] = *c
bcl.configs[i] = *c
}

if cl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", cl.configs[i].Name))
if bcl.configs[i].Name != "" {
glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name))
}
if cl.configs[i].Description != "" {
if bcl.configs[i].Description != "" {
//glamText("**Description**")
glamText(cl.configs[i].Description)
glamText(bcl.configs[i].Description)
}
if cl.configs[i].Usage != "" {
if bcl.configs[i].Usage != "" {
//glamText("**Usage**")
glamText(cl.configs[i].Usage)
glamText(bcl.configs[i].Usage)
}
}
return nil
}

// LoadBackendConfigsFromPath reads all the configurations of the models from a path
// (non-recursive)
func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
cm.Lock()
defer cm.Unlock()
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()
entries, err := os.ReadDir(path)
if err != nil {
return err
Expand All @@ -313,9 +306,9 @@ func (cm *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...C
strings.HasPrefix(file.Name(), ".") {
continue
}
c, err := ReadBackendConfig(filepath.Join(path, file.Name()), opts...)
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
if err == nil {
cm.configs[c.Name] = *c
bcl.configs[c.Name] = *c
}
}

Expand Down
29 changes: 16 additions & 13 deletions core/config/config_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package config_test
package config

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

according to golang docs, since this file ends in _test.go it won't be included in our final binary even if the config package is linked. Therefore, I've dropped the _test package suffix here, to allow us to drop the public exports of the readBackendConfig functions

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is more a bit of a code/style in writing tests. I always prefer (especially in go) to go with black box testing rather then white box testing - in the long run black box tests are less brittle and allow me to test the intention rather then corner cases, but I do agree that here things are quite shady and can be improved

import (
"os"

. "github.com/go-skynet/LocalAI/core/config"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
Expand All @@ -17,8 +15,8 @@ var _ = Describe("Test cases for config related functions", func() {

Context("Test Read configuration functions", func() {
configFile = os.Getenv("CONFIG_FILE")
It("Test ReadConfigFile", func() {
config, err := ReadBackendConfigFile(configFile)
It("Test readConfigFile", func() {
config, err := readMultipleBackendConfigsFromFile(configFile)
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
// two configs in config.yaml
Expand All @@ -27,26 +25,31 @@ var _ = Describe("Test cases for config related functions", func() {
})

It("Test LoadConfigs", func() {
cm := NewBackendConfigLoader()
bcl := NewBackendConfigLoader()
opts := NewApplicationConfig()
err := cm.LoadBackendConfigsFromPath(opts.ModelPath)
err := bcl.LoadBackendConfigsFromPath(opts.ModelPath)
Expect(err).To(BeNil())
Expect(cm.ListBackendConfigs()).ToNot(BeNil())
configs := bcl.GetAllBackendConfigs()
loadedModelNames := []string{}
for _, v := range configs {
loadedModelNames = append(loadedModelNames, v.Name)
}
Expect(configs).ToNot(BeNil())

// config should includes gpt4all models's api.config
Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all"))
Expect(loadedModelNames).To(ContainElements("gpt4all"))

// config should includes gpt2 models's api.config
Expect(cm.ListBackendConfigs()).To(ContainElements("gpt4all-2"))
Expect(loadedModelNames).To(ContainElements("gpt4all-2"))

// config should includes text-embedding-ada-002 models's api.config
Expect(cm.ListBackendConfigs()).To(ContainElements("text-embedding-ada-002"))
Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002"))

// config should includes rwkv_test models's api.config
Expect(cm.ListBackendConfigs()).To(ContainElements("rwkv_test"))
Expect(loadedModelNames).To(ContainElements("rwkv_test"))

// config should includes whisper-1 models's api.config
Expect(cm.ListBackendConfigs()).To(ContainElements("whisper-1"))
Expect(loadedModelNames).To(ContainElements("whisper-1"))
})
})
})
7 changes: 3 additions & 4 deletions core/startup/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
}

if options.ConfigFile != "" {
if err := cl.LoadBackendConfigFile(options.ConfigFile, configLoaderOpts...); err != nil {
if err := cl.LoadMultipleBackendConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
log.Error().Err(err).Msg("error loading config file")
}
}
Expand All @@ -94,9 +94,8 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode
}

if options.Debug {
for _, v := range cl.ListBackendConfigs() {
cfg, _ := cl.GetBackendConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
for _, v := range cl.GetAllBackendConfigs() {
log.Debug().Msgf("Model: %s (config: %+v)", v.Name, v)
}
}

Expand Down
Loading