20 changes: 17 additions & 3 deletions ldm/generate.py
@@ -297,6 +297,13 @@ def prompt2image(
catch_interrupts = False,
hires_fix = False,
use_mps_noise = False,
# Seam settings for outpainting
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
force_outpaint: bool = False,
**args,
): # eat up additional cruft
"""
@@ -467,7 +474,13 @@ def process_image(image,seed):
embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius,
safety_checker=checker
safety_checker=checker,
seam_size = seam_size,
seam_blur = seam_blur,
seam_strength = seam_strength,
seam_steps = seam_steps,
tile_size = tile_size,
force_outpaint = force_outpaint
)

if init_color:
@@ -929,8 +942,9 @@ def _load_img(self, img)->Image:
image = ImageOps.exif_transpose(image)
return image

def _create_init_image(self, image, width, height, fit=True):
image = image.convert('RGB')
def _create_init_image(self, image: Image.Image, width, height, fit=True):
if image.mode != 'RGBA':
image = image.convert('RGB')
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
return image

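Usage note (not part of the diff): the new keyword arguments are accepted by `prompt2image` and threaded down to the inpaint generator. Below is a minimal sketch of how a caller might enable the seam pass when outpainting. The `Generate()` construction and the pre-existing argument names (`init_img`, `strength`) are assumptions about the surrounding API, not shown in this diff, and the values are illustrative only.

```python
from ldm.generate import Generate

# Assumes a locally configured InvokeAI model setup; constructor arguments omitted here.
gr = Generate()

# init_img is an RGBA image whose transparent border marks the area to outpaint.
results = gr.prompt2image(
    prompt         = "a wide mountain landscape, golden hour",
    init_img       = "outpaint_canvas.png",
    strength       = 0.8,
    # New parameters from this PR: a second "seam" pass hides the border between
    # the original pixels and the newly generated region.
    seam_size      = 96,
    seam_blur      = 16,
    seam_strength  = 0.7,
    seam_steps     = 10,
    tile_size      = 32,
    force_outpaint = True,
)
```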
2 changes: 1 addition & 1 deletion ldm/invoke/generator/img2img.py
@@ -30,7 +30,7 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
)

if isinstance(init_image, PIL.Image.Image):
init_image = self._image_to_tensor(init_image)
init_image = self._image_to_tensor(init_image.convert('RGB'))

scope = choose_autocast(self.precision)
with scope(self.model.device.type):
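Usage note (not part of the diff): init images may now arrive with an alpha channel, so img2img converts to RGB before building the tensor; the diffusion image encoder expects three channels. A tiny sketch of the shape difference (plain NumPy, not the project's `_image_to_tensor` helper):

```python
import numpy as np
from PIL import Image

rgba = Image.new('RGBA', (64, 64), (10, 20, 30, 0))   # outpainting canvases carry alpha
print(np.asarray(rgba).shape)                  # (64, 64, 4) -> one channel too many
print(np.asarray(rgba.convert('RGB')).shape)   # (64, 64, 3) -> what the encoder consumes
```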
206 changes: 181 additions & 25 deletions ldm/invoke/generator/inpaint.py
@@ -2,12 +2,13 @@
ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''

import math
import torch
import torchvision.transforms as T
import numpy as np
import cv2 as cv
import PIL
from PIL import Image, ImageFilter
from PIL import Image, ImageFilter, ImageOps
from skimage.exposure.histogram_matching import match_histograms
from einops import rearrange, repeat
from ldm.invoke.devices import choose_autocast
@@ -24,11 +25,128 @@ def __init__(self, model, precision):
self.mask_blur_radius = 0
super().__init__(model, precision)

# Outpaint support code
def get_tile_images(self, image: np.ndarray, width=8, height=8):
_nrows, _ncols, depth = image.shape
_strides = image.strides

nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
return None

return np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False
)

def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: int = None) -> Image:
a = np.asarray(im, dtype=np.uint8)

tile_size = (tile_size, tile_size)

# Get the image as tiles of a specified size
tiles = self.get_tile_images(a,*tile_size).copy()

# Get the mask as tiles
tiles_mask = tiles[:,:,:,:,3]

# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n,ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = (tiles_mask > 0)
tiles_mask = tiles_mask.reshape((n,ny)).all(axis = 1)

# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), * tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]

if len(filtered_tiles) == 0:
return im

# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum()
rng = np.random.default_rng(seed = seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count),:,:,:]

# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1,2)
st = tiles_all.reshape((math.prod(tiles_all.shape[0:2]), math.prod(tiles_all.shape[2:4]), tiles_all.shape[4]))
si = Image.fromarray(st, mode='RGBA')

return si


def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
npimg = np.asarray(mask, dtype=np.uint8)

# Detect any partially transparent regions
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))

# Detect hard edges
npedge = cv.Canny(npimg, threshold1=100, threshold2=200)

# Combine
npmask = npgradient + npedge

# Expand
npmask = cv.dilate(npmask, np.ones((3,3), np.uint8), iterations = int(edge_size / 2))

new_mask = Image.fromarray(npmask)

if edge_blur > 0:
new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))

return ImageOps.invert(new_mask)


def seam_paint(self,
im: Image.Image,
seam_size: int,
seam_blur: int,
prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,strength,
noise
) -> Image.Image:
hard_mask = self.pil_image.split()[-1].copy()
mask = self.mask_edge(hard_mask, seam_size, seam_blur)

make_image = self.get_make_image(
prompt,
sampler,
steps,
cfg_scale,
ddim_eta,
conditioning,
init_image = im.copy().convert('RGBA'),
mask_image = mask.convert('RGB'), # Code currently requires an RGB mask
strength = strength,
mask_blur_radius = 0,
seam_size = 0
)

result = make_image(noise)

return result


@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength,
mask_blur_radius: int = 8,
step_callback=None,inpaint_replace=False, **kwargs):
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_strength: float = 0.7,
seam_steps: int = 10,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False, **kwargs):
"""
Returns a function returning an image derived from the prompt and
the initial image + mask. Return value depends on the seed at
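Illustration (not part of the diff): `get_tile_images` builds a zero-copy tile view with `as_strided`, and `tile_fill_missing` then flags tiles whose alpha is fully opaque as valid fill material for the transparent region. A self-contained sketch of that tiling and masking step on a toy RGBA array; `tile_view` is a free-function stand-in for the method above.

```python
import numpy as np

def tile_view(image: np.ndarray, width: int = 8, height: int = 8) -> np.ndarray:
    # Non-copying (H//h, W//w, h, w, C) view of an HxWxC array, mirroring get_tile_images.
    nrows_px, ncols_px, depth = image.shape
    strides = image.strides
    nrows, rem_r = divmod(nrows_px, height)
    ncols, rem_c = divmod(ncols_px, width)
    if rem_r or rem_c:
        raise ValueError("image dimensions must be multiples of the tile size")
    return np.lib.stride_tricks.as_strided(
        np.ravel(image),
        shape=(nrows, ncols, height, width, depth),
        strides=(height * strides[0], width * strides[1], *strides),
        writeable=False,
    )

# Toy 4x4 RGBA image split into 2x2 tiles; the top-left tile is fully transparent.
img = np.full((4, 4, 4), 255, dtype=np.uint8)
img[:2, :2, 3] = 0                                    # alpha = 0 -> "missing" region

tiles = tile_view(img, width=2, height=2)             # shape (2, 2, 2, 2, 4)
alpha = tiles[:, :, :, :, 3]                          # alpha channel of every tile
valid = (alpha > 0).reshape(4, -1).all(axis=1)        # one flag per tile: True = fully opaque
print(valid)   # [False  True  True  True] -> only the top-left tile needs filling
```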
@@ -37,7 +155,17 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,

if isinstance(init_image, PIL.Image.Image):
self.pil_image = init_image
init_image = self._image_to_tensor(init_image)

# Fill missing areas of original image
init_filled = self.tile_fill_missing(
self.pil_image.copy(),
seed = self.seed,
tile_size = tile_size
)
init_filled.paste(init_image, (0,0), init_image.split()[-1])

# Create init tensor
init_image = self._image_to_tensor(init_filled.convert('RGB'))

if isinstance(mask_image, PIL.Image.Image):
self.pil_mask = mask_image
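Illustration (not part of the diff): before the init tensor is built, the transparent area of the original is tiled over with random valid tiles and the original pixels are then pasted back on top, using the original's own alpha channel as the paste mask. A minimal PIL sketch of that paste step, with a flat grey fill standing in for `tile_fill_missing`:

```python
from PIL import Image

# Toy stand-in for self.pil_image: a 64x64 RGBA canvas whose right half is transparent
# (the area to be outpainted).
original = Image.new('RGBA', (64, 64), (255, 0, 0, 255))
original.paste(Image.new('RGBA', (32, 64), (0, 0, 0, 0)), (32, 0))

# Stand-in for tile_fill_missing(): here just a flat grey fill of the whole canvas.
filled = Image.new('RGBA', (64, 64), (128, 128, 128, 255))

# Same pattern as the diff: paste the original back using its own alpha as the mask,
# so only the transparent (missing) pixels keep the filled content.
filled.paste(original, (0, 0), original.split()[-1])

init_image_rgb = filled.convert('RGB')   # ready for the tensor conversion step
```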
@@ -106,38 +234,56 @@ def make_image(x_T):
mask = mask_image,
init_latent = self.init_latent
)
return self.sample_to_image(samples)

return make_image
result = self.sample_to_image(samples)

def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
if seam_size > 0:
result = self.seam_paint(
result,
seam_size,
seam_blur,
prompt,
sampler,
seam_steps,
cfg_scale,
ddim_eta,
conditioning,
seam_strength,
x_T)

if self.pil_image is None or self.pil_mask is None:
return gen_result

pil_mask = self.pil_mask
pil_image = self.pil_image
mask_blur_radius = self.mask_blur_radius
return result

return make_image


def color_correct(self, image: Image.Image, base_image: Image.Image, mask: Image.Image, mask_blur_radius: int) -> Image.Image:
# Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
pil_init_mask = mask.getchannel('A') if mask.mode == 'RGBA' else mask.convert('L')
pil_init_image = base_image.convert('RGBA') # Add an alpha channel if one doesn't exist

# Build an image with only visible pixels from source to use as reference for color-matching.
# Note that this doesn't use the mask, which would exclude some source image pixels from the
# histogram and cause slight color changes.
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
init_rgb_pixels = np.asarray(base_image.convert('RGB'), dtype=np.uint8)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)

# Get numpy version of result
np_image = np.asarray(image, dtype=np.uint8)

# Get numpy version
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
# Mask and calculate mean and standard deviation
mask_pixels = init_a_pixels * init_mask_pixels > 0
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_image_masked = np_image[mask_pixels, :]

init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
gen_std = np_image_masked.std(axis=0)

# Color correct
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB')

# Blur the mask out (into init image) by specified amount
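Illustration (not part of the diff): the rewritten `color_correct` replaces the global `match_histograms` call with a per-channel mean/variance transfer, where the statistics come only from pixels that are both opaque in the source and selected by the mask. A NumPy sketch of that transform on synthetic data:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: a "source" region whose statistics we want to match, and a generated result.
init_pixels = rng.normal(loc=120.0, scale=20.0, size=(64, 64, 3))   # reference pixels
gen_image   = rng.normal(loc=90.0,  scale=35.0, size=(64, 64, 3))   # generated result
mask        = np.ones((64, 64), dtype=bool)                         # pixels used for statistics

init_means = init_pixels[mask].mean(axis=0)
init_std   = init_pixels[mask].std(axis=0)
gen_means  = gen_image[mask].mean(axis=0)
gen_std    = gen_image[mask].std(axis=0)

# Same per-channel transform as the diff: normalise by the generated statistics,
# then rescale to the reference statistics.
corrected = ((gen_image - gen_means) / gen_std) * init_std + init_means
corrected = corrected.clip(0, 255).astype(np.uint8)

print(corrected.reshape(-1, 3).mean(axis=0))   # close to init_means, up to clipping
```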
@@ -150,6 +296,16 @@ def sample_to_image(self, samples)->Image.Image:
blurred_init_mask = pil_init_mask

# Paste original on color-corrected generation (using blurred mask)
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
matched_result.paste(base_image, (0,0), mask = blurred_init_mask)
return matched_result


def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')

if self.pil_image is None or self.pil_mask is None:
return gen_result

corrected_result = self.color_correct(gen_result, self.pil_image, self.pil_mask, self.mask_blur_radius)

return corrected_result
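Illustration (not part of the diff): the seam pass only regenerates a thin band along the boundary between original and outpainted pixels. That band comes from `mask_edge`, which combines partially transparent pixels with Canny edges of the hard alpha mask, dilates, blurs, and inverts the result. A self-contained sketch of the same recipe, rewritten as a free function and applied to a toy half-opaque alpha channel:

```python
import numpy as np
import cv2 as cv
from PIL import Image, ImageFilter, ImageOps

def seam_mask_sketch(alpha: Image.Image, edge_size: int = 16, edge_blur: int = 8) -> Image.Image:
    # Roughly mirrors mask_edge(): partial transparency + Canny edges, dilated,
    # blurred, then inverted so the seam band becomes the "repaint here" region.
    npimg = np.asarray(alpha, dtype=np.uint8)
    npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
    npedge = cv.Canny(npimg, threshold1=100, threshold2=200)
    npmask = cv.dilate(npgradient + npedge, np.ones((3, 3), np.uint8), iterations=edge_size // 2)
    new_mask = Image.fromarray(npmask)
    if edge_blur > 0:
        new_mask = new_mask.filter(ImageFilter.BoxBlur(edge_blur))
    return ImageOps.invert(new_mask)

# Hard alpha: left half opaque, right half transparent -> a vertical seam down the middle.
alpha = Image.new('L', (64, 64), 0)
alpha.paste(255, (0, 0, 32, 64))
seam = seam_mask_sketch(alpha)
seam.save('seam_mask.png')   # white everywhere except a dark band along the seam
```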