diff --git a/ldm/generate.py b/ldm/generate.py index dcef529988d..6c1ecfa803b 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -297,6 +297,13 @@ def prompt2image( catch_interrupts = False, hires_fix = False, use_mps_noise = False, + # Seam settings for outpainting + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + force_outpaint: bool = False, **args, ): # eat up additional cruft """ @@ -467,7 +474,13 @@ def process_image(image,seed): embiggen_tiles=embiggen_tiles, inpaint_replace=inpaint_replace, mask_blur_radius=mask_blur_radius, - safety_checker=checker + safety_checker=checker, + seam_size = seam_size, + seam_blur = seam_blur, + seam_strength = seam_strength, + seam_steps = seam_steps, + tile_size = tile_size, + force_outpaint = force_outpaint ) if init_color: @@ -929,8 +942,9 @@ def _load_img(self, img)->Image: image = ImageOps.exif_transpose(image) return image - def _create_init_image(self, image, width, height, fit=True): - image = image.convert('RGB') + def _create_init_image(self, image: Image.Image, width, height, fit=True): + if image.mode != 'RGBA': + image = image.convert('RGB') image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image) return image diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index c4810de3855..1981b4eacb6 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -30,7 +30,7 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, ) if isinstance(init_image, PIL.Image.Image): - init_image = self._image_to_tensor(init_image) + init_image = self._image_to_tensor(init_image.convert('RGB')) scope = choose_autocast(self.precision) with scope(self.model.device.type): diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 34d4f209fc6..a6ab0cd06a6 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -2,12 +2,13 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator ''' +import math import torch import torchvision.transforms as T import numpy as np import cv2 as cv import PIL -from PIL import Image, ImageFilter +from PIL import Image, ImageFilter, ImageOps from skimage.exposure.histogram_matching import match_histograms from einops import rearrange, repeat from ldm.invoke.devices import choose_autocast @@ -24,11 +25,128 @@ def __init__(self, model, precision): self.mask_blur_radius = 0 super().__init__(model, precision) + # Outpaint support code + def get_tile_images(self, image: np.ndarray, width=8, height=8): + _nrows, _ncols, depth = image.shape + _strides = image.strides + + nrows, _m = divmod(_nrows, height) + ncols, _n = divmod(_ncols, width) + if _m != 0 or _n != 0: + return None + + return np.lib.stride_tricks.as_strided( + np.ravel(image), + shape=(nrows, ncols, height, width, depth), + strides=(height * _strides[0], width * _strides[1], *_strides), + writeable=False + ) + + def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image: + a = np.asarray(im, dtype=np.uint8) + + tile_size = (tile_size, tile_size) + + # Get the image as tiles of a specified size + tiles = self.get_tile_images(a,*tile_size).copy() + + # Get the mask as tiles + tiles_mask = tiles[:,:,:,:,3] + + # Find any mask tiles with any fully transparent pixels (we will be replacing these later) + tmask_shape = tiles_mask.shape + tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape)) + n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:]) + tiles_mask = (tiles_mask > 0) + tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1) + + # Get RGB tiles in single array and filter by the mask + tshape = tiles.shape + tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:])) + filtered_tiles = tiles_all[tiles_mask] + + if len(filtered_tiles) == 0: + return im + + # Find all invalid tiles and replace with a random valid tile + replace_count = (tiles_mask == False).sum() + rng = np.random.default_rng(seed = seed) + tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:] + + # Convert back to an image + tiles_all = tiles_all.reshape(tshape) + tiles_all = tiles_all.swapaxes(1,2) + st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4])) + si = Image.fromarray(st, mode='RGBA') + + return si + + + def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image: + npimg = np.asarray(mask, dtype=np.uint8) + + # Detect any partially transparent regions + npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))) + + # Detect hard edges + npedge = cv.Canny(npimg, threshold1=100, threshold2=200) + + # Combine + npmask = npgradient + npedge + + # Expand + npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2)) + + new_mask = Image.fromarray(npmask) + + if edge_blur > 0: + new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur)) + + return ImageOps.invert(new_mask) + + + def seam_paint(self, + im: Image.Image, + seam_size: int, + seam_blur: int, + prompt,sampler,steps,cfg_scale,ddim_eta, + conditioning,strength, + noise + ) -> Image.Image: + hard_mask = self.pil_image.split()[-1].copy() + mask = self.mask_edge(hard_mask, seam_size, seam_blur) + + make_image = self.get_make_image( + prompt, + sampler, + steps, + cfg_scale, + ddim_eta, + conditioning, + init_image = im.copy().convert('RGBA'), + mask_image = mask.convert('RGB'), # Code currently requires an RGB mask + strength = strength, + mask_blur_radius = 0, + seam_size = 0 + ) + + result = make_image(noise) + + return result + + @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,init_image,mask_image,strength, mask_blur_radius: int = 8, - step_callback=None,inpaint_replace=False, **kwargs): + # Seam settings - when 0, doesn't fill seam + seam_size: int = 0, + seam_blur: int = 0, + seam_strength: float = 0.7, + seam_steps: int = 10, + tile_size: int = 32, + step_callback=None, + inpaint_replace=False, **kwargs): """ Returns a function returning an image derived from the prompt and the initial image + mask. Return value depends on the seed at @@ -37,7 +155,17 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, if isinstance(init_image, PIL.Image.Image): self.pil_image = init_image - init_image = self._image_to_tensor(init_image) + + # Fill missing areas of original image + init_filled = self.tile_fill_missing( + self.pil_image.copy(), + seed = self.seed, + tile_size = tile_size + ) + init_filled.paste(init_image, (0,0), init_image.split()[-1]) + + # Create init tensor + init_image = self._image_to_tensor(init_filled.convert('RGB')) if isinstance(mask_image, PIL.Image.Image): self.pil_mask = mask_image @@ -106,38 +234,56 @@ def make_image(x_T): mask = mask_image, init_latent = self.init_latent ) - return self.sample_to_image(samples) - return make_image + result = self.sample_to_image(samples) - def sample_to_image(self, samples)->Image.Image: - gen_result = super().sample_to_image(samples).convert('RGB') + # Seam paint if this is our first pass (seam_size set to 0 during seam painting) + if seam_size > 0: + result = self.seam_paint( + result, + seam_size, + seam_blur, + prompt, + sampler, + seam_steps, + cfg_scale, + ddim_eta, + conditioning, + seam_strength, + x_T) - if self.pil_image is None or self.pil_mask is None: - return gen_result - - pil_mask = self.pil_mask - pil_image = self.pil_image - mask_blur_radius = self.mask_blur_radius + return result + return make_image + + + def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image: # Get the original alpha channel of the mask if there is one. # Otherwise it is some other black/white image format ('1', 'L' or 'RGB') - pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L') - pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist + pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L') + pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. - # Note that this doesn't use the mask, which would exclude some source image pixels from the - # histogram and cause slight color changes. - init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3) - init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height) - init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0] - init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram + init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8) + init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8) + init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8) + + # Get numpy version of result + np_image = np.asarray(image, dtype=np.uint8) - # Get numpy version - np_gen_result = np.asarray(gen_result, dtype=np.uint8) + # Mask and calculate mean and standard deviation + mask_pixels = init_a_pixels * init_mask_pixels > 0 + np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :] + np_image_masked = np_image[mask_pixels, :] + + init_means = np_init_rgb_pixels_masked.mean(axis=0) + init_std = np_init_rgb_pixels_masked.std(axis=0) + gen_means = np_image_masked.mean(axis=0) + gen_std = np_image_masked.std(axis=0) # Color correct - np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1) + np_matched_result = np_image.copy() + np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8) matched_result = Image.fromarray(np_matched_result, mode='RGB') # Blur the mask out (into init image) by specified amount @@ -150,6 +296,16 @@ def sample_to_image(self, samples)->Image.Image: blurred_init_mask = pil_init_mask # Paste original on color-corrected generation (using blurred mask) - matched_result.paste(pil_image, (0,0), mask = blurred_init_mask) + matched_result.paste(base_image, (0,0), mask = blurred_init_mask) return matched_result + + def sample_to_image(self, samples)->Image.Image: + gen_result = super().sample_to_image(samples).convert('RGB') + + if self.pil_image is None or self.pil_mask is None: + return gen_result + + corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius) + + return corrected_result