Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
Expand All @@ -20,7 +21,6 @@

from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
Expand Down Expand Up @@ -57,8 +57,8 @@ class ApiDependencies:
invoker: Invoker = None

@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
logger.debug(f'InvokeAI version {__version__}')
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
logger.debug(f"InvokeAI version {__version__}")
logger.debug(f"Internet connectivity is {config.internet_available}")

events = FastAPIEventService(event_handler_id)
Expand Down Expand Up @@ -117,7 +117,7 @@ def initialize(config, event_handler_id: int, logger: Logger = logger):
)

services = InvocationServices(
model_manager=ModelManagerService(config,logger),
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
Expand All @@ -129,7 +129,6 @@ def initialize(config, event_handler_id: int, logger: Logger = logger):
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
)
Expand Down
2 changes: 0 additions & 2 deletions invokeai/app/cli_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
from .services.invoker import Invoker
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage

import torch
Expand Down Expand Up @@ -295,7 +294,6 @@ def invoke_cli():
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,
configuration=config,
)
Expand Down
55 changes: 0 additions & 55 deletions invokeai/app/invocations/reconstruct.py

This file was deleted.

124 changes: 94 additions & 30 deletions invokeai/app/invocations/upscale.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,112 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path, PosixPath
from typing import Literal, Union, cast

import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import Field
from realesrgan import RealESRGANer

from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput

from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageOutput

class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
# TODO: Populate this from disk?
# TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[
"RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]

# fmt: off
type: Literal["upscale"] = "upscale"

# Inputs
image: Optional[ImageField] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
# fmt: on
class RealESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""

# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
type: Literal["realesrgan"] = "realesrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: REALESRGAN_MODELS = Field(
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
)

def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
strength=0.0, # GFPGAN strength
save_original=False,
image_callback=None,
models_path = context.services.configuration.models_path

rrdbnet_model = None
netscale = None
esrgan_model_path = None

if self.model_name in [
"RealESRGAN_x4plus.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]:
# x4 RRDBNet model
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
netscale = 4
elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]:
# x4 RRDBNet model, 6 blocks
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6, # 6 blocks
num_grow_ch=32,
scale=4,
)
netscale = 4
# TODO: add x2 models handling?
# elif self.model_name in ["RealESRGAN_x2plus"]:
# # x2 RRDBNet model
# model = RRDBNet(
# num_in_ch=3,
# num_out_ch=3,
# num_feat=64,
# num_block=23,
# num_grow_ch=32,
# scale=2,
# )
# model_path = Path()
# netscale = 2
else:
msg = f"Invalid RealESRGAN model: {self.model_name}"
context.services.logger.error(msg)
raise ValueError(msg)

esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")

upsampler = RealESRGANer(
scale=netscale,
model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model,
half=False,
)

# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)

# We can pass an `outscale` value here, but it just resizes the image by that factor after
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
# upscaling, you'll need to add a resize node after this one.
upscaled_image, img_mode = upsampler.enhance(cv_image)

# back to PIL
pil_image = Image.fromarray(
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
).convert("RGBA")

image_dto = context.services.images.create(
image=results[0][0],
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
Expand Down
8 changes: 4 additions & 4 deletions invokeai/app/services/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,13 @@ def add_subparser(cls, parser: argparse.ArgumentParser):

@classmethod
def _excluded(self)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed
# internal fields that shouldn't be exposed as command line options
return ['type','initconf']

@classmethod
def _excluded_from_yaml(self)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model']
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore']

class Config:
env_file_encoding = 'utf-8'
Expand Down Expand Up @@ -366,7 +366,7 @@ class InvokeAIAppConfig(InvokeAISettings):
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')

always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
Expand Down
10 changes: 3 additions & 7 deletions invokeai/app/services/invocation_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC

Expand All @@ -24,7 +23,7 @@ class InvocationServices:
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
board_images: "BoardImagesServiceABC"
boards: "BoardServiceABC"
configuration: "InvokeAISettings"
configuration: "InvokeAIAppConfig"
events: "EventServiceBase"
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
Expand All @@ -34,13 +33,12 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
queue: "InvocationQueueABC"
restoration: "RestorationServices"

def __init__(
self,
board_images: "BoardImagesServiceABC",
boards: "BoardServiceABC",
configuration: "InvokeAISettings",
configuration: "InvokeAIAppConfig",
events: "EventServiceBase",
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
graph_library: "ItemStorageABC"["LibraryGraph"],
Expand All @@ -50,7 +48,6 @@ def __init__(
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
queue: "InvocationQueueABC",
restoration: "RestorationServices",
):
self.board_images = board_images
self.boards = boards
Expand All @@ -65,4 +62,3 @@ def __init__(
self.model_manager = model_manager
self.processor = processor
self.queue = queue
self.restoration = restoration
Loading