diff --git a/core/application/distributed.go b/core/application/distributed.go index 26d56d12185d..2a77d5b3c659 100644 --- a/core/application/distributed.go +++ b/core/application/distributed.go @@ -71,7 +71,9 @@ func (ds *DistributedServices) Shutdown() { // initDistributed validates distributed mode prerequisites and initializes // NATS, object storage, node registry, and instance identity. // Returns nil if distributed mode is not enabled. -func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*DistributedServices, error) { +// configLoader is used by the SmartRouter to compute concurrency-group +// anti-affinity at placement time (#9659); it may be nil in tests. +func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoader *config.ModelConfigLoader) (*DistributedServices, error) { if !cfg.Distributed.Enabled { return nil, nil } @@ -234,12 +236,17 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB) (*Distribut remoteUnloader := nodes.NewRemoteUnloaderAdapter(registry, natsClient) // All dependencies ready — build SmartRouter with all options at once + var conflictResolver nodes.ConcurrencyConflictResolver + if configLoader != nil { + conflictResolver = configLoader + } router := nodes.NewSmartRouter(registry, nodes.SmartRouterOptions{ - Unloader: remoteUnloader, - FileStager: fileStager, - GalleriesJSON: routerGalleriesJSON, - AuthToken: routerAuthToken, - DB: authDB, + Unloader: remoteUnloader, + FileStager: fileStager, + GalleriesJSON: routerGalleriesJSON, + AuthToken: routerAuthToken, + DB: authDB, + ConflictResolver: conflictResolver, }) // Create ReplicaReconciler for auto-scaling model replicas. Adapter + diff --git a/core/application/startup.go b/core/application/startup.go index 4744ea8311bd..8747125e8284 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -139,7 +139,7 @@ func New(opts ...config.AppOption) (*Application, error) { } // Initialize distributed mode services (NATS, object storage, node registry) - distSvc, err := initDistributed(options, application.authDB) + distSvc, err := initDistributed(options, application.authDB, application.ModelConfigLoader()) if err != nil { return nil, fmt.Errorf("distributed mode initialization failed: %w", err) } @@ -680,6 +680,12 @@ func initializeWatchdog(application *Application, options *config.ApplicationCon options.LRUEvictionRetryInterval, ) + // Sync per-model state from configs to the watchdog. Without this, + // `pinned: true` and `concurrency_groups:` are only honored after a + // settings-driven RestartWatchdog and never at boot. + application.SyncPinnedModelsToWatchdog() + application.SyncModelGroupsToWatchdog() + // Start watchdog goroutine if any periodic checks are enabled // LRU eviction doesn't need the Run() loop - it's triggered on model load // But memory reclaimer needs the Run() loop for periodic checking diff --git a/core/application/watchdog.go b/core/application/watchdog.go index c1aee6c7adb5..9658b51149df 100644 --- a/core/application/watchdog.go +++ b/core/application/watchdog.go @@ -1,6 +1,7 @@ package application import ( + "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) @@ -26,6 +27,40 @@ func (a *Application) SyncPinnedModelsToWatchdog() { xlog.Debug("Synced pinned models to watchdog", "count", len(pinned)) } +// SyncModelGroupsToWatchdog reads concurrency_groups from all model configs and +// updates the watchdog so EnforceGroupExclusivity has the current view. +func (a *Application) SyncModelGroupsToWatchdog() { + cl := a.ModelConfigLoader() + if cl == nil { + return + } + wd := a.modelLoader.GetWatchDog() + if wd == nil { + return + } + groups := extractModelGroupsFromConfigs(cl.GetAllModelsConfigs()) + wd.ReplaceModelGroups(groups) + xlog.Debug("Synced concurrency groups to watchdog", "count", len(groups)) +} + +// extractModelGroupsFromConfigs builds the model→groups map the watchdog +// expects. Disabled models are skipped — their declared groups should not +// block other models from loading. +func extractModelGroupsFromConfigs(configs []config.ModelConfig) map[string][]string { + out := make(map[string][]string) + for _, cfg := range configs { + if cfg.IsDisabled() { + continue + } + gs := cfg.GetConcurrencyGroups() + if len(gs) == 0 { + continue + } + out[cfg.Name] = gs + } + return out +} + func (a *Application) StopWatchdog() error { if a.watchdogStop != nil { close(a.watchdogStop) @@ -65,8 +100,9 @@ func (a *Application) startWatchdog() error { // Set the watchdog on the model loader a.modelLoader.SetWatchDog(wd) - // Sync pinned models from config to the watchdog + // Sync pinned models and concurrency groups from config to the watchdog a.SyncPinnedModelsToWatchdog() + a.SyncModelGroupsToWatchdog() // Start watchdog goroutine if any periodic checks are enabled // LRU eviction doesn't need the Run() loop - it's triggered on model load @@ -148,8 +184,9 @@ func (a *Application) RestartWatchdog() error { newWD.RestoreState(oldState) } - // Re-sync pinned models after restart + // Re-sync pinned models and concurrency groups after restart a.SyncPinnedModelsToWatchdog() + a.SyncModelGroupsToWatchdog() return nil } diff --git a/core/application/watchdog_test.go b/core/application/watchdog_test.go new file mode 100644 index 000000000000..ddb1a6849853 --- /dev/null +++ b/core/application/watchdog_test.go @@ -0,0 +1,47 @@ +package application + +import ( + "github.com/mudler/LocalAI/core/config" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("extractModelGroupsFromConfigs", func() { + It("returns an empty map when no config declares groups", func() { + out := extractModelGroupsFromConfigs([]config.ModelConfig{ + {Name: "a"}, + {Name: "b"}, + }) + Expect(out).To(BeEmpty()) + }) + + It("returns each model's normalized groups", func() { + out := extractModelGroupsFromConfigs([]config.ModelConfig{ + {Name: "a", ConcurrencyGroups: []string{" heavy ", "vision", "heavy"}}, + {Name: "b", ConcurrencyGroups: []string{"heavy"}}, + {Name: "c"}, // no groups → omitted + }) + Expect(out).To(HaveLen(2)) + Expect(out["a"]).To(Equal([]string{"heavy", "vision"})) + Expect(out["b"]).To(Equal([]string{"heavy"})) + Expect(out).ToNot(HaveKey("c")) + }) + + It("omits models whose groups normalize to empty", func() { + out := extractModelGroupsFromConfigs([]config.ModelConfig{ + {Name: "blanks", ConcurrencyGroups: []string{"", " "}}, + }) + Expect(out).To(BeEmpty()) + }) + + It("skips disabled models so they cannot block loading after re-enable", func() { + disabled := true + out := extractModelGroupsFromConfigs([]config.ModelConfig{ + {Name: "a", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled}, + {Name: "b", ConcurrencyGroups: []string{"heavy"}}, + }) + Expect(out).To(HaveLen(1)) + Expect(out).To(HaveKey("b")) + Expect(out).ToNot(HaveKey("a")) + }) +}) diff --git a/core/config/model_config.go b/core/config/model_config.go index f1f0b30d3df9..5c97d477e8c3 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -87,6 +87,11 @@ type ModelConfig struct { Disabled *bool `yaml:"disabled,omitempty" json:"disabled,omitempty"` Pinned *bool `yaml:"pinned,omitempty" json:"pinned,omitempty"` + // ConcurrencyGroups declares per-node mutual-exclusion groups: the model + // cannot be loaded alongside another model that shares any group name. + // See docs/content/advanced/vram-management.md for usage. + ConcurrencyGroups []string `yaml:"concurrency_groups,omitempty" json:"concurrency_groups,omitempty"` + Options []string `yaml:"options,omitempty" json:"options,omitempty"` Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"` @@ -587,6 +592,28 @@ func (c *ModelConfig) IsPinned() bool { return c.Pinned != nil && *c.Pinned } +// GetConcurrencyGroups returns the model's concurrency groups, normalized: +// trimmed of whitespace, empty entries dropped, deduped. Returns nil when no +// effective groups remain. The result is a fresh slice; the caller may +// mutate it without affecting the config. +func (c *ModelConfig) GetConcurrencyGroups() []string { + if len(c.ConcurrencyGroups) == 0 { + return nil + } + out := make([]string, 0, len(c.ConcurrencyGroups)) + for _, g := range c.ConcurrencyGroups { + g = strings.TrimSpace(g) + if g == "" || slices.Contains(out, g) { + continue + } + out = append(out, g) + } + if len(out) == 0 { + return nil + } + return out +} + type ModelConfigUsecase int const ( diff --git a/core/config/model_config_loader.go b/core/config/model_config_loader.go index c3671a24f3ed..32b2bb38a03a 100644 --- a/core/config/model_config_loader.go +++ b/core/config/model_config_loader.go @@ -249,6 +249,40 @@ func (bcl *ModelConfigLoader) RemoveModelConfig(m string) { delete(bcl.configs, m) } +// GetModelsConflictingWith returns the names of every other configured (and +// not-disabled) model that shares at least one concurrency group with the +// named model. Returns nil if the named model has no groups, is unknown, or +// has no peers in any of its groups. The result excludes the queried name. +func (bcl *ModelConfigLoader) GetModelsConflictingWith(name string) []string { + bcl.Lock() + defer bcl.Unlock() + target, ok := bcl.configs[name] + if !ok { + return nil + } + targetGroups := target.GetConcurrencyGroups() + if len(targetGroups) == 0 { + return nil + } + var conflicts []string + for n, cfg := range bcl.configs { + if n == name || cfg.IsDisabled() { + continue + } + other := cfg.GetConcurrencyGroups() + if len(other) == 0 { + continue + } + for _, g := range targetGroups { + if slices.Contains(other, g) { + conflicts = append(conflicts, n) + break + } + } + } + return conflicts +} + // UpdateModelConfig updates an existing model config in the loader. // This is useful for updating runtime-detected properties like thinking support. func (bcl *ModelConfigLoader) UpdateModelConfig(m string, updater func(*ModelConfig)) { diff --git a/core/config/model_config_loader_test.go b/core/config/model_config_loader_test.go new file mode 100644 index 000000000000..924a4d1e42a0 --- /dev/null +++ b/core/config/model_config_loader_test.go @@ -0,0 +1,63 @@ +package config + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ModelConfigLoader.GetModelsConflictingWith", func() { + var bcl *ModelConfigLoader + + BeforeEach(func() { + bcl = NewModelConfigLoader("/tmp/conflict-test-models") + }) + + insert := func(cfg ModelConfig) { + bcl.Lock() + bcl.configs[cfg.Name] = cfg + bcl.Unlock() + } + + It("returns nil when the named model has no groups", func() { + insert(ModelConfig{Name: "loner"}) + Expect(bcl.GetModelsConflictingWith("loner")).To(BeNil()) + }) + + It("returns nil when the named model is unknown", func() { + Expect(bcl.GetModelsConflictingWith("ghost")).To(BeNil()) + }) + + It("returns nil when no other model shares a group", func() { + insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}}) + insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"vision"}}) + Expect(bcl.GetModelsConflictingWith("a")).To(BeNil()) + }) + + It("returns models that share at least one group", func() { + insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}}) + insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}}) + insert(ModelConfig{Name: "c", ConcurrencyGroups: []string{"vision"}}) + insert(ModelConfig{Name: "d", ConcurrencyGroups: []string{"heavy", "vision"}}) + + conflicts := bcl.GetModelsConflictingWith("a") + Expect(conflicts).To(ConsistOf("b", "d")) + }) + + It("never lists the queried model itself", func() { + insert(ModelConfig{Name: "self", ConcurrencyGroups: []string{"heavy"}}) + Expect(bcl.GetModelsConflictingWith("self")).To(BeNil()) + }) + + It("ignores disabled conflicting models", func() { + disabled := true + insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{"heavy"}}) + insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy"}, Disabled: &disabled}) + Expect(bcl.GetModelsConflictingWith("a")).To(BeNil()) + }) + + It("normalizes groups so whitespace and duplicates do not break overlap", func() { + insert(ModelConfig{Name: "a", ConcurrencyGroups: []string{" heavy "}}) + insert(ModelConfig{Name: "b", ConcurrencyGroups: []string{"heavy", "heavy"}}) + Expect(bcl.GetModelsConflictingWith("a")).To(ConsistOf("b")) + }) +}) diff --git a/core/config/model_config_test.go b/core/config/model_config_test.go index 81b86fe3ff0a..c1216ec35b26 100644 --- a/core/config/model_config_test.go +++ b/core/config/model_config_test.go @@ -264,4 +264,53 @@ mcp: Expect(err).To(BeNil()) Expect(valid).To(BeTrue()) }) + Context("ConcurrencyGroups", func() { + It("returns nil when no groups are configured", func() { + cfg := &ModelConfig{Name: "no-groups"} + Expect(cfg.GetConcurrencyGroups()).To(BeNil()) + }) + It("returns nil when all entries are blank", func() { + cfg := &ModelConfig{ + Name: "blanks", + ConcurrencyGroups: []string{"", " ", "\t"}, + } + Expect(cfg.GetConcurrencyGroups()).To(BeNil()) + }) + It("trims whitespace, drops empty entries, and dedupes", func() { + cfg := &ModelConfig{ + Name: "messy", + ConcurrencyGroups: []string{" vram-heavy ", "", "vram-heavy", "vision", " vision "}, + } + Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "vision"})) + }) + It("returns a defensive copy", func() { + cfg := &ModelConfig{ + Name: "copy", + ConcurrencyGroups: []string{"heavy"}, + } + got := cfg.GetConcurrencyGroups() + got[0] = "tampered" + Expect(cfg.GetConcurrencyGroups()).To(Equal([]string{"heavy"})) + }) + It("parses concurrency_groups from YAML", func() { + tmp, err := os.CreateTemp("", "concgroups.yaml") + Expect(err).To(BeNil()) + defer func() { _ = os.Remove(tmp.Name()) }() + _, err = tmp.WriteString( + `name: heavy-a +backend: llama-cpp +parameters: + model: heavy-a.gguf +concurrency_groups: + - vram-heavy + - "120b" +`) + Expect(err).ToNot(HaveOccurred()) + configs, err := readModelConfigsFromFile(tmp.Name()) + Expect(err).To(BeNil()) + Expect(configs).To(HaveLen(1)) + Expect(configs[0].ConcurrencyGroups).To(Equal([]string{"vram-heavy", "120b"})) + Expect(configs[0].GetConcurrencyGroups()).To(Equal([]string{"vram-heavy", "120b"})) + }) + }) }) diff --git a/core/services/nodes/interfaces.go b/core/services/nodes/interfaces.go index c5f09105c5b5..086d83142b1a 100644 --- a/core/services/nodes/interfaces.go +++ b/core/services/nodes/interfaces.go @@ -35,6 +35,15 @@ type ModelRouter interface { FindIdleNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) FindLeastLoadedNodeFromSet(ctx context.Context, nodeIDs []string) (*BackendNode, error) GetNodeLabels(ctx context.Context, nodeID string) ([]NodeLabel, error) + FindNodesWithModel(ctx context.Context, modelName string) ([]BackendNode, error) +} + +// ConcurrencyConflictResolver returns the names of configured models that +// share at least one concurrency group with the given model. It is satisfied +// by *config.ModelConfigLoader and lets the SmartRouter make group-aware +// placement decisions without importing the config package's full surface. +type ConcurrencyConflictResolver interface { + GetModelsConflictingWith(modelName string) []string } // NodeHealthStore is used by HealthMonitor for node status management. diff --git a/core/services/nodes/model_router_test.go b/core/services/nodes/model_router_test.go index 4c0844158ea4..fedeb0e3d3ff 100644 --- a/core/services/nodes/model_router_test.go +++ b/core/services/nodes/model_router_test.go @@ -115,6 +115,9 @@ func (f *fakeModelRouterForSmartRouter) FindLeastLoadedNodeFromSet(_ context.Con func (f *fakeModelRouterForSmartRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabel, error) { return nil, nil } +func (f *fakeModelRouterForSmartRouter) FindNodesWithModel(_ context.Context, _ string) ([]BackendNode, error) { + return nil, nil +} // Compile-time check var _ ModelRouter = (*fakeModelRouterForSmartRouter)(nil) diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go index ea49255d85e0..7eedd51e14ac 100644 --- a/core/services/nodes/router.go +++ b/core/services/nodes/router.go @@ -37,18 +37,24 @@ type SmartRouterOptions struct { AuthToken string ClientFactory BackendClientFactory // optional; defaults to tokenClientFactory DB *gorm.DB // for advisory locks during routing + // ConflictResolver, when set, lets the scheduler narrow placement + // candidates by per-model concurrency_groups (#9659). When nil, group + // anti-affinity is disabled at the scheduler layer; the per-node + // watchdog still enforces the rule on arrival. + ConflictResolver ConcurrencyConflictResolver } // SmartRouter routes inference requests to the best available backend node. // It uses the ModelRouter interface (backed by NodeRegistry in production) for routing decisions. type SmartRouter struct { - registry ModelRouter - unloader NodeCommandSender // optional, for NATS-driven load/unload - fileStager FileStager // optional, for distributed file transfer - galleriesJSON string // backend gallery config for dynamic installation - clientFactory BackendClientFactory // creates gRPC backend clients - db *gorm.DB // for advisory locks during routing - stagingTracker *StagingTracker // tracks file staging progress for UI visibility + registry ModelRouter + unloader NodeCommandSender // optional, for NATS-driven load/unload + fileStager FileStager // optional, for distributed file transfer + galleriesJSON string // backend gallery config for dynamic installation + clientFactory BackendClientFactory // creates gRPC backend clients + db *gorm.DB // for advisory locks during routing + stagingTracker *StagingTracker // tracks file staging progress for UI visibility + conflictResolver ConcurrencyConflictResolver } // NewSmartRouter creates a new SmartRouter backed by the given ModelRouter. @@ -59,13 +65,14 @@ func NewSmartRouter(registry ModelRouter, opts SmartRouterOptions) *SmartRouter factory = &tokenClientFactory{token: opts.AuthToken} } return &SmartRouter{ - registry: registry, - unloader: opts.Unloader, - fileStager: opts.FileStager, - galleriesJSON: opts.GalleriesJSON, - clientFactory: factory, - db: opts.DB, - stagingTracker: NewStagingTracker(), + registry: registry, + unloader: opts.Unloader, + fileStager: opts.FileStager, + galleriesJSON: opts.GalleriesJSON, + clientFactory: factory, + db: opts.DB, + stagingTracker: NewStagingTracker(), + conflictResolver: opts.ConflictResolver, } } @@ -382,6 +389,60 @@ func (r *SmartRouter) resolveSelectorCandidates(ctx context.Context, modelID str return extractNodeIDs(candidates), nil } +// narrowByGroupAntiAffinity removes candidate nodes that already host a model +// declared as concurrent-conflicting with modelID via concurrency_groups +// (#9659). This is a soft filter: when *every* candidate would be excluded, +// the original set is returned and the per-node watchdog evicts on arrival. +// +// candidates may be nil ("any healthy node" — registry helpers treat nil as +// no filter). nil is returned unchanged: hard-narrowing the implicit "all +// nodes" set would silently exclude every node we know nothing about. +func (r *SmartRouter) narrowByGroupAntiAffinity(ctx context.Context, modelID string, candidates []string) ([]string, error) { + if r.conflictResolver == nil || candidates == nil { + return candidates, nil + } + conflicts := r.conflictResolver.GetModelsConflictingWith(modelID) + if len(conflicts) == 0 { + return candidates, nil + } + + excluded := make(map[string]struct{}) + for _, name := range conflicts { + nodes, err := r.registry.FindNodesWithModel(ctx, name) + if err != nil { + // Best-effort: a single lookup failure shouldn't fail placement. + // Log and move on — the watchdog still enforces the rule on arrival. + xlog.Warn("Group anti-affinity: lookup failed, skipping", "model", name, "error", err) + continue + } + for _, n := range nodes { + excluded[n.ID] = struct{}{} + } + } + if len(excluded) == 0 { + return candidates, nil + } + + narrowed := candidates[:0:0] + for _, id := range candidates { + if _, bad := excluded[id]; bad { + continue + } + narrowed = append(narrowed, id) + } + if len(narrowed) == 0 { + // Soft fallback: every candidate has a conflict. Return the original + // set and let the per-node watchdog evict on arrival rather than + // failing the request. + xlog.Debug("Group anti-affinity: all candidates conflict, falling back to original set", + "model", modelID, "conflicts", conflicts) + return candidates, nil + } + xlog.Debug("Group anti-affinity narrowed candidates", + "model", modelID, "before", len(candidates), "after", len(narrowed)) + return narrowed, nil +} + // nodeMatchesScheduling checks if a node satisfies the scheduling constraints for a model. // Returns true if no constraints exist or the node matches all selector labels. func (r *SmartRouter) nodeMatchesScheduling(ctx context.Context, node *BackendNode, modelName string) bool { @@ -438,6 +499,15 @@ func (r *SmartRouter) scheduleNewModel(ctx context.Context, backendType, modelID return nil, "", 0, err } + // Apply concurrency-group anti-affinity (#9659): prefer nodes that don't + // already host a model declared exclusive with this one. Soft filter — if + // every candidate has a conflict, the original set is returned and the + // per-node watchdog evicts on arrival. + candidateNodeIDs, err = r.narrowByGroupAntiAffinity(ctx, modelID, candidateNodeIDs) + if err != nil { + return nil, "", 0, err + } + // Narrow candidates to nodes that still have a free replica slot for this // model. Without this filter, the scheduler would happily pick a node // already at capacity for this model (e.g. when MinReplicas > free diff --git a/core/services/nodes/router_test.go b/core/services/nodes/router_test.go index 2e5d3ce9c5e9..f61613d27b16 100644 --- a/core/services/nodes/router_test.go +++ b/core/services/nodes/router_test.go @@ -106,6 +106,10 @@ type fakeModelRouter struct { getNodeLabels []NodeLabel getNodeLabelsErr error + // FindNodesWithModel returns (keyed by model name) + findNodesWithModelByName map[string][]BackendNode + findNodesWithModelErr error + // Track calls for assertions decrementCalls []string // "nodeID:modelName" incrementCalls []string @@ -228,6 +232,25 @@ func (f *fakeModelRouter) GetNodeLabels(_ context.Context, _ string) ([]NodeLabe return f.getNodeLabels, f.getNodeLabelsErr } +func (f *fakeModelRouter) FindNodesWithModel(_ context.Context, modelName string) ([]BackendNode, error) { + if f.findNodesWithModelErr != nil { + return nil, f.findNodesWithModelErr + } + return f.findNodesWithModelByName[modelName], nil +} + +// fakeConflictResolver implements ConcurrencyConflictResolver from a static map. +type fakeConflictResolver struct { + conflicts map[string][]string +} + +func (f *fakeConflictResolver) GetModelsConflictingWith(name string) []string { + if f == nil { + return nil + } + return f.conflicts[name] +} + // --------------------------------------------------------------------------- // Fake BackendClientFactory + Backend // --------------------------------------------------------------------------- @@ -847,4 +870,84 @@ var _ = Describe("SmartRouter", func() { Expect(original.MMProj).To(Equal(origMMProj)) }) }) + + // ----------------------------------------------------------------------- + // narrowByGroupAntiAffinity + // ----------------------------------------------------------------------- + Describe("narrowByGroupAntiAffinity", func() { + var ( + reg *fakeModelRouter + resolver *fakeConflictResolver + router *SmartRouter + ctx context.Context + ) + + BeforeEach(func() { + reg = &fakeModelRouter{} + resolver = &fakeConflictResolver{conflicts: map[string][]string{}} + router = NewSmartRouter(reg, SmartRouterOptions{ + ConflictResolver: resolver, + }) + ctx = context.Background() + }) + + It("returns the input set unchanged when the model has no conflicts", func() { + candidates := []string{"n1", "n2", "n3"} + out, err := router.narrowByGroupAntiAffinity(ctx, "lonely", candidates) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal(candidates)) + }) + + It("removes nodes that already host a conflicting model", func() { + resolver.conflicts["b"] = []string{"a"} + reg.findNodesWithModelByName = map[string][]BackendNode{ + "a": {{ID: "n1"}}, + } + out, err := router.narrowByGroupAntiAffinity(ctx, "b", []string{"n1", "n2"}) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(ConsistOf("n2")) + }) + + It("returns the original set unchanged when every candidate has a conflict (soft fallback)", func() { + resolver.conflicts["b"] = []string{"a"} + reg.findNodesWithModelByName = map[string][]BackendNode{ + "a": {{ID: "n1"}, {ID: "n2"}}, + } + candidates := []string{"n1", "n2"} + out, err := router.narrowByGroupAntiAffinity(ctx, "b", candidates) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal(candidates)) + }) + + It("removes nodes hosting any of multiple conflicting models", func() { + resolver.conflicts["c"] = []string{"a", "b"} + reg.findNodesWithModelByName = map[string][]BackendNode{ + "a": {{ID: "n1"}}, + "b": {{ID: "n2"}}, + } + out, err := router.narrowByGroupAntiAffinity(ctx, "c", []string{"n1", "n2", "n3"}) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(ConsistOf("n3")) + }) + + It("treats a nil candidate set (\"any healthy node\") by returning nil unchanged when narrowing yields nothing", func() { + resolver.conflicts["b"] = []string{"a"} + reg.findNodesWithModelByName = map[string][]BackendNode{ + "a": {{ID: "n1"}, {ID: "n2"}}, + } + out, err := router.narrowByGroupAntiAffinity(ctx, "b", nil) + Expect(err).ToNot(HaveOccurred()) + // nil in → nil out: caller's "any healthy node" semantics preserved. + // Hard-narrowing nil would silently exclude every other node. + Expect(out).To(BeNil()) + }) + + It("is a no-op when no resolver is configured", func() { + plain := NewSmartRouter(reg, SmartRouterOptions{}) + candidates := []string{"n1", "n2"} + out, err := plain.narrowByGroupAntiAffinity(ctx, "b", candidates) + Expect(err).ToNot(HaveOccurred()) + Expect(out).To(Equal(candidates)) + }) + }) }) diff --git a/docs/content/advanced/vram-management.md b/docs/content/advanced/vram-management.md index fb7e89490dfe..3b7620e80e65 100644 --- a/docs/content/advanced/vram-management.md +++ b/docs/content/advanced/vram-management.md @@ -8,7 +8,8 @@ url = '/advanced/vram-management' When running multiple models in LocalAI, especially on systems with limited GPU memory (VRAM), you may encounter situations where loading a new model fails because there isn't enough available VRAM. LocalAI provides several mechanisms to automatically manage model memory allocation and prevent VRAM exhaustion: 1. **Max Active Backends (LRU Eviction)**: Limit the number of loaded models, evicting the least recently used when the limit is reached -2. **Watchdog Mechanisms**: Automatically unload idle or stuck models based on configurable timeouts +2. **Concurrency Groups**: Per-model anti-affinity rules that prevent specific models from coexisting on the same node +3. **Watchdog Mechanisms**: Automatically unload idle or stuck models based on configurable timeouts ## The Problem @@ -136,6 +137,86 @@ LOCALAI_SINGLE_ACTIVE_BACKEND=true ./local-ai - When you only need one model active at a time - Simple deployments where model switching is acceptable +## Solution 1b: Concurrency Groups (per-model anti-affinity) + +`--max-active-backends` is a global count — three loaded models is fine, but it +doesn't know that two of them are 120B and shouldn't share a GPU. +**Concurrency groups** give per-model rules: any two models that share a group +name are mutually exclusive on the same node. Loading one evicts the others. +Models with no groups behave exactly as before. + +This addresses [issue #9659](https://github.com/mudler/LocalAI/issues/9659): + +> allow my zed prediction model to run alongside anything but don't allow my +> two 120b models to run alongside each other + +### Configuration + +Declare groups per model in the YAML config — no CLI flag, no env var: + +```yaml +# llama-120b-a.yaml +name: llama-120b-a +backend: llama-cpp +parameters: + model: llama-120b-a.gguf +concurrency_groups: ["vram-heavy"] +``` + +```yaml +# llama-120b-b.yaml +name: llama-120b-b +backend: llama-cpp +parameters: + model: llama-120b-b.gguf +concurrency_groups: ["vram-heavy"] +``` + +```yaml +# zed-predict.yaml — no groups, runs alongside anything +name: zed-predict +backend: llama-cpp +parameters: + model: zed-predict.gguf +``` + +With this configuration: + +1. Request `zed-predict` → loads. +2. Request `llama-120b-a` → loads alongside `zed-predict`. +3. Request `llama-120b-b` → `llama-120b-a` is evicted (shared group + `vram-heavy`); `zed-predict` stays loaded. + +A model can declare multiple groups; two models conflict if they share **any** +group name. Group names are arbitrary strings — pick names that make sense for +your hardware (`vram-heavy`, `gpu-1`, `large-context`, ...). + +### Interaction with other knobs + +- **`--max-active-backends`**: groups are checked *before* the LRU cap. Group + evictions may already make room; LRU then enforces the global count. +- **`pinned: true`**: a pinned model is never evicted, including by a group + conflict. The new request is loaded with a warning logged — pinning two + models in the same group is a configuration mismatch. +- **`--force-eviction-when-busy`**: same retry semantics as LRU. A busy + conflict is skipped and retried (`--lru-eviction-max-retries`, + `--lru-eviction-retry-interval`); after retries exhaust, the load proceeds + with a warning. + +### Distributed mode + +`concurrency_groups` is enforced **per node**, not cluster-wide — VRAM is a +node-local resource, so two heavy models on different nodes is fine. The +distributed scheduler additionally uses the rule as a placement hint: when +choosing where to load a new model, it prefers nodes that don't already host a +same-group model, falling back to eviction only if every candidate has a +conflict. + +`concurrency_groups` composes with `NodeSelector` (which decides *which +nodes* a model is eligible for) — the two filters apply in sequence. Use +`NodeSelector` to target hardware classes; use `concurrency_groups` to keep +specific models from co-residing on whichever node hosts them. + ## Solution 2: Watchdog Mechanisms For more flexible memory management, LocalAI provides watchdog mechanisms that automatically unload models based on their activity state. This allows multiple models to be loaded simultaneously, but automatically frees memory when models become inactive or stuck. diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index ad9a88c06d17..6d291961b545 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -197,6 +197,35 @@ func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err e return model.GRPC(o.parallelRequests, ml.wd), nil } +// retryEnforce repeatedly invokes fn until it returns NeedMore=false or the +// retry budget is exhausted. It sleeps `retryInterval` between attempts and +// logs progress under `label`. Used by both LRU and group-exclusivity +// enforcement so the busy-model wait behaviour is identical. +func retryEnforce(fn func() EnforceLRULimitResult, maxRetries int, retryInterval time.Duration, label string) { + for attempt := range maxRetries { + result := fn() + if !result.NeedMore { + if result.EvictedCount > 0 { + xlog.Info("[ModelLoader] "+label+" enforcement complete", "evicted", result.EvictedCount) + } + return + } + if attempt < maxRetries-1 { + xlog.Info("[ModelLoader] Waiting for busy models to become idle before eviction", + "label", label, + "evicted", result.EvictedCount, + "attempt", attempt+1, + "maxRetries", maxRetries, + "retryIn", retryInterval) + time.Sleep(retryInterval) + } else { + xlog.Warn("[ModelLoader] "+label+" enforcement incomplete after max retries", + "evicted", result.EvictedCount, + "reason", "conflicts are still busy or pinned") + } + } +} + // enforceLRULimit enforces the LRU limit before loading a new model. // This is called before loading a model to ensure we don't exceed the limit. // It accounts for models that are currently being loaded by other goroutines. @@ -206,41 +235,34 @@ func (ml *ModelLoader) enforceLRULimit() { return } - // Get the count of models currently being loaded to account for concurrent requests pendingLoads := ml.GetLoadingCount() - // Get retry settings from ModelLoader ml.mu.Lock() maxRetries := ml.lruEvictionMaxRetries retryInterval := ml.lruEvictionRetryInterval ml.mu.Unlock() - for attempt := range maxRetries { - result := ml.wd.EnforceLRULimit(pendingLoads) - - if !result.NeedMore { - // Successfully evicted enough models (or no eviction needed) - if result.EvictedCount > 0 { - xlog.Info("[ModelLoader] LRU enforcement complete", "evicted", result.EvictedCount) - } - return - } + retryEnforce(func() EnforceLRULimitResult { + return ml.wd.EnforceLRULimit(pendingLoads) + }, maxRetries, retryInterval, "LRU") +} - // Need more evictions but models are busy - wait and retry - if attempt < maxRetries-1 { - xlog.Info("[ModelLoader] Waiting for busy models to become idle before eviction", - "evicted", result.EvictedCount, - "attempt", attempt+1, - "maxRetries", maxRetries, - "retryIn", retryInterval) - time.Sleep(retryInterval) - } else { - // Last attempt - log warning but proceed (might fail to load, but at least we tried) - xlog.Warn("[ModelLoader] LRU enforcement incomplete after max retries", - "evicted", result.EvictedCount, - "reason", "models are still busy with active API calls") - } +// enforceGroupExclusivity evicts every loaded model that shares a concurrency +// group with modelID before loading proceeds. Reuses the LRU retry settings so +// busy conflicts wait for the same window as a busy LRU eviction. +func (ml *ModelLoader) enforceGroupExclusivity(modelID string) { + if ml.wd == nil { + return } + + ml.mu.Lock() + maxRetries := ml.lruEvictionMaxRetries + retryInterval := ml.lruEvictionRetryInterval + ml.mu.Unlock() + + retryEnforce(func() EnforceLRULimitResult { + return ml.wd.EnforceGroupExclusivity(modelID) + }, maxRetries, retryInterval, "group-exclusivity") } // updateModelLastUsed updates the last used time for a model (for LRU tracking) @@ -270,6 +292,12 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { return client, nil } + // Evict any loaded model that shares a concurrency group with the + // requested one before applying the global LRU cap — group eviction may + // already make room, and otherwise LRU might evict an unrelated model + // only for the group check to immediately evict another. + ml.enforceGroupExclusivity(o.modelID) + // Enforce LRU limit before loading a new model ml.enforceLRULimit() diff --git a/pkg/model/initializers_retry_test.go b/pkg/model/initializers_retry_test.go new file mode 100644 index 000000000000..b42fde6ebb77 --- /dev/null +++ b/pkg/model/initializers_retry_test.go @@ -0,0 +1,50 @@ +package model + +import ( + "sync/atomic" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("retryEnforce", func() { + It("returns immediately when the first attempt is satisfied", func() { + var calls atomic.Int32 + retryEnforce(func() EnforceLRULimitResult { + calls.Add(1) + return EnforceLRULimitResult{} + }, 5, 1*time.Millisecond, "test") + Expect(calls.Load()).To(Equal(int32(1))) + }) + + It("retries until NeedMore clears", func() { + var calls atomic.Int32 + retryEnforce(func() EnforceLRULimitResult { + n := calls.Add(1) + if n < 3 { + return EnforceLRULimitResult{NeedMore: true} + } + return EnforceLRULimitResult{EvictedCount: 1} + }, 5, 1*time.Millisecond, "test") + Expect(calls.Load()).To(Equal(int32(3))) + }) + + It("stops after maxRetries when NeedMore never clears", func() { + var calls atomic.Int32 + retryEnforce(func() EnforceLRULimitResult { + calls.Add(1) + return EnforceLRULimitResult{NeedMore: true} + }, 4, 1*time.Millisecond, "test") + Expect(calls.Load()).To(Equal(int32(4))) + }) + + It("treats maxRetries <= 0 as a no-op (no calls)", func() { + var calls atomic.Int32 + retryEnforce(func() EnforceLRULimitResult { + calls.Add(1) + return EnforceLRULimitResult{} + }, 0, 1*time.Millisecond, "test") + Expect(calls.Load()).To(Equal(int32(0))) + }) +}) diff --git a/pkg/model/watchdog.go b/pkg/model/watchdog.go index fea188803c06..d876d6ad9019 100644 --- a/pkg/model/watchdog.go +++ b/pkg/model/watchdog.go @@ -48,6 +48,11 @@ type WatchDog struct { // Pinned models are excluded from idle, LRU, and memory-pressure eviction pinnedModels map[string]bool + + // modelGroups maps a model name to its declared concurrency groups. + // Two loaded models that share at least one group cannot coexist on this + // node — see EnforceGroupExclusivity. + modelGroups map[string][]string } type ProcessManager interface { @@ -82,6 +87,7 @@ func NewWatchDog(opts ...WatchDogOption) *WatchDog { lruLimit: o.lruLimit, addressModelMap: make(map[string]string), pinnedModels: make(map[string]bool), + modelGroups: make(map[string][]string), stop: make(chan bool, 1), done: make(chan bool, 1), memoryReclaimerEnabled: o.memoryReclaimerEnabled, @@ -145,6 +151,34 @@ func (wd *WatchDog) IsModelPinned(modelName string) bool { return wd.pinnedModels[modelName] } +// ReplaceModelGroups replaces the per-model concurrency-group registry. The +// supplied map is copied; callers may mutate it after the call. Passing an +// empty or nil map clears all entries. +func (wd *WatchDog) ReplaceModelGroups(groups map[string][]string) { + wd.Lock() + defer wd.Unlock() + wd.modelGroups = make(map[string][]string, len(groups)) + for name, gs := range groups { + if len(gs) == 0 { + continue + } + wd.modelGroups[name] = slices.Clone(gs) + } +} + +// GetModelGroups returns a copy of the concurrency groups configured for +// the given model, or nil if the model has no groups. The result may be +// freely mutated by the caller. +func (wd *WatchDog) GetModelGroups(modelName string) []string { + wd.Lock() + defer wd.Unlock() + gs, ok := wd.modelGroups[modelName] + if !ok || len(gs) == 0 { + return nil + } + return slices.Clone(gs) +} + func (wd *WatchDog) Shutdown() { wd.Lock() defer wd.Unlock() @@ -327,43 +361,111 @@ func (wd *WatchDog) EnforceLRULimit(pendingLoads int) EnforceLRULimitResult { }) // Collect models to evict (the oldest ones) - var modelsToShutdown []string - evictedCount := 0 - skippedBusyCount := 0 - for i := 0; evictedCount < modelsToEvict && i < len(models); i++ { - m := models[i] - // Skip pinned models + modelsToShutdown, skippedBusyCount := wd.collectEvictionsLocked(models, modelsToEvict, forceEvictionWhenBusy) + needMore := len(modelsToShutdown) < modelsToEvict && skippedBusyCount > 0 + wd.Unlock() + + // Now shutdown models without holding the watchdog lock to prevent deadlock + for _, model := range modelsToShutdown { + if err := wd.pm.ShutdownModel(model); err != nil { + xlog.Error("[WatchDog] error shutting down model during LRU eviction", "error", err, "model", model) + } + xlog.Debug("[WatchDog] LRU eviction complete", "model", model) + } + + if needMore { + xlog.Warn("[WatchDog] LRU eviction incomplete", "evicted", len(modelsToShutdown), "needed", modelsToEvict, "skippedBusy", skippedBusyCount, "reason", "some models are busy with active API calls") + } + + return EnforceLRULimitResult{ + EvictedCount: len(modelsToShutdown), + NeedMore: needMore, + } +} + +// collectEvictionsLocked walks `candidates` (already in eviction order) and +// untracks up to `maxToEvict` models that are eligible for eviction. Pinned +// models are always skipped; busy models are skipped unless `force` is true. +// Returns the names of evicted models and the number skipped because they +// were busy. Must be called with wd.Lock() held. +func (wd *WatchDog) collectEvictionsLocked(candidates []modelUsageInfo, maxToEvict int, force bool) (evicted []string, skippedBusy int) { + for i := 0; len(evicted) < maxToEvict && i < len(candidates); i++ { + m := candidates[i] if wd.pinnedModels[m.model] { - xlog.Debug("[WatchDog] Skipping LRU eviction for pinned model", "model", m.model) + xlog.Debug("[WatchDog] Skipping eviction for pinned model", "model", m.model) continue } - // Check if model is busy _, isBusy := wd.busyTime[m.address] - if isBusy && !forceEvictionWhenBusy { - // Skip eviction for busy models when forceEvictionWhenBusy is false - xlog.Warn("[WatchDog] Skipping LRU eviction for busy model", "model", m.model, "reason", "model has active API calls") - skippedBusyCount++ + if isBusy && !force { + xlog.Warn("[WatchDog] Skipping eviction for busy model", "model", m.model, "reason", "model has active API calls") + skippedBusy++ continue } - xlog.Info("[WatchDog] LRU evicting model", "model", m.model, "lastUsed", m.lastUsed, "busy", isBusy) - modelsToShutdown = append(modelsToShutdown, m.model) - // Clean up the maps while we have the lock + xlog.Info("[WatchDog] evicting model", "model", m.model, "busy", isBusy) + evicted = append(evicted, m.model) wd.untrack(m.address) - evictedCount++ } - needMore := evictedCount < modelsToEvict && skippedBusyCount > 0 + return evicted, skippedBusy +} + +// EnforceGroupExclusivity evicts every loaded model that shares at least one +// concurrency group with the requested model. The pinned/busy/retry semantics +// match EnforceLRULimit so the loader's retry loop can stay generic. +func (wd *WatchDog) EnforceGroupExclusivity(requestedModel string) EnforceLRULimitResult { + wd.Lock() + + requestedGroups := wd.modelGroups[requestedModel] + if len(requestedGroups) == 0 { + wd.Unlock() + return EnforceLRULimitResult{} + } + + forceEvictionWhenBusy := wd.forceEvictionWhenBusy + + // Build the conflict candidate list: every loaded model whose groups + // overlap with requestedGroups. Order doesn't affect correctness, but + // sort by lastUsed (oldest first) so logs and behaviour are deterministic. + var conflicts []modelUsageInfo + for address, name := range wd.addressModelMap { + if name == requestedModel { + continue + } + if !groupsOverlap(requestedGroups, wd.modelGroups[name]) { + continue + } + conflicts = append(conflicts, modelUsageInfo{ + address: address, + model: name, + lastUsed: wd.lastUsed[address], + }) + } + if len(conflicts) == 0 { + wd.Unlock() + return EnforceLRULimitResult{} + } + slices.SortFunc(conflicts, func(a, b modelUsageInfo) int { + return a.lastUsed.Compare(b.lastUsed) + }) + + xlog.Debug("[WatchDog] Group exclusivity triggered", "requested", requestedModel, "groups", requestedGroups, "conflicts", len(conflicts)) + + modelsToShutdown, skippedBusyCount := wd.collectEvictionsLocked(conflicts, len(conflicts), forceEvictionWhenBusy) + // For groups any unresolved conflict matters — busy *or* pinned. The loader + // retries on NeedMore; pinned cases will eventually time out and the load + // proceeds with a visible warning, which is the right signal for what is a + // configuration mismatch. + needMore := len(modelsToShutdown) < len(conflicts) wd.Unlock() - // Now shutdown models without holding the watchdog lock to prevent deadlock - for _, model := range modelsToShutdown { - if err := wd.pm.ShutdownModel(model); err != nil { - xlog.Error("[WatchDog] error shutting down model during LRU eviction", "error", err, "model", model) + for _, m := range modelsToShutdown { + if err := wd.pm.ShutdownModel(m); err != nil { + xlog.Error("[WatchDog] error shutting down model during group eviction", "error", err, "model", m) } - xlog.Debug("[WatchDog] LRU eviction complete", "model", model) + xlog.Debug("[WatchDog] Group eviction complete", "model", m) } if needMore { - xlog.Warn("[WatchDog] LRU eviction incomplete", "evicted", evictedCount, "needed", modelsToEvict, "skippedBusy", skippedBusyCount, "reason", "some models are busy with active API calls") + xlog.Warn("[WatchDog] Group eviction incomplete", "requested", requestedModel, "evicted", len(modelsToShutdown), "needed", len(conflicts), "skippedBusy", skippedBusyCount, "reason", "some conflicts are busy or pinned") } return EnforceLRULimitResult{ @@ -372,6 +474,19 @@ func (wd *WatchDog) EnforceLRULimit(pendingLoads int) EnforceLRULimitResult { } } +// groupsOverlap reports whether the two group lists share any name. +func groupsOverlap(a, b []string) bool { + if len(a) == 0 || len(b) == 0 { + return false + } + for _, x := range a { + if slices.Contains(b, x) { + return true + } + } + return false +} + func (wd *WatchDog) Run() { xlog.Info("[WatchDog] starting watchdog") diff --git a/pkg/model/watchdog_test.go b/pkg/model/watchdog_test.go index f0f982843593..b21a87985330 100644 --- a/pkg/model/watchdog_test.go +++ b/pkg/model/watchdog_test.go @@ -666,6 +666,182 @@ var _ = Describe("WatchDog", func() { }) }) + Context("Concurrency Groups", func() { + Describe("ReplaceModelGroups / GetModelGroups", func() { + It("returns nil for unknown models", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + Expect(wd.GetModelGroups("nope")).To(BeNil()) + }) + + It("stores and retrieves groups", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.ReplaceModelGroups(map[string][]string{ + "a": {"heavy", "vision"}, + "b": {"heavy"}, + }) + Expect(wd.GetModelGroups("a")).To(Equal([]string{"heavy", "vision"})) + Expect(wd.GetModelGroups("b")).To(Equal([]string{"heavy"})) + Expect(wd.GetModelGroups("c")).To(BeNil()) + }) + + It("replaces previous state on subsequent calls", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.ReplaceModelGroups(map[string][]string{"a": {"heavy"}}) + wd.ReplaceModelGroups(map[string][]string{"b": {"vision"}}) + Expect(wd.GetModelGroups("a")).To(BeNil()) + Expect(wd.GetModelGroups("b")).To(Equal([]string{"vision"})) + }) + + It("clears state when called with an empty map", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.ReplaceModelGroups(map[string][]string{"a": {"heavy"}}) + wd.ReplaceModelGroups(nil) + Expect(wd.GetModelGroups("a")).To(BeNil()) + }) + + It("returns a defensive copy", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.ReplaceModelGroups(map[string][]string{"a": {"heavy"}}) + got := wd.GetModelGroups("a") + got[0] = "tampered" + Expect(wd.GetModelGroups("a")).To(Equal([]string{"heavy"})) + }) + }) + + Describe("EnforceGroupExclusivity", func() { + It("is a no-op when the requested model has no groups", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.AddAddressModelMap("addr1", "model1") + wd.AddAddressModelMap("addr2", "model2") + + result := wd.EnforceGroupExclusivity("requested") + Expect(result.EvictedCount).To(Equal(0)) + Expect(result.NeedMore).To(BeFalse()) + Expect(pm.getShutdownCalls()).To(BeEmpty()) + }) + + It("is a no-op when no loaded model shares a group", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.ReplaceModelGroups(map[string][]string{ + "loaded": {"vision"}, + "requested": {"heavy"}, + }) + wd.AddAddressModelMap("addr1", "loaded") + + result := wd.EnforceGroupExclusivity("requested") + Expect(result.EvictedCount).To(Equal(0)) + Expect(result.NeedMore).To(BeFalse()) + Expect(pm.getShutdownCalls()).To(BeEmpty()) + }) + + It("evicts a loaded model that shares a single group", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.ReplaceModelGroups(map[string][]string{ + "a": {"heavy"}, + "b": {"heavy"}, + }) + wd.AddAddressModelMap("addrA", "a") + wd.Mark("addrA") + wd.UnMark("addrA") + + result := wd.EnforceGroupExclusivity("b") + Expect(result.EvictedCount).To(Equal(1)) + Expect(result.NeedMore).To(BeFalse()) + Expect(pm.getShutdownCalls()).To(ConsistOf("a")) + }) + + It("evicts when groups overlap on any single name", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.ReplaceModelGroups(map[string][]string{ + "a": {"x", "y"}, + "b": {"y", "z"}, + }) + wd.AddAddressModelMap("addrA", "a") + wd.Mark("addrA") + wd.UnMark("addrA") + + result := wd.EnforceGroupExclusivity("b") + Expect(result.EvictedCount).To(Equal(1)) + Expect(pm.getShutdownCalls()).To(ConsistOf("a")) + }) + + It("evicts every conflicting loaded model", func() { + wd = model.NewWatchDog( + model.WithProcessManager(pm), + model.WithForceEvictionWhenBusy(true), + ) + wd.ReplaceModelGroups(map[string][]string{ + "a": {"heavy"}, + "b": {"heavy"}, + "c": {"heavy"}, + }) + wd.AddAddressModelMap("addrA", "a") + wd.Mark("addrA") + wd.UnMark("addrA") + wd.AddAddressModelMap("addrB", "b") + wd.Mark("addrB") + wd.UnMark("addrB") + + result := wd.EnforceGroupExclusivity("c") + Expect(result.EvictedCount).To(Equal(2)) + Expect(pm.getShutdownCalls()).To(ConsistOf("a", "b")) + }) + + It("skips a pinned conflicting model and reports NeedMore", func() { + wd = model.NewWatchDog( + model.WithProcessManager(pm), + model.WithForceEvictionWhenBusy(true), + ) + wd.ReplaceModelGroups(map[string][]string{ + "a": {"heavy"}, + "b": {"heavy"}, + }) + wd.SetPinnedModels([]string{"a"}) + wd.AddAddressModelMap("addrA", "a") + wd.Mark("addrA") + wd.UnMark("addrA") + + result := wd.EnforceGroupExclusivity("b") + Expect(result.EvictedCount).To(Equal(0)) + Expect(result.NeedMore).To(BeTrue()) + Expect(pm.getShutdownCalls()).To(BeEmpty()) + }) + + It("skips a busy conflict when forceEvictionWhenBusy is false", func() { + wd = model.NewWatchDog(model.WithProcessManager(pm)) + wd.ReplaceModelGroups(map[string][]string{ + "a": {"heavy"}, + "b": {"heavy"}, + }) + wd.AddAddressModelMap("addrA", "a") + wd.Mark("addrA") // leave busy + + result := wd.EnforceGroupExclusivity("b") + Expect(result.EvictedCount).To(Equal(0)) + Expect(result.NeedMore).To(BeTrue()) + Expect(pm.getShutdownCalls()).To(BeEmpty()) + }) + + It("evicts a busy conflict when forceEvictionWhenBusy is true", func() { + wd = model.NewWatchDog( + model.WithProcessManager(pm), + model.WithForceEvictionWhenBusy(true), + ) + wd.ReplaceModelGroups(map[string][]string{ + "a": {"heavy"}, + "b": {"heavy"}, + }) + wd.AddAddressModelMap("addrA", "a") + wd.Mark("addrA") // leave busy + + result := wd.EnforceGroupExclusivity("b") + Expect(result.EvictedCount).To(Equal(1)) + Expect(result.NeedMore).To(BeFalse()) + Expect(pm.getShutdownCalls()).To(ConsistOf("a")) + }) + }) + }) + Context("Functional Options", func() { It("should use default options when none provided", func() { wd = model.NewWatchDog(