Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions ldm/invoke/generator/inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,26 @@ def sample_to_image(self, samples)->Image.Image:
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist

# Build an image with only visible pixels from source to use as reference for color-matching.
# Note that this doesn't use the mask, which would exclude some source image pixels from the
# histogram and cause slight color changes.
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram

# Get numpy version
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8)
init_mask_pixels = np.asarray(pil_init_mask, dtype=np.uint8)

# Get numpy version of result
np_gen_result = np.asarray(gen_result, dtype=np.uint8)

# Mask and calculate mean and standard deviation
mask_pixels = init_a_pixels * init_mask_pixels > 0
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_gen_result_masked = np_gen_result[mask_pixels, :]

init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_gen_result_masked.mean(axis=0)
gen_std = np_gen_result_masked.std(axis=0)

# Color correct
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
np_matched_result = np_gen_result.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB')

# Blur the mask out (into init image) by specified amount
Expand All @@ -151,4 +159,3 @@ def sample_to_image(self, samples)->Image.Image:
# Paste original on color-corrected generation (using blurred mask)
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
return matched_result