In [2]:
# Fetch the model code and install its Python dependencies.
# Version specifiers are quoted: unquoted, the shell parses `>=` as an output
# redirect, so the pins would be silently dropped and stray files created.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
!git clone https://github.com/CompVis/latent-diffusion.git
!git clone https://github.com/CompVis/taming-transformers
%pip install -e ./taming-transformers
%pip install "omegaconf>=2.0.0" "pytorch-lightning>=1.0.8" torch-fidelity einops

# Download the LAION-400M text2img checkpoint to the *relative* path that
# `model_location` reads later (models/ldm/text2img-large/model.ckpt),
# not to the filesystem root.
!mkdir -p models/ldm/text2img-large/
!wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt

^C


In [1]:
# imports
import os
import time
from contextlib import contextmanager, nullcontext
from itertools import islice

import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from torch import autocast
from torchvision.utils import make_grid
from tqdm import tqdm, trange

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

In [2]:
# Variables
use_laion400m = True  # True: LAION-400M latent-diffusion checkpoint; False: Stable Diffusion v1
use_plms = False      # True: PLMS sampler; False: DDIM sampler

# (banner, config path, checkpoint path, output directory) for each model choice.
_laion400m_setup = (
    "Falling back to LAION 400M model...",
    "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
    "models/ldm/text2img-large/model.ckpt",
    "outputs/txt2img-samples-laion400m",
)
_stable_diffusion_setup = (
    "Using Stable Diffusion model...",
    "configs/stable-diffusion/v1-inference.yaml",
    "models/ldm/stable-diffusion-v1/model.ckpt",
    "outputs/txt2img-samples",
)

_banner, config_location, model_location, outdir = (
    _laion400m_setup if use_laion400m else _stable_diffusion_setup
)
print(_banner)

Falling back to LAION 400M model...


In [3]:
def load_model_from_config(config, ckpt, verbose=False, device=None):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint file readable by ``torch.load``. Either a
            Lightning-style dict with a ``"state_dict"`` entry or a bare
            state dict is accepted.
        verbose: if True, print the state-dict keys that were missing from or
            unexpected by the model. Loading uses ``strict=False``, so these
            are otherwise silently ignored.
        device: device to place the model on. Defaults to CUDA when it is
            available, otherwise CPU. Previously ``.cuda()`` was called
            unconditionally, which fails on CPU-only machines.

    Returns:
        The instantiated model, in eval mode, on ``device``.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU first so a checkpoint saved from GPU tensors can be
    # read on any machine; the model is moved to `device` below.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from sources you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd.get("state_dict", pl_sd)

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if missing:
            print("missing keys:", missing)
        if unexpected:
            print("unexpected keys:", unexpected)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model


def get_model(config_path=None, ckpt_path=None):
    """Load the model for a config/checkpoint pair.

    Args:
        config_path: OmegaConf YAML path. Defaults to the notebook-level
            ``config_location`` (read at call time).
        ckpt_path: checkpoint path. Defaults to the notebook-level
            ``model_location`` (read at call time).

    Returns:
        The model returned by ``load_model_from_config``.
    """
    if config_path is None:
        config_path = config_location
    if ckpt_path is None:
        ckpt_path = model_location
    config = OmegaConf.load(config_path)
    return load_model_from_config(config, ckpt_path)

In [4]:
# Load the model, move it to the best available device, and build the sampler.
model = get_model()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

sampler = (PLMSSampler if use_plms else DDIMSampler)(model)

# Output folders. The counters continue numbering after whatever is already on
# disk; the `- 1` discounts the `samples` subfolder inside `outdir`.
sample_path = os.path.join(outdir, "samples")
for folder in (outdir, sample_path):
    os.makedirs(folder, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outdir)) - 1

