In [None]:
%load_ext autoreload
%autoreload 2

from nerfiller.inpaint.rgb_inpainter import RGBInpainter, RGBInpainterXL
from nerfiller.inpaint.lama_inpainter import LaMaInpainter
from nerfiller.inpaint.depth_inpainter import DepthInpainter
from nerfiller.utils.image_utils import get_inpainted_image_row
from nerfiller.nerf.dataset_utils import parse_nerfstudio_frame
from nerfiller.utils.mask_utils import downscale_mask
from nerfiller.utils.mesh_utils import dilate

from pathlib import Path

import torch
import mediapy
import json

In [None]:
# Notebook-wide settings; edit these before running the cells below.

# Optional LoRA weights handed to the inpainter. While this is None the default
# prompt is built from `dataset`; otherwise the prompt "a photo of sks" is used.
lora_model_path = None

# Scene name: selects the folder ../data/nerfstudio/<dataset> and is inserted
# into the default text prompt.
dataset = "billiards"

# Torch device string passed to the inpainter (as both `device` and
# `vae_device`) and to the frame loader.
device = "cuda:1"

In [None]:
# Build the RGB inpainter used by every inpainting cell below. Exactly one of
# the three constructors should be active; the other two are kept as
# ready-to-uncomment alternatives.
rgb_inpainter = RGBInpainter(
    device=device,
    lora_model_path=lora_model_path,
    vae_device=device,
)
# rgb_inpainter = RGBInpainterXL(device=device, lora_model_path=lora_model_path, vae_device=device)
# rgb_inpainter = LaMaInpainter(device=device, model_path=Path("../data/models/big-lama"))

# uncomment for depth (needed by the commented-out depth cells further down)
# depth_inpainter = DepthInpainter(device=device, depth_method="zoedepth")

In [None]:
# Load a handful of frames (image, depth, mask, pose, intrinsics) from a
# nerfstudio-format dataset and preview the images and masks.

# Frame indices passed to parse_nerfstudio_frame for lookup in transforms.json.
indices = [31, 10, 20, 55]

depth_max = 100.0  # forwarded to parse_nerfstudio_frame
display_width = 512  # display width used for every mediapy preview
strength = 1.0  # only set < 1 if the region under the mask is known but noisy e.g., for nerfbusters scenes

data_path = Path(f"../data/nerfstudio/{dataset}")
# The context manager closes the file even if JSON parsing raises.
with open(data_path / "transforms.json") as f:
    transforms = json.load(f)

images = []
masks = []
depths = []
Ks = []
c2ws = []
for idx in indices:
    image, depth, mask, c2w, K = parse_nerfstudio_frame(
        transforms, data_path, idx, depth_max=depth_max, device=device
    )
    images.append(image)
    masks.append(mask)
    depths.append(depth)
    Ks.append(K)
    c2ws.append(c2w)

# Concatenate the per-frame tensors along dim 0.
# NOTE(review): presumably each per-frame tensor already carries a leading
# batch dimension of size 1 -- confirm against parse_nerfstudio_frame.
images = torch.cat(images)
masks = torch.cat(masks)
depths = torch.cat(depths)
Ks = torch.cat(Ks)
c2ws = torch.cat(c2ws)

# Leave as None for non-seeded sampling, or uncomment for one seeded generator per image.
# seed = 0
# generator = [torch.Generator(device=device).manual_seed(seed) for seed in range(len(images))]
generator = None

# Move dim 1 to the end before display (assumes (N, C, H, W) -> (N, H, W, C) -- TODO confirm).
mediapy.show_images(images.permute(0, 2, 3, 1).cpu(), width=display_width)
mediapy.show_images(masks.permute(0, 2, 3, 1).cpu(), width=display_width)
# mediapy.show_images(depths.permute(0, 2, 3, 1).cpu(), width=display_width)

### Inpaint one image

You can choose which settings to use for inpainting.

In [None]:
# Positive (ps) and negative (pn) text prompts. With a LoRA model the fixed
# prompt "a photo of sks" is used; otherwise the prompt names the dataset.
ps = "a photo of sks" if lora_model_path is not None else f"a photo of a {dataset}"
pn = ""

# uncomment if you want to set the positive prompt manually
# ps = "bunny ears"
# ps = "santa claus"
# ps = "farmer, overalls"
# ps = "basket"

