Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { heightChanged, widthChanged } from 'features/controlLayers/store/contro
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { simpleUpscaleModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
Expand Down Expand Up @@ -186,21 +186,23 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log)
};

const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch, _log) => {
const currentUpscaleModel = state.upscale.upscaleModel;
const { upscaleModel: currentUpscaleModel, simpleUpscaleModel: currentSimpleUpscaleModel } = state.upscale;
const upscaleModels = models.filter(isSpandrelImageToImageModelConfig);
const firstModel = upscaleModels[0] || null;

if (currentUpscaleModel) {
const isCurrentUpscaleModelAvailable = upscaleModels.some((m) => m.key === currentUpscaleModel.key);
if (isCurrentUpscaleModelAvailable) {
return;
}
}
const isCurrentUpscaleModelAvailable = currentUpscaleModel
? upscaleModels.some((m) => m.key === currentUpscaleModel.key)
: false;

const firstModel = upscaleModels[0];
if (firstModel) {
if (!isCurrentUpscaleModelAvailable) {
dispatch(upscaleModelChanged(firstModel));
return;
}

dispatch(upscaleModelChanged(null));
const isCurrentSimpleUpscaleModelAvailable = currentSimpleUpscaleModel
? upscaleModels.some((m) => m.key === currentSimpleUpscaleModel.key)
: false;

if (!isCurrentSimpleUpscaleModelAvailable) {
dispatch(simpleUpscaleModelChanged(firstModel));
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
const log = logger('session');

const { imageDTO } = action.payload;
const { image_name } = imageDTO;
const state = getState();

const { isAllowedToUpscale, detailTKey } = createIsAllowedToUpscaleSelector(imageDTO)(state);
Expand All @@ -40,8 +39,8 @@ export const addUpscaleRequestedListener = (startAppListening: AppStartListening
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: buildAdHocUpscaleGraph({
image_name,
graph: await buildAdHocUpscaleGraph({
image: imageDTO,
state,
}),
runs: 1,
Expand Down
3 changes: 0 additions & 3 deletions invokeai/frontend/web/src/app/store/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/no
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
Expand Down Expand Up @@ -53,7 +52,6 @@ const allReducers = {
[gallerySlice.name]: gallerySlice.reducer,
[generationSlice.name]: generationSlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[postprocessingSlice.name]: postprocessingSlice.reducer,
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
[uiSlice.name]: uiSlice.reducer,
Expand Down Expand Up @@ -104,7 +102,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
[generationPersistConfig.name]: generationPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[postprocessingPersistConfig.name]: postprocessingPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[workflowPersistConfig.name]: workflowPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,46 @@
import type { RootState } from 'app/store/store';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import type { Graph, Invocation, NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageDTO,
type Invocation,
isSpandrelImageToImageModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { assert } from 'tsafe';

import { addCoreMetadataNode, upsertMetadata } from './canvas/metadata';
import { ESRGAN } from './constants';
import { addCoreMetadataNode, getModelMetadataField, upsertMetadata } from './canvas/metadata';
import { SPANDREL } from './constants';

type Arg = {
image_name: string;
image: ImageDTO;
state: RootState;
};

export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
const { esrganModelName } = state.postprocessing;
export const buildAdHocUpscaleGraph = async ({ image, state }: Arg): Promise<NonNullableGraph> => {
const { simpleUpscaleModel } = state.upscale;

const realesrganNode: Invocation<'esrgan'> = {
id: ESRGAN,
type: 'esrgan',
image: { image_name },
model_name: esrganModelName,
is_intermediate: false,
board: getBoardField(state),
assert(simpleUpscaleModel, 'No upscale model found in state');

const upscaleNode: Invocation<'spandrel_image_to_image'> = {
id: SPANDREL,
type: 'spandrel_image_to_image',
image_to_image_model: simpleUpscaleModel,
tile_size: 500,
image,
};

const graph: NonNullableGraph = {
id: `adhoc-esrgan-graph`,
id: `adhoc-upscale-graph`,
nodes: {
[ESRGAN]: realesrganNode,
[SPANDREL]: upscaleNode,
},
edges: [],
};
const modelConfig = await fetchModelConfigWithTypeGuard(simpleUpscaleModel.key, isSpandrelImageToImageModelConfig);

addCoreMetadataNode(graph, {}, ESRGAN);
addCoreMetadataNode(graph, {}, SPANDREL);
upsertMetadata(graph, {
esrgan_model: esrganModelName,
upscale_model: getModelMetadataField(modelConfig),
});

return graph;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ export const CONTROL_NET_COLLECT = 'control_net_collect';
export const IP_ADAPTER_COLLECT = 'ip_adapter_collect';
export const T2I_ADAPTER_COLLECT = 't2i_adapter_collect';
export const METADATA = 'core_metadata';
export const ESRGAN = 'esrgan';
export const SPANDREL = 'spandrel';
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
export const SDXL_DENOISE_LATENTS = 'sdxl_denoise_latents';
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import { Box, Combobox, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { simpleUpscaleModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
import type { SpandrelImageToImageModelConfig } from 'services/api/types';

const ParamSpandrelModel = () => {
interface Props {
isMultidiffusion: boolean;
}

const ParamSpandrelModel = ({ isMultidiffusion }: Props) => {
const { t } = useTranslation();
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();

const model = useAppSelector((s) => s.upscale.upscaleModel);
const model = useAppSelector((s) => (isMultidiffusion ? s.upscale.upscaleModel : s.upscale.simpleUpscaleModel));
const dispatch = useAppDispatch();

const tooltipLabel = useMemo(() => {
Expand All @@ -23,9 +27,13 @@ const ParamSpandrelModel = () => {

const _onChange = useCallback(
(v: SpandrelImageToImageModelConfig | null) => {
dispatch(upscaleModelChanged(v));
if (isMultidiffusion) {
dispatch(upscaleModelChanged(v));
} else {
dispatch(simpleUpscaleModelChanged(v));
}
},
[dispatch]
[isMultidiffusion, dispatch]
);

const { options, value, onChange, placeholder, noOptionsMessage } = useModelCombobox({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listen
import { useAppDispatch } from 'app/store/storeHooks';
import { useIsAllowedToUpscale } from 'features/parameters/hooks/useIsAllowedToUpscale';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { UpscaleWarning } from 'features/settingsAccordions/components/UpscaleSettingsAccordion/UpscaleWarning';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFrameCornersBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';

import ParamESRGANModel from './ParamRealESRGANModel';
import ParamSpandrelModel from './ParamSpandrelModel';

type Props = { imageDTO?: ImageDTO };

Expand Down Expand Up @@ -48,9 +49,10 @@ const ParamUpscalePopover = (props: Props) => {
/>
</PopoverTrigger>
<PopoverContent>
<PopoverBody minW={96}>
<PopoverBody w={96}>
<Flex flexDirection="column" gap={4}>
<ParamESRGANModel />
<ParamSpandrelModel isMultidiffusion={false} />
<UpscaleWarning usesTile={false} />
<Button
tooltip={detail}
size="sm"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectPostprocessingSlice } from 'features/parameters/store/postprocessingSlice';
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
Expand Down Expand Up @@ -55,13 +55,16 @@ const getDetailTKey = (isAllowedToUpscale?: ReturnType<typeof getIsAllowedToUpsc
};

export const createIsAllowedToUpscaleSelector = (imageDTO?: ImageDTO) =>
createMemoizedSelector(selectPostprocessingSlice, selectConfigSlice, (postprocessing, config) => {
const { esrganModelName } = postprocessing;
createMemoizedSelector(selectUpscalelice, selectConfigSlice, (upscale, config) => {
const { simpleUpscaleModel } = upscale;
const { maxUpscalePixels } = config;
if (!simpleUpscaleModel) {
return { isAllowedToUpscale: false, detailTKey: undefined };
}

const upscaledPixels = getUpscaledPixels(imageDTO, maxUpscalePixels);
const isAllowedToUpscale = getIsAllowedToUpscale(upscaledPixels, maxUpscalePixels);
const scaleFactor = esrganModelName.includes('x2') ? 2 : 4;
const scaleFactor = simpleUpscaleModel.name.includes('x2') ? 2 : 4;
const detailTKey = getDetailTKey(isAllowedToUpscale, scaleFactor);
return {
isAllowedToUpscale: scaleFactor === 2 ? isAllowedToUpscale.x2 : isAllowedToUpscale.x4,
Expand Down

This file was deleted.

Loading