<a href="https://colab.research.google.com/github/j-min/IterInpaint/blob/main/inference_iterinpaint.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the IterInpaint repository (provides the ldm package, configs/ and requirements.txt used below).
!git clone https://github.com/j-min/IterInpaint

In [None]:
cd IterInpaint

In [None]:
# Remove any preinstalled torchtext before installing the pinned requirements.
# NOTE(review): presumably avoids a version conflict with the pinned torch — confirm.
!pip uninstall -y torchtext

In [None]:
# Install the repository's Python dependencies.
!pip install -r requirements.txt

In [None]:
import argparse, os, sys, glob, re
import json
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, ImageEnhance
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid, save_image
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

from transformers import CLIPTokenizer, CLIPTextModel
from pathlib import Path

In [None]:
def parse(argument=None):
    """Build the sampling CLI parser and parse ``argument`` into options.

    Args:
        argument: Command-line tokens to parse. Accepts a list of strings
            (e.g. ``"--plms --n_samples 1".split()``) or a single string,
            which is tokenized with :func:`shlex.split`. ``None`` lets
            argparse fall back to ``sys.argv[1:]``. Inside a Jupyter kernel,
            ``sys.argv`` holds the kernel's own arguments, so pass tokens
            explicitly there.

    Returns:
        argparse.Namespace with one attribute per option defined below
        (``--from-file`` is exposed as ``from_file``).
    """
    import shlex

    if isinstance(argument, str):
        # A bare string would otherwise be consumed character by character.
        argument = shlex.split(argument)

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples",
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--laion400m",
        action='store_true',
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across samples ",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        # Help text previously lacked its closing parenthesis.
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=2,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference-box.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/stable-diffusion-v1/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast",
    )
    parser.add_argument(
        "--embedding_path",
        type=str,
        help="Path to a pre-trained embedding manager checkpoint",
    )

    opt = parser.parse_args(argument)
    return opt

In [None]:
# Parse the notebook's sampling options: PLMS sampler, guidance scale 4.0,
# one iteration of one sample, 50 steps, results under outputs/.
opt = parse([
    "--plms",
    "--scale", "4.0",
    "--n_iter", "1",
    "--ddim_steps", "50",
    "--outdir", "outputs/",
    "--n_samples", "1",
])

In [None]:
# Display the parsed options as the cell's output.
opt

In [None]:
# Override the parser's default config (v1-inference-box.yaml) with the IterInpaint model config.
opt.config = "configs/stable-diffusion/v1-inference-iterinpaint.yaml"

# Download Pretrained Checkpoint

In [None]:
# Download the IterInpaint CLEVR checkpoint from the Hugging Face Hub into checkpoints/.
!mkdir -p checkpoints
!wget https://huggingface.co/j-min/IterInpaint_CLEVR/resolve/main/iterinpaint_CLEVR_FG30.ckpt -O checkpoints/iterinpaint_CLEVR_FG30.ckpt

In [None]:
def load_model_from_config(config, ckpt, verbose=False, device=None):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint readable by ``torch.load``. Either a dict
            holding a ``"state_dict"`` entry (and optionally ``"global_step"``),
            or a bare state dict.
        verbose: If True, print the missing and unexpected keys reported by
            the non-strict ``load_state_dict``.
        device: Device to move the model to. Defaults to CUDA when available,
            otherwise CPU, so the loader no longer fails on CPU-only machines.

    Returns:
        The loaded model, in eval mode, on ``device``.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Lightning checkpoints nest weights under "state_dict"; accept bare state dicts too.
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; they are only reported when verbose.
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model

In [None]:
# Use the checkpoint fetched by the wget cell instead of the parser's default path.
opt.ckpt = 'checkpoints/iterinpaint_CLEVR_FG30.ckpt'

In [None]:
# Seed the random number generators (pytorch_lightning helper) for reproducible sampling.
seed_everything(opt.seed)

# Only let the transformers library emit error-level log messages.
from transformers import logging
logging.set_verbosity_error()


# Build the model described by the YAML config and load the checkpoint weights into it.
config = OmegaConf.load(str(opt.config))
model = load_model_from_config(config, str(opt.ckpt))

In [None]:
# Prefer the GPU when one is present; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# --plms selects the PLMS sampler; without it, DDIM is used.
sampler = (PLMSSampler if opt.plms else DDIMSampler)(model)

In [None]:
# NOTE(review): extra option carried on `opt`; presumably read by the
# ldm.gen_utils inference helpers (not shown here). From its name, False keeps
# the default "paste" behaviour during iterative inpainting; confirm there.
opt.iterinpaint_nopaste = False

In [None]:
from ldm.gen_utils import sample_images, prepare_text, encode_scene, prepare_clip_tokenizer, sample_images_iterative_inpaint
from ldm.viz_utils import plot_results, fig2img, show_images

In [None]:
from ldm.gen_utils import inference_from_custom_annotation, encode_from_custom_annotation

In [None]:
# Display the options again, now including the overrides set in earlier cells.
opt

# Define a Custom Layout

In [None]:
# Demo layout: eight labelled boxes, each given as x, y, width, height plus a
# text label. The values are presumably pixels on the 512x512 canvas that
# encode_from_custom_annotation(..., size=512) is called with; confirm against
# ldm.gen_utils.
custom_annotations = [
    dict(x=19, y=61, width=158, height=169, label='blue metal cube'),
    dict(x=183, y=94, width=103, height=109, label='brown rubber sphere'),
    dict(x=289, y=112, width=82, height=77, label='gray metal sphere'),
    dict(x=374, y=128, width=48, height=46, label='yellow rubber cylinder'),
    dict(x=22, y=346, width=82, height=73, label='gray metal cylinder'),
    dict(x=110, y=326, width=107, height=102, label='cyan rubber sphere'),
    dict(x=218, y=313, width=124, height=125, label='green rubber cube'),
    dict(x=343, y=295, width=164, height=179, label='red metal cylinder'),
]

In [None]:
# Convert the custom annotations into a scene dict (helper from ldm.gen_utils).
# Only the 'boxes_normalized' and 'box_captions' keys are used in this notebook.
scene = encode_from_custom_annotation(custom_annotations, size=512)

print(scene['boxes_normalized'])
print(scene['box_captions'])

In [None]:
# Preview the layout: draw every box and its caption (all in blue) on a blank
# white 512x512 image. plot_results and fig2img come from ldm.viz_utils;
# presumably plot_results returns a matplotlib figure that fig2img rasterizes
# into a PIL image. The trailing expression displays it.
layout_img = fig2img(plot_results(
    Image.new('RGB', (512, 512), color='white'),
    boxes=scene['boxes_normalized'],
    captions=scene['box_captions'],
    colors=['blue'] * len(scene['box_captions'])
))
layout_img

In [None]:
# Label for this layout. NOTE(review): not referenced by any later cell in this notebook.
layout_name = "eight_objects_two_rows"

In [None]:
# NOTE(review): None presumably means the inference helper picks its default
# order for generating boxes; confirm in ldm.gen_utils.
opt.box_generation_order = None

# Generate Image

In [None]:
# Run IterInpaint on the custom layout with the configured sampler and options.
# The returned dict's 'context_imgs' and 'generated_images' are shown side by
# side via show_images, and 'final_image' is displayed as the cell output.
generated = inference_from_custom_annotation(custom_annotations, sampler, opt)
show_images(
    generated['context_imgs'],
    generated['generated_images']
)
generated['final_image']