In [1]:
import torch
from torchvision.utils import save_image
from diffusers import AutoencoderKL
import argparse
import yaml
import os

from tqdm import tqdm

from src.ema import calculate_posthoc_ema
from utils import get_model, CLS_LOC_MAPPING
from diffusion import create_diffusion

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# --- Sampling configuration ---------------------------------------------------
# Training run directory; the next cell reads <result_dir>/config.yaml and the
# EMA snapshots under <result_dir>/ema.
result_dir = "/home/erbill/models/mp"
ema_std = 0.1             # EMA profile width passed to calculate_posthoc_ema
cfg_scale = 4.0           # classifier-free guidance scale passed to forward_with_cfg
num_sampling_steps = 250  # timestep respacing passed to create_diffusion
seed = 42                 # int for reproducible noise, or None to leave the RNG unseeded

# One interpolation video per entry; repeated labels get independent noise anchors.
# Consecutive groups of n_col labels form one row of the final GIF grid.
class_labels = [17, 17, 947, 947]

num_samples = 4   # noise anchors per label; the video loops through them in order
num_images = 512  # frames per label, split evenly across consecutive anchor pairs

n_col = 2         # labels (videos) per grid row

# Frames are decoded by the VAE in chunks of 64.
assert num_images % 64 == 0
# Each anchor pair gets num_images // num_samples frames; a remainder would leave
# fewer latents than class labels in the batched CFG forward pass.
assert num_images % num_samples == 0, "num_images must be divisible by num_samples"
# The grid concatenates n_col videos per row, so every row must be full.
assert len(class_labels) % n_col == 0

In [None]:
# Seed torch's global RNG so the noise anchors drawn later are reproducible.
# Compare against None (not truthiness) so that seed = 0 is honoured as well.
# NOTE: get_model below runs after seeding (as before) in case its parameter
# initialisation consumes the RNG; keep this order to reproduce old outputs.
if seed is not None:
    torch.manual_seed(seed)

# Inference only: autograd is disabled globally for the rest of the notebook.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper-parameters saved with the training run (model config, latent stats, ...).
with open(os.path.join(result_dir, "config.yaml"), "r") as f:
    train_args = yaml.safe_load(f)

# Build the denoiser from the training config.
model = get_model(train_args).to(device)

# Weights produced by calculate_posthoc_ema from the files under <result_dir>/ema
# for the requested ema_std.
state_dict = calculate_posthoc_ema(ema_std, os.path.join(result_dir, "ema"), verbose=True)

model.load_state_dict(state_dict)
model.eval()

# VAE from the stabilityai/sd-vae-ft-mse checkpoint, used to decode latents to pixels.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)

def _slerp(a, b, t, eps=1e-7):
    """Spherically interpolate between two same-shaped noise tensors.

    Args:
        a, b: endpoint tensors of identical shape (treated as flat vectors).
        t: interpolation weight in [0, 1]; t=0 returns `a`.
        eps: if |sin(angle)| falls below this, the endpoints are treated as
            (anti)parallel and plain linear interpolation is used instead.

    Returns:
        Tensor with the shape of `a`.

    Why not lerp: for independent standard-normal `a`, `b` (drawn with
    torch.randn below), (1-t)*a + t*b has per-element variance (1-t)^2 + t^2,
    only 0.5 at t=0.5, so mid-path starting noise is off the distribution the
    sampler expects. Slerp roughly preserves the endpoints' norm instead.
    """
    a_flat, b_flat = a.flatten(), b.flatten()
    denom = a_flat.norm() * b_flat.norm()
    if denom == 0:
        # Angle is undefined for a zero vector; fall back to a straight line.
        return torch.lerp(a, b, t)
    cos_theta = (torch.dot(a_flat, b_flat) / denom).clamp(-1.0, 1.0)
    theta = torch.acos(cos_theta)
    sin_theta = torch.sin(theta)
    if sin_theta.abs() < eps:
        return torch.lerp(a, b, t)
    return (torch.sin((1.0 - t) * theta) * a + torch.sin(t * theta) * b) / sin_theta


