Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ const resetLayer = (layer: Layer) => {
layer.isEnabled = true;
layer.needsPixelBbox = false;
layer.bboxNeedsUpdate = false;
layer.uploadedMaskImage = null;
return;
}
};
Expand Down Expand Up @@ -173,6 +174,7 @@ export const controlLayersSlice = createSlice({
if (bbox === null && layer.type === 'regional_guidance_layer') {
// The layer was fully erased, empty its objects to prevent accumulation of invisible objects
layer.maskObjects = [];
layer.uploadedMaskImage = null;
layer.needsPixelBbox = false;
}
}
Expand Down Expand Up @@ -456,6 +458,7 @@ export const controlLayersSlice = createSlice({
negativePrompt: null,
ipAdapters: [],
isSelected: true,
uploadedMaskImage: null,
};
state.layers.push(layer);
state.selectedLayerId = layer.id;
Expand Down Expand Up @@ -505,6 +508,7 @@ export const controlLayersSlice = createSlice({
strokeWidth: state.brushSize,
});
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
if (!layer.needsPixelBbox && tool === 'eraser') {
layer.needsPixelBbox = true;
}
Expand All @@ -524,6 +528,7 @@ export const controlLayersSlice = createSlice({
// TODO: Handle this in the event listener
lastLine.points.push(point[0] - layer.x, point[1] - layer.y);
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
},
rgLayerRectAdded: {
reducer: (state, action: PayloadAction<{ layerId: string; rect: IRect; rectUuid: string }>) => {
Expand All @@ -543,9 +548,15 @@ export const controlLayersSlice = createSlice({
height: rect.height,
});
layer.bboxNeedsUpdate = true;
layer.uploadedMaskImage = null;
},
prepare: (payload: { layerId: string; rect: IRect }) => ({ payload: { ...payload, rectUuid: uuidv4() } }),
},
rgLayerMaskImageUploaded: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO }>) => {
const { layerId, imageDTO } = action.payload;
const layer = selectRGLayerOrThrow(state, layerId);
layer.uploadedMaskImage = imageDTOToImageWithDims(imageDTO);
},
rgLayerAutoNegativeChanged: (
state,
action: PayloadAction<{ layerId: string; autoNegative: ParameterAutoNegative }>
Expand Down Expand Up @@ -825,6 +836,7 @@ export const {
rgLayerLineAdded,
rgLayerPointsAdded,
rgLayerRectAdded,
rgLayerMaskImageUploaded,
rgLayerAutoNegativeChanged,
rgLayerIPAdapterAdded,
rgLayerIPAdapterDeleted,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export type RegionalGuidanceLayer = RenderableLayerBase & {
previewColor: RgbColor;
autoNegative: ParameterAutoNegative;
needsPixelBbox: boolean; // Needs the slower pixel-based bbox calculation - set to true when an there is an eraser object
uploadedMaskImage: ImageWithDims | null;
};

export type InitialImageLayer = RenderableLayerBase & {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ export const ImageViewer = memo(() => {
useHotkeys('z', onToggle, { enabled: isViewerEnabled }, [isViewerEnabled, onToggle]);
useHotkeys('esc', onClose, { enabled: isViewerEnabled }, [isViewerEnabled, onClose]);

// The AnimatePresence mode must be wait - else framer can get confused if you spam the toggle button
return (
<AnimatePresence>
<AnimatePresence mode="wait">
{shouldShowViewer && (
<Flex
key="imageViewer"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import {
isControlAdapterLayer,
isIPAdapterLayer,
isRegionalGuidanceLayer,
rgLayerMaskImageUploaded,
} from 'features/controlLayers/store/controlLayersSlice';
import type { RegionalGuidanceLayer } from 'features/controlLayers/store/types';
import {
type ControlNetConfigV2,
type ImageWithDims,
Expand Down Expand Up @@ -32,12 +34,13 @@ import {
} from 'features/nodes/util/graph/constants';
import { upsertMetadata } from 'features/nodes/util/graph/metadata';
import { size } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type {
CollectInvocation,
ControlNetInvocation,
CoreMetadataInvocation,
Edge,
ImageDTO,
IPAdapterInvocation,
NonNullableGraph,
S,
Expand Down Expand Up @@ -337,7 +340,6 @@ const addGlobalIPAdaptersToGraph = async (
};

export const addControlLayersToGraph = async (state: RootState, graph: NonNullableGraph, denoiseNodeId: string) => {
const { dispatch } = getStore();
const mainModel = state.generation.model;
assert(mainModel, 'Missing main model when building graph');
const isSDXL = mainModel.base === 'sdxl';
Expand Down Expand Up @@ -404,10 +406,6 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
return hasTextPrompt || hasIPAdapter;
});

const layerIds = rgLayers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');

// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes.

Expand Down Expand Up @@ -470,22 +468,15 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
},
});

// Upload the blobs to the backend, add each to graph
// TODO: Store the uploaded image names in redux to reuse them, so long as the layer hasn't otherwise changed. This
// would be a great perf win - not only would we skip re-uploading the same image, but we'd be able to use the node
// cache (currently, when we re-use the same mask data, since it is a different image, the node cache is not used).
const layerIds = rgLayers.map((l) => l.id);
const blobs = await getRegionalPromptLayerBlobs(layerIds);
assert(size(blobs) === size(layerIds), 'Mismatch between layer IDs and blobs');

for (const layer of rgLayers) {
const blob = blobs[layer.id];
assert(blob, `Blob for layer ${layer.id} not found`);

const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();

// TODO: This will raise on network error
const { image_name } = await req.unwrap();
// Upload the mask image, or get the cached image if it exists
const { image_name } = await getMaskImage(layer, blob);

// The main mask-to-tensor node
const maskToTensorNode: S['AlphaMaskToTensorInvocation'] = {
Expand Down Expand Up @@ -679,3 +670,23 @@ export const addControlLayersToGraph = async (state: RootState, graph: NonNullab
}
}
};

const getMaskImage = async (layer: RegionalGuidanceLayer, blob: Blob): Promise<ImageDTO> => {
if (layer.uploadedMaskImage) {
const imageDTO = await getImageDTO(layer.uploadedMaskImage.imageName);
if (imageDTO) {
return imageDTO;
}
}
const { dispatch } = getStore();
// No cached mask, or the cached image no longer exists - we need to upload the mask image
const file = new File([blob], `${layer.id}_mask.png`, { type: 'image/png' });
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate({ file, image_category: 'mask', is_intermediate: true })
);
req.reset();

const imageDTO = await req.unwrap();
dispatch(rgLayerMaskImageUploaded({ layerId: layer.id, imageDTO }));
return imageDTO;
};
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,11 @@ const InvokeTabs = () => {
/>
</>
)}
<Panel style={{ position: 'relative' }} id="main-panel" order={1} minSize={20}>
<TabPanels w="full" h="full">
<Panel id="main-panel" order={1} minSize={20}>
<TabPanels w="full" h="full" position="relative">
{tabPanels}
<ImageViewer />
</TabPanels>
<ImageViewer />
</Panel>
{shouldShowGalleryPanel && (
<>
Expand Down