In [None]:
from dataclasses import dataclass
from pathlib import Path
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

In [None]:
def chunk(it, size):
    """
    Lazily split an iterable into consecutive tuples of at most `size` items.

    The last tuple may be shorter than `size`. Iteration ends as soon as an
    empty tuple would be produced.
    """
    source = iter(it)

    def take_next():
        # Pull up to `size` items from the shared underlying iterator.
        return tuple(islice(source, size))

    # Two-argument iter(): keep calling take_next() until it returns ().
    return iter(take_next, ())


def numpy_to_pil(images):
    """
    Convert a numpy image, or a batch of images, into a list of PIL images.

    A 3-dimensional input is treated as a single image and gets a leading
    batch axis. Values are multiplied by 255, rounded and cast to uint8
    before each batch entry is handed to `Image.fromarray`.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch_uint8 = (batch * 255).round().astype("uint8")

    pil_images = []
    for frame in batch_uint8:
        pil_images.append(Image.fromarray(frame))
    return pil_images


def load_model_from_config(config, ckpt, verbose=False):
    """
    Instantiate the model described by `config.model` and load weights from `ckpt`.

    Args:
        config: Configuration object whose `model` attribute is passed to
            `instantiate_from_config`.
        ckpt: Path to a checkpoint file that contains a "state_dict" entry.
        verbose: When True, print the missing / unexpected keys reported by
            `load_state_dict`.

    Returns:
        The model in eval mode. It is moved to CUDA when a CUDA device is
        available and left on the CPU otherwise.

    Raises:
        KeyError: If the checkpoint has no "state_dict" entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are returned instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False)
    if verbose and len(missing_keys) > 0:
        print("missing keys:")
        print(missing_keys)
    if verbose and len(unexpected_keys) > 0:
        print("unexpected keys:")
        print(unexpected_keys)

    # Only move to the GPU when one exists, so CPU-only hosts do not crash here.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


def load_replacement(x):
    """
    Best-effort substitute for an image array.

    Tries to load "assets/rick.jpeg", resize it to the height/width of `x`,
    scale it to the 0..1 range and cast it to the dtype of `x`. If any step
    fails (including a final shape mismatch), `x` is returned unchanged.
    """
    try:
        target_shape = x.shape
        replacement = Image.open("assets/rick.jpeg")
        replacement = replacement.convert("RGB")
        # PIL expects (width, height), i.e. the first two array axes swapped.
        replacement = replacement.resize((target_shape[1], target_shape[0]))
        scaled = np.array(replacement) / 255.0
        replacement = scaled.astype(x.dtype)
        assert replacement.shape == x.shape
    except Exception:
        return x
    else:
        return replacement


def check_safety(x_image):
    """
    Pass-through safety check: no filtering is performed.

    Returns the input unchanged together with one `False` flag per entry
    along the first axis of `x_image`.
    """
    batch_len = x_image.shape[0]
    flags = [False for _ in range(batch_len)]
    return x_image, flags

In [None]:
@dataclass
class Config:
    """
    Settings for this text-to-image sampling notebook.

    `output_dir`, `config` and `ckpt` start out as relative strings and are
    anchored to `root_dir` by `prepare_config()`.
    """
    # Base directory for the relative paths below; the default is evaluated
    # once, when the class is defined.
    root_dir: Path = Path.cwd().parent
    # Directory that receives the generated images.
    output_dir: str = "outputs/txt2img-samples-test"
    # When True, no combined grid image is built or saved.
    skip_grid: bool = False
    # When True, individual samples are not written to disk.
    skip_save: bool = True
    # Number of sampler steps.
    ddim_steps: int = 50
    # Use the PLMS sampler instead of DDIM.
    plms: bool = True
    # Switch config/checkpoint/output paths to the LAION 400M model.
    laion400m: bool = False
    # Eta value handed to the sampler.
    ddim_eta: float = 0.
    # Number of sampling iterations.
    n_iter: int = 2
    # Number of samples per iteration (the batch size).
    n_samples: int = 2
    # Image width and height; the latent size is H // f by W // f.
    W: int = 512
    H: int = 512
    # Number of latent channels.
    C: int = 4
    # Divisor applied to H and W to get the latent size.
    f: int = 8
    # Images per grid row; values <= 0 fall back to the batch size.
    n_rows: int = 3
    # Unconditional guidance scale; 1.0 disables unconditional conditioning.
    scale: float = 7.5
    # Model config file and checkpoint, relative to root_dir.
    config: str = "configs/stable-diffusion/v1-inference.yaml"
    ckpt: str = "models/ldm/stable-diffusion-v1/model.ckpt"
    # Seed passed to seed_everything.
    seed: int = 42
    # "autocast" runs sampling under torch autocast; anything else uses nullcontext.
    precision: str = "autocast"
    # When True, the same start code is reused for every sampling call.
    fixed_code: bool = False
    # When True, every generated image is shown after sampling.
    show_images: bool = True

    def prepare_config(self):
        """Anchor `output_dir`, `config` and `ckpt` to `root_dir` (in place)."""
        for field_name in ("output_dir", "config", "ckpt"):
            relative = getattr(self, field_name)
            setattr(self, field_name, self.root_dir / relative)
        