def _interpolate_noise(anchors, steps_per_pair):
    """Build a closed loop of latents walking through `anchors` in order.

    For anchor i this emits `steps_per_pair` latents going from anchors[i]
    (inclusive) towards anchors[(i + 1) % len(anchors)] (exclusive), so the last
    frame leads back into the first and the resulting video loops seamlessly.

    Returns:
        Tensor of shape (len(anchors) * steps_per_pair, *anchors.shape[1:]).
    """
    n_anchors = anchors.shape[0]
    frames = [
        _slerp(anchors[i], anchors[(i + 1) % n_anchors], step / steps_per_pair)
        for i in range(n_anchors)
        for step in range(steps_per_pair)
    ]
    return torch.stack(frames)


latent_shape = (train_args["in_channels"], train_args["input_size"], train_args["input_size"])
steps_per_pair = num_images // num_samples
assert steps_per_pair > 0, "num_images must be at least num_samples"

vae_batch = 64     # latents decoded per VAE call
# Label used for the unconditional CFG half of the batch (presumably the extra
# "null" class of a 1000-class model -- confirm against get_model).
null_label = 1000

# Loop-invariant setup, built once instead of once per class label.
diffusion = create_diffusion(str(num_sampling_steps))
latent_mean = torch.tensor(train_args["stats_mean"], device=device).reshape(1, -1, 1, 1)
latent_std = torch.tensor(train_args["stats_std"], device=device).reshape(1, -1, 1, 1)

out = []  # one (frames, C, H, W) tensor in [-1, 1] per entry of class_labels
for class_label in class_labels:
    # Fresh anchors per label, so repeated labels produce different videos.
    anchors = torch.randn(num_samples, *latent_shape, device=device)
    latents = _interpolate_noise(anchors, steps_per_pair)
    # Size labels from the latents actually produced so the batch always matches.
    n_frames = latents.shape[0]

    # Classifier-free guidance: duplicate the batch; the second half is conditioned
    # on the null label.
    cond = torch.full((n_frames,), class_label, dtype=torch.long, device=device)
    uncond = torch.full((n_frames,), null_label, dtype=torch.long, device=device)
    model_kwargs = dict(y=torch.cat([cond, uncond], dim=0), cfg_scale=cfg_scale)
    z = torch.cat([latents, latents], dim=0)

    samples = diffusion.ddim_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    )
    # Drop the null-class half of the batch.
    samples, _ = samples.chunk(2, dim=0)

    # Undo the latent normalisation stored in the training config.
    samples = samples * latent_std + latent_mean

    decoded = [
        vae.decode(samples[start:start + vae_batch]).sample.cpu()
        for start in tqdm(range(0, n_frames, vae_batch))
    ]
    out.append(torch.cat(decoded, dim=0).clamp(-1, 1))

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
 24%|██▍       | 61/250 [00:30<01:35,  1.98it/s]

In [None]:
from PIL import Image
import numpy as np

# Tile the per-label videos into one grid video: n_col videos side by side along
# width (dim=-1), rows stacked along height (dim=-2).
# Each out[i] is (frames, C, H, W); grid is (frames, C, rows*H, n_col*W).
grid = torch.cat(
    [torch.cat(out[i:i + n_col], dim=-1) for i in range(0, len(out), n_col)],
    dim=-2,
)

# Map [-1, 1] -> [0, 255] with rounding rather than truncation (as torchvision's
# save_image does), then NCHW -> NHWC, which is the layout PIL expects.
frames_u8 = (
    ((grid + 1) * 127.5)
    .round()
    .clamp(0, 255)
    .to(torch.uint8)
    .permute(0, 2, 3, 1)
    .contiguous()
    .numpy()
)
images = [Image.fromarray(frame) for frame in frames_u8]

images[0].save(
    'output.gif',
    save_all=True,
    append_images=images[1:],  # remaining frames
    duration=50,  # milliseconds per frame
    loop=0,  # 0 = loop forever
)