Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions core/config/model_config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"fmt"
"os"
"regexp"
"slices"
Expand Down Expand Up @@ -475,7 +476,7 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
cfg.syncKnownUsecasesFromString()
}

func (c *ModelConfig) Validate() bool {
func (c *ModelConfig) Validate() (bool, error) {
downloadedFileNames := []string{}
for _, f := range c.DownloadFiles {
downloadedFileNames = append(downloadedFileNames, f.Filename)
Expand All @@ -489,17 +490,20 @@ func (c *ModelConfig) Validate() bool {
}
if strings.HasPrefix(n, string(os.PathSeparator)) ||
strings.Contains(n, "..") {
return false
return false, fmt.Errorf("invalid file path: %s", n)
}
}

if c.Backend != "" {
// a regex that checks that is a string name with no special characters, except '-' and '_'
re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`)
return re.MatchString(c.Backend)
if !re.MatchString(c.Backend) {
return false, fmt.Errorf("invalid backend name: %s", c.Backend)
}
return true, nil
}

return true
return true, nil
}

func (c *ModelConfig) HasTemplate() bool {
Expand Down
6 changes: 3 additions & 3 deletions core/config/model_config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, op
}

for _, cc := range c {
if cc.Validate() {
if valid, _ := cc.Validate(); valid {
bcl.configs[cc.Name] = *cc
}
}
Expand All @@ -184,7 +184,7 @@ func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderO
return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err)
}

if c.Validate() {
if valid, _ := c.Validate(); valid {
bcl.configs[c.Name] = *c
} else {
return fmt.Errorf("config is not valid")
Expand Down Expand Up @@ -362,7 +362,7 @@ func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...Conf
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadModelConfigsFromPath cannot read config file")
continue
}
if c.Validate() {
if valid, _ := c.Validate(); valid {
bcl.configs[c.Name] = *c
} else {
log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")
Expand Down
12 changes: 9 additions & 3 deletions core/config/model_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ known_usecases:
config, err := readModelConfigFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(config).ToNot(BeNil())
Expect(config.Validate()).To(BeFalse())
valid, err := config.Validate()
Expect(err).To(HaveOccurred())
Expect(valid).To(BeFalse())
Expect(config.KnownUsecases).ToNot(BeNil())
})
It("Test Validate", func() {
Expand All @@ -46,7 +48,9 @@ parameters:
Expect(config).ToNot(BeNil())
// two configs in config.yaml
Expect(config.Name).To(Equal("bar-baz"))
Expect(config.Validate()).To(BeTrue())
valid, err := config.Validate()
Expect(err).To(BeNil())
Expect(valid).To(BeTrue())

// download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml
httpClient := http.Client{}
Expand All @@ -63,7 +67,9 @@ parameters:
Expect(config).ToNot(BeNil())
// two configs in config.yaml
Expect(config.Name).To(Equal("hermes-2-pro-mistral"))
Expect(config.Validate()).To(BeTrue())
valid, err = config.Validate()
Expect(err).To(BeNil())
Expect(valid).To(BeTrue())
})
})
It("Properly handles backend usecase matching", func() {
Expand Down
13 changes: 12 additions & 1 deletion core/gallery/backends.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
return fmt.Errorf("failed copying: %w", err)
}
} else {
uri := downloader.URI(config.URI)
log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloading backend")
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
success := false
// Try to download from mirrors
Expand All @@ -177,16 +177,27 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
}
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
success = true
log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloaded backend")
break
}
}

if !success {
log.Error().Str("uri", config.URI).Str("backendPath", backendPath).Err(err).Msg("Failed to download backend")
return fmt.Errorf("failed to download backend %q: %v", config.URI, err)
}
} else {
log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloaded backend")
}
}

// sanity check - check if runfile is present
runFile := filepath.Join(backendPath, runFile)
if _, err := os.Stat(runFile); os.IsNotExist(err) {
log.Error().Str("runFile", runFile).Msg("Run file not found")
return fmt.Errorf("not a valid backend: run file not found %q", runFile)
}

// Create metadata for the backend
metadata := &BackendMetadata{
Name: name,
Expand Down
2 changes: 1 addition & 1 deletion core/gallery/backends_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,8 @@ var _ = Describe("Gallery Backends", func() {
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
})

It("should overwrite existing backend", func() {
Expand Down
56 changes: 45 additions & 11 deletions core/gallery/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"os"
"path/filepath"
"strings"
"time"

"github.com/lithammer/fuzzysearch/fuzzy"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/LocalAI/pkg/xsync"
"github.com/rs/zerolog/log"

"gopkg.in/yaml.v2"
Expand All @@ -19,7 +21,7 @@ import (
func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
var config T
uri := downloader.URI(url)
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
Expand All @@ -32,7 +34,7 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
var config T
uri := downloader.URI(url)
err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
err := uri.ReadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
Expand Down Expand Up @@ -141,7 +143,7 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst

// Get models from galleries
for _, gallery := range galleries {
galleryModels, err := getGalleryElements[*GalleryModel](gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool {
galleryModels, err := getGalleryElements(gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool {
if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil {
return true
}
Expand Down Expand Up @@ -182,7 +184,7 @@ func AvailableBackends(galleries []config.Gallery, systemState *system.SystemSta
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
var refFile string
uri := downloader.URI(url)
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
refFile = string(d)
if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
Expand All @@ -194,6 +196,17 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error)
return refFile, err
}

type galleryCacheEntry struct {
yamlEntry []byte
lastUpdated time.Time
}

func (entry galleryCacheEntry) hasExpired() bool {
return entry.lastUpdated.Before(time.Now().Add(-1 * time.Hour))
}

var galleryCache = xsync.NewSyncedMap[string, galleryCacheEntry]()

func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string, isInstalledCallback func(T) bool) ([]T, error) {
var models []T = []T{}

Expand All @@ -204,16 +217,37 @@ func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath strin
return models, err
}
}

cacheKey := fmt.Sprintf("%s-%s", gallery.Name, gallery.URL)
if galleryCache.Exists(cacheKey) {
entry := galleryCache.Get(cacheKey)
// refresh if last updated is more than 1 hour ago
if !entry.hasExpired() {
err := yaml.Unmarshal(entry.yamlEntry, &models)
if err != nil {
return models, err
}
} else {
galleryCache.Delete(cacheKey)
}
}

uri := downloader.URI(gallery.URL)

err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {
if yamlErr, ok := err.(*yaml.TypeError); ok {
log.Debug().Msgf("YAML errors: %s\n\nwreckage of models: %+v", strings.Join(yamlErr.Errors, "\n"), models)
if len(models) == 0 {
err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
galleryCache.Set(cacheKey, galleryCacheEntry{
yamlEntry: d,
lastUpdated: time.Now(),
})
return yaml.Unmarshal(d, &models)
})
if err != nil {
if yamlErr, ok := err.(*yaml.TypeError); ok {
log.Debug().Msgf("YAML errors: %s\n\nwreckage of models: %+v", strings.Join(yamlErr.Errors, "\n"), models)
}
return models, fmt.Errorf("failed to read gallery elements: %w", err)
}
return models, err
}

// Add gallery to models
Expand Down
57 changes: 55 additions & 2 deletions core/gallery/importers/importers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@ package importers

import (
"encoding/json"
"fmt"
"os"
"strings"

"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"

"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)

Expand All @@ -28,6 +33,10 @@ type Importer interface {
Import(details Details) (gallery.ModelConfig, error)
}

func hasYAMLExtension(uri string) bool {
return strings.HasSuffix(uri, ".yaml") || strings.HasSuffix(uri, ".yml")
}

func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) {
var err error
var modelConfig gallery.ModelConfig
Expand All @@ -42,26 +51,70 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model
if err != nil {
// maybe not a HF repository
// TODO: maybe we can check if the URI is a valid HF repository
log.Debug().Str("uri", uri).Msg("Failed to get model details, maybe not a HF repository")
log.Debug().Str("uri", uri).Str("hfrepoID", hfrepoID).Msg("Failed to get model details, maybe not a HF repository")
} else {
log.Debug().Str("uri", uri).Msg("Got model details")
log.Debug().Any("details", hfDetails).Msg("Model details")
}

// handle local config files ("/my-model.yaml" or "file://my-model.yaml")
localURI := uri
if strings.HasPrefix(uri, downloader.LocalPrefix) {
localURI = strings.TrimPrefix(uri, downloader.LocalPrefix)
}

// if a file exists or it's an url that ends with .yaml or .yml, read the config file directly
if _, e := os.Stat(localURI); hasYAMLExtension(localURI) && (e == nil || downloader.URI(localURI).LooksLikeURL()) {
var modelYAML []byte
if downloader.URI(localURI).LooksLikeURL() {
err := downloader.URI(localURI).ReadWithCallback(localURI, func(url string, i []byte) error {
modelYAML = i
return nil
})
if err != nil {
log.Error().Err(err).Str("filepath", localURI).Msg("error reading model definition")
return gallery.ModelConfig{}, err
}
} else {
modelYAML, err = os.ReadFile(localURI)
if err != nil {
log.Error().Err(err).Str("filepath", localURI).Msg("error reading model definition")
return gallery.ModelConfig{}, err
}
}

var modelConfig config.ModelConfig
if e := yaml.Unmarshal(modelYAML, &modelConfig); e != nil {
return gallery.ModelConfig{}, e
}

configFile, err := yaml.Marshal(modelConfig)
return gallery.ModelConfig{
Description: modelConfig.Description,
Name: modelConfig.Name,
ConfigFile: string(configFile),
}, err
}

details := Details{
HuggingFace: hfDetails,
URI: uri,
Preferences: preferences,
}

importerMatched := false
for _, importer := range defaultImporters {
if importer.Match(details) {
importerMatched = true
modelConfig, err = importer.Import(details)
if err != nil {
continue
}
break
}
}
return modelConfig, err
if !importerMatched {
return gallery.ModelConfig{}, fmt.Errorf("no importer matched for %s", uri)
}
return modelConfig, nil
}
Loading
Loading