In [None]:
# may have to restart runtime/reset kernel after running
!git clone https://github.com/CompVis/latent-diffusion.git
!git clone https://github.com/CompVis/taming-transformers
!pip install -e ./taming-transformers
!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops
!pip install git+https://github.com/openai/CLIP.git
!pip install transformers kornia imageio imageio_ffmpeg pillow

In [None]:
!mkdir -p models/ldm/text2img-large/
!wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt

In [3]:
# imports
import os
import time
import torch
import imageio
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm.auto import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid

from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

In [4]:
# variables
# Choose which config/checkpoint pair to load and where outputs are written.
# The LAION-400M checkpoint is the one downloaded above; the Stable Diffusion
# checkpoint is not downloaded by this notebook and must already be in place.
use_laion400m = True
use_plms = False  # PLMS sampler instead of DDIM (the setup cell warns it does not support img2img)

if not use_laion400m:
    print("Using Stable Diffusion model...")
    config_location, model_location, outdir = (
        'configs/stable-diffusion/v1-inference.yaml',
        'models/ldm/stable-diffusion-v1/model.ckpt',
        'outputs/txt2img-samples',
    )
else:
    print("Falling back to LAION 400M model...")
    config_location, model_location, outdir = (
        "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
        "models/ldm/text2img-large/model.ckpt",
        "outputs/txt2img-samples-laion400m",
    )

Falling back to LAION 400M model...


In [5]:
# utility functions
def load_model_from_config(config, ckpt, verbose=False):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint file containing a ``"state_dict"`` entry.
        verbose: if True, print every missing / unexpected key, not just counts.

    Returns:
        The model in eval mode, moved to the GPU when CUDA is available
        (otherwise left on the CPU).
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU first: without map_location, tensors saved from a
    # GPU fail to load on CPU-only machines and occupy GPU memory a second time.
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them instead of silently dropping them.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"missing keys: {len(missing)}")
        if verbose:
            print(missing)
    if unexpected:
        print(f"unexpected keys: {len(unexpected)}")
        if verbose:
            print(unexpected)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model

def get_model():
    """Load the model selected in the variables cell.

    Reads the module-level ``config_location`` and ``model_location`` and
    delegates to ``load_model_from_config``.
    """
    return load_model_from_config(OmegaConf.load(config_location), model_location)

def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    The last tuple may be shorter. The result is a single-use iterator.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if not batch:
                return
            yield batch

    return _batches()

def load_img(path, size=None):
    """Load an image file as an RGB tensor scaled to [-1, 1].

    Args:
        path: image file path.
        size: target ``(width, height)``. When empty or None, the image keeps
            its own size rounded down to an integer multiple of 32.

    Returns:
        float32 torch tensor of shape (1, 3, height, width) with values in [-1, 1].

    Raises:
        ValueError: if the resulting width or height would be zero (an image
            smaller than 32 px along a side with no explicit ``size``).
    """
    # Context manager closes the underlying file once the pixels are converted.
    with Image.open(path) as src:
        image = src.convert("RGB")
    orig_w, orig_h = image.size
    if size:
        w, h = size
    else:
        # resize to integer multiple of 32
        w, h = (x - x % 32 for x in (orig_w, orig_h))
    if w <= 0 or h <= 0:
        raise ValueError(f"cannot resize {path} ({orig_w}x{orig_h}) to ({w}, {h})")
    print(f"loaded input image of size ({orig_w}, {orig_h}) and resized to ({w}, {h}) from {path}")
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    # (H, W, 3) -> (1, 3, H, W)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.*image - 1.

