From d8351a9829f1a10a161d7e6795e6992bcdbd0bde Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 9 Apr 2026 16:30:49 -0400 Subject: [PATCH] fix anima model auto-selection --- .../listeners/modelSelected.test.ts | 360 ++++++++++++++++++ .../listeners/modelSelected.ts | 4 + .../controlLayers/store/paramsSlice.ts | 7 +- 3 files changed, 366 insertions(+), 5 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.test.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.test.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.test.ts new file mode 100644 index 00000000000..9443001c2d7 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.test.ts @@ -0,0 +1,360 @@ +import { zModelIdentifierField } from 'features/nodes/types/common'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +// Mock model configs returned by selectors - these simulate what RTK Query provides +const mockAnimaQwen3Encoder = { + key: 'qwen3-06b-key', + hash: 'qwen3-06b-hash', + name: 'Qwen3 0.6B Encoder', + base: 'any' as const, + type: 'qwen3_encoder' as const, + variant: 'qwen3_06b' as const, + format: 'qwen3_encoder' as const, +}; + +const mockAnimaVAE = { + key: 'anima-vae-key', + hash: 'anima-vae-hash', + name: 'Anima VAE', + base: 'anima' as const, + type: 'vae' as const, + format: 'diffusers' as const, +}; + +const mockT5Encoder = { + key: 't5-xxl-key', + hash: 't5-xxl-hash', + name: 'T5-XXL Encoder', + base: 'any' as const, + type: 't5_encoder' as const, + format: 't5_encoder' as const, +}; + +const mockAnimaMainModel = { + key: 'anima-main-key', + hash: 'anima-main-hash', + name: 'Anima Generate', + base: 'anima' as const, + type: 'main' as const, +}; + +const mockFluxMainModel = { + key: 'flux-main-key', + hash: 'flux-main-hash', + name: 'FLUX.1 Dev', + base: 'flux' as const, + type: 'main' as const, +}; + +// Track dispatched actions +const dispatched: Array<{ type: string; payload: unknown }> = []; +const mockDispatch = vi.fn((action: { type: string; payload: unknown }) => { + dispatched.push(action); +}); + +// Mock logger +vi.mock('app/logging/logger', () => ({ + logger: () => ({ + debug: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + info: vi.fn(), + }), +})); + +// Mock toast +vi.mock('features/toast/toast', () => ({ + toast: vi.fn(), +})); + +// Mock i18next +vi.mock('i18next', () => ({ + t: (key: string) => key, +})); + +// Mock model selectors from RTK Query hooks + +const mockSelectAnimaQwen3EncoderModels = vi.fn((_state: unknown) => [mockAnimaQwen3Encoder]); + +const mockSelectAnimaVAEModels = vi.fn((_state: unknown) => [mockAnimaVAE]); + +const mockSelectT5EncoderModels = vi.fn((_state: unknown) => [mockT5Encoder]); + +vi.mock('services/api/hooks/modelsByType', () => ({ + selectAnimaQwen3EncoderModels: (state: unknown) => mockSelectAnimaQwen3EncoderModels(state), + selectAnimaVAEModels: (state: unknown) => mockSelectAnimaVAEModels(state), + selectT5EncoderModels: (state: unknown) => mockSelectT5EncoderModels(state), + selectQwen3EncoderModels: vi.fn(() => []), + selectZImageDiffusersModels: vi.fn(() => []), + selectFluxVAEModels: vi.fn(() => []), + selectGlobalRefImageModels: vi.fn(() => []), + selectRegionalRefImageModels: vi.fn(() => []), +})); + +// Mock model configs adapter +vi.mock('services/api/endpoints/models', () => ({ + modelConfigsAdapterSelectors: { selectById: vi.fn() }, + selectModelConfigsQuery: vi.fn(() => ({ data: undefined })), +})); + +vi.mock('services/api/types', () => ({ + isFluxKontextModelConfig: vi.fn(() => false), + isFluxReduxModelConfig: vi.fn(() => false), +})); + +// Mock canvas selectors +vi.mock('features/controlLayers/store/canvasStagingAreaSlice', () => ({ + buildSelectIsStaging: vi.fn(() => vi.fn(() => false)), + selectCanvasSessionId: vi.fn(() => null), +})); + +vi.mock('features/controlLayers/store/selectors', () => ({ + selectAllEntitiesOfType: vi.fn(() => []), + selectBboxModelBase: vi.fn(() => 'anima'), + selectCanvasSlice: vi.fn(() => ({})), +})); + +vi.mock('features/controlLayers/store/refImagesSlice', () => ({ + refImageConfigChanged: vi.fn(), + refImageModelChanged: vi.fn(), + selectReferenceImageEntities: vi.fn(() => []), +})); + +vi.mock('features/controlLayers/store/types', async () => { + const actual = await vi.importActual('features/controlLayers/store/types'); + return { + ...(actual as Record), + getEntityIdentifier: vi.fn(), + isFlux2ReferenceImageConfig: vi.fn(() => false), + }; +}); + +vi.mock('features/controlLayers/store/util', () => ({ + initialFlux2ReferenceImage: {}, + initialFluxKontextReferenceImage: {}, + initialFLUXRedux: {}, + initialIPAdapter: {}, +})); + +vi.mock('features/modelManagerV2/models', () => ({ + SUPPORTS_REF_IMAGES_BASE_MODELS: ['sd-1', 'sdxl', 'flux', 'flux2'], +})); + +vi.mock('features/controlLayers/store/canvasSlice', () => ({ + bboxSyncedToOptimalDimension: vi.fn(() => ({ type: 'bboxSyncedToOptimalDimension' })), + rgRefImageModelChanged: vi.fn(), +})); + +vi.mock('features/controlLayers/store/lorasSlice', () => ({ + loraIsEnabledChanged: vi.fn((payload: unknown) => ({ type: 'loraIsEnabledChanged', payload })), +})); + +// Capture the listener effect so we can call it directly +let capturedEffect: ((action: unknown, api: unknown) => void) | null = null; + +// Import actual action creators for assertion matching +const paramsSliceActual = (await vi.importActual('features/controlLayers/store/paramsSlice')) as { + animaQwen3EncoderModelSelected: { type: string }; + animaT5EncoderModelSelected: { type: string }; + animaVaeModelSelected: { type: string }; +}; +const { animaQwen3EncoderModelSelected, animaT5EncoderModelSelected, animaVaeModelSelected } = paramsSliceActual; + +// Import after mocks are set up +const { addModelSelectedListener } = await import('./modelSelected'); +const { modelSelected } = await import('features/parameters/store/actions'); +const { zParameterModel } = await import('features/parameters/types/parameterSchemas'); + +// Capture the effect +addModelSelectedListener(((config: { effect: typeof capturedEffect }) => { + capturedEffect = config.effect; +}) as never); + +function buildMockState(overrides: Record = {}) { + return { + params: { + model: null, + vae: null, + zImageVaeModel: null, + zImageQwen3EncoderModel: null, + zImageQwen3SourceModel: null, + animaVaeModel: null, + animaQwen3EncoderModel: null, + animaT5EncoderModel: null, + animaScheduler: 'euler', + kleinVaeModel: null, + kleinQwen3EncoderModel: null, + zImageScheduler: 'euler', + ...overrides, + }, + loras: { loras: [] }, + canvas: {}, + }; +} + +describe('modelSelected listener - Anima defaulting', () => { + beforeEach(() => { + dispatched.length = 0; + mockDispatch.mockClear(); + mockSelectAnimaQwen3EncoderModels.mockReturnValue([mockAnimaQwen3Encoder]); + mockSelectAnimaVAEModels.mockReturnValue([mockAnimaVAE]); + mockSelectT5EncoderModels.mockReturnValue([mockT5Encoder]); + }); + + it('should dispatch encoder models with full ModelIdentifierField payloads when switching to Anima', () => { + const state = buildMockState({ model: mockFluxMainModel }); + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); + + capturedEffect!(action, { + getState: () => state, + dispatch: mockDispatch, + }); + + // Find the dispatched actions for Anima encoders + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); + const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type); + + // All three should have been dispatched + expect(qwen3Dispatch).toBeDefined(); + expect(t5Dispatch).toBeDefined(); + expect(vaeDispatch).toBeDefined(); + + // The payloads must pass zModelIdentifierField validation (the actual schema used by reducers) + expect(zModelIdentifierField.safeParse(qwen3Dispatch!.payload).success).toBe(true); + expect(zModelIdentifierField.safeParse(t5Dispatch!.payload).success).toBe(true); + expect(zModelIdentifierField.safeParse(vaeDispatch!.payload).success).toBe(true); + }); + + it('should include hash and type in Qwen3 encoder payload', () => { + const state = buildMockState({ model: mockFluxMainModel }); + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); + + capturedEffect!(action, { + getState: () => state, + dispatch: mockDispatch, + }); + + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); + expect(qwen3Dispatch!.payload).toMatchObject({ + key: mockAnimaQwen3Encoder.key, + hash: mockAnimaQwen3Encoder.hash, + name: mockAnimaQwen3Encoder.name, + base: mockAnimaQwen3Encoder.base, + type: mockAnimaQwen3Encoder.type, + }); + }); + + it('should include hash and type in T5 encoder payload', () => { + const state = buildMockState({ model: mockFluxMainModel }); + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); + + capturedEffect!(action, { + getState: () => state, + dispatch: mockDispatch, + }); + + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); + expect(t5Dispatch!.payload).toMatchObject({ + key: mockT5Encoder.key, + hash: mockT5Encoder.hash, + name: mockT5Encoder.name, + base: mockT5Encoder.base, + type: mockT5Encoder.type, + }); + }); + + it('should not dispatch encoder defaults when Anima models are already set', () => { + const existingQwen3 = { key: 'existing', hash: 'h', name: 'Existing', base: 'any', type: 'qwen3_encoder' }; + const existingT5 = { key: 'existing-t5', hash: 'h', name: 'Existing T5', base: 'any', type: 't5_encoder' }; + const existingVae = { key: 'existing-vae', hash: 'h', name: 'Existing VAE', base: 'anima', type: 'vae' }; + + const state = buildMockState({ + model: mockFluxMainModel, + animaQwen3EncoderModel: existingQwen3, + animaT5EncoderModel: existingT5, + animaVaeModel: existingVae, + }); + + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); + + capturedEffect!(action, { + getState: () => state, + dispatch: mockDispatch, + }); + + // Should NOT dispatch any encoder model selections since they're already set + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); + const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type); + + expect(qwen3Dispatch).toBeUndefined(); + expect(t5Dispatch).toBeUndefined(); + expect(vaeDispatch).toBeUndefined(); + }); + + it('should not dispatch encoder defaults when no encoder models are available', () => { + mockSelectAnimaQwen3EncoderModels.mockReturnValue([]); + mockSelectAnimaVAEModels.mockReturnValue([]); + + const state = buildMockState({ model: mockFluxMainModel }); + const action = modelSelected(zParameterModel.parse(mockAnimaMainModel)); + + capturedEffect!(action, { + getState: () => state, + dispatch: mockDispatch, + }); + + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); + const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type); + + expect(qwen3Dispatch).toBeUndefined(); + expect(t5Dispatch).toBeUndefined(); + expect(vaeDispatch).toBeUndefined(); + }); + + it('should clear Anima models when switching away from Anima', () => { + const existingQwen3 = { key: 'existing', hash: 'h', name: 'Existing', base: 'any', type: 'qwen3_encoder' }; + const existingT5 = { key: 'existing-t5', hash: 'h', name: 'Existing T5', base: 'any', type: 't5_encoder' }; + const existingVae = { key: 'existing-vae', hash: 'h', name: 'Existing VAE', base: 'anima', type: 'vae' }; + + const state = buildMockState({ + model: mockAnimaMainModel, + animaQwen3EncoderModel: existingQwen3, + animaT5EncoderModel: existingT5, + animaVaeModel: existingVae, + }); + + const action = modelSelected(zParameterModel.parse(mockFluxMainModel)); + + capturedEffect!(action, { + getState: () => state, + dispatch: mockDispatch, + }); + + // Should dispatch null for all three + const qwen3Dispatch = dispatched.find((a) => a.type === animaQwen3EncoderModelSelected.type); + const t5Dispatch = dispatched.find((a) => a.type === animaT5EncoderModelSelected.type); + const vaeDispatch = dispatched.find((a) => a.type === animaVaeModelSelected.type); + + expect(qwen3Dispatch).toBeDefined(); + expect(qwen3Dispatch!.payload).toBeNull(); + expect(t5Dispatch).toBeDefined(); + expect(t5Dispatch!.payload).toBeNull(); + expect(vaeDispatch).toBeDefined(); + expect(vaeDispatch!.payload).toBeNull(); + }); +}); + +describe('zModelIdentifierField schema validation', () => { + it('should reject payloads missing hash and type', () => { + const incomplete = { key: 'some-key', name: 'Some Model', base: 'any' }; + expect(zModelIdentifierField.safeParse(incomplete).success).toBe(false); + }); + + it('should accept payloads with all required fields', () => { + const complete = { key: 'some-key', hash: 'some-hash', name: 'Some Model', base: 'any', type: 'qwen3_encoder' }; + expect(zModelIdentifierField.safeParse(complete).success).toBe(true); + }); +}); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 6cefc1ba4f1..186f2c52dae 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -193,8 +193,10 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = dispatch( animaQwen3EncoderModelSelected({ key: qwen3Encoder.key, + hash: qwen3Encoder.hash, name: qwen3Encoder.name, base: qwen3Encoder.base, + type: qwen3Encoder.type, }) ); } @@ -214,8 +216,10 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = dispatch( animaT5EncoderModelSelected({ key: t5Encoder.key, + hash: t5Encoder.hash, name: t5Encoder.name, base: t5Encoder.base, + type: t5Encoder.type, }) ); } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 7e0c7a5029e..7543396cfcb 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -218,17 +218,14 @@ const slice = createSlice({ } state.animaVaeModel = result.data; }, - animaQwen3EncoderModelSelected: ( - state, - action: PayloadAction<{ key: string; name: string; base: string } | null> - ) => { + animaQwen3EncoderModelSelected: (state, action: PayloadAction) => { const result = zParamsState.shape.animaQwen3EncoderModel.safeParse(action.payload); if (!result.success) { return; } state.animaQwen3EncoderModel = result.data; }, - animaT5EncoderModelSelected: (state, action: PayloadAction<{ key: string; name: string; base: string } | null>) => { + animaT5EncoderModelSelected: (state, action: PayloadAction) => { const result = zParamsState.shape.animaT5EncoderModel.safeParse(action.payload); if (!result.success) { return;