Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions core/application/distributed.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ func (ds *DistributedServices) Shutdown() {
// initDistributed validates distributed mode prerequisites and initializes
// NATS, object storage, node registry, and instance identity.
// Returns nil if distributed mode is not enabled.
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) {
// configLoader is used by the SmartRouter to compute concurrency-group
// anti-affinity at placement time (#9659); it may be nil in tests.
func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoader *config.ModelConfigLoader) (*DistributedServices, error) {
if !cfg.Distributed.Enabled {
return nil, nil
}
Expand Down Expand Up @@ -234,12 +236,17 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut
remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient)

// All dependencies ready — build SmartRouter with all options at once
var conflictResolver nodes.ConcurrencyConflictResolver
if configLoader != nil {
conflictResolver = configLoader
}
router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{
Unloader: remoteUnloader,
FileStager: fileStager,
GalleriesJSON: routerGalleriesJSON,
AuthToken: routerAuthToken,
DB: authDB,
Unloader: remoteUnloader,
FileStager: fileStager,
GalleriesJSON: routerGalleriesJSON,
AuthToken: routerAuthToken,
DB: authDB,
ConflictResolver: conflictResolver,
})

// Create ReplicaReconciler for auto-scaling model replicas. Adapter +
Expand Down
8 changes: 7 additions & 1 deletion core/application/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func New(opts ...config.AppOption) (*Application, error) {
}

// Initialize distributed mode services (NATS, object storage, node registry)
distSvc, err := initDistributed(options, application.authDB)
distSvc, err := initDistributed(options, application.authDB, application.ModelConfigLoader())
if err != nil {
return nil, fmt.Errorf("distributed mode initialization failed: %w", err)
}
Expand Down Expand Up @@ -680,6 +680,12 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon
options.LRUEvictionRetryInterval,
)

// Sync per-model state from configs to the watchdog. Without this,
// `pinned: true` and `concurrency_groups:` are only honored after a
// settings-driven RestartWatchdog and never at boot.
application.SyncPinnedModelsToWatchdog()
application.SyncModelGroupsToWatchdog()

// Start watchdog goroutine if any periodic checks are enabled
// LRU eviction doesn't need the Run() loop - it's triggered on model load
// But memory reclaimer needs the Run() loop for periodic checking
Expand Down
41 changes: 39 additions & 2 deletions core/application/watchdog.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package application

import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/xlog"
)
Expand All @@ -26,6 +27,40 @@ func (a *Application) SyncPinnedModelsToWatchdog() {
xlog.Debug("Synced pinned models to watchdog", "count", len(pinned))
}

// SyncModelGroupsToWatchdog reads concurrency_groups from all model configs and
// updates the watchdog so EnforceGroupExclusivity has the current view.
func (a *Application) SyncModelGroupsToWatchdog() {
cl := a.ModelConfigLoader()
if cl == nil {
return
}
wd := a.modelLoader.GetWatchDog()
if wd == nil {
return
}
groups := extractModelGroupsFromConfigs(cl.GetAllModelsConfigs())
wd.ReplaceModelGroups(groups)
xlog.Debug("Synced concurrency groups to watchdog", "count", len(groups))
}

// extractModelGroupsFromConfigs builds the model→groups map the watchdog
// expects. Disabled models are skipped — their declared groups should not
// block other models from loading.
func extractModelGroupsFromConfigs(configs []config.ModelConfig) map[string][]string {
out := make(map[string][]string)
for _, cfg := range configs {
if cfg.IsDisabled() {
continue
}
gs := cfg.GetConcurrencyGroups()
if len(gs) == 0 {
continue
}
out[cfg.Name] = gs
}
return out
}

