In [1]:
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

In [2]:
def load_model_from_config(config, ckpt, verbose=False):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: Configuration object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file. The loaded object must contain a
            ``"state_dict"`` entry.
        verbose: When true, print the missing and unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The instantiated model in eval mode. It is moved to CUDA when a CUDA
        device is available; otherwise it stays on the CPU.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): depending on the installed torch version's `weights_only`
    # default, torch.load may unpickle arbitrary objects. Only load checkpoints
    # from sources you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False tolerates key mismatches; they are only reported when verbose.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    # An unconditional model.cuda() raises on machines without CUDA, so only
    # move the model when a CUDA device is actually available.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


In [6]:
# Display the loaded configuration.
# NOTE(review): `config` is only assigned by the OmegaConf.load(...) cell that
# follows this one in the file, so on a fresh kernel this cell raises NameError
# unless that cell has been run first.
config

{'model': {'base_learning_rate': 5e-05, 'target': 'ldm.models.diffusion.ddpm.LatentDiffusion', 'params': {'linear_start': 0.00085, 'linear_end': 0.012, 'num_timesteps_cond': 1, 'log_every_t': 200, 'timesteps': 1000, 'first_stage_key': 'image', 'cond_stage_key': 'caption', 'image_size': 32, 'channels': 4, 'cond_stage_trainable': True, 'conditioning_key': 'crossattn', 'monitor': 'val/loss_simple_ema', 'scale_factor': 0.18215, 'use_ema': False, 'unet_config': {'target': 'ldm.modules.diffusionmodules.openaimodel.UNetModel', 'params': {'image_size': 32, 'in_channels': 4, 'out_channels': 4, 'model_channels': 320, 'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2, 'channel_mult': [1, 2, 4, 4], 'num_heads': 8, 'use_spatial_transformer': True, 'transformer_depth': 1, 'context_dim': 1280, 'use_checkpoint': True, 'legacy': False}}, 'first_stage_config': {'target': 'ldm.models.autoencoder.AutoencoderKL', 'params': {'embed_dim': 4, 'monitor': 'val/rec_loss', 'ddconfig': {'double_z': True, 'z_

In [3]:
# Read the evaluation config, then build the model and load its checkpoint weights.
config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")
model = load_model_from_config(
    config=config,
    ckpt="models/ldm/text2img-large/model.ckpt",  # TODO: check path
)

Loading model from models/ldm/text2img-large/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 872.30 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


In [5]:
# Use the GPU when torch reports CUDA as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [6]:
# Falsy value -> the sampler-selection cell picks DDIMSampler instead of PLMSSampler.
plms = None
# Root directory for generated images.
outdir = "outputs/txt2img-samples"

In [7]:
# Choose the sampler implementation according to the `plms` flag.
sampler = PLMSSampler(model) if plms else DDIMSampler(model)

In [12]:
# --- Sampling configuration used by the generation cell ---

# Where results are written and what to generate.
outdir = "outputs/txt2img-samples"
prompt = "a painting of a virus monster playing guitar"

# Output size; the generation cell divides both by 8 to get the latent shape.
H = W = 256

# Batch size per iteration and number of iterations.
n_samples = 4
n_iter = 4

# Sampler settings: number of steps (passed as S) and eta.
ddim_steps = 50
ddim_eta = 0.0

# Passed as unconditional_guidance_scale; at exactly 1.0 the generation cell
# skips computing the empty-prompt conditioning.
scale = 1.0

In [13]:


# Generate `n_iter` batches of `n_samples` images for `prompt`, save every image
# individually under <outdir>/samples, and save one combined grid image.

os.makedirs(outdir, exist_ok=True)
outpath = outdir

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
# Continue numbering after the files already present so that earlier samples
# are not overwritten.
base_count = len(os.listdir(sample_path))

all_samples = list()
with torch.no_grad():
    with model.ema_scope():
        # The empty-prompt conditioning is only needed when scale != 1.0.
        uc = None
        if scale != 1.0:
            uc = model.get_learned_conditioning(n_samples * [""])

        # The prompt conditioning and latent shape are identical for every
        # iteration, so compute them once instead of once per batch.
        # NOTE(review): assumes the conditioning encoder is deterministic here
        # (model.eval() is called when the model is loaded) — confirm.
        c = model.get_learned_conditioning(n_samples * [prompt])
        shape = [4, H // 8, W // 8]

        for n in trange(n_iter, desc="Sampling"):
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Shift and scale by (x + 1) / 2, then clamp into [0, 1].
            x_samples_ddim = torch.clamp(
                (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples_ddim:
                # channels-first tensor -> channels-last uint8 array for PIL
                x_sample = 255. * \
                    rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                Image.fromarray(x_sample.astype(np.uint8)).save(
                    os.path.join(sample_path, f"{base_count:04}.png"))
                base_count += 1
            all_samples.append(x_samples_ddim)

# additionally, save as grid (skipped when nothing was sampled, because
# torch.stack raises on an empty list)
if all_samples:
    grid = torch.stack(all_samples, 0)
    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
    grid = make_grid(grid, nrow=n_samples)

    # to image
    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
    Image.fromarray(grid.astype(np.uint8)).save(
        os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))

print(
    f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")


Sampling:   0%|          | 0/4 [00:00<?, ?it/s]

Data shape for DDIM sampling is (4, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:27<00:00,  1.80it/s]
Sampling:  25%|██▌       | 1/4 [00:29<01:27, 29.08s/it]

Data shape for DDIM sampling is (4, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:27<00:00,  1.80it/s]
Sampling:  50%|█████     | 2/4 [00:57<00:57, 28.82s/it]

Data shape for DDIM sampling is (4, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:28<00:00,  1.78it/s]
Sampling:  75%|███████▌  | 3/4 [01:26<00:28, 28.91s/it]

Data shape for DDIM sampling is (4, 4, 32, 32), eta 0.0
Running DDIM Sampling with 50 timesteps


DDIM Sampler: 100%|██████████| 50/50 [00:28<00:00,  1.76it/s]
Sampling: 100%|██████████| 4/4 [01:56<00:00, 29.02s/it]


Your samples are ready and waiting four you here: 
outputs/txt2img-samples 
Enjoy.