def slerp(val, low, high, dot_threshold=0.9995):
    """Spherical linear interpolation between two vectors.

    Args:
        val: interpolation fraction; 0 returns ``low``, 1 returns ``high``.
        low, high: tensors of the same shape (the callers pass flattened 1-D tensors).
        dot_threshold: when the absolute cosine between ``low`` and ``high``
            exceeds this, plain linear interpolation is used instead.

    Returns:
        Tensor with the same shape as ``low``.
    """
    low_norm = low / torch.norm(low)
    high_norm = high / torch.norm(high)
    # Clamp: rounding can push the cosine just outside [-1, 1], where acos is NaN.
    dot = torch.clamp((low_norm * high_norm).sum(), -1.0, 1.0)
    if dot.abs() > dot_threshold:
        # (Nearly) parallel vectors: sin(omega) ~ 0 would divide by zero, and
        # lerp is numerically equivalent in this regime.
        return (1.0 - val) * low + val * high
    omega = torch.acos(dot)
    so = torch.sin(omega)
    return (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high

def get_slerp_vectors(start, end, frames=60):
    """Return ``frames`` evenly spaced slerp samples from ``start`` to ``end``.

    Args:
        start, end: 1-D tensors of equal length.
        frames: number of samples, endpoints included. With ``frames == 1``
            only ``start`` is returned; with 0 the result is empty.

    Returns:
        Tensor of shape (frames, start.shape[0]) on ``start``'s device and dtype.
    """
    out = torch.empty((frames, start.shape[0]), dtype=start.dtype, device=start.device)
    # max(..., 1) avoids dividing by zero when a single frame is requested.
    denom = max(frames - 1, 1)
    for i in range(frames):
        out[i] = slerp(i / denom, start, end)
    return out

def get_starting_code_and_conditioning_vector(seed, prompt):
    """Seed the RNGs, draw an initial latent, and embed ``prompt``.

    Args:
        seed: integer seed, or None to pick a random one.
        prompt: text passed to ``model.get_learned_conditioning``.

    Returns:
        ``(conditioning, start_code)`` where ``start_code`` has shape
        ``[batch_size, C, H // f, W // f]`` (module-level settings) on ``device``.
    """
    chosen_seed = np.random.randint(np.iinfo(np.int32).max) if seed is None else seed
    seed_everything(chosen_seed)
    latent_shape = [batch_size, C, H // f, W // f]
    noise = torch.randn(latent_shape, device=device)
    conditioning = model.get_learned_conditioning(prompt)
    return conditioning, noise

In [None]:
# Load the selected model, move it to the GPU when available (CPU otherwise),
# choose a sampler, and prepare the output folders.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model().to(device)

if not use_plms:
    sampler = DDIMSampler(model)
else:
    print("Warning: img2img not compatible with PLMSSampler")
    sampler = PLMSSampler(model)

sample_path = os.path.join(outdir, "samples")
for directory in (outdir, sample_path):
    os.makedirs(directory, exist_ok=True)
# Continue numbering after the files already on disk; the "- 1" discounts the
# "samples" sub-directory that lives inside outdir.
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outdir)) - 1

Text -> Image

In [None]:
prompts = [
    "Joe Biden"
] # list of string prompts
seed = None # None picks a random seed (printed below so the run can be reproduced)

init_image_location = "" # file location of init image (use a blank string if none desired)
init_noise_strength = 0.75 # how much to noise init image (0-1.0 where 1.0 is full destruction of init image information)

ddim_steps = 50 # ddim sampling steps
ddim_eta = 0.0 # ddim eta (eta=0.0 corresponds to deterministic sampling) (must be 0.0 if using PLMS sampling)
unconditional_guidance_scale = 7.5 # unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))

precision = 'autocast' # precision to evaluate at (full or autocast)

n_iter = 1 # how many sample iterations
batch_size = 1 # how many samples to generate per prompt
n_rows = 0 # rows in the grid (will use batch_size if set to 0)

fixed_code = False # if enabled, uses the same starting code across samples
skip_grid = True # do not save a grid, only individual samples
skip_save = False # do not save individual samples
show_images = True # whether or not to show images after generation

H = 512 # height
W = 512 # width
C = 4 # channels
f = 8 # downsampling factor

if n_rows == 0:
    n_rows = batch_size

if seed is None:
    seed = np.random.randint(np.iinfo(np.int32).max)
print(f"Using seed {seed}")
seed_everything(seed)