func (a *Application) StopWatchdog() error {
if a.watchdogStop != nil {
close(a.watchdogStop)
Expand Down Expand Up @@ -65,8 +100,9 @@ func (a *Application) startWatchdog() error {
// Set the watchdog on the model loader
a.modelLoader.SetWatchDog(wd)

// Sync pinned models from config to the watchdog
// Sync pinned models and concurrency groups from config to the watchdog
a.SyncPinnedModelsToWatchdog()
a.SyncModelGroupsToWatchdog()

// Start watchdog goroutine if any periodic checks are enabled
// LRU eviction doesn't need the Run() loop - it's triggered on model load
Expand Down Expand Up @@ -148,8 +184,9 @@ func (a *Application) RestartWatchdog() error {
newWD.RestoreState(oldState)
}

// Re-sync pinned models after restart
// Re-sync pinned models and concurrency groups after restart
a.SyncPinnedModelsToWatchdog()
a.SyncModelGroupsToWatchdog()

return nil
}
47 changes: 47 additions & 0 deletions core/application/watchdog_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package application

import (
"github.com/mudler/LocalAI/core/config"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("extractModelGroupsFromConfigs", func() {
It("returns an empty map when no config declares groups", func() {
out := extractModelGroupsFromConfigs([]config.ModelConfig{
{Name: "a"},
{Name: "b"},
})
Expect(out).To(BeEmpty())
})

It("returns each model's normalized groups", func() {
out := extractModelGroupsFromConfigs([]config.ModelConfig{
{Name: "a", ConcurrencyGroups: []string{" heavy ", "vision", "heavy"}},
{Name: "b", ConcurrencyGroups: []string{"heavy"}},
{Name: "c"}, // no groups → omitted
})
Expect(out).To(HaveLen(2))
Expect(out["a"]).To(Equal([]string{"heavy", "vision"}))
Expect(out["b"]).To(Equal([]string{"heavy"}))
Expect(out).ToNot(HaveKey("c"))
})

It("omits models whose groups normalize to empty", func() {
out := extractModelGroupsFromConfigs([]config.ModelConfig{
{Name: "blanks", ConcurrencyGroups: []string{"", " "}},
})
Expect(out).To(BeEmpty())
})

It("skips disabled models so they cannot block loading after re-enable", func() {
disabled := true
out := extractModelGroupsFromConfigs([]config.ModelConfig{
{Name: "a", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled},
{Name: "b", ConcurrencyGroups: []string{"heavy"}},
})
Expect(out).To(HaveLen(1))
Expect(out).To(HaveKey("b"))
Expect(out).ToNot(HaveKey("a"))
})
})
27 changes: 27 additions & 0 deletions core/config/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ type ModelConfig struct {
Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"`
Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"`

// ConcurrencyGroups declares per-node mutual-exclusion groups: the model
// cannot be loaded alongside another model that shares any group name.
// See docs/content/advanced/vram-management.md for usage.
ConcurrencyGroups []string `yaml:"concurrency_groups,omitempty" json:"concurrency_groups,omitempty"`

Options []string `yaml:"options,omitempty" json:"options,omitempty"`
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`

Expand Down Expand Up @@ -587,6 +592,28 @@ func (c *ModelConfig) IsPinned() bool {
return c.Pinned != nil && *c.Pinned
}

// GetConcurrencyGroups returns the model's concurrency groups, normalized:
// trimmed of whitespace, empty entries dropped, deduped. Returns nil when no
// effective groups remain. The result is a fresh slice; the caller may
// mutate it without affecting the config.
func (c *ModelConfig) GetConcurrencyGroups() []string {
if len(c.ConcurrencyGroups) == 0 {
return nil
}
out := make([]string, 0, len(c.ConcurrencyGroups))
for _, g := range c.ConcurrencyGroups {
g = strings.TrimSpace(g)
if g == "" || slices.Contains(out, g) {
continue
}
out = append(out, g)
}
if len(out) == 0 {
return nil
}
return out
}

type ModelConfigUsecase int

const (
Expand Down
34 changes: 34 additions & 0 deletions core/config/model_config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,40 @@ func (bcl *ModelConfigLoader) RemoveModelConfig(m string) {
delete(bcl.configs, m)
}

// GetModelsConflictingWith returns the names of every other configured (and
// not-disabled) model that shares at least one concurrency group with the
// named model. Returns nil if the named model has no groups, is unknown, or
// has no peers in any of its groups. The result excludes the queried name.
func (bcl *ModelConfigLoader) GetModelsConflictingWith(name string) []string {
bcl.Lock()
defer bcl.Unlock()
target, ok := bcl.configs[name]
if !ok {
return nil
}
targetGroups := target.GetConcurrencyGroups()
if len(targetGroups) == 0 {
return nil
}
var conflicts []string
for n, cfg := range bcl.configs {
if n == name || cfg.IsDisabled() {
continue
}
other := cfg.GetConcurrencyGroups()
if len(other) == 0 {
continue
}
for _, g := range targetGroups {
if slices.Contains(other, g) {
conflicts = append(conflicts, n)
break
}
}
}
return conflicts
}

// UpdateModelConfig updates an existing model config in the loader.
// This is useful for updating runtime-detected properties like thinking support.
func (bcl *ModelConfigLoader) UpdateModelConfig(m string, updater func(*ModelConfig)) {
Expand Down
63 changes: 63 additions & 0 deletions core/config/model_config_loader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package config

import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("ModelConfigLoader.GetModelsConflictingWith", func() {
var bcl *ModelConfigLoader

BeforeEach(func() {
bcl = NewModelConfigLoader("/tmp/conflict-test-models")
})

insert := func(cfg ModelConfig) {
bcl.Lock()
bcl.configs[cfg.Name] = cfg
bcl.Unlock()
}

It("returns nil when the named model has no groups", func() {
insert(ModelConfig{Name: "loner"})
Expect(bcl.GetModelsConflictingWith("loner")).To(BeNil())
})

It("returns nil when the named model is unknown", func() {
Expect(bcl.GetModelsConflictingWith("ghost")).To(BeNil())
})

It("returns nil when no other model shares a group", func() {
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"vision"}})
Expect(bcl.GetModelsConflictingWith("a")).To(BeNil())
})

It("returns models that share at least one group", func() {
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}})
insert(ModelConfig{Name: "c", ConcurrencyGroups: []string{"vision"}})
insert(ModelConfig{Name: "d", ConcurrencyGroups: []string{"heavy", "vision"}})

conflicts := bcl.GetModelsConflictingWith("a")
Expect(conflicts).To(ConsistOf("b", "d"))
})

It("never lists the queried model itself", func() {
insert(ModelConfig{Name: "self", ConcurrencyGroups: []string{"heavy"}})
Expect(bcl.GetModelsConflictingWith("self")).To(BeNil())
})

It("ignores disabled conflicting models", func() {
disabled := true
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}})
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled})
Expect(bcl.GetModelsConflictingWith("a")).To(BeNil())
})

It("normalizes groups so whitespace and duplicates do not break overlap", func() {
insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{" heavy "}})
insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy", "heavy"}})
Expect(bcl.GetModelsConflictingWith("a")).To(ConsistOf("b"))
})
})
49 changes: 49 additions & 0 deletions core/config/model_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,53 @@ mcp:
Expect(err).To(BeNil())
Expect(valid).To(BeTrue())
})
Context("ConcurrencyGroups", func() {
It("returns nil when no groups are configured", func() {
cfg := &ModelConfig{Name: "no-groups"}
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
})
It("returns nil when all entries are blank", func() {
cfg := &ModelConfig{
Name: "blanks",
ConcurrencyGroups: []string{"", " ", "\t"},
}
Expect(cfg.GetConcurrencyGroups()).To(BeNil())
})
It("trims whitespace, drops empty entries, and dedupes", func() {
cfg := &ModelConfig{
Name: "messy",
ConcurrencyGroups: []string{" vram-heavy ", "", "vram-heavy", "vision", " vision "},
}
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "vision"}))
})
It("returns a defensive copy", func() {
cfg := &ModelConfig{
Name: "copy",
ConcurrencyGroups: []string{"heavy"},
}
got := cfg.GetConcurrencyGroups()
got[0] = "tampered"
Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"heavy"}))
})
It("parses concurrency_groups from YAML", func() {
tmp, err := os.CreateTemp("", "concgroups.yaml")
Expect(err).To(BeNil())
defer func() { _ = os.Remove(tmp.Name()) }()
_, err = tmp.WriteString(
`name: heavy-a
backend: llama-cpp
parameters:
model: heavy-a.gguf
concurrency_groups:
- vram-heavy
- "120b"
`)
Expect(err).ToNot(HaveOccurred())
configs, err := readModelConfigsFromFile(tmp.Name())
Expect(err).To(BeNil())
Expect(configs).To(HaveLen(1))
Expect(configs[0].ConcurrencyGroups).To(Equal([]string{"vram-heavy", "120b"}))
Expect(configs[0].GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "120b"}))
})
})
})
9 changes: 9 additions & 0 deletions core/services/nodes/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ type ModelRouter interface {
FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error)
GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error)
FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error)
}

// ConcurrencyConflictResolver returns the names of configured models that
// share at least one concurrency group with the given model. It is satisfied
// by *config.ModelConfigLoader and lets the SmartRouter make group-aware
// placement decisions without importing the config package's full surface.
type ConcurrencyConflictResolver interface {
GetModelsConflictingWith(modelName string) []string
}

// NodeHealthStore is used by HealthMonitor for node status management.
Expand Down
Loading
Loading