diff --git a/core/config/model_config.go b/core/config/model_config.go index 50ac44625cb6..4f02776c48fd 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "os" "regexp" "slices" @@ -475,7 +476,7 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) { cfg.syncKnownUsecasesFromString() } -func (c *ModelConfig) Validate() bool { +func (c *ModelConfig) Validate() (bool, error) { downloadedFileNames := []string{} for _, f := range c.DownloadFiles { downloadedFileNames = append(downloadedFileNames, f.Filename) @@ -489,17 +490,20 @@ func (c *ModelConfig) Validate() bool { } if strings.HasPrefix(n, string(os.PathSeparator)) || strings.Contains(n, "..") { - return false + return false, fmt.Errorf("invalid file path: %s", n) } } if c.Backend != "" { // a regex that checks that is a string name with no special characters, except '-' and '_' re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`) - return re.MatchString(c.Backend) + if !re.MatchString(c.Backend) { + return false, fmt.Errorf("invalid backend name: %s", c.Backend) + } + return true, nil } - return true + return true, nil } func (c *ModelConfig) HasTemplate() bool { diff --git a/core/config/model_config_loader.go b/core/config/model_config_loader.go index 9895a4a0e5f8..f0f2c3338c13 100644 --- a/core/config/model_config_loader.go +++ b/core/config/model_config_loader.go @@ -169,7 +169,7 @@ func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, op } for _, cc := range c { - if cc.Validate() { + if valid, _ := cc.Validate(); valid { bcl.configs[cc.Name] = *cc } } @@ -184,7 +184,7 @@ func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderO return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err) } - if c.Validate() { + if valid, _ := c.Validate(); valid { bcl.configs[c.Name] = *c } else { return fmt.Errorf("config is not valid") @@ -362,7 +362,7 @@ func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...Conf log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadModelConfigsFromPath cannot read config file") continue } - if c.Validate() { + if valid, _ := c.Validate(); valid { bcl.configs[c.Name] = *c } else { log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid") diff --git a/core/config/model_config_test.go b/core/config/model_config_test.go index 9d49e270c751..342b10c47f8b 100644 --- a/core/config/model_config_test.go +++ b/core/config/model_config_test.go @@ -28,7 +28,9 @@ known_usecases: config, err := readModelConfigFromFile(tmp.Name()) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) - Expect(config.Validate()).To(BeFalse()) + valid, err := config.Validate() + Expect(err).To(HaveOccurred()) + Expect(valid).To(BeFalse()) Expect(config.KnownUsecases).ToNot(BeNil()) }) It("Test Validate", func() { @@ -46,7 +48,9 @@ parameters: Expect(config).ToNot(BeNil()) // two configs in config.yaml Expect(config.Name).To(Equal("bar-baz")) - Expect(config.Validate()).To(BeTrue()) + valid, err := config.Validate() + Expect(err).To(BeNil()) + Expect(valid).To(BeTrue()) // download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml httpClient := http.Client{} @@ -63,7 +67,9 @@ parameters: Expect(config).ToNot(BeNil()) // two configs in config.yaml Expect(config.Name).To(Equal("hermes-2-pro-mistral")) - Expect(config.Validate()).To(BeTrue()) + valid, err = config.Validate() + Expect(err).To(BeNil()) + Expect(valid).To(BeTrue()) }) }) It("Properly handles backend usecase matching", func() { diff --git a/core/gallery/backends.go b/core/gallery/backends.go index aee4b2d93928..9049664b3549 100644 --- a/core/gallery/backends.go +++ b/core/gallery/backends.go @@ -164,7 +164,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL return fmt.Errorf("failed copying: %w", err) } } else { - uri := downloader.URI(config.URI) + log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloading backend") if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil { success := false // Try to download from mirrors @@ -177,16 +177,27 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL } if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { success = true + log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloaded backend") break } } if !success { + log.Error().Str("uri", config.URI).Str("backendPath", backendPath).Err(err).Msg("Failed to download backend") return fmt.Errorf("failed to download backend %q: %v", config.URI, err) } + } else { + log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloaded backend") } } + // sanity check - check if runfile is present + runFile := filepath.Join(backendPath, runFile) + if _, err := os.Stat(runFile); os.IsNotExist(err) { + log.Error().Str("runFile", runFile).Msg("Run file not found") + return fmt.Errorf("not a valid backend: run file not found %q", runFile) + } + // Create metadata for the backend metadata := &BackendMetadata{ Name: name, diff --git a/core/gallery/backends_test.go b/core/gallery/backends_test.go index 15900d25018b..756d2e7a23b8 100644 --- a/core/gallery/backends_test.go +++ b/core/gallery/backends_test.go @@ -563,8 +563,8 @@ var _ = Describe("Gallery Backends", func() { ) Expect(err).NotTo(HaveOccurred()) err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) - Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created Expect(newPath).To(BeADirectory()) + Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created }) It("should overwrite existing backend", func() { diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index 62362148ecef..0475b898c645 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -6,11 +6,13 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/lithammer/fuzzysearch/fuzzy" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/LocalAI/pkg/xsync" "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" @@ -19,7 +21,7 @@ import ( func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) { var config T uri := downloader.URI(url) - err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { + err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { @@ -32,7 +34,7 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) { func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) { var config T uri := downloader.URI(url) - err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error { + err := uri.ReadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { @@ -141,7 +143,7 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst // Get models from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryElements[*GalleryModel](gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool { + galleryModels, err := getGalleryElements(gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool { if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { return true } @@ -182,7 +184,7 @@ func AvailableBackends(galleries []config.Gallery, systemState *system.SystemSta func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) { var refFile string uri := downloader.URI(url) - err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { + err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { refFile = string(d) if len(refFile) == 0 { return fmt.Errorf("invalid reference file at url %s: %s", url, d) @@ -194,6 +196,17 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) return refFile, err } +type galleryCacheEntry struct { + yamlEntry []byte + lastUpdated time.Time +} + +func (entry galleryCacheEntry) hasExpired() bool { + return entry.lastUpdated.Before(time.Now().Add(-1 * time.Hour)) +} + +var galleryCache = xsync.NewSyncedMap[string, galleryCacheEntry]() + func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string, isInstalledCallback func(T) bool) ([]T, error) { var models []T = []T{} @@ -204,16 +217,37 @@ func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath strin return models, err } } + + cacheKey := fmt.Sprintf("%s-%s", gallery.Name, gallery.URL) + if galleryCache.Exists(cacheKey) { + entry := galleryCache.Get(cacheKey) + // refresh if last updated is more than 1 hour ago + if !entry.hasExpired() { + err := yaml.Unmarshal(entry.yamlEntry, &models) + if err != nil { + return models, err + } + } else { + galleryCache.Delete(cacheKey) + } + } + uri := downloader.URI(gallery.URL) - err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { - return yaml.Unmarshal(d, &models) - }) - if err != nil { - if yamlErr, ok := err.(*yaml.TypeError); ok { - log.Debug().Msgf("YAML errors: %s\n\nwreckage of models: %+v", strings.Join(yamlErr.Errors, "\n"), models) + if len(models) == 0 { + err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { + galleryCache.Set(cacheKey, galleryCacheEntry{ + yamlEntry: d, + lastUpdated: time.Now(), + }) + return yaml.Unmarshal(d, &models) + }) + if err != nil { + if yamlErr, ok := err.(*yaml.TypeError); ok { + log.Debug().Msgf("YAML errors: %s\n\nwreckage of models: %+v", strings.Join(yamlErr.Errors, "\n"), models) + } + return models, fmt.Errorf("failed to read gallery elements: %w", err) } - return models, err } // Add gallery to models diff --git a/core/gallery/importers/importers.go b/core/gallery/importers/importers.go index 238aad6f1634..283a3349a5e4 100644 --- a/core/gallery/importers/importers.go +++ b/core/gallery/importers/importers.go @@ -2,11 +2,16 @@ package importers import ( "encoding/json" + "fmt" + "os" "strings" "github.com/rs/zerolog/log" + "gopkg.in/yaml.v3" + "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/pkg/downloader" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" ) @@ -28,6 +33,10 @@ type Importer interface { Import(details Details) (gallery.ModelConfig, error) } +func hasYAMLExtension(uri string) bool { + return strings.HasSuffix(uri, ".yaml") || strings.HasSuffix(uri, ".yml") +} + func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) { var err error var modelConfig gallery.ModelConfig @@ -42,20 +51,61 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model if err != nil { // maybe not a HF repository // TODO: maybe we can check if the URI is a valid HF repository - log.Debug().Str("uri", uri).Msg("Failed to get model details, maybe not a HF repository") + log.Debug().Str("uri", uri).Str("hfrepoID", hfrepoID).Msg("Failed to get model details, maybe not a HF repository") } else { log.Debug().Str("uri", uri).Msg("Got model details") log.Debug().Any("details", hfDetails).Msg("Model details") } + // handle local config files ("/my-model.yaml" or "file://my-model.yaml") + localURI := uri + if strings.HasPrefix(uri, downloader.LocalPrefix) { + localURI = strings.TrimPrefix(uri, downloader.LocalPrefix) + } + + // if a file exists or it's an url that ends with .yaml or .yml, read the config file directly + if _, e := os.Stat(localURI); hasYAMLExtension(localURI) && (e == nil || downloader.URI(localURI).LooksLikeURL()) { + var modelYAML []byte + if downloader.URI(localURI).LooksLikeURL() { + err := downloader.URI(localURI).ReadWithCallback(localURI, func(url string, i []byte) error { + modelYAML = i + return nil + }) + if err != nil { + log.Error().Err(err).Str("filepath", localURI).Msg("error reading model definition") + return gallery.ModelConfig{}, err + } + } else { + modelYAML, err = os.ReadFile(localURI) + if err != nil { + log.Error().Err(err).Str("filepath", localURI).Msg("error reading model definition") + return gallery.ModelConfig{}, err + } + } + + var modelConfig config.ModelConfig + if e := yaml.Unmarshal(modelYAML, &modelConfig); e != nil { + return gallery.ModelConfig{}, e + } + + configFile, err := yaml.Marshal(modelConfig) + return gallery.ModelConfig{ + Description: modelConfig.Description, + Name: modelConfig.Name, + ConfigFile: string(configFile), + }, err + } + details := Details{ HuggingFace: hfDetails, URI: uri, Preferences: preferences, } + importerMatched := false for _, importer := range defaultImporters { if importer.Match(details) { + importerMatched = true modelConfig, err = importer.Import(details) if err != nil { continue @@ -63,5 +113,8 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model break } } - return modelConfig, err + if !importerMatched { + return gallery.ModelConfig{}, fmt.Errorf("no importer matched for %s", uri) + } + return modelConfig, nil } diff --git a/core/gallery/importers/importers_test.go b/core/gallery/importers/importers_test.go index 34814fe66c0d..f34be0d6e5bc 100644 --- a/core/gallery/importers/importers_test.go +++ b/core/gallery/importers/importers_test.go @@ -3,6 +3,8 @@ package importers_test import ( "encoding/json" "fmt" + "os" + "path/filepath" "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/onsi/ginkgo/v2" @@ -212,4 +214,139 @@ var _ = Describe("DiscoverModelConfig", func() { Expect(modelConfig.Name).To(BeEmpty()) }) }) + + Context("with local YAML config files", func() { + var tempDir string + + BeforeEach(func() { + var err error + tempDir, err = os.MkdirTemp("", "importers-test-*") + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tempDir) + }) + + It("should read local YAML file with file:// prefix", func() { + yamlContent := `name: test-model +backend: llama-cpp +description: Test model from local YAML +parameters: + model: /path/to/model.gguf + temperature: 0.7 +` + yamlFile := filepath.Join(tempDir, "test-model.yaml") + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := "file://" + yamlFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("test-model")) + Expect(modelConfig.Description).To(Equal("Test model from local YAML")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("name: test-model")) + }) + + It("should read local YAML file without file:// prefix (direct path)", func() { + yamlContent := `name: direct-path-model +backend: mlx +description: Test model from direct path +parameters: + model: /path/to/model.safetensors +` + yamlFile := filepath.Join(tempDir, "direct-model.yaml") + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := yamlFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("direct-path-model")) + Expect(modelConfig.Description).To(Equal("Test model from direct path")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx")) + }) + + It("should read local YAML file with .yml extension", func() { + yamlContent := `name: yml-extension-model +backend: transformers +description: Test model with .yml extension +parameters: + model: /path/to/model +` + yamlFile := filepath.Join(tempDir, "test-model.yml") + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := "file://" + yamlFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("yml-extension-model")) + Expect(modelConfig.Description).To(Equal("Test model with .yml extension")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers")) + }) + + It("should ignore preferences when reading YAML files directly", func() { + yamlContent := `name: yaml-model +backend: llama-cpp +description: Original description +parameters: + model: /path/to/model.gguf +` + yamlFile := filepath.Join(tempDir, "prefs-test.yaml") + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := "file://" + yamlFile + // Preferences should be ignored when reading YAML directly + preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description", "backend": "mlx"}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).ToNot(HaveOccurred()) + // Should use values from YAML file, not preferences + Expect(modelConfig.Name).To(Equal("yaml-model")) + Expect(modelConfig.Description).To(Equal("Original description")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) + }) + + It("should return error when local YAML file doesn't exist", func() { + nonExistentFile := filepath.Join(tempDir, "nonexistent.yaml") + uri := "file://" + nonExistentFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).To(HaveOccurred()) + Expect(modelConfig.Name).To(BeEmpty()) + }) + + It("should return error when YAML file is invalid/malformed", func() { + invalidYaml := `name: invalid-model +backend: llama-cpp +invalid: yaml: content: [unclosed bracket +` + yamlFile := filepath.Join(tempDir, "invalid.yaml") + err := os.WriteFile(yamlFile, []byte(invalidYaml), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := "file://" + yamlFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).To(HaveOccurred()) + Expect(modelConfig.Name).To(BeEmpty()) + }) + }) }) diff --git a/core/gallery/importers/llama-cpp.go b/core/gallery/importers/llama-cpp.go index 669faf79076c..f1c4a4dc96e0 100644 --- a/core/gallery/importers/llama-cpp.go +++ b/core/gallery/importers/llama-cpp.go @@ -9,7 +9,9 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/functions" + "github.com/rs/zerolog/log" "go.yaml.in/yaml/v2" ) @@ -20,14 +22,22 @@ type LlamaCPPImporter struct{} func (i *LlamaCPPImporter) Match(details Details) bool { preferences, err := details.Preferences.MarshalJSON() if err != nil { + log.Error().Err(err).Msg("failed to marshal preferences") return false } + preferencesMap := make(map[string]any) - err = json.Unmarshal(preferences, &preferencesMap) - if err != nil { - return false + + if len(preferences) > 0 { + err = json.Unmarshal(preferences, &preferencesMap) + if err != nil { + log.Error().Err(err).Msg("failed to unmarshal preferences") + return false + } } + uri := downloader.URI(details.URI) + if preferencesMap["backend"] == "llama-cpp" { return true } @@ -36,6 +46,10 @@ func (i *LlamaCPPImporter) Match(details Details) bool { return true } + if uri.LooksLikeOCI() { + return true + } + if details.HuggingFace != nil { for _, file := range details.HuggingFace.Files { if strings.HasSuffix(file.Path, ".gguf") { @@ -48,14 +62,19 @@ func (i *LlamaCPPImporter) Match(details Details) bool { } func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) { + + log.Debug().Str("uri", details.URI).Msg("llama.cpp importer matched") + preferences, err := details.Preferences.MarshalJSON() if err != nil { return gallery.ModelConfig{}, err } preferencesMap := make(map[string]any) - err = json.Unmarshal(preferences, &preferencesMap) - if err != nil { - return gallery.ModelConfig{}, err + if len(preferences) > 0 { + err = json.Unmarshal(preferences, &preferencesMap) + if err != nil { + return gallery.ModelConfig{}, err + } } name, ok := preferencesMap["name"].(string) @@ -108,7 +127,40 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) Description: description, } - if strings.HasSuffix(details.URI, ".gguf") { + uri := downloader.URI(details.URI) + + switch { + case uri.LooksLikeOCI(): + ociName := strings.TrimPrefix(string(uri), downloader.OCIPrefix) + ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix) + ociName = strings.ReplaceAll(ociName, "/", "__") + ociName = strings.ReplaceAll(ociName, ":", "__") + cfg.Files = append(cfg.Files, gallery.File{ + URI: details.URI, + Filename: ociName, + }) + modelConfig.PredictionOptions = schema.PredictionOptions{ + BasicModelRequest: schema.BasicModelRequest{ + Model: ociName, + }, + } + case uri.LooksLikeURL() && strings.HasSuffix(details.URI, ".gguf"): + // Extract filename from URL + fileName, e := uri.FilenameFromUrl() + if e != nil { + return gallery.ModelConfig{}, e + } + + cfg.Files = append(cfg.Files, gallery.File{ + URI: details.URI, + Filename: fileName, + }) + modelConfig.PredictionOptions = schema.PredictionOptions{ + BasicModelRequest: schema.BasicModelRequest{ + Model: fileName, + }, + } + case strings.HasSuffix(details.URI, ".gguf"): cfg.Files = append(cfg.Files, gallery.File{ URI: details.URI, Filename: filepath.Base(details.URI), @@ -118,7 +170,7 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) Model: filepath.Base(details.URI), }, } - } else if details.HuggingFace != nil { + case details.HuggingFace != nil: // We want to: // Get first the chosen quants that match filenames // OR the first mmproj/gguf file found @@ -195,7 +247,6 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) } break } - } data, err := yaml.Marshal(modelConfig) diff --git a/core/gallery/models.go b/core/gallery/models.go index 7205886b633c..6b20ad7b2b40 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -9,7 +9,6 @@ import ( "strings" "dario.cat/mergo" - "github.com/mudler/LocalAI/core/config" lconfig "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/model" @@ -17,7 +16,7 @@ import ( "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) /* @@ -74,7 +73,7 @@ type PromptTemplate struct { // Installs a model from the gallery func InstallModelFromGallery( ctx context.Context, - modelGalleries, backendGalleries []config.Gallery, + modelGalleries, backendGalleries []lconfig.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error { @@ -260,8 +259,8 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err) } - if !modelConfig.Validate() { - return nil, fmt.Errorf("failed to validate updated config YAML") + if valid, err := modelConfig.Validate(); !valid { + return nil, fmt.Errorf("failed to validate updated config YAML: %v", err) } err = os.WriteFile(configFilePath, updatedConfigYAML, 0600) @@ -304,7 +303,7 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error { // Galleryname is the name of the model in this case dat, err := os.ReadFile(configFile) if err == nil { - modelConfig := &config.ModelConfig{} + modelConfig := &lconfig.ModelConfig{} err = yaml.Unmarshal(dat, &modelConfig) if err != nil { @@ -369,7 +368,7 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error { // This is ***NEVER*** going to be perfect or finished. // This is a BEST EFFORT function to surface known-vulnerable models to users. -func SafetyScanGalleryModels(galleries []config.Gallery, systemState *system.SystemState) error { +func SafetyScanGalleryModels(galleries []lconfig.Gallery, systemState *system.SystemState) error { galleryModels, err := AvailableGalleryModels(galleries, systemState) if err != nil { return err diff --git a/core/http/app_test.go b/core/http/app_test.go index 7935d2c91a2f..0e84530a7c53 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -87,7 +87,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) { response := []gallery.GalleryModel{} uri := downloader.URI(url) // TODO: No tests currently seem to exercise file:// urls. Fix? - err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error { + err := uri.ReadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error { // Unmarshal YAML data into a struct return json.Unmarshal(i, &response) }) @@ -513,6 +513,124 @@ var _ = Describe("API test", func() { }) }) + + Context("Importing models from URI", func() { + var testYamlFile string + + BeforeEach(func() { + // Create a test YAML config file + yamlContent := `name: test-import-model +backend: llama-cpp +description: Test model imported from file URI +parameters: + model: path/to/model.gguf + temperature: 0.7 +` + testYamlFile = filepath.Join(tmpdir, "test-import.yaml") + err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + err := os.Remove(testYamlFile) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should import model from file:// URI pointing to local YAML config", func() { + importReq := schema.ImportModelRequest{ + URI: "file://" + testYamlFile, + Preferences: json.RawMessage(`{}`), + } + + var response schema.GalleryResponse + err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response.ID).ToNot(BeEmpty()) + + uuid := response.ID + resp := map[string]interface{}{} + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + resp = response + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + // Check that the model was imported successfully + Expect(resp["message"]).ToNot(ContainSubstring("error")) + Expect(resp["error"]).To(BeNil()) + + // Verify the model config file was created + dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml")) + Expect(err).ToNot(HaveOccurred()) + + content := map[string]interface{}{} + err = yaml.Unmarshal(dat, &content) + Expect(err).ToNot(HaveOccurred()) + Expect(content["name"]).To(Equal("test-import-model")) + Expect(content["backend"]).To(Equal("llama-cpp")) + }) + + It("should return error when file:// URI points to non-existent file", func() { + nonExistentFile := filepath.Join(tmpdir, "nonexistent.yaml") + importReq := schema.ImportModelRequest{ + URI: "file://" + nonExistentFile, + Preferences: json.RawMessage(`{}`), + } + + var response schema.GalleryResponse + err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + // The endpoint should return an error immediately + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to discover model config")) + }) + }) + + Context("Importing models from URI can't point to absolute paths", func() { + var testYamlFile string + + BeforeEach(func() { + // Create a test YAML config file + yamlContent := `name: test-import-model +backend: llama-cpp +description: Test model imported from file URI +parameters: + model: /path/to/model.gguf + temperature: 0.7 +` + testYamlFile = filepath.Join(tmpdir, "test-import.yaml") + err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + err := os.Remove(testYamlFile) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail to import model from file:// URI pointing to local YAML config", func() { + importReq := schema.ImportModelRequest{ + URI: "file://" + testYamlFile, + Preferences: json.RawMessage(`{}`), + } + + var response schema.GalleryResponse + err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response.ID).ToNot(BeEmpty()) + + uuid := response.ID + resp := map[string]interface{}{} + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + resp = response + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + // Check that the model was imported successfully + Expect(resp["message"]).To(ContainSubstring("error")) + Expect(resp["error"]).ToNot(BeNil()) + }) + }) }) Context("Model gallery", func() { diff --git a/core/http/endpoints/localai/edit_model.go b/core/http/endpoints/localai/edit_model.go index 4c59add22c31..f84b4d21bd00 100644 --- a/core/http/endpoints/localai/edit_model.go +++ b/core/http/endpoints/localai/edit_model.go @@ -135,7 +135,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati } // Validate the configuration - if !req.Validate() { + if valid, _ := req.Validate(); !valid { response := ModelResponse{ Success: false, Error: "Validation failed", @@ -196,7 +196,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { // Reload configurations - if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { + if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil { response := ModelResponse{ Success: false, Error: "Failed to reload configurations: " + err.Error(), diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go index d44d11ff8deb..77abcdfb60b3 100644 --- a/core/http/endpoints/localai/import_model.go +++ b/core/http/endpoints/localai/import_model.go @@ -148,7 +148,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica modelConfig.SetDefaults() // Validate the configuration - if !modelConfig.Validate() { + if valid, _ := modelConfig.Validate(); !valid { response := ModelResponse{ Success: false, Error: "Invalid configuration", @@ -185,7 +185,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica return c.JSON(http.StatusInternalServerError, response) } // Reload configurations - if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { + if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil { response := ModelResponse{ Success: false, Error: "Failed to reload configurations: " + err.Error(), diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index a62e5a18a902..6bc00480224d 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -112,7 +112,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig return nil, nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgVAD.Validate() { + if valid, _ := cfgVAD.Validate(); !valid { return nil, nil, fmt.Errorf("failed to validate config: %w", err) } @@ -128,7 +128,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig return nil, nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgSST.Validate() { + if valid, _ := cfgSST.Validate(); !valid { return nil, nil, fmt.Errorf("failed to validate config: %w", err) } @@ -155,7 +155,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgVAD.Validate() { + if valid, _ := cfgVAD.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } @@ -172,7 +172,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgSST.Validate() { + if valid, _ := cfgSST.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } @@ -191,7 +191,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgAnyToAny.Validate() { + if valid, _ := cfgAnyToAny.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } @@ -218,7 +218,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgLLM.Validate() { + if valid, _ := cfgLLM.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } @@ -228,7 +228,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgTTS.Validate() { + if valid, _ := cfgTTS.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index 362feadc1677..24720578ef2b 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -475,7 +475,7 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. } } - if config.Validate() { + if valid, _ := config.Validate(); valid { return nil } return fmt.Errorf("unable to validate configuration after merging") diff --git a/core/http/static/chat.js b/core/http/static/chat.js index 317a517d6c00..fea4b1efac95 100644 --- a/core/http/static/chat.js +++ b/core/http/static/chat.js @@ -1213,9 +1213,6 @@ async function promptGPT(systemPrompt, input) { document.getElementById("system_prompt").addEventListener("submit", submitSystemPrompt); document.getElementById("prompt").addEventListener("submit", submitPrompt); document.getElementById("input").focus(); -document.getElementById("input_image").addEventListener("change", readInputImage); -document.getElementById("input_audio").addEventListener("change", readInputAudio); -document.getElementById("input_file").addEventListener("change", readInputFile); storesystemPrompt = localStorage.getItem("system_prompt"); if (storesystemPrompt) { diff --git a/core/http/views/backends.html b/core/http/views/backends.html index 72b92e72a674..0735e09cdd8f 100644 --- a/core/http/views/backends.html +++ b/core/http/views/backends.html @@ -629,11 +629,33 @@