# Build the run configuration; relative paths are resolved by prepare_config().
config = Config()

if config.laion400m:
    print("Falling back to LAION 400M model...")
    config.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
    config.ckpt = "models/ldm/text2img-large/model.ckpt"
    # The field is `output_dir`; assigning `outdir` would create an unused
    # attribute and leave the default output directory in place.
    config.output_dir = "outputs/txt2img-samples-laion400m"

# Must run after the overrides above so they are anchored to root_dir too.
config.prepare_config()

In [None]:
# Seed the random number generators with the configured seed before sampling.
seed_everything(seed=config.seed)

In [None]:
# Read the model configuration, then build the model and load its checkpoint.
model_config = OmegaConf.load(str(config.config))
model = load_model_from_config(model_config, str(config.ckpt))

In [None]:
# Prefer CUDA when it is available, otherwise fall back to the CPU,
# and make sure the model lives on the selected device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
# Pick the sampler implementation: PLMS when requested, DDIM otherwise.
sampler = (PLMSSampler if config.plms else DDIMSampler)(model)

In [None]:
# Output directory for this run; created (with parents) if it does not exist yet.
outpath = Path(config.output_dir)
outpath.mkdir(exist_ok=True, parents=True)

In [None]:
# Images per grid row. A non-positive setting falls back to the batch size,
# read from the config because `batch_size` is only defined in a later cell.
n_rows = config.n_rows if config.n_rows > 0 else config.n_samples

In [None]:
# Individual samples go into a "samples" subdirectory of the output directory.
sample_path = outpath / "samples"
sample_path.mkdir(parents=True, exist_ok=True)
# Continue numbering after the PNG files that already exist.
base_count = len(list(sample_path.glob("*.png")))
# The glob only matches PNG files directly inside `outpath` (the "samples"
# directory is not matched), so the count itself is the next free grid index.
grid_count = len(list(outpath.glob("*.png")))

In [None]:
# With fixed_code enabled, draw the start code once so that every sampling
# call receives the same tensor; otherwise pass None to the sampler.
start_code = (
    torch.randn(
        [config.n_samples, config.C, config.H // config.f, config.W // config.f],
        device=device,
    )
    if config.fixed_code
    else None
)

In [None]:
# Context manager factory used around sampling: torch autocast when requested,
# otherwise a no-op context.
if config.precision == "autocast":
    precision_scope = autocast
else:
    precision_scope = nullcontext

In [None]:
# A single batch in which the same prompt text is repeated once per sample.
promt = ""
batch_size = config.n_samples
data = [[promt] * batch_size]

In [None]:
# Sampling loop: for every iteration and prompt batch, run the sampler, decode
# the result to images, optionally save each sample, and finally (unless
# disabled) save all batches together as one grid image.
images = []  # PIL image of every generated sample, in generation order
with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            tic = time.time()  # start time
            all_samples = list()
            for n in tqdm(range(config.n_iter), desc="Iterating by config.n_iter"):
                for prompts in data:
                    # Unconditional conditioning is only computed when the
                    # guidance scale differs from 1.0.
                    uc = None
                    if config.scale != 1.0:
                        uc = model.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = model.get_learned_conditioning(prompts)
                    shape = [config.C, config.H // config.f, config.W // config.f]
                    samples_ddim, _ = sampler.sample(
                        S=config.ddim_steps,
                        conditioning=c,
                        batch_size=config.n_samples,
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=config.scale,
                        unconditional_conditioning=uc,
                        eta=config.ddim_eta,
                        x_T=start_code
                    )

                    # Decode, then map values from [-1, 1] to [0, 1] with clamping.
                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    # Back to a tensor with the channel axis in position 1.
                    x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2)

                    for x_sample in x_checked_image_torch:
                        # Scale to 0..255 and move channels last for PIL.
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        img = Image.fromarray(x_sample.astype(np.uint8))
                        if not config.skip_save:
                            img.save(sample_path / f"{base_count:05}.png")
                        images.append(img)
                        base_count += 1

                    if not config.skip_grid:
                        all_samples.append(x_checked_image_torch)

            # torch.stack raises on an empty list, so skip the grid when no
            # batch was produced (e.g. config.n_iter == 0).
            if not config.skip_grid and all_samples:
                # additionally, save as grid
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                img = Image.fromarray(grid.astype(np.uint8))
                img.save(outpath / f'grid-{grid_count:04}.png')
                grid_count += 1

            toc = time.time()  # end time (recorded only; not reported here)
if config.show_images:
    for img in images:
        img.show()