# Class-Conditional Latent Diffusion Models

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install pytorch-lightning==1.9.4
!pip install einops
!pip install omegaconf
!pip install kornia
!pip install wget
!pip install super-image

In [None]:
!nvidia-smi

Thu Apr 13 13:04:05 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Image Generation

In [6]:
ROOT_PATH = "/content/drive/MyDrive/Tesi Bardella 2022/LDM_pipeline_porting/Second_stage_LDM"

import sys
sys.path.insert(0, ROOT_PATH)

In [7]:
import torch
from omegaconf import OmegaConf
import os
from modules.util import instantiate_from_config
from models.diffusion.ddpm import DDIMSampler
from models.diffusion.plms import PLMSSampler
import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid

def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    device = "cuda"#torch.device("cuda:1")
    model.to(device)
    model.eval()

    return model


config = OmegaConf.load(ROOT_PATH + "/configs/custom-ldm-cwa-vq-f8.yaml")  
sample_folder = ROOT_PATH + f"/sample/ldm/wikiart"

vq_gan_pretrained_ckpt_path = ROOT_PATH + "/pretrained_model/vq-f8/model.ckpt"
ldm_pretrained_ckpt_path = "/content/drive/MyDrive/Tesi Bardella 2022/LDM_pipeline_porting/Second_stage_LDM/LDM_training/version_12/checkpoints/epoch=299-step=69600.ckpt"
config.model.params.first_stage_config.params["ckpt_path"] = vq_gan_pretrained_ckpt_path

model = load_model_from_config(config, ldm_pretrained_ckpt_path)
# sampler = DDIMSampler(model)
sampler = PLMSSampler(model)

Loading model from /content/drive/MyDrive/Tesi Bardella 2022/LDM_pipeline_porting/Second_stage_LDM/LDM_training/version_12/checkpoints/epoch=299-step=69600.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 394.98 M params.
Keeping EMAs of 628.
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Restored from /content/drive/MyDrive/Tesi Bardella 2022/LDM_pipeline_porting/Second_stage_LDM/pretrained_model/vq-f8/model.ckpt with 0 missing and 49 unexpected keys


Higher values of `scale` produce better samples at the cost of a reduced output diversity. 

Increasing `ddim_steps` generally also gives higher quality samples, but returns are diminishing for values > 250. 

Fast sampling (i e. low values of `ddim_steps`) while retaining good quality can be achieved by using `ddim_eta = 0.0`.

### Sample for Exposition

In [10]:
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8]   # define classes to be sampled here
n_samples_per_class = 15
unconditional_class = 8

ddim_steps = 50
ddim_eta = 0
scale = 15  # for unconditional guidance

all_samples = list()

with torch.no_grad():
    with model.ema_scope():
        # Un-conditional class for free classifier sampling
        uc = model.get_learned_conditioning({model.cond_stage_key: torch.tensor(n_samples_per_class*[unconditional_class]).to(model.device)})
        
        for class_label in classes:
            #Conditional class
            print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
            xc = torch.tensor(n_samples_per_class*[class_label])
            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
            
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=n_samples_per_class,
                                             shape=[4, 32, 32],
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc, 
                                             eta=ddim_eta)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, 
                                         min=0.0, max=1.0)
            all_samples.append(x_samples_ddim)


# display as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_samples_per_class)

# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))


# Save the producted image
image_name = f"sample_labels_{classes}_scale{scale}_ddim_steps{ddim_steps}_ddim_eta_{ddim_eta}"
image_ext = ".png"
image_number = [0]
for entry in os.listdir(sample_folder):
    if os.path.isfile(os.path.join(sample_folder, entry)):
        splitted_entry = entry[:-len(image_ext)].split("_")
        number = int(splitted_entry.pop(-1))
        if splitted_entry == image_name.split("_"):
            image_number.append(number)


save_folder = sample_folder + "/" + image_name + "_" + str( max(image_number) + 1 )
img.save(save_folder + image_ext)
# img

rendering 15 examples of class '0' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s]


rendering 15 examples of class '1' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:48<00:00,  1.04it/s]


rendering 15 examples of class '2' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s]


rendering 15 examples of class '3' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s]


rendering 15 examples of class '4' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s]


rendering 15 examples of class '5' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s]


rendering 15 examples of class '6' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s]


rendering 15 examples of class '7' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s]


rendering 15 examples of class '8' in 50 steps and using s=15.00.
Data shape for PLMS sampling is (15, 4, 32, 32)
Running PLMS Sampling with 50 timesteps