# Text embeddings are only computed for the non-LaMa inpainters; the LaMa
# branches below call get_image without them.
uses_text_prompt = not isinstance(rgb_inpainter, LaMaInpainter)
if uses_text_prompt:
    text_embeddings = rgb_inpainter.compute_text_embeddings(ps, pn)

In [None]:
# choose starting image
# `starting_image` is forwarded to rgb_inpainter.get_image in the
# single-image inpainting cell. Keep exactly one option active.
# Option 1: no starting image.
# starting_image = None
# Option 2 (active): start from the loaded dataset images.
starting_image = images
# Option 3: start from a render on disk. The line below reads the image,
# divides by 255, moves the last axis to the front, adds a leading batch
# dimension and matches the dtype/device of `images`.
# starting_image = torch.from_numpy(mediapy.read_image("../outputs/drawing-none/grid-prior-du/2023-11-02_032615/renders/step_78000/images/image_000030.png") / 255.0).permute(2, 0, 1)[None].to(images)
# mediapy.show_images(starting_image.permute(0, 2, 3, 1).cpu(), width=display_width)

In [None]:
# Inpaint the frames in the batch range [s, e) without grid denoising and
# show each result next to its mask.
s = 0
e = images.shape[0]

# `ma` is the mask batch for [s, e). It is sliced exactly once here; slicing
# it again with [s:e] would select the wrong masks whenever s != 0.
ma = masks[s:e]
if isinstance(rgb_inpainter, LaMaInpainter):
    # lama inpainter
    reference_imagein = rgb_inpainter.get_image(image=images[s:e], mask=ma)
else:
    # diffusion inpainter
    dilate_iters = 0  # number of dilation passes applied to the mask
    dilate_kernel_size = 3
    for _ in range(dilate_iters):
        ma = dilate(ma, kernel_size=dilate_kernel_size)
    # NOTE(review): `starting_image` is passed unsliced; if s/e are changed to
    # a sub-range, confirm it still lines up with images[s:e].
    reference_imagein = rgb_inpainter.get_image(
        text_embeddings=text_embeddings,
        image=images[s:e],
        mask=ma,
        denoise_in_grid=False,
        multidiffusion_steps=1,
        randomize_latents=False,
        text_guidance_scale=0.0, # modify to > 0, e.g., 7.5 or 15.0, if you want to use text CFG
        image_guidance_scale=1.5,
        num_inference_steps=20,
        generator=generator,
        use_decoder_approximation=False,
        replace_original_pixels=True,
        starting_image=starting_image,
        starting_lower_bound=strength,
        starting_upper_bound=strength,
        # output_folder=Path("inpaint/"),
    )

# Build the visualisation rows and move dim 1 to the end for display.
reference_imagein_row = (
    get_inpainted_image_row(
        image=images[s:e],
        mask=ma,
        inpainted_image=reference_imagein,
        show_original=False,
    )
    .permute(0, 2, 3, 1)
    .detach()
    .cpu()
)
mediapy.show_images(reference_imagein_row, width=display_width)

In [None]:
# show masks and images separately
# mediapy.show_images(reference_imagein_row[:,:512])
# mediapy.show_images(reference_imagein_row[:,512:])

# uncomment to save a reference image. then, copy it to the nerfstudio dataset folder
# mediapy.write_image("reference.png", reference_imagein.permute(0,2,3,1)[0].detach().cpu())
# images[0:1] = reference_imagein
# masks[0:1] = 0.0
# images[4:5] = reference_imagein
# masks[4:5] = 0.0

# uncomment to inpaint the depth
# with torch.no_grad():
#     reference_depthin = depth_inpainter.get_depth(image=reference_imagein)
# mediapy.show_images(reference_depthin.permute(0,2,3,1).detach().cpu(), width=display_width)

### Inpaint multiple images

In [None]:
# Inpaint all loaded frames together, then composite the result back onto the
# original images and show each result next to its mask.

# uncomment for expanded attention
# from nerfiller.utils.diff_utils import register_extended_attention
# register_extended_attention(rgb_inpainter.unet)

if isinstance(rgb_inpainter, LaMaInpainter):
    # lama inpainter: called on the full-resolution inputs
    imagein = rgb_inpainter.get_image(image=images, mask=masks)
