Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions core/gallery/importers/importers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import (
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)

var DefaultImporters = []Importer{
var defaultImporters = []Importer{
&LlamaCPPImporter{},
&MLXImporter{},
&VLLMImporter{},
&TransformersImporter{},
}

type Details struct {
Expand Down Expand Up @@ -52,7 +54,7 @@ func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.Model
Preferences: preferences,
}

for _, importer := range DefaultImporters {
for _, importer := range defaultImporters {
if importer.Match(details) {
modelConfig, err = importer.Import(details)
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions core/gallery/importers/llama-cpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,13 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
mmprojQuantsList = strings.Split(mmprojQuants, ",")
}

embeddings, _ := preferencesMap["embeddings"].(string)

modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Options: []string{"use_jinja:true"},
Backend: "llama-cpp",
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
Expand All @@ -95,6 +98,11 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error)
},
}

if embeddings != "" && strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes" {
trueV := true
modelConfig.Embeddings = &trueV
}

cfg := gallery.ModelConfig{
Name: name,
Description: description,
Expand Down
21 changes: 11 additions & 10 deletions core/gallery/importers/llama-cpp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@ import (
"fmt"

"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("LlamaCPPImporter", func() {
var importer *importers.LlamaCPPImporter
var importer *LlamaCPPImporter

BeforeEach(func() {
importer = &importers.LlamaCPPImporter{}
importer = &LlamaCPPImporter{}
})

Context("Match", func() {
It("should match when URI ends with .gguf", func() {
details := importers.Details{
details := Details{
URI: "https://example.com/model.gguf",
}

Expand All @@ -28,7 +29,7 @@ var _ = Describe("LlamaCPPImporter", func() {

It("should match when backend preference is llama-cpp", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := importers.Details{
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
Expand All @@ -38,7 +39,7 @@ var _ = Describe("LlamaCPPImporter", func() {
})

It("should not match when URI does not end with .gguf and no backend preference", func() {
details := importers.Details{
details := Details{
URI: "https://example.com/model.bin",
}

Expand All @@ -48,7 +49,7 @@ var _ = Describe("LlamaCPPImporter", func() {

It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "mlx"}`)
details := importers.Details{
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
Expand All @@ -59,7 +60,7 @@ var _ = Describe("LlamaCPPImporter", func() {

It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := importers.Details{
details := Details{
URI: "https://example.com/model.gguf",
Preferences: preferences,
}
Expand All @@ -72,7 +73,7 @@ var _ = Describe("LlamaCPPImporter", func() {

Context("Import", func() {
It("should import model config with default name and description", func() {
details := importers.Details{
details := Details{
URI: "https://example.com/my-model.gguf",
}

Expand All @@ -89,7 +90,7 @@ var _ = Describe("LlamaCPPImporter", func() {

It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := importers.Details{
details := Details{
URI: "https://example.com/my-model.gguf",
Preferences: preferences,
}
Expand All @@ -106,7 +107,7 @@ var _ = Describe("LlamaCPPImporter", func() {

It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := importers.Details{
details := Details{
URI: "https://example.com/my-model.gguf",
Preferences: preferences,
}
Expand Down
110 changes: 110 additions & 0 deletions core/gallery/importers/transformers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package importers

import (
"encoding/json"
"path/filepath"
"strings"

"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)

var _ Importer = &TransformersImporter{}

type TransformersImporter struct{}

func (i *TransformersImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}

b, ok := preferencesMap["backend"].(string)
if ok && b == "transformers" {
return true
}

if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.Contains(file.Path, "tokenizer.json") ||
strings.Contains(file.Path, "tokenizer_config.json") {
return true
}
}
}

return false
}

func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}

name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}

description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}

backend := "transformers"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}

modelType, ok := preferencesMap["type"].(string)
if !ok {
modelType = "AutoModelForCausalLM"
}

quantization, ok := preferencesMap["quantization"].(string)
if !ok {
quantization = ""
}

modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
modelConfig.ModelType = modelType
modelConfig.Quantization = quantization

data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}

return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}
Loading
Loading