Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions invokeai/app/invocations/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_scale"] = "img_scale"

# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on

Expand Down
37 changes: 18 additions & 19 deletions invokeai/app/invocations/upscale.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path, PosixPath
from typing import Literal, Union, cast
from pathlib import Path
from typing import Literal, Union

import cv2 as cv
import numpy as np
Expand All @@ -16,19 +16,20 @@

# TODO: Populate this from disk?
# TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[
ESRGAN_MODELS = Literal[
"RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"RealESRGAN_x2plus.pth",
]


class RealESRGANInvocation(BaseInvocation):
class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""

type: Literal["realesrgan"] = "realesrgan"
type: Literal["esrgan"] = "esrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: REALESRGAN_MODELS = Field(
model_name: ESRGAN_MODELS = Field(
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
)

Expand Down Expand Up @@ -73,19 +74,17 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
scale=4,
)
netscale = 4
# TODO: add x2 models handling?
# elif self.model_name in ["RealESRGAN_x2plus"]:
# # x2 RRDBNet model
# model = RRDBNet(
# num_in_ch=3,
# num_out_ch=3,
# num_feat=64,
# num_block=23,
# num_grow_ch=32,
# scale=2,
# )
# model_path = Path()
# netscale = 2
elif self.model_name in ["RealESRGAN_x2plus.pth"]:
# x2 RRDBNet model
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
netscale = 2
else:
msg = f"Invalid RealESRGAN model: {self.model_name}"
context.services.logger.error(msg)
Expand Down
7 changes: 6 additions & 1 deletion invokeai/backend/install/invokeai_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def download_conversion_models():

# ---------------------------------------------
def download_realesrgan():
logger.info("Installing RealESRGAN models...")
logger.info("Installing ESRGAN Upscaling models...")
URLs = [
dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
Expand All @@ -239,6 +239,11 @@ def download_realesrgan():
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description = "ESRGAN_SRx4_DF2KOST_official.pth",
),
dict(
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description = "RealESRGAN_x2plus.pth",
),
]
for model in URLs:
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
Expand Down
7 changes: 0 additions & 7 deletions invokeai/frontend/web/src/app/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,8 @@ export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {

export type Scheduler = (typeof SCHEDULER_NAMES)[number];

// Valid upscaling levels
export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [
{ label: '2x', value: '2' },
{ label: '4x', value: '4' },
];
export const NUMPY_RAND_MIN = 0;

export const NUMPY_RAND_MAX = 2147483647;

export const FACETOOL_TYPES = ['gfpgan', 'codeformer'] as const;

export const NODE_MIN_WIDTH = 250;
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addModelLoadStartedEventListener } from './listeners/socketio/socketModelLoadStarted';
import { addModelLoadCompletedEventListener } from './listeners/socketio/socketModelLoadCompleted';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';

export const listenerMiddleware = createListenerMiddleware();

Expand Down Expand Up @@ -228,3 +229,5 @@ addModelSelectedListener();
addAppStartedListener();
addModelsLoadedListener();
addAppConfigReceivedListener();

addUpscaleRequestedListener();
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { sessionCreated } from 'services/api/thunks/session';
import { serializeError } from 'serialize-error';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';

const moduleLog = log.child({ namespace: 'session' });

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { createAction } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger';
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graphBuilders/buildAdHocUpscaleGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';

const moduleLog = log.child({ namespace: 'upscale' });

export const upscaleRequested = createAction<{ image_name: string }>(
`upscale/upscaleRequested`
);

export const addUpscaleRequestedListener = () => {
startAppListening({
actionCreator: upscaleRequested,
effect: async (
action,
{ dispatch, getState, take, unsubscribe, subscribe }
) => {
const { image_name } = action.payload;
const { esrganModelName } = getState().postprocessing;

const graph = buildAdHocUpscaleGraph({
image_name,
esrganModelName,
});

// Create a session to run the graph & wait til it's ready to invoke
dispatch(sessionCreated({ graph }));

await take(sessionCreated.fulfilled.match);

dispatch(sessionReadyToInvoke());
},
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {

const IAIMantineSelectItemWithTooltip = forwardRef<HTMLDivElement, ItemProps>(
({ label, tooltip, description, disabled, ...others }: ItemProps, ref) => (
<Tooltip label={tooltip} placement="top" hasArrow>
<Tooltip label={tooltip} placement="top" hasArrow openDelay={500}>
<Box ref={ref} {...others}>
<Box>
<Text>{label}</Text>
Expand Down
Loading