else:
    # diffusion inpainter
    denoise_in_grid = True
    # Inputs are halved when denoising in a grid and left at full size otherwise.
    scale_factor = 0.5 if denoise_in_grid else 1.0
    dilate_iters = 0
    dilate_kernel_size = 3
    im = torch.nn.functional.interpolate(images, scale_factor=scale_factor, mode="bilinear")
    ma = downscale_mask(
        masks,
        scale_factor=scale_factor,
        dilate_iters=dilate_iters,
        dilate_kernel_size=dilate_kernel_size,
    )
    imagein = rgb_inpainter.get_image(
        text_embeddings=text_embeddings,
        image=im,
        mask=ma,
        denoise_in_grid=denoise_in_grid,
        multidiffusion_steps=8,
        multidiffusion_type="epsilon",
        randomize_latents=True,
        randomize_within_grid=False,
        text_guidance_scale=0.0,
        image_guidance_scale=1.5,
        num_inference_steps=20,
        generator=generator,
        replace_original_pixels=True,
        starting_image=im,
        starting_lower_bound=strength,
        starting_upper_bound=strength,
        # output_folder=Path("inpaint_grid/"),
    )

# Bring the result back to the resolution of `images` before compositing.
# Resizing to an explicit target size (instead of multiplying by
# 1 / scale_factor) works for both branches -- `scale_factor` is not defined
# on the LaMa path -- and always lands on the size torch.where needs below,
# even if the spatial dimensions are odd.
if imagein.shape[-2:] != images.shape[-2:]:
    imagein = torch.nn.functional.interpolate(imagein, size=images.shape[-2:], mode="bilinear")

# Keep inpainted pixels only where the mask equals 1; elsewhere use the original image.
imagein = torch.where(masks==1, imagein, images)
imagein_row = get_inpainted_image_row(image=images, mask=masks, inpainted_image=imagein, show_original=False).permute(0,2,3,1).detach().cpu()
mediapy.show_images(imagein_row, width=display_width)

In [None]:
# show masks and images separately, nice for drag and drop to make figures
# mediapy.show_images(imagein_row[:,:512].detach().cpu())
# mediapy.show_images(imagein_row[:,512:].detach().cpu())
# or, save to a folder
# for i in range(len(imagein_row)):
#     mediapy.write_image(f"joint_inpainting_figure/mask_{i:0d}.png", imagein_row[i,:512])
#     mediapy.write_image(f"joint_inpainting_figure/image_{i:0d}.png", imagein_row[i,512:])

In [None]:
# with torch.no_grad():
#     depthin = depth_inpainter.get_depth(image=imagein)
# mediapy.show_images(depthin.permute(0,2,3,1).detach().cpu(), width=display_width)

In [None]:
# # uncomment to make grid prior figure
# prefix = "grid_prior_figure"
# imagerow = get_inpainted_image_row(image=images, mask=masks, inpainted_image=reference_imagein, show_original=False).permute(0,2,3,1).detach().cpu()
# gridimagerow = get_inpainted_image_row(image=images, mask=masks, inpainted_image=imagein, show_original=False).permute(0,2,3,1).detach().cpu()
# for i in range(len(imagerow)):
#     mediapy.write_image(f"{prefix}/mask_{i:06d}.png", imagerow[i][:512])
#     mediapy.write_image(f"{prefix}/individual_{i:06d}.png", imagerow[i][512:])
#     mediapy.write_image(f"{prefix}/grid_{i:06d}.png", gridimagerow[i][512:])

# # uncomment to split an intermediate grid image into separate images (reuses `prefix` from the block above)
# image_filename = "inpaint_grid/x0-000004.png"
# intermediate_image = mediapy.read_image(image_filename)
# for i in range(images.shape[0]):
#     # we assume resolution is 256
#     intermediate_image_i = intermediate_image[:, i*256:i*256+256]
#     mediapy.write_image(f"{prefix}/intermediate_grid_{i:06d}.png", intermediate_image_i)

# uncomment to concat rows together for a big image
# reference_imagein_row_cat = torch.cat(list(reference_imagein_row), dim=1).detach().cpu()
# imagein_row_cat = torch.cat(list(imagein_row), dim=1).detach().cpu()
# full_image = torch.cat([reference_imagein_row_cat, imagein_row_cat[512:]])
# mediapy.show_image(full_image)
# mediapy.write_image(f"full_image_{dataset}.png", full_image)