PLMS Sampler: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s]


### Sample for FID

In [None]:
import json

mapping_file = ROOT_PATH + "/mapping.json"

with open(mapping_file, "r") as fin:
    mapping = json.load(fin)
    mapping = {v:k for k, v in mapping}

In [None]:
# classes = [0,1,2,3,4,5,6,7,8]   # define classes to be sampled here
classes = [0]   # define classes to be sampled here
n_samples_per_class = 5000
max_sample_size = 1
unconditional_class = 8

# ddim_steps = [20, 50, 200, 500, 1000]
ddim_steps = [200]
ddim_eta = 0
scale = 1  # for unconditional guidance

for ddim_step in ddim_steps:
    for cls in classes:
        # Create the folder
        save_folder = sample_folder + f"/{ddim_step}/{mapping[cls]}"
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        all_samples = list()

        with torch.no_grad():
            with model.ema_scope():
                # Un-conditional class for free classifier sampling
                uc = model.get_learned_conditioning({model.cond_stage_key: torch.tensor(n_samples_per_class*[unconditional_class]).to(model.device)}) if scale != 1 else None
                
                for class_label in classes:
                    print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
                    if n_samples_per_class > max_sample_size:
                        iterations = [max_sample_size] * (n_samples_per_class//max_sample_size)
                        if n_samples_per_class % max_sample_size != 0 :
                          iterations.append(n_samples_per_class % max_sample_size)

                    for iteration in iterations:
                        #Conditional class
                        xc = torch.tensor(iteration*[class_label])
                        c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
                        
                        samples_ddim, _ = sampler.sample(S=ddim_step,
                                                        conditioning=c,
                                                        batch_size=max_sample_size,
                                                        shape=[4, 32, 32],
                                                        verbose=False,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=uc, 
                                                        eta=ddim_eta)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0).to("cpu")

                        for images in torch.chunk(x_samples_ddim, chunks = iteration, dim = 0):
                            all_samples.append(x_samples_ddim)

        for i, img in enumerate(all_samples):
            # display as grid
            
            final_image = 255. * rearrange(img, 'c h w -> h w c').cpu().numpy().astype(np.uint8)

            final_image.save(save_folder+ f"image_{i}.png")


# Save the producted image
# image_name = f"sample_labels_{classes}_scale{scale}_ddim_steps{ddim_steps}_ddim_eta_{ddim_eta}"
# image_ext = ".png"
# image_number = [0]
# for entry in os.listdir(sample_folder):
#     if os.path.isfile(os.path.join(sample_folder, entry)):
#         splitted_entry = entry[:-len(image_ext)].split("_")
#         number = int(splitted_entry.pop(-1))
#         if splitted_entry == image_name.split("_"):
#             image_number.append(number)


# save_folder = sample_folder + "/" + image_name + "_" + str( max(image_number) + 1 )
# img.save(save_folder + image_ext)

rendering 5000 examples of class '0' in [200] steps and using s=1.00.
Data shape for DDIM sampling is (5000, 4, 32, 32), eta 0
Running DDIM Sampling with 200 timesteps


DDIM Sampler:   0%|          | 0/200 [00:01<?, ?it/s]


OutOfMemoryError: ignored

## Image Upscale

In [None]:
from super_image import DrlnModel, ImageLoader
from PIL import Image
import requests

scale = 4
image = Image.open("/content/drive/MyDrive/Tesi Bardella 2022/LDM_pipeline_porting/Second_stage_LDM/sample/ldm/wikiart/sample_labels_[2]_scale1_ddim_steps500_ddim_eta_0_1.png")

model = DrlnModel.from_pretrained('eugenesiow/drln', scale=scale).to("cuda").eval()      # scale 2, 3 and 4 models available


In [None]:
import torch
with torch.no_grad():
  inputs = torch.tensor(ImageLoader.load_image(image)).to("cuda").detach()
  preds = model(inputs)

  # ImageLoader.save_image(preds, save_folder + f'_scaled_{scale}x' + image_ext)                        # save the output 2x scaled image to `./scaled_2x.png`
  # ImageLoader.save_compare(inputs, preds, save_folder + f'_scaled_{scale}x_compare' + image_ext)      # save an output comparing the super-image with a bicubic scaling
  ImageLoader.save_image(preds, f'./scaled_{scale}x.png')                        # save the output 2x scaled image to `./scaled_2x.png`
  ImageLoader.save_compare(inputs, preds, f'./scaled_{scale}x_compare.png')      # save an output comparing the super-image with a bicubic scaling