start_code = None
if fixed_code:
    start_code = torch.randn([batch_size, C, H // f, W // f], device=device)

precision_scope = autocast if precision=="autocast" else nullcontext
# One batch per prompt, each holding the prompt batch_size times, so the
# conditioning batch matches the batch_size used for `uc` and by the sampler.
# A list (not a one-shot iterator) so every n_iter pass, and every re-run of
# the sampling cell, sees all prompts.
data = [batch_size * [prompt] for prompt in prompts]

if init_image_location != "":
    assert os.path.isfile(init_image_location), f"init image not found: {init_image_location}"
    assert 0. <= init_noise_strength <= 1., 'can only work with strength in [0.0, 1.0]'
    # Resize the init image to the configured output size.
    init_image = load_img(init_image_location, (W, H)).to(device)
    init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
    with torch.no_grad():  # encoding only; no autograd graph is needed
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)

    t_enc = int(init_noise_strength * ddim_steps)
    print(f"target t_enc is {t_enc} steps")

In [None]:
def _to_pil(x_sample):
    """Convert one decoded sample (C, H, W) with values in [0, 1] to a uint8 PIL image."""
    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
    return Image.fromarray(x_sample.astype(np.uint8))


with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            tic = time.time()
            all_samples = list()
            # Materialize once so every n_iter pass sees all prompt batches
            # (a one-shot iterator would be exhausted after the first pass).
            prompt_batches = list(data)
            if not prompt_batches:
                print("No prompts to sample - re-run the configuration cell above.")
            # The unconditional embedding is identical for every batch; compute it once.
            uc = None
            if unconditional_guidance_scale != 1.0:
                uc = model.get_learned_conditioning(batch_size * [""])
            for _ in trange(n_iter, desc="Sampling"):
                # Separate name so the `prompts` config variable is not overwritten.
                for batch_prompts in tqdm(prompt_batches, desc="data"):
                    if isinstance(batch_prompts, tuple):
                        batch_prompts = list(batch_prompts)
                    c = model.get_learned_conditioning(batch_prompts)

                    if len(init_image_location) == 0:
                        shape = [C, H // f, W // f]
                        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                            conditioning=c,
                                                            batch_size=batch_size,
                                                            shape=shape,
                                                            verbose=False,
                                                            unconditional_guidance_scale=unconditional_guidance_scale,
                                                            unconditional_conditioning=uc,
                                                            eta=ddim_eta,
                                                            x_T=start_code)
                    else:
                        # encode (scaled latent)
                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
                        # decode it
                        samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=unconditional_guidance_scale,
                                                unconditional_conditioning=uc)

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    # Convert each sample once, then save and/or display it.
                    if not skip_save or show_images:
                        for x_sample in x_samples_ddim:
                            img = _to_pil(x_sample)
                            if not skip_save:
                                img.save(os.path.join(sample_path, f"{base_count:05}.png"))
                                base_count += 1
                            if show_images:
                                display(img)

                    if not skip_grid:
                        all_samples.append(x_samples_ddim)

            if not skip_grid and all_samples:
                # additionally, save as grid; torch.cat also tolerates batches of unequal size
                grid = torch.cat(all_samples, 0)
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outdir, f'grid-{grid_count:04}.png'))
                grid_count += 1

            toc = time.time()

print(f"Your samples are ready and waiting for you here: \n{outdir}\n"
        f"\nTime elapsed: {round(toc - tic, 2)} seconds")

Text, Text -> Interpolation

In [None]:
prompts = [
    (741, "Joe Biden exhaling a large smoke cloud, featured on artstation"),
    (None, "The world on fire, trending on artstation"),
    (420, "Barack Obama exhaling a large smoke cloud, featured on artstation"),
] # (seed, prompt) keyframes; a None seed picks a random one

ddim_steps = 50 # ddim sampling steps
ddim_eta = 0.0 # ddim eta (eta=0.0 corresponds to deterministic sampling) (must be 0.0 if using PLMS sampling)
unconditional_guidance_scale = 7.5 # unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))

precision = 'autocast' # precision to evaluate at (full or autocast)

batch_size = 1 # how many samples to generate per prompt (currently must be set to one for animation)
assert batch_size == 1, "the interpolation/animation cells currently require batch_size == 1"

loop = True # if enabled, make animation loop by interpolating between end and start vectors
fixed_code = True # not read by the interpolation cells: each keyframe's start code comes from its seed
skip_save = False # do not write frames to the output video
show_images = True # whether or not to show images after generation

H = 512 # height
W = 512 # width
C = 4 # channels
f = 8 # downsampling factor

precision_scope = autocast if precision=="autocast" else nullcontext

frames = 60 # number of frames between text prompts
fps = 60 # frames per second of output mp4

In [None]:
# Build per-frame conditioning vectors and start codes by slerping between
# consecutive keyframes (and, if `loop`, from the last keyframe back to the first).


def _slerp_path(start, end, n_frames):
    """Slerp ``start`` -> ``end`` (same shape) in ``n_frames`` steps; returns (n_frames, *start.shape)."""
    path = get_slerp_vectors(start.flatten(), end.flatten(), frames=n_frames)
    return path.reshape(-1, *start.shape)


# pad for beginning and end frame; a separate name so re-running this cell
# does not keep growing the `frames` setting
padded_frames = frames + 2
previous_c = None
previous_start_code = None
slerp_c_vectors = []
slerp_start_codes = []
# Nothing here needs gradients; without no_grad the text-encoder outputs (and
# every interpolated tensor derived from them) may keep an autograd graph alive.
with torch.no_grad():
    # Distinct loop names so the text->image `data` variable is not clobbered.
    for i, (keyframe_seed, keyframe_prompt) in enumerate(prompts):
        c, start_code = get_starting_code_and_conditioning_vector(keyframe_seed, keyframe_prompt)
        if i == 0:
            slerp_c_vectors.append(c)
            slerp_start_codes.append(start_code)
        else:
            # [1:] drops the first frame, which equals the previous keyframe already in the list.
            slerp_c_vectors.extend(_slerp_path(previous_c, c, padded_frames)[1:])
            slerp_start_codes.extend(_slerp_path(previous_start_code, start_code, padded_frames)[1:])
        previous_c = c
        previous_start_code = start_code

    if loop and len(prompts) > 1:
        # Close the loop back to the first keyframe; [1:-1] drops both endpoints,
        # which are already in the lists.
        slerp_c_vectors.extend(_slerp_path(previous_c, slerp_c_vectors[0], padded_frames)[1:-1])
        slerp_start_codes.extend(_slerp_path(previous_start_code, slerp_start_codes[0], padded_frames)[1:-1])
    previous_start_code = start_code

In [None]:
# Render every interpolated (conditioning, start code) pair and stream the
# frames into an H.264 video.
video_path = 'test.mp4'
# Context manager: the video file is finalized and closed even if sampling raises.
with imageio.get_writer(video_path, mode='I', fps=fps, codec='libx264') as video_out:
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                # The unconditional embedding is identical for every frame; compute it once.
                uc = None
                if unconditional_guidance_scale != 1.0:
                    uc = model.get_learned_conditioning(batch_size * [""])
                shape = [C, H // f, W // f]
                for c, start_code in tqdm(zip(slerp_c_vectors, slerp_start_codes), desc="data", total=len(slerp_c_vectors)):
                    if isinstance(c, (tuple, list)):
                        c = torch.stack(list(c), dim=0)
                    if isinstance(start_code, (tuple, list)):
                        start_code = start_code[0]

                    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                        conditioning=c,
                                                        batch_size=batch_size,
                                                        shape=shape,
                                                        verbose=False,
                                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                                        unconditional_conditioning=uc,
                                                        eta=ddim_eta,
                                                        x_T=start_code)

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    for x_sample in x_samples_ddim:
                        # Hand imageio uint8 frames: given float frames it converts them
                        # itself, min-max rescaling each frame independently, which makes
                        # brightness flicker between frames.
                        frame = (255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')).astype(np.uint8)
                        if not skip_save:
                            video_out.append_data(frame)
                        if show_images:
                            display(Image.fromarray(frame))

                toc = time.time()

print(f"Your video is ready and waiting for you here: \n{video_path}\n"
        f"\nTime elapsed: {round(toc - tic, 2)} seconds")

Image, Mask -> Image (inpainting)