Loading model from models/ldm/text2img-large/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 872.30 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 25.4kB/s]
Downloading: 100%|██████████| 226k/226k [00:00<00:00, 551kB/s]  
Downloading: 100%|██████████| 455k/455k [00:00<00:00, 731kB/s] 
Downloading: 100%|██████████| 570/570 [00:00<00:00, 288kB/s]


In [5]:
# Prompts: a list of batches; each batch is a list of exactly `batch_size` strings.
prompts = [["test"]]
seed = 420

ddim_steps = 50 # ddim sampling steps
ddim_eta = 0.0 # ddim eta (eta=0.0 corresponds to deterministic sampling)
# Classifier-free guidance: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)).
# scale = 0.0 collapses this to eps(x, empty), so the prompt is ignored entirely;
# scale = 1.0 is plain conditional sampling; larger values push harder toward the prompt.
# Values around 5-7.5 are typical for txt2img -- tune as needed.
unconditional_guidance_scale = 7.5

precision = 'autocast' # precision to evaluate at (full or autocast)

n_iter = 1 # how many sample iterations
batch_size = 1 # how many samples to generate per prompt
n_rows = 0 # rows in the grid (will use batch_size if set to 0)

fixed_code = False # if enabled, uses the same starting code across samples
skip_grid = True # do not save a grid, only individual samples
skip_save = False # do not save individual samples

H = 512 # height
W = 512 # width
C = 4 # channels
f = 8 # downsampling factor

if n_rows == 0:
    n_rows = batch_size

start_code = None
if fixed_code:
    start_code = torch.randn([batch_size, C, H // f, W // f], device=device)

precision_scope = autocast if precision=="autocast" else nullcontext
seed_everything(seed)
data = prompts

# The unconditional conditioning, the sampler batch and `start_code` are all
# sized by `batch_size`, so every prompt batch must have exactly that many entries.
assert all(len(batch) == batch_size for batch in data), (
    f"each prompt batch must contain batch_size={batch_size} prompts"
)
data = prompts

Global seed set to 420


In [None]:
# Run `n_iter` passes over every prompt batch in `data`. Each image goes to
# `sample_path`; unless `skip_grid` is set, one combined grid goes to `outdir`.
with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            tic = time.time()
            all_samples = list()

            # Loop-invariant work, hoisted out of the sampling loops: the latent
            # shape, and the empty-prompt conditioning used for classifier-free
            # guidance (not needed when the guidance scale is exactly 1.0).
            shape = [C, H // f, W // f]
            uc = None
            if unconditional_guidance_scale != 1.0:
                uc = model.get_learned_conditioning(batch_size * [""])

            for _iteration in trange(n_iter, desc="Sampling"):
                # Named `batch_prompts` (not `prompts`) so this loop does not
                # rebind the global `prompts` list from the parameter cell.
                for batch_prompts in tqdm(data, desc="data"):
                    if isinstance(batch_prompts, tuple):
                        batch_prompts = list(batch_prompts)
                    c = model.get_learned_conditioning(batch_prompts)
                    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                     conditioning=c,
                                                     batch_size=batch_size,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                                     unconditional_conditioning=uc,
                                                     eta=ddim_eta,
                                                     x_T=start_code)

                    # Decode the latents, then rescale with (x + 1) / 2 and clamp to [0, 1].
                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    if not skip_save:
                        for x_sample in x_samples_ddim:
                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                            Image.fromarray(x_sample.astype(np.uint8)).save(
                                os.path.join(sample_path, f"{base_count:05}.png"))
                            base_count += 1

                    if not skip_grid:
                        all_samples.append(x_samples_ddim)

            if not skip_grid:
                # additionally, save as grid
                grid = torch.stack(all_samples, 0)
                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                grid = make_grid(grid, nrow=n_rows)

                # to image
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outdir, f'grid-{grid_count:04}.png'))
                grid_count += 1

            toc = time.time()

print(f"Sampling finished in {toc - tic:.1f}s.")
print(f"Your samples are ready and waiting for you here: \n{outdir} \n"
      f" \nEnjoy.")