diff --git a/core/config/model_config.go b/core/config/model_config.go index 50ac44625cb6..4f02776c48fd 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "os" "regexp" "slices" @@ -475,7 +476,7 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) { cfg.syncKnownUsecasesFromString() } -func (c *ModelConfig) Validate() bool { +func (c *ModelConfig) Validate() (bool, error) { downloadedFileNames := []string{} for _, f := range c.DownloadFiles { downloadedFileNames = append(downloadedFileNames, f.Filename) @@ -489,17 +490,20 @@ func (c *ModelConfig) Validate() bool { } if strings.HasPrefix(n, string(os.PathSeparator)) || strings.Contains(n, "..") { - return false + return false, fmt.Errorf("invalid file path: %s", n) } } if c.Backend != "" { // a regex that checks that is a string name with no special characters, except '-' and '_' re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`) - return re.MatchString(c.Backend) + if !re.MatchString(c.Backend) { + return false, fmt.Errorf("invalid backend name: %s", c.Backend) + } + return true, nil } - return true + return true, nil } func (c *ModelConfig) HasTemplate() bool { diff --git a/core/config/model_config_loader.go b/core/config/model_config_loader.go index 9895a4a0e5f8..f0f2c3338c13 100644 --- a/core/config/model_config_loader.go +++ b/core/config/model_config_loader.go @@ -169,7 +169,7 @@ func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, op } for _, cc := range c { - if cc.Validate() { + if valid, _ := cc.Validate(); valid { bcl.configs[cc.Name] = *cc } } @@ -184,7 +184,7 @@ func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderO return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err) } - if c.Validate() { + if valid, _ := c.Validate(); valid { bcl.configs[c.Name] = *c } else { return fmt.Errorf("config is not valid") @@ -362,7 +362,7 @@ func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...Conf log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadModelConfigsFromPath cannot read config file") continue } - if c.Validate() { + if valid, _ := c.Validate(); valid { bcl.configs[c.Name] = *c } else { log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid") diff --git a/core/config/model_config_test.go b/core/config/model_config_test.go index 9d49e270c751..342b10c47f8b 100644 --- a/core/config/model_config_test.go +++ b/core/config/model_config_test.go @@ -28,7 +28,9 @@ known_usecases: config, err := readModelConfigFromFile(tmp.Name()) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) - Expect(config.Validate()).To(BeFalse()) + valid, err := config.Validate() + Expect(err).To(HaveOccurred()) + Expect(valid).To(BeFalse()) Expect(config.KnownUsecases).ToNot(BeNil()) }) It("Test Validate", func() { @@ -46,7 +48,9 @@ parameters: Expect(config).ToNot(BeNil()) // two configs in config.yaml Expect(config.Name).To(Equal("bar-baz")) - Expect(config.Validate()).To(BeTrue()) + valid, err := config.Validate() + Expect(err).To(BeNil()) + Expect(valid).To(BeTrue()) // download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml httpClient := http.Client{} @@ -63,7 +67,9 @@ parameters: Expect(config).ToNot(BeNil()) // two configs in config.yaml Expect(config.Name).To(Equal("hermes-2-pro-mistral")) - Expect(config.Validate()).To(BeTrue()) + valid, err = config.Validate() + Expect(err).To(BeNil()) + Expect(valid).To(BeTrue()) }) }) It("Properly handles backend usecase matching", func() { diff --git a/core/gallery/backends.go b/core/gallery/backends.go index aee4b2d93928..9049664b3549 100644 --- a/core/gallery/backends.go +++ b/core/gallery/backends.go @@ -164,7 +164,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL return fmt.Errorf("failed copying: %w", err) } } else { - uri := downloader.URI(config.URI) + log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloading backend") if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil { success := false // Try to download from mirrors @@ -177,16 +177,27 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL } if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { success = true + log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloaded backend") break } } if !success { + log.Error().Str("uri", config.URI).Str("backendPath", backendPath).Err(err).Msg("Failed to download backend") return fmt.Errorf("failed to download backend %q: %v", config.URI, err) } + } else { + log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloaded backend") } } + // sanity check - check if runfile is present + runFile := filepath.Join(backendPath, runFile) + if _, err := os.Stat(runFile); os.IsNotExist(err) { + log.Error().Str("runFile", runFile).Msg("Run file not found") + return fmt.Errorf("not a valid backend: run file not found %q", runFile) + } + // Create metadata for the backend metadata := &BackendMetadata{ Name: name, diff --git a/core/gallery/backends_test.go b/core/gallery/backends_test.go index 15900d25018b..756d2e7a23b8 100644 --- a/core/gallery/backends_test.go +++ b/core/gallery/backends_test.go @@ -563,8 +563,8 @@ var _ = Describe("Gallery Backends", func() { ) Expect(err).NotTo(HaveOccurred()) err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) - Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created Expect(newPath).To(BeADirectory()) + Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created }) It("should overwrite existing backend", func() { diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index 62362148ecef..0475b898c645 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -6,11 +6,13 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/lithammer/fuzzysearch/fuzzy" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/system" + "github.com/mudler/LocalAI/pkg/xsync" "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" @@ -19,7 +21,7 @@ import ( func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) { var config T uri := downloader.URI(url) - err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { + err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { @@ -32,7 +34,7 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) { func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) { var config T uri := downloader.URI(url) - err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error { + err := uri.ReadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { @@ -141,7 +143,7 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst // Get models from galleries for _, gallery := range galleries { - galleryModels, err := getGalleryElements[*GalleryModel](gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool { + galleryModels, err := getGalleryElements(gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool { if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { return true } @@ -182,7 +184,7 @@ func AvailableBackends(galleries []config.Gallery, systemState *system.SystemSta func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) { var refFile string uri := downloader.URI(url) - err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { + err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { refFile = string(d) if len(refFile) == 0 { return fmt.Errorf("invalid reference file at url %s: %s", url, d) @@ -194,6 +196,17 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) return refFile, err } +type galleryCacheEntry struct { + yamlEntry []byte + lastUpdated time.Time +} + +func (entry galleryCacheEntry) hasExpired() bool { + return entry.lastUpdated.Before(time.Now().Add(-1 * time.Hour)) +} + +var galleryCache = xsync.NewSyncedMap[string, galleryCacheEntry]() + func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string, isInstalledCallback func(T) bool) ([]T, error) { var models []T = []T{} @@ -204,16 +217,37 @@ func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath strin return models, err } } + + cacheKey := fmt.Sprintf("%s-%s", gallery.Name, gallery.URL) + if galleryCache.Exists(cacheKey) { + entry := galleryCache.Get(cacheKey) + // refresh if last updated is more than 1 hour ago + if !entry.hasExpired() { + err := yaml.Unmarshal(entry.yamlEntry, &models) + if err != nil { + return models, err + } + } else { + galleryCache.Delete(cacheKey) + } + } + uri := downloader.URI(gallery.URL) - err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error { - return yaml.Unmarshal(d, &models) - }) - if err != nil { - if yamlErr, ok := err.(*yaml.TypeError); ok { - log.Debug().Msgf("YAML errors: %s\n\nwreckage of models: %+v", strings.Join(yamlErr.Errors, "\n"), models) + if len(models) == 0 { + err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { + galleryCache.Set(cacheKey, galleryCacheEntry{ + yamlEntry: d, + lastUpdated: time.Now(), + }) + return yaml.Unmarshal(d, &models) + }) + if err != nil { + if yamlErr, ok := err.(*yaml.TypeError); ok { + log.Debug().Msgf("YAML errors: %s\n\nwreckage of models: %+v", strings.Join(yamlErr.Errors, "\n"), models) + } + return models, fmt.Errorf("failed to read gallery elements: %w", err) } - return models, err } // Add gallery to models diff --git a/core/gallery/importers/importers.go b/core/gallery/importers/importers.go index 238aad6f1634..283a3349a5e4 100644 --- a/core/gallery/importers/importers.go +++ b/core/gallery/importers/importers.go @@ -2,11 +2,16 @@ package importers import ( "encoding/json" + "fmt" + "os" "strings" "github.com/rs/zerolog/log" + "gopkg.in/yaml.v3" + "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/pkg/downloader" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" ) @@ -28,6 +33,10 @@ type Importer interface { Import(details Details) (gallery.ModelConfig, error) } +func hasYAMLExtension(uri string) bool { + return strings.HasSuffix(uri, ".yaml") || strings.HasSuffix(uri, ".yml") +} + func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) { var err error var modelConfig gallery.ModelConfig @@ -42,20 +51,61 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model if err != nil { // maybe not a HF repository // TODO: maybe we can check if the URI is a valid HF repository - log.Debug().Str("uri", uri).Msg("Failed to get model details, maybe not a HF repository") + log.Debug().Str("uri", uri).Str("hfrepoID", hfrepoID).Msg("Failed to get model details, maybe not a HF repository") } else { log.Debug().Str("uri", uri).Msg("Got model details") log.Debug().Any("details", hfDetails).Msg("Model details") } + // handle local config files ("/my-model.yaml" or "file://my-model.yaml") + localURI := uri + if strings.HasPrefix(uri, downloader.LocalPrefix) { + localURI = strings.TrimPrefix(uri, downloader.LocalPrefix) + } + + // if a file exists or it's an url that ends with .yaml or .yml, read the config file directly + if _, e := os.Stat(localURI); hasYAMLExtension(localURI) && (e == nil || downloader.URI(localURI).LooksLikeURL()) { + var modelYAML []byte + if downloader.URI(localURI).LooksLikeURL() { + err := downloader.URI(localURI).ReadWithCallback(localURI, func(url string, i []byte) error { + modelYAML = i + return nil + }) + if err != nil { + log.Error().Err(err).Str("filepath", localURI).Msg("error reading model definition") + return gallery.ModelConfig{}, err + } + } else { + modelYAML, err = os.ReadFile(localURI) + if err != nil { + log.Error().Err(err).Str("filepath", localURI).Msg("error reading model definition") + return gallery.ModelConfig{}, err + } + } + + var modelConfig config.ModelConfig + if e := yaml.Unmarshal(modelYAML, &modelConfig); e != nil { + return gallery.ModelConfig{}, e + } + + configFile, err := yaml.Marshal(modelConfig) + return gallery.ModelConfig{ + Description: modelConfig.Description, + Name: modelConfig.Name, + ConfigFile: string(configFile), + }, err + } + details := Details{ HuggingFace: hfDetails, URI: uri, Preferences: preferences, } + importerMatched := false for _, importer := range defaultImporters { if importer.Match(details) { + importerMatched = true modelConfig, err = importer.Import(details) if err != nil { continue @@ -63,5 +113,8 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model break } } - return modelConfig, err + if !importerMatched { + return gallery.ModelConfig{}, fmt.Errorf("no importer matched for %s", uri) + } + return modelConfig, nil } diff --git a/core/gallery/importers/importers_test.go b/core/gallery/importers/importers_test.go index 34814fe66c0d..f34be0d6e5bc 100644 --- a/core/gallery/importers/importers_test.go +++ b/core/gallery/importers/importers_test.go @@ -3,6 +3,8 @@ package importers_test import ( "encoding/json" "fmt" + "os" + "path/filepath" "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/onsi/ginkgo/v2" @@ -212,4 +214,139 @@ var _ = Describe("DiscoverModelConfig", func() { Expect(modelConfig.Name).To(BeEmpty()) }) }) + + Context("with local YAML config files", func() { + var tempDir string + + BeforeEach(func() { + var err error + tempDir, err = os.MkdirTemp("", "importers-test-*") + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tempDir) + }) + + It("should read local YAML file with file:// prefix", func() { + yamlContent := `name: test-model +backend: llama-cpp +description: Test model from local YAML +parameters: + model: /path/to/model.gguf + temperature: 0.7 +` + yamlFile := filepath.Join(tempDir, "test-model.yaml") + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := "file://" + yamlFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("test-model")) + Expect(modelConfig.Description).To(Equal("Test model from local YAML")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("name: test-model")) + }) + + It("should read local YAML file without file:// prefix (direct path)", func() { + yamlContent := `name: direct-path-model +backend: mlx +description: Test model from direct path +parameters: + model: /path/to/model.safetensors +` + yamlFile := filepath.Join(tempDir, "direct-model.yaml") + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := yamlFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("direct-path-model")) + Expect(modelConfig.Description).To(Equal("Test model from direct path")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx")) + }) + + It("should read local YAML file with .yml extension", func() { + yamlContent := `name: yml-extension-model +backend: transformers +description: Test model with .yml extension +parameters: + model: /path/to/model +` + yamlFile := filepath.Join(tempDir, "test-model.yml") + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := "file://" + yamlFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("yml-extension-model")) + Expect(modelConfig.Description).To(Equal("Test model with .yml extension")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers")) + }) + + It("should ignore preferences when reading YAML files directly", func() { + yamlContent := `name: yaml-model +backend: llama-cpp +description: Original description +parameters: + model: /path/to/model.gguf +` + yamlFile := filepath.Join(tempDir, "prefs-test.yaml") + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := "file://" + yamlFile + // Preferences should be ignored when reading YAML directly + preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description", "backend": "mlx"}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).ToNot(HaveOccurred()) + // Should use values from YAML file, not preferences + Expect(modelConfig.Name).To(Equal("yaml-model")) + Expect(modelConfig.Description).To(Equal("Original description")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) + }) + + It("should return error when local YAML file doesn't exist", func() { + nonExistentFile := filepath.Join(tempDir, "nonexistent.yaml") + uri := "file://" + nonExistentFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).To(HaveOccurred()) + Expect(modelConfig.Name).To(BeEmpty()) + }) + + It("should return error when YAML file is invalid/malformed", func() { + invalidYaml := `name: invalid-model +backend: llama-cpp +invalid: yaml: content: [unclosed bracket +` + yamlFile := filepath.Join(tempDir, "invalid.yaml") + err := os.WriteFile(yamlFile, []byte(invalidYaml), 0644) + Expect(err).ToNot(HaveOccurred()) + + uri := "file://" + yamlFile + preferences := json.RawMessage(`{}`) + + modelConfig, err := importers.DiscoverModelConfig(uri, preferences) + + Expect(err).To(HaveOccurred()) + Expect(modelConfig.Name).To(BeEmpty()) + }) + }) }) diff --git a/core/gallery/importers/llama-cpp.go b/core/gallery/importers/llama-cpp.go index 669faf79076c..f1c4a4dc96e0 100644 --- a/core/gallery/importers/llama-cpp.go +++ b/core/gallery/importers/llama-cpp.go @@ -9,7 +9,9 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/functions" + "github.com/rs/zerolog/log" "go.yaml.in/yaml/v2" ) @@ -20,14 +22,22 @@ type LlamaCPPImporter struct{} func (i *LlamaCPPImporter) Match(details Details) bool { preferences, err := details.Preferences.MarshalJSON() if err != nil { + log.Error().Err(err).Msg("failed to marshal preferences") return false } + preferencesMap := make(map[string]any) - err = json.Unmarshal(preferences, &preferencesMap) - if err != nil { - return false + + if len(preferences) > 0 { + err = json.Unmarshal(preferences, &preferencesMap) + if err != nil { + log.Error().Err(err).Msg("failed to unmarshal preferences") + return false + } } + uri := downloader.URI(details.URI) + if preferencesMap["backend"] == "llama-cpp" { return true } @@ -36,6 +46,10 @@ func (i *LlamaCPPImporter) Match(details Details) bool { return true } + if uri.LooksLikeOCI() { + return true + } + if details.HuggingFace != nil { for _, file := range details.HuggingFace.Files { if strings.HasSuffix(file.Path, ".gguf") { @@ -48,14 +62,19 @@ func (i *LlamaCPPImporter) Match(details Details) bool { } func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) { + + log.Debug().Str("uri", details.URI).Msg("llama.cpp importer matched") + preferences, err := details.Preferences.MarshalJSON() if err != nil { return gallery.ModelConfig{}, err } preferencesMap := make(map[string]any) - err = json.Unmarshal(preferences, &preferencesMap) - if err != nil { - return gallery.ModelConfig{}, err + if len(preferences) > 0 { + err = json.Unmarshal(preferences, &preferencesMap) + if err != nil { + return gallery.ModelConfig{}, err + } } name, ok := preferencesMap["name"].(string) @@ -108,7 +127,40 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) Description: description, } - if strings.HasSuffix(details.URI, ".gguf") { + uri := downloader.URI(details.URI) + + switch { + case uri.LooksLikeOCI(): + ociName := strings.TrimPrefix(string(uri), downloader.OCIPrefix) + ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix) + ociName = strings.ReplaceAll(ociName, "/", "__") + ociName = strings.ReplaceAll(ociName, ":", "__") + cfg.Files = append(cfg.Files, gallery.File{ + URI: details.URI, + Filename: ociName, + }) + modelConfig.PredictionOptions = schema.PredictionOptions{ + BasicModelRequest: schema.BasicModelRequest{ + Model: ociName, + }, + } + case uri.LooksLikeURL() && strings.HasSuffix(details.URI, ".gguf"): + // Extract filename from URL + fileName, e := uri.FilenameFromUrl() + if e != nil { + return gallery.ModelConfig{}, e + } + + cfg.Files = append(cfg.Files, gallery.File{ + URI: details.URI, + Filename: fileName, + }) + modelConfig.PredictionOptions = schema.PredictionOptions{ + BasicModelRequest: schema.BasicModelRequest{ + Model: fileName, + }, + } + case strings.HasSuffix(details.URI, ".gguf"): cfg.Files = append(cfg.Files, gallery.File{ URI: details.URI, Filename: filepath.Base(details.URI), @@ -118,7 +170,7 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) Model: filepath.Base(details.URI), }, } - } else if details.HuggingFace != nil { + case details.HuggingFace != nil: // We want to: // Get first the chosen quants that match filenames // OR the first mmproj/gguf file found @@ -195,7 +247,6 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) } break } - } data, err := yaml.Marshal(modelConfig) diff --git a/core/gallery/models.go b/core/gallery/models.go index 7205886b633c..6b20ad7b2b40 100644 --- a/core/gallery/models.go +++ b/core/gallery/models.go @@ -9,7 +9,6 @@ import ( "strings" "dario.cat/mergo" - "github.com/mudler/LocalAI/core/config" lconfig "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/model" @@ -17,7 +16,7 @@ import ( "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" ) /* @@ -74,7 +73,7 @@ type PromptTemplate struct { // Installs a model from the gallery func InstallModelFromGallery( ctx context.Context, - modelGalleries, backendGalleries []config.Gallery, + modelGalleries, backendGalleries []lconfig.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error { @@ -260,8 +259,8 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err) } - if !modelConfig.Validate() { - return nil, fmt.Errorf("failed to validate updated config YAML") + if valid, err := modelConfig.Validate(); !valid { + return nil, fmt.Errorf("failed to validate updated config YAML: %v", err) } err = os.WriteFile(configFilePath, updatedConfigYAML, 0600) @@ -304,7 +303,7 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error { // Galleryname is the name of the model in this case dat, err := os.ReadFile(configFile) if err == nil { - modelConfig := &config.ModelConfig{} + modelConfig := &lconfig.ModelConfig{} err = yaml.Unmarshal(dat, &modelConfig) if err != nil { @@ -369,7 +368,7 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error { // This is ***NEVER*** going to be perfect or finished. // This is a BEST EFFORT function to surface known-vulnerable models to users. -func SafetyScanGalleryModels(galleries []config.Gallery, systemState *system.SystemState) error { +func SafetyScanGalleryModels(galleries []lconfig.Gallery, systemState *system.SystemState) error { galleryModels, err := AvailableGalleryModels(galleries, systemState) if err != nil { return err diff --git a/core/http/app_test.go b/core/http/app_test.go index 7935d2c91a2f..0e84530a7c53 100644 --- a/core/http/app_test.go +++ b/core/http/app_test.go @@ -87,7 +87,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) { response := []gallery.GalleryModel{} uri := downloader.URI(url) // TODO: No tests currently seem to exercise file:// urls. Fix? - err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error { + err := uri.ReadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error { // Unmarshal YAML data into a struct return json.Unmarshal(i, &response) }) @@ -513,6 +513,124 @@ var _ = Describe("API test", func() { }) }) + + Context("Importing models from URI", func() { + var testYamlFile string + + BeforeEach(func() { + // Create a test YAML config file + yamlContent := `name: test-import-model +backend: llama-cpp +description: Test model imported from file URI +parameters: + model: path/to/model.gguf + temperature: 0.7 +` + testYamlFile = filepath.Join(tmpdir, "test-import.yaml") + err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + err := os.Remove(testYamlFile) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should import model from file:// URI pointing to local YAML config", func() { + importReq := schema.ImportModelRequest{ + URI: "file://" + testYamlFile, + Preferences: json.RawMessage(`{}`), + } + + var response schema.GalleryResponse + err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response.ID).ToNot(BeEmpty()) + + uuid := response.ID + resp := map[string]interface{}{} + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + resp = response + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + // Check that the model was imported successfully + Expect(resp["message"]).ToNot(ContainSubstring("error")) + Expect(resp["error"]).To(BeNil()) + + // Verify the model config file was created + dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml")) + Expect(err).ToNot(HaveOccurred()) + + content := map[string]interface{}{} + err = yaml.Unmarshal(dat, &content) + Expect(err).ToNot(HaveOccurred()) + Expect(content["name"]).To(Equal("test-import-model")) + Expect(content["backend"]).To(Equal("llama-cpp")) + }) + + It("should return error when file:// URI points to non-existent file", func() { + nonExistentFile := filepath.Join(tmpdir, "nonexistent.yaml") + importReq := schema.ImportModelRequest{ + URI: "file://" + nonExistentFile, + Preferences: json.RawMessage(`{}`), + } + + var response schema.GalleryResponse + err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + // The endpoint should return an error immediately + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to discover model config")) + }) + }) + + Context("Importing models from URI can't point to absolute paths", func() { + var testYamlFile string + + BeforeEach(func() { + // Create a test YAML config file + yamlContent := `name: test-import-model +backend: llama-cpp +description: Test model imported from file URI +parameters: + model: /path/to/model.gguf + temperature: 0.7 +` + testYamlFile = filepath.Join(tmpdir, "test-import.yaml") + err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + err := os.Remove(testYamlFile) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should fail to import model from file:// URI pointing to local YAML config", func() { + importReq := schema.ImportModelRequest{ + URI: "file://" + testYamlFile, + Preferences: json.RawMessage(`{}`), + } + + var response schema.GalleryResponse + err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) + Expect(err).ToNot(HaveOccurred()) + Expect(response.ID).ToNot(BeEmpty()) + + uuid := response.ID + resp := map[string]interface{}{} + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + resp = response + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + // Check that the model was imported successfully + Expect(resp["message"]).To(ContainSubstring("error")) + Expect(resp["error"]).ToNot(BeNil()) + }) + }) }) Context("Model gallery", func() { diff --git a/core/http/endpoints/localai/edit_model.go b/core/http/endpoints/localai/edit_model.go index 4c59add22c31..f84b4d21bd00 100644 --- a/core/http/endpoints/localai/edit_model.go +++ b/core/http/endpoints/localai/edit_model.go @@ -135,7 +135,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati } // Validate the configuration - if !req.Validate() { + if valid, _ := req.Validate(); !valid { response := ModelResponse{ Success: false, Error: "Validation failed", @@ -196,7 +196,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { // Reload configurations - if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { + if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil { response := ModelResponse{ Success: false, Error: "Failed to reload configurations: " + err.Error(), diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go index d44d11ff8deb..77abcdfb60b3 100644 --- a/core/http/endpoints/localai/import_model.go +++ b/core/http/endpoints/localai/import_model.go @@ -148,7 +148,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica modelConfig.SetDefaults() // Validate the configuration - if !modelConfig.Validate() { + if valid, _ := modelConfig.Validate(); !valid { response := ModelResponse{ Success: false, Error: "Invalid configuration", @@ -185,7 +185,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica return c.JSON(http.StatusInternalServerError, response) } // Reload configurations - if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { + if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil { response := ModelResponse{ Success: false, Error: "Failed to reload configurations: " + err.Error(), diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index a62e5a18a902..6bc00480224d 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -112,7 +112,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig return nil, nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgVAD.Validate() { + if valid, _ := cfgVAD.Validate(); !valid { return nil, nil, fmt.Errorf("failed to validate config: %w", err) } @@ -128,7 +128,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfig return nil, nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgSST.Validate() { + if valid, _ := cfgSST.Validate(); !valid { return nil, nil, fmt.Errorf("failed to validate config: %w", err) } @@ -155,7 +155,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgVAD.Validate() { + if valid, _ := cfgVAD.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } @@ -172,7 +172,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgSST.Validate() { + if valid, _ := cfgSST.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } @@ -191,7 +191,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgAnyToAny.Validate() { + if valid, _ := cfgAnyToAny.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } @@ -218,7 +218,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgLLM.Validate() { + if valid, _ := cfgLLM.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } @@ -228,7 +228,7 @@ func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model return nil, fmt.Errorf("failed to load backend config: %w", err) } - if !cfgTTS.Validate() { + if valid, _ := cfgTTS.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } diff --git a/core/http/middleware/request.go b/core/http/middleware/request.go index 362feadc1677..24720578ef2b 100644 --- a/core/http/middleware/request.go +++ b/core/http/middleware/request.go @@ -475,7 +475,7 @@ func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema. } } - if config.Validate() { + if valid, _ := config.Validate(); valid { return nil } return fmt.Errorf("unable to validate configuration after merging") diff --git a/core/http/static/chat.js b/core/http/static/chat.js index 317a517d6c00..fea4b1efac95 100644 --- a/core/http/static/chat.js +++ b/core/http/static/chat.js @@ -1213,9 +1213,6 @@ async function promptGPT(systemPrompt, input) { document.getElementById("system_prompt").addEventListener("submit", submitSystemPrompt); document.getElementById("prompt").addEventListener("submit", submitPrompt); document.getElementById("input").focus(); -document.getElementById("input_image").addEventListener("change", readInputImage); -document.getElementById("input_audio").addEventListener("change", readInputAudio); -document.getElementById("input_file").addEventListener("change", readInputFile); storesystemPrompt = localStorage.getItem("system_prompt"); if (storesystemPrompt) { diff --git a/core/http/views/backends.html b/core/http/views/backends.html index 72b92e72a674..0735e09cdd8f 100644 --- a/core/http/views/backends.html +++ b/core/http/views/backends.html @@ -629,11 +629,33 @@

0) { + // Try common error object properties + errorMessage = jobData.error.message || jobData.error.error || jobData.error.Error || JSON.stringify(jobData.error); + } else { + // Empty object {}, fall back to message field + errorMessage = jobData.message || 'Unknown error'; + } + } else if (jobData.message) { + // Use message field if error is not present or is empty + errorMessage = jobData.message; + } + // Remove "error: " prefix if present + if (errorMessage.startsWith('error: ')) { + errorMessage = errorMessage.substring(7); + } + this.addNotification(`Error ${action} backend "${backend.name}": ${errorMessage}`, 'error'); } } catch (error) { console.error('Error polling job:', error); diff --git a/core/http/views/index.html b/core/http/views/index.html index 794bd28c1755..460ff322fe90 100644 --- a/core/http/views/index.html +++ b/core/http/views/index.html @@ -127,6 +127,7 @@

imageFiles: [], audioFiles: [], textFiles: [], + attachedFiles: [], currentPlaceholder: 'Send a message...', placeholderIndex: 0, charIndex: 0, @@ -241,6 +242,30 @@

} else { this.resumeTyping(); } + }, + handleFileSelection(files, fileType) { + Array.from(files).forEach(file => { + // Check if file already exists + const exists = this.attachedFiles.some(f => f.name === file.name && f.type === fileType); + if (!exists) { + this.attachedFiles.push({ name: file.name, type: fileType }); + } + }); + }, + removeAttachedFile(fileType, fileName) { + // Remove from attachedFiles array + const index = this.attachedFiles.findIndex(f => f.name === fileName && f.type === fileType); + if (index !== -1) { + this.attachedFiles.splice(index, 1); + } + // Remove from corresponding file array + if (fileType === 'image') { + this.imageFiles = this.imageFiles.filter(f => f.name !== fileName); + } else if (fileType === 'audio') { + this.audioFiles = this.audioFiles.filter(f => f.name !== fileName); + } else if (fileType === 'file') { + this.textFiles = this.textFiles.filter(f => f.name !== fileName); + } } }"> @@ -265,6 +290,24 @@

+ +
+ +
+
-
diff --git a/core/http/views/model-editor.html b/core/http/views/model-editor.html index d976e5fd1f49..734a327db205 100644 --- a/core/http/views/model-editor.html +++ b/core/http/views/model-editor.html @@ -77,18 +77,197 @@

- +

Enter the URI or path to the model file you want to import

+ + +
+ + +
+ + +
+

+ + HuggingFace +

+
+
+ +
+ huggingface://TheBloke/Llama-2-7B-Chat-GGUF +

Standard HuggingFace format

+
+
+
+ +
+ hf://TheBloke/Llama-2-7B-Chat-GGUF +

Short HuggingFace format

+
+
+
+ +
+ https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF +

Full HuggingFace URL

+
+
+
+
+ + +
+

+ + HTTP/HTTPS URLs +

+
+
+ +
+ https://example.com/model.gguf +

Direct download from any HTTPS URL

+
+
+
+
+ + +
+

+ + Local Files +

+
+
+ +
+ file:///path/to/model.gguf +

Local file path (absolute)

+
+
+
+ +
+ /path/to/model.yaml +

Direct local YAML config file

+
+
+
+
+ + +
+

+ + OCI Registry +

+
+
+ +
+ oci://registry.example.com/model:tag +

OCI container registry

+
+
+
+ +
+ ocifile:///path/to/image.tar +

Local OCI tarball file

+
+
+
+
+ + +
+

+ + Ollama +

+
+
+ +
+ ollama://llama2:7b +

Ollama model format

+
+
+
+
+ + +
+

+ + YAML Configuration Files +

+
+
+ +
+ https://example.com/model.yaml +

Remote YAML config file

+
+
+
+ +
+ file:///path/to/config.yaml +

Local YAML config file

+
+
+
+
+ +
+

+ + Tip: For HuggingFace models, you can use any of the three formats. The system will automatically detect and download the appropriate model files. +

+
+
+
@@ -629,11 +808,33 @@

setTimeout(() => { window.location.reload(); }, 2000); - } else if (jobData.error) { + } else if (jobData.error || (jobData.message && jobData.message.startsWith('error:'))) { clearInterval(this.jobPollInterval); this.isSubmitting = false; this.currentJobId = null; - this.showAlert('error', 'Import failed: ' + jobData.error); + // Extract error message - handle both string and object errors + let errorMessage = 'Unknown error'; + if (typeof jobData.error === 'string') { + errorMessage = jobData.error; + } else if (jobData.error && typeof jobData.error === 'object') { + // Check if error object has any properties + const errorKeys = Object.keys(jobData.error); + if (errorKeys.length > 0) { + // Try common error object properties + errorMessage = jobData.error.message || jobData.error.error || jobData.error.Error || JSON.stringify(jobData.error); + } else { + // Empty object {}, fall back to message field + errorMessage = jobData.message || 'Unknown error'; + } + } else if (jobData.message) { + // Use message field if error is not present or is empty + errorMessage = jobData.message; + } + // Remove "error: " prefix if present + if (errorMessage.startsWith('error: ')) { + errorMessage = errorMessage.substring(7); + } + this.showAlert('error', 'Import failed: ' + errorMessage); } } catch (error) { console.error('Error polling job status:', error); diff --git a/core/http/views/models.html b/core/http/views/models.html index 4e3dd10204ec..96d6a1ebd598 100644 --- a/core/http/views/models.html +++ b/core/http/views/models.html @@ -714,11 +714,33 @@

0) { + // Try common error object properties + errorMessage = jobData.error.message || jobData.error.error || jobData.error.Error || JSON.stringify(jobData.error); + } else { + // Empty object {}, fall back to message field + errorMessage = jobData.message || 'Unknown error'; + } + } else if (jobData.message) { + // Use message field if error is not present or is empty + errorMessage = jobData.message; + } + // Remove "error: " prefix if present + if (errorMessage.startsWith('error: ')) { + errorMessage = errorMessage.substring(7); + } + this.addNotification(`Error ${action} model "${model.name}": ${errorMessage}`, 'error'); } } catch (error) { console.error('Error polling job:', error); diff --git a/core/services/models.go b/core/services/models.go index 40ebbc98ee63..5e76adc98c98 100644 --- a/core/services/models.go +++ b/core/services/models.go @@ -85,7 +85,7 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel, galler } // Reload models - err = cl.LoadModelConfigsFromPath(systemState.Model.ModelsPath) + err = cl.LoadModelConfigsFromPath(systemState.Model.ModelsPath, g.appConfig.ToConfigLoaderOptions()...) if err != nil { return err } diff --git a/core/startup/model_preload.go b/core/startup/model_preload.go index 9377830a4df1..69b67a5fddb1 100644 --- a/core/startup/model_preload.go +++ b/core/startup/model_preload.go @@ -5,10 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "os" - "path" - "path/filepath" - "strings" "time" "github.com/google/uuid" @@ -16,12 +12,10 @@ import ( "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery/importers" "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/utils" "github.com/rs/zerolog/log" - "gopkg.in/yaml.v2" ) const ( @@ -34,178 +28,59 @@ const ( func InstallModels(ctx context.Context, galleryService *services.GalleryService, galleries, backendGalleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error { // create an error that groups all errors var err error - - installBackend := func(modelPath string) error { - // Then load the model file, and read the backend - modelYAML, e := os.ReadFile(modelPath) - if e != nil { - log.Error().Err(e).Str("filepath", modelPath).Msg("error reading model definition") - return e - } - - var model config.ModelConfig - if e := yaml.Unmarshal(modelYAML, &model); e != nil { - log.Error().Err(e).Str("filepath", modelPath).Msg("error unmarshalling model definition") - return e - } - - if model.Backend == "" { - log.Debug().Str("filepath", modelPath).Msg("no backend found in model definition") - return nil - } - - if err := gallery.InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, model.Backend, downloadStatus, false); err != nil { - log.Error().Err(err).Str("backend", model.Backend).Msg("error installing backend") - return err - } - - return nil - } - for _, url := range models { - // As a best effort, try to resolve the model from the remote library - // if it's not resolved we try with the other method below - - uri := downloader.URI(url) - - switch { - case uri.LooksLikeOCI(): - log.Debug().Msgf("[startup] resolved OCI model to download: %s", url) - - // convert OCI image name to a file name. - ociName := strings.TrimPrefix(url, downloader.OCIPrefix) - ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix) - ociName = strings.ReplaceAll(ociName, "/", "__") - ociName = strings.ReplaceAll(ociName, ":", "__") - - // check if file exists - if _, e := os.Stat(filepath.Join(systemState.Model.ModelsPath, ociName)); errors.Is(e, os.ErrNotExist) { - modelDefinitionFilePath := filepath.Join(systemState.Model.ModelsPath, ociName) - e := uri.DownloadFile(modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - }) - if e != nil { - log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") - err = errors.Join(err, e) - } + // Check if it's a model gallery, or print a warning + e, found := installModel(ctx, galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries) + if e != nil && found { + log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url) + err = errors.Join(err, e) + } else if !found { + log.Debug().Msgf("[startup] model not found in the gallery '%s'", url) + + if galleryService == nil { + return fmt.Errorf("cannot start autoimporter, not sure how to handle this uri") } - log.Info().Msgf("[startup] installed model from OCI repository: %s", ociName) - case uri.LooksLikeURL(): - log.Debug().Msgf("[startup] downloading %s", url) - - // Extract filename from URL - fileName, e := uri.FilenameFromUrl() - if e != nil { - log.Warn().Err(e).Str("url", url).Msg("error extracting filename from URL") - err = errors.Join(err, e) + // TODO: we should just use the discoverModelConfig here and default to this. + modelConfig, discoverErr := importers.DiscoverModelConfig(url, json.RawMessage{}) + if discoverErr != nil { + log.Error().Err(discoverErr).Msgf("[startup] failed to discover model config '%s'", url) + err = errors.Join(discoverErr, fmt.Errorf("failed to discover model config: %w", err)) continue } - modelPath := filepath.Join(systemState.Model.ModelsPath, fileName) - - if e := utils.VerifyPath(fileName, modelPath); e != nil { - log.Error().Err(e).Str("filepath", modelPath).Msg("error verifying path") - err = errors.Join(err, e) + uuid, uuidErr := uuid.NewUUID() + if uuidErr != nil { + err = errors.Join(uuidErr, fmt.Errorf("failed to generate UUID: %w", uuidErr)) continue } - // check if file exists - if _, e := os.Stat(modelPath); errors.Is(e, os.ErrNotExist) { - e := uri.DownloadFile(modelPath, "", 0, 0, func(fileName, current, total string, percent float64) { - utils.DisplayDownloadFunction(fileName, current, total, percent) - }) - if e != nil { - log.Error().Err(e).Str("url", url).Str("filepath", modelPath).Msg("error downloading model") - err = errors.Join(err, e) - } + galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ + Req: gallery.GalleryModel{ + Overrides: map[string]interface{}{}, + }, + ID: uuid.String(), + GalleryElementName: modelConfig.Name, + GalleryElement: &modelConfig, + BackendGalleries: backendGalleries, } - // Check if we have the backend installed - if autoloadBackendGalleries && path.Ext(modelPath) == YAML_EXTENSION { - if err := installBackend(modelPath); err != nil { - log.Error().Err(err).Str("filepath", modelPath).Msg("error installing backend") + var status *services.GalleryOpStatus + // wait for op to finish + for { + status = galleryService.GetStatus(uuid.String()) + if status != nil && status.Processed { + break } + time.Sleep(1 * time.Second) } - default: - if _, e := os.Stat(url); e == nil { - log.Debug().Msgf("[startup] resolved local model: %s", url) - // copy to modelPath - md5Name := utils.MD5(url) - - modelYAML, e := os.ReadFile(url) - if e != nil { - log.Error().Err(e).Str("filepath", url).Msg("error reading model definition") - err = errors.Join(err, e) - continue - } - modelDefinitionFilePath := filepath.Join(systemState.Model.ModelsPath, md5Name) + YAML_EXTENSION - if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") - err = errors.Join(err, e) - } - - // Check if we have the backend installed - if autoloadBackendGalleries && path.Ext(modelDefinitionFilePath) == YAML_EXTENSION { - if err := installBackend(modelDefinitionFilePath); err != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error installing backend") - } - } - } else { - // Check if it's a model gallery, or print a warning - e, found := installModel(ctx, galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries) - if e != nil && found { - log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url) - err = errors.Join(err, e) - } else if !found { - log.Warn().Msgf("[startup] failed resolving model '%s'", url) - - if galleryService == nil { - err = errors.Join(err, fmt.Errorf("cannot start autoimporter, not sure how to handle this uri")) - continue - } - - // TODO: we should just use the discoverModelConfig here and default to this. - modelConfig, discoverErr := importers.DiscoverModelConfig(url, json.RawMessage{}) - if discoverErr != nil { - err = errors.Join(discoverErr, fmt.Errorf("failed to discover model config: %w", err)) - continue - } - - uuid, uuidErr := uuid.NewUUID() - if uuidErr != nil { - err = errors.Join(uuidErr, fmt.Errorf("failed to generate UUID: %w", uuidErr)) - continue - } - - galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ - Req: gallery.GalleryModel{ - Overrides: map[string]interface{}{}, - }, - ID: uuid.String(), - GalleryElementName: modelConfig.Name, - GalleryElement: &modelConfig, - BackendGalleries: backendGalleries, - } - - var status *services.GalleryOpStatus - // wait for op to finish - for { - status = galleryService.GetStatus(uuid.String()) - if status != nil && status.Processed { - break - } - time.Sleep(1 * time.Second) - } - - if status.Error != nil { - return status.Error - } - - log.Info().Msgf("[startup] imported model '%s' from '%s'", modelConfig.Name, url) - } + if status.Error != nil { + log.Error().Err(status.Error).Msgf("[startup] failed to import model '%s' from '%s'", modelConfig.Name, url) + return status.Error } + + log.Info().Msgf("[startup] imported model '%s' from '%s'", modelConfig.Name, url) } } return err diff --git a/core/startup/model_preload_test.go b/core/startup/model_preload_test.go index 54dc5507392f..3bf6d2687b0a 100644 --- a/core/startup/model_preload_test.go +++ b/core/startup/model_preload_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services" . "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" @@ -19,8 +20,11 @@ var _ = Describe("Preload test", func() { var tmpdir string var systemState *system.SystemState var ml *model.ModelLoader + var ctx context.Context + var cancel context.CancelFunc BeforeEach(func() { + ctx, cancel = context.WithCancel(context.Background()) var err error tmpdir, err = os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) @@ -29,13 +33,24 @@ var _ = Describe("Preload test", func() { ml = model.NewModelLoader(systemState, true) }) + AfterEach(func() { + cancel() + }) + Context("Preloading from strings", func() { It("loads from embedded full-urls", func() { url := "https://raw.githubusercontent.com/mudler/LocalAI-examples/main/configurations/phi-2.yaml" fileName := fmt.Sprintf("%s.yaml", "phi-2") - InstallModels(context.TODO(), nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url) + galleryService := services.NewGalleryService(&config.ApplicationConfig{ + SystemState: systemState, + }, ml) + galleryService.Start(ctx, config.NewModelConfigLoader(tmpdir), systemState) + err := InstallModels(ctx, galleryService, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, func(s1, s2, s3 string, f float64) { + fmt.Println(s1, s2, s3, f) + }, url) + Expect(err).ToNot(HaveOccurred()) resultFile := filepath.Join(tmpdir, fileName) content, err := os.ReadFile(resultFile) @@ -47,13 +62,22 @@ var _ = Describe("Preload test", func() { url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf" fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K") - err := InstallModels(context.TODO(), nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url) + galleryService := services.NewGalleryService(&config.ApplicationConfig{ + SystemState: systemState, + }, ml) + galleryService.Start(ctx, config.NewModelConfigLoader(tmpdir), systemState) + + err := InstallModels(ctx, galleryService, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, func(s1, s2, s3 string, f float64) { + fmt.Println(s1, s2, s3, f) + }, url) Expect(err).ToNot(HaveOccurred()) resultFile := filepath.Join(tmpdir, fileName) + dirs, err := os.ReadDir(tmpdir) + Expect(err).ToNot(HaveOccurred()) _, err = os.Stat(resultFile) - Expect(err).ToNot(HaveOccurred()) + Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("%+v", dirs)) }) }) }) diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index ea1631f4fffa..0129c5fdc12e 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -19,6 +19,7 @@ import ( "github.com/mudler/LocalAI/pkg/oci" "github.com/mudler/LocalAI/pkg/utils" + "github.com/mudler/LocalAI/pkg/xio" "github.com/rs/zerolog/log" ) @@ -49,17 +50,16 @@ func loadConfig() string { return HF_ENDPOINT } -func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error { - return uri.DownloadWithAuthorizationAndCallback(context.Background(), basePath, "", f) +func (uri URI) ReadWithCallback(basePath string, f func(url string, i []byte) error) error { + return uri.ReadWithAuthorizationAndCallback(context.Background(), basePath, "", f) } -func (uri URI) DownloadWithAuthorizationAndCallback(ctx context.Context, basePath string, authorization string, f func(url string, i []byte) error) error { +func (uri URI) ReadWithAuthorizationAndCallback(ctx context.Context, basePath string, authorization string, f func(url string, i []byte) error) error { url := uri.ResolveURL() - if strings.HasPrefix(url, LocalPrefix) { - rawURL := strings.TrimPrefix(url, LocalPrefix) + if strings.HasPrefix(string(uri), LocalPrefix) { // checks if the file is symbolic, and resolve if so - otherwise, this function returns the path unmodified. - resolvedFile, err := filepath.EvalSymlinks(rawURL) + resolvedFile, err := filepath.EvalSymlinks(url) if err != nil { return err } @@ -175,6 +175,8 @@ func (s URI) LooksLikeOCIFile() bool { func (s URI) ResolveURL() string { switch { + case strings.HasPrefix(string(s), LocalPrefix): + return strings.TrimPrefix(string(s), LocalPrefix) case strings.HasPrefix(string(s), GithubURI2): repository := strings.Replace(string(s), GithubURI2, "", 1) @@ -311,11 +313,6 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus) } - // We need to check if url looks like an URL or bail out - if !URI(url).LooksLikeHTTPURL() { - return fmt.Errorf("url %q does not look like an HTTP URL", url) - } - // Check for cancellation before starting select { case <-ctx.Done(): @@ -326,6 +323,7 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // Check if the file already exists _, err := os.Stat(filePath) if err == nil { + log.Debug().Str("filePath", filePath).Msg("[downloader] File already exists") // File exists, check SHA if sha != "" { // Verify SHA @@ -350,12 +348,12 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string log.Debug().Msgf("File %q already exists. Skipping download", filePath) return nil } - } else if !os.IsNotExist(err) { + } else if !os.IsNotExist(err) || !URI(url).LooksLikeHTTPURL() { // Error occurred while checking file existence - return fmt.Errorf("failed to check file %q existence: %v", filePath, err) + return fmt.Errorf("file %s does not exist (%v) and %s does not look like an HTTP URL", filePath, err, url) } - log.Info().Msgf("Downloading %q", url) + log.Info().Msgf("Downloading %s", url) req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { @@ -365,7 +363,7 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string // save partial download to dedicated file tmpFilePath := filePath + ".partial" tmpFileInfo, err := os.Stat(tmpFilePath) - if err == nil { + if err == nil && uri.LooksLikeHTTPURL() { support, err := uri.checkSeverSupportsRangeHeader() if err != nil { return fmt.Errorf("failed to check if uri server supports range header: %v", err) @@ -383,22 +381,40 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string return fmt.Errorf("failed to check file %q existence: %v", filePath, err) } - // Start the request - resp, err := http.DefaultClient.Do(req) - if err != nil { - // Check if error is due to context cancellation - if errors.Is(err, context.Canceled) { - // Clean up partial file on cancellation - removePartialFile(tmpFilePath) - return err + var source io.ReadCloser + var contentLength int64 + if _, e := os.Stat(uri.ResolveURL()); strings.HasPrefix(string(uri), LocalPrefix) || e == nil { + file, err := os.Open(uri.ResolveURL()) + if err != nil { + return fmt.Errorf("failed to open file %q: %v", uri.ResolveURL(), err) } - return fmt.Errorf("failed to download file %q: %v", filePath, err) - } - defer resp.Body.Close() + l, err := file.Stat() + if err != nil { + return fmt.Errorf("failed to get file size %q: %v", uri.ResolveURL(), err) + } + source = file + contentLength = l.Size() + } else { + // Start the request + resp, err := http.DefaultClient.Do(req) + if err != nil { + // Check if error is due to context cancellation + if errors.Is(err, context.Canceled) { + // Clean up partial file on cancellation + removePartialFile(tmpFilePath) + return err + } + return fmt.Errorf("failed to download file %q: %v", filePath, err) + } + //defer resp.Body.Close() - if resp.StatusCode >= 400 { - return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode) + if resp.StatusCode >= 400 { + return fmt.Errorf("failed to download url %q, invalid status code %d", url, resp.StatusCode) + } + source = resp.Body + contentLength = resp.ContentLength } + defer source.Close() // Create parent directory err = os.MkdirAll(filepath.Dir(filePath), 0750) @@ -418,14 +434,15 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string } progress := &progressWriter{ fileName: tmpFilePath, - total: resp.ContentLength, + total: contentLength, hash: hash, fileNo: fileN, totalFiles: total, downloadStatus: downloadStatus, ctx: ctx, } - _, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body) + + _, err = xio.Copy(ctx, io.MultiWriter(outFile, progress), source) if err != nil { // Check if error is due to context cancellation if errors.Is(err, context.Canceled) { diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go index 17ade771661d..57186907777b 100644 --- a/pkg/downloader/uri_test.go +++ b/pkg/downloader/uri_test.go @@ -20,7 +20,7 @@ var _ = Describe("Gallery API tests", func() { It("parses github with a branch", func() { uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml") Expect( - uri.DownloadWithCallback("", func(url string, i []byte) error { + uri.ReadWithCallback("", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), @@ -30,7 +30,7 @@ var _ = Describe("Gallery API tests", func() { uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml@main") Expect( - uri.DownloadWithCallback("", func(url string, i []byte) error { + uri.ReadWithCallback("", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), @@ -39,7 +39,7 @@ var _ = Describe("Gallery API tests", func() { It("parses github with urls", func() { uri := URI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml") Expect( - uri.DownloadWithCallback("", func(url string, i []byte) error { + uri.ReadWithCallback("", func(url string, i []byte) error { Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) return nil }), diff --git a/pkg/huggingface-api/client.go b/pkg/huggingface-api/client.go index 9b1959f1d857..74494ed0b866 100644 --- a/pkg/huggingface-api/client.go +++ b/pkg/huggingface-api/client.go @@ -185,7 +185,7 @@ func (c *Client) ListFiles(repoID string) ([]FileInfo, error) { func (c *Client) GetFileSHA(repoID, fileName string) (string, error) { files, err := c.ListFiles(repoID) if err != nil { - return "", fmt.Errorf("failed to list files: %w", err) + return "", fmt.Errorf("failed to list files while getting SHA: %w", err) } for _, file := range files {