In [1]:
# Root of the project checkout. The default below is a Windows location; set
# the LDM_ROOT_PATH environment variable to point the notebook at a checkout
# on another machine (Linux or Windows) without editing this cell.
import os
import sys

ROOT_PATH = os.environ.get(
    "LDM_ROOT_PATH",
    "F:\\Thesis\\Deep-Learning-Techniques-for-Image-Generation-from-Music",
)

# Make the project-local packages importable. Guarded so that re-running the
# cell does not keep appending duplicate entries to sys.path.
if ROOT_PATH not in sys.path:
    sys.path.append(ROOT_PATH)

import torch
from omegaconf import OmegaConf
import os
from modules.util import instantiate_from_config
from pytorch_lightning import seed_everything
from models.diffusion.ddim import DDIMSampler
from models.diffusion.plms import PLMSSampler
import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid

def load_model_from_config(config, ckpt, device_n):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint readable by ``torch.load``; the weights are
            read from its ``"state_dict"`` entry.
        device_n: Index of the CUDA device the model should be moved to.

    Returns:
        The instantiated model in eval mode, placed on ``cuda:{device_n}`` when
        CUDA is available and left on the CPU otherwise.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False tolerates key mismatches, so surface them instead of
    # discarding the result: a silently half-loaded model is hard to diagnose.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"Missing keys ({len(missing)}): {missing}")
    if unexpected:
        print(f"Unexpected keys ({len(unexpected)}): {unexpected}")

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{device_n}")
    else:
        # Keep the notebook usable on machines without a GPU.
        print("CUDA is not available; keeping the model on the CPU.")
        device = torch.device("cpu")
    model.to(device)
    model.eval()

    return model

# Seed the run so repeated executions are reproducible.
seed_everything(43)
torch.set_float32_matmul_precision('high')

# Locations derived from ROOT_PATH (always joined with forward slashes).
config = OmegaConf.load(f"{ROOT_PATH}/configs/custom-ldm-cwa-vq-f8.yaml")
sample_folder = f"{ROOT_PATH}/sample/wikiart/ldm"
vq_gan_pretrained_ckpt_path = f"{ROOT_PATH}/pretrained_model/vq-f8/model.ckpt"

# Latent-diffusion checkpoint. This is an absolute path that does NOT live
# under ROOT_PATH; an earlier run used
# ROOT_PATH + "/model_checkpts/ldm/wikiart/epoch=136-step=31784.ckpt" instead.
ldm_pretrained_ckpt_path = (
    "/mnt/data1/bardella_data/gitRepos/Thesis/ldm_porting"
    "/model_checkpts/ldm/wikiart/epoch=299-step=69600.ckpt"
)

# Point the first-stage config at the pretrained VQ checkpoint before the
# model is instantiated from the config.
config.model.params.first_stage_config.params["ckpt_path"] = vq_gan_pretrained_ckpt_path

model = load_model_from_config(config, ldm_pretrained_ckpt_path, device_n=0)

# PLMS is the active sampler; DDIMSampler(model) is the imported alternative.
sampler = PLMSSampler(model)

2023-07-17 01:12:38.174927: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1
Global seed set to 43


Loading model from /mnt/data1/bardella_data/gitRepos/Thesis/ldm_porting/model_checkpts/ldm/wikiart/epoch=299-step=69600.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 394.98 M params.
Keeping EMAs of 628.
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Restored from /mnt/data1/bardella_data/gitRepos/Deep-Learning-Techniques-for-Image-Generation-from-Music/pretrained_model/vq-f8/model.ckpt with 0 missing and 49 unexpected keys


In [2]:
# --- Sampling settings -------------------------------------------------------
classes = [7]               # class labels to render
n_samples_per_class = 15    # batch size used for every label
unconditional_class = 8     # label used to build the unconditional conditioning

ddim_steps = 75             # passed to the sampler as S
ddim_eta = 0
scale = 2                   # unconditional guidance scale

all_samples = []

# Sample without gradient tracking, inside the model's EMA scope.
with torch.no_grad(), model.ema_scope():
    # Conditioning for the unconditional branch: one label per sample.
    uncond_labels = torch.tensor([unconditional_class] * n_samples_per_class)
    uc = model.get_learned_conditioning(
        {model.cond_stage_key: uncond_labels.to(model.device)}
    )

    for class_label in classes:
        print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")

        # Conditioning for the requested label, repeated across the batch.
        xc = torch.tensor([class_label] * n_samples_per_class)
        c = model.get_learned_conditioning(
            {model.cond_stage_key: xc.to(model.device)}
        )

        samples_ddim, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples_per_class,
            shape=[4, 32, 32],
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
        )

        # Decode the sampled latents, then map values via (x + 1) / 2 and
        # clamp to [0, 1] (presumably the decoder emits values in [-1, 1]).
        decoded = model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
        all_samples.append(x_samples_ddim)


# Arrange every sample in a single grid, one row per sampled class.
stacked_samples = torch.stack(all_samples, 0)
flat_samples = rearrange(stacked_samples, 'n b c h w -> (n b) c h w')
grid_chw = make_grid(flat_samples, nrow=n_samples_per_class)

# Convert to a channel-last array scaled by 255, then to an 8-bit PIL image.
grid = 255. * rearrange(grid_chw, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))


# Save the produced image as "<image_name>_<N>.png", where N is one more than
# the highest N already used for this exact parameter combination.
image_name = f"sample_steps-{ddim_steps}_scale-{scale}_labels_-{','.join(str(n) for n in classes)}_eta-{ddim_eta}_uc-{unconditional_class}"
image_ext = ".png"

# Create the output folder on first use so neither the listing nor the save
# can fail (and throw away the freshly sampled image) on a clean checkout.
os.makedirs(sample_folder, exist_ok=True)

# 0 is the baseline, so the first image for a parameter combination gets 1.
image_number = [0]
name_prefix = image_name + "_"
for entry in os.listdir(sample_folder):
    if not os.path.isfile(os.path.join(sample_folder, entry)):
        continue
    # Skip unrelated files instead of trying to parse a number out of them.
    if not (entry.startswith(name_prefix) and entry.endswith(image_ext)):
        continue
    run_suffix = entry[len(name_prefix):-len(image_ext)]
    # isdecimal() rejects empty strings, signs and underscores, so int()
    # below cannot raise and "1_2" is not misread as 12.
    if run_suffix.isdecimal():
        image_number.append(int(run_suffix))

img.save(sample_folder+"/"+image_name+"_"+str(max(image_number) + 1)+image_ext)

rendering 15 examples of class '7' in 75 steps and using s=2.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 77 timesteps


PLMS Sampler:   0%|          | 0/77 [00:00<?, ?it/s]

PLMS Sampler: 100%|██████████| 77/77 [00:41<00:00,  1.87it/s]
