diff --git a/core/gallery/importers/diffuser.go b/core/gallery/importers/diffuser.go new file mode 100644 index 000000000000..c702da3d3025 --- /dev/null +++ b/core/gallery/importers/diffuser.go @@ -0,0 +1,121 @@ +package importers + +import ( + "encoding/json" + "path/filepath" + "strings" + + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/gallery" + "github.com/mudler/LocalAI/core/schema" + "gopkg.in/yaml.v3" +) + +var _ Importer = &DiffuserImporter{} + +type DiffuserImporter struct{} + +func (i *DiffuserImporter) Match(details Details) bool { + preferences, err := details.Preferences.MarshalJSON() + if err != nil { + return false + } + preferencesMap := make(map[string]any) + err = json.Unmarshal(preferences, &preferencesMap) + if err != nil { + return false + } + + b, ok := preferencesMap["backend"].(string) + if ok && b == "diffusers" { + return true + } + + if details.HuggingFace != nil { + for _, file := range details.HuggingFace.Files { + if strings.Contains(file.Path, "model_index.json") || + strings.Contains(file.Path, "scheduler/scheduler_config.json") { + return true + } + } + } + + return false +} + +func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error) { + preferences, err := details.Preferences.MarshalJSON() + if err != nil { + return gallery.ModelConfig{}, err + } + preferencesMap := make(map[string]any) + err = json.Unmarshal(preferences, &preferencesMap) + if err != nil { + return gallery.ModelConfig{}, err + } + + name, ok := preferencesMap["name"].(string) + if !ok { + name = filepath.Base(details.URI) + } + + description, ok := preferencesMap["description"].(string) + if !ok { + description = "Imported from " + details.URI + } + + backend := "diffusers" + b, ok := preferencesMap["backend"].(string) + if ok { + backend = b + } + + pipelineType, ok := preferencesMap["pipeline_type"].(string) + if !ok { + pipelineType = "StableDiffusionPipeline" + } + + schedulerType, ok := preferencesMap["scheduler_type"].(string) + if !ok { + schedulerType = "" + } + + enableParameters, ok := preferencesMap["enable_parameters"].(string) + if !ok { + enableParameters = "negative_prompt,num_inference_steps" + } + + cuda := false + if cudaVal, ok := preferencesMap["cuda"].(bool); ok { + cuda = cudaVal + } + + modelConfig := config.ModelConfig{ + Name: name, + Description: description, + KnownUsecaseStrings: []string{"image"}, + Backend: backend, + PredictionOptions: schema.PredictionOptions{ + BasicModelRequest: schema.BasicModelRequest{ + Model: details.URI, + }, + }, + Diffusers: config.Diffusers{ + PipelineType: pipelineType, + SchedulerType: schedulerType, + EnableParameters: enableParameters, + CUDA: cuda, + }, + } + + data, err := yaml.Marshal(modelConfig) + if err != nil { + return gallery.ModelConfig{}, err + } + + return gallery.ModelConfig{ + Name: name, + Description: description, + ConfigFile: string(data), + }, nil +} diff --git a/core/gallery/importers/diffuser_test.go b/core/gallery/importers/diffuser_test.go new file mode 100644 index 000000000000..38765e88bade --- /dev/null +++ b/core/gallery/importers/diffuser_test.go @@ -0,0 +1,246 @@ +package importers_test + +import ( + "encoding/json" + + "github.com/mudler/LocalAI/core/gallery/importers" + . "github.com/mudler/LocalAI/core/gallery/importers" + hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("DiffuserImporter", func() { + var importer *DiffuserImporter + + BeforeEach(func() { + importer = &DiffuserImporter{} + }) + + Context("Match", func() { + It("should match when backend preference is diffusers", func() { + preferences := json.RawMessage(`{"backend": "diffusers"}`) + details := Details{ + URI: "https://example.com/model", + Preferences: preferences, + } + + result := importer.Match(details) + Expect(result).To(BeTrue()) + }) + + It("should match when HuggingFace details contain model_index.json", func() { + hfDetails := &hfapi.ModelDetails{ + Files: []hfapi.ModelFile{ + {Path: "model_index.json"}, + }, + } + details := Details{ + URI: "https://huggingface.co/test/model", + HuggingFace: hfDetails, + } + + result := importer.Match(details) + Expect(result).To(BeTrue()) + }) + + It("should match when HuggingFace details contain scheduler config", func() { + hfDetails := &hfapi.ModelDetails{ + Files: []hfapi.ModelFile{ + {Path: "scheduler/scheduler_config.json"}, + }, + } + details := Details{ + URI: "https://huggingface.co/test/model", + HuggingFace: hfDetails, + } + + result := importer.Match(details) + Expect(result).To(BeTrue()) + }) + + It("should not match when URI has no diffuser files and no backend preference", func() { + details := Details{ + URI: "https://example.com/model.bin", + } + + result := importer.Match(details) + Expect(result).To(BeFalse()) + }) + + It("should not match when backend preference is different", func() { + preferences := json.RawMessage(`{"backend": "llama-cpp"}`) + details := Details{ + URI: "https://example.com/model", + Preferences: preferences, + } + + result := importer.Match(details) + Expect(result).To(BeFalse()) + }) + + It("should return false when JSON preferences are invalid", func() { + preferences := json.RawMessage(`invalid json`) + details := Details{ + URI: "https://example.com/model", + Preferences: preferences, + } + + result := importer.Match(details) + Expect(result).To(BeFalse()) + }) + }) + + Context("Import", func() { + It("should import model config with default name and description", func() { + details := Details{ + URI: "https://huggingface.co/test/my-diffuser-model", + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("my-diffuser-model")) + Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-diffuser-model")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-diffuser-model")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: negative_prompt,num_inference_steps")) + }) + + It("should import model config with custom name and description from preferences", func() { + preferences := json.RawMessage(`{"name": "custom-diffuser", "description": "Custom diffuser model"}`) + details := Details{ + URI: "https://huggingface.co/test/my-model", + Preferences: preferences, + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("custom-diffuser")) + Expect(modelConfig.Description).To(Equal("Custom diffuser model")) + }) + + It("should use custom pipeline_type from preferences", func() { + preferences := json.RawMessage(`{"pipeline_type": "StableDiffusion3Pipeline"}`) + details := Details{ + URI: "https://huggingface.co/test/my-model", + Preferences: preferences, + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusion3Pipeline")) + }) + + It("should use default pipeline_type when not specified", func() { + details := Details{ + URI: "https://huggingface.co/test/my-model", + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline")) + }) + + It("should use custom scheduler_type from preferences", func() { + preferences := json.RawMessage(`{"scheduler_type": "k_dpmpp_2m"}`) + details := Details{ + URI: "https://huggingface.co/test/my-model", + Preferences: preferences, + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.ConfigFile).To(ContainSubstring("scheduler_type: k_dpmpp_2m")) + }) + + It("should use cuda setting from preferences", func() { + preferences := json.RawMessage(`{"cuda": true}`) + details := Details{ + URI: "https://huggingface.co/test/my-model", + Preferences: preferences, + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.ConfigFile).To(ContainSubstring("cuda: true")) + }) + + It("should use custom enable_parameters from preferences", func() { + preferences := json.RawMessage(`{"enable_parameters": "num_inference_steps,guidance_scale"}`) + details := Details{ + URI: "https://huggingface.co/test/my-model", + Preferences: preferences, + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: num_inference_steps,guidance_scale")) + }) + + It("should use custom backend from preferences", func() { + preferences := json.RawMessage(`{"backend": "diffusers"}`) + details := Details{ + URI: "https://huggingface.co/test/my-model", + Preferences: preferences, + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers")) + }) + + It("should handle invalid JSON preferences", func() { + preferences := json.RawMessage(`invalid json`) + details := Details{ + URI: "https://huggingface.co/test/my-model", + Preferences: preferences, + } + + _, err := importer.Import(details) + Expect(err).To(HaveOccurred()) + }) + + It("should extract filename correctly from URI with path", func() { + details := importers.Details{ + URI: "https://huggingface.co/test/path/to/model", + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.Name).To(Equal("model")) + }) + + It("should include known_usecases as image in config", func() { + details := Details{ + URI: "https://huggingface.co/test/my-model", + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:")) + Expect(modelConfig.ConfigFile).To(ContainSubstring("- image")) + }) + + It("should include diffusers configuration in config", func() { + details := Details{ + URI: "https://huggingface.co/test/my-model", + } + + modelConfig, err := importer.Import(details) + + Expect(err).ToNot(HaveOccurred()) + Expect(modelConfig.ConfigFile).To(ContainSubstring("diffusers:")) + }) + }) +}) diff --git a/core/gallery/importers/importers.go b/core/gallery/importers/importers.go index 283a3349a5e4..76020ca80cd8 100644 --- a/core/gallery/importers/importers.go +++ b/core/gallery/importers/importers.go @@ -20,6 +20,7 @@ var defaultImporters = []Importer{ &MLXImporter{}, &VLLMImporter{}, &TransformersImporter{}, + &DiffuserImporter{}, } type Details struct { diff --git a/core/http/views/model-editor.html b/core/http/views/model-editor.html index 734a327db205..8729d0829697 100644 --- a/core/http/views/model-editor.html +++ b/core/http/views/model-editor.html @@ -299,6 +299,7 @@
Force a specific backend. Leave empty to auto-detect from URI. @@ -401,6 +402,71 @@
+ Pipeline type for diffusers backend. Examples: StableDiffusionPipeline, StableDiffusion3Pipeline, FluxPipeline. Leave empty to use default (StableDiffusionPipeline). +
++ Scheduler type for diffusers backend. Examples: k_dpmpp_2m, euler_a, ddim. Leave empty to use model default. +
++ Enabled parameters for diffusers backend (comma-separated). Leave empty to use default (negative_prompt,num_inference_steps). +
++ Enable CUDA support for GPU acceleration with diffusers backend. +
+