Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
$tool,
layerReset,
selectControlLayersSlice,
selectedLayerDeleted,
selectedLayerReset,
} from 'features/controlLayers/store/controlLayersSlice';
import { useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
Expand All @@ -22,6 +22,7 @@ export const ToolChooser: React.FC = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const isDisabled = useAppSelector(selectIsDisabled);
const selectedLayerId = useAppSelector((s) => s.controlLayers.present.selectedLayerId);
const tool = useStore($tool);

const setToolToBrush = useCallback(() => {
Expand All @@ -42,8 +43,11 @@ export const ToolChooser: React.FC = () => {
useHotkeys('v', setToolToMove, { enabled: !isDisabled }, [isDisabled]);

const resetSelectedLayer = useCallback(() => {
dispatch(selectedLayerReset());
}, [dispatch]);
if (selectedLayerId === null) {
return;
}
dispatch(layerReset(selectedLayerId));
}, [dispatch, selectedLayerId]);
useHotkeys('shift+c', resetSelectedLayer);

const deleteSelectedLayer = useCallback(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,6 @@ export const isRenderableLayer = (
layer?.type === 'regional_guidance_layer' ||
layer?.type === 'control_adapter_layer' ||
layer?.type === 'initial_image_layer';
const resetLayer = (layer: Layer) => {
if (layer.type === 'regional_guidance_layer') {
layer.maskObjects = [];
layer.bbox = null;
layer.isEnabled = true;
layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false;
layer.uploadedMaskImage = null;
return;
}
};

export const selectCALayerOrThrow = (state: ControlLayersState, layerId: string): ControlAdapterLayer => {
const layer = state.layers.find((l) => l.id === layerId);
Expand Down Expand Up @@ -164,6 +153,9 @@ export const controlLayersSlice = createSlice({
layer.x = x;
layer.y = y;
}
if (isRegionalGuidanceLayer(layer)) {
layer.uploadedMaskImage = null;
}
},
layerBboxChanged: (state, action: PayloadAction<{ layerId: string; bbox: IRect | null }>) => {
const { layerId, bbox } = action.payload;
Expand All @@ -181,8 +173,14 @@ export const controlLayersSlice = createSlice({
},
layerReset: (state, action: PayloadAction<string>) => {
const layer = state.layers.find((l) => l.id === action.payload);
if (layer) {
resetLayer(layer);
// TODO(psyche): Should other layer types also have reset functionality?
if (isRegionalGuidanceLayer(layer)) {
layer.maskObjects = [];
layer.bbox = null;
layer.isEnabled = true;
layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false;
layer.uploadedMaskImage = null;
}
},
layerDeleted: (state, action: PayloadAction<string>) => {
Expand Down Expand Up @@ -215,12 +213,6 @@ export const controlLayersSlice = createSlice({
moveToFront(renderableLayers, cb);
state.layers = [...ipAdapterLayers, ...renderableLayers];
},
selectedLayerReset: (state) => {
const layer = state.layers.find((l) => l.id === state.selectedLayerId);
if (layer) {
resetLayer(layer);
}
},
selectedLayerDeleted: (state) => {
state.layers = state.layers.filter((l) => l.id !== state.selectedLayerId);
state.selectedLayerId = state.layers[0]?.id ?? null;
Expand Down Expand Up @@ -803,7 +795,6 @@ export const {
layerMovedToFront,
layerMovedBackward,
layerMovedToBack,
selectedLayerReset,
selectedLayerDeleted,
allLayersDeleted,
// CA Layers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,20 @@ const CurrentImagePreview = () => {
// Show and hide the next/prev buttons on mouse move
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
const timeoutId = useRef(0);
const onMouseMove = useCallback(() => {
const onMouseOver = useCallback(() => {
setShouldShowNextPrevButtons(true);
window.clearTimeout(timeoutId.current);
}, []);
const onMouseOut = useCallback(() => {
timeoutId.current = window.setTimeout(() => {
setShouldShowNextPrevButtons(false);
}, 1000);
}, 500);
}, []);

return (
<Flex
onMouseMove={onMouseMove}
onMouseOver={onMouseOver}
onMouseOut={onMouseOut}
width="full"
height="full"
alignItems="center"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { uniqBy } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
Expand Down Expand Up @@ -83,6 +84,9 @@ export const gallerySlice = createSlice({
},
},
extraReducers: (builder) => {
builder.addCase(setActiveTab, (state) => {
state.isImageViewerOpen = false;
});
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
const deletedBoardId = action.meta.arg.originalArgs;
if (deletedBoardId === state.selectedBoardId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ import { assert } from 'tsafe';

import { IMAGE_TO_LATENTS, NOISE, RESIZE } from './constants';

/**
* Returns true if an initial image was added, false if not.
*/
export const addInitialImageToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
denoiseNodeId: string
): void => {
): boolean => {
// Remove Existing UNet Connections
const { img2imgStrength, vaePrecision, model } = state.generation;
const { refinerModel, refinerStart } = state.sdxl;
Expand All @@ -19,7 +22,7 @@ export const addInitialImageToLinearGraph = (
const initialImage = initialImageLayer?.isEnabled ? initialImageLayer?.image : null;

if (!initialImage) {
return;
return false;
}

const isSDXL = model?.base === 'sdxl';
Expand Down Expand Up @@ -122,4 +125,6 @@ export const addInitialImageToLinearGraph = (
strength: img2imgStrength,
init_image: initialImage.imageName,
});

return true;
};
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { RootState } from 'app/store/store';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import { getBoardField } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ESRGANInvocation, Graph, NonNullableGraph } from 'services/api/types';

import { ESRGAN } from './constants';
Expand All @@ -18,7 +18,7 @@ export const buildAdHocUpscaleGraph = ({ image_name, state }: Arg): Graph => {
type: 'esrgan',
image: { image_name },
model_name: esrganModelName,
is_intermediate: getIsIntermediate(state),
is_intermediate: false,
board: getBoardField(state),
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<NonNull
LATENTS_TO_IMAGE
);

addInitialImageToLinearGraph(state, graph, DENOISE_LATENTS);
const didAddInitialImage = addInitialImageToLinearGraph(state, graph, DENOISE_LATENTS);

// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
Expand All @@ -249,7 +249,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<NonNull
await addControlLayersToGraph(state, graph, DENOISE_LATENTS);

// High resolution fix.
if (state.hrf.hrfEnabled) {
if (state.hrf.hrfEnabled && !didAddInitialImage) {
addHrfToGraph(state, graph);
}

Expand Down