In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math

import matplotlib.pyplot as plt
import torch

from src.config.settings import CONFIG_PATH, DATA_PATH
from src.models.denoising_diffusion import GaussianDiffusion
from src.models.unet import Unet
from src.utils.files import load_omegaconf_from_yaml

# Model loading

In the following cell, we instantiate the model and load the checkpoint weights. \
Adapt the checkpoint path (`path_weights`) to the location of your own checkpoint.


In [3]:
# Build the U-Net denoiser and its Gaussian diffusion wrapper from the project
# config, restore the trained weights, and prepare the wrapper for sampling.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = load_omegaconf_from_yaml(path=CONFIG_PATH)

path_weights = "./results/model-15.pt"  # adapt the path to your checkpoint location

# NOTE(review): `dim` receives config.ddpm.size, which is also passed as
# `image_size` below. If this Unet follows the common DDPM convention, `dim` is the
# base channel width rather than the image size — confirm. Either way it must match
# the value used at training time, or load_state_dict below will fail.
model = Unet(dim=config.ddpm.size, channels=config.ddpm.unet.channels, dim_mults=config.ddpm.unet.dim_mults)

batch_size = 4  # number of images drawn by the sampling cell

diffusion = GaussianDiffusion(
    model,
    image_size=config.ddpm.size,
    timesteps=config.ddpm.gaussian_diffusion.timesteps,
    sampling_timesteps=config.ddpm.gaussian_diffusion.sampling_timesteps,
)

# torch.load unpickles the file, which can execute arbitrary code: only load
# checkpoints from a trusted source. map_location lets a checkpoint saved on GPU
# also load on a CPU-only machine.
model_state_dict = torch.load(f=path_weights, map_location=device)
# The checkpoint is a dict; the diffusion weights live under the "model" key.
diffusion.load_state_dict(model_state_dict["model"])

# Move to the target device and switch to evaluation mode so that any dropout /
# normalization layers behave deterministically during sampling (nn.Module.eval
# returns the module itself).
diffusion = diffusion.to(device).eval()

Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


# Generating samples from the checkpoint

Let's generate chest X-ray samples using the pre-trained model.


In [None]:
# Seed torch's RNGs (CPU and all CUDA devices) so the initial noise, and hence the
# generated images, are repeatable across runs. Bit-exact reproducibility on GPU may
# additionally require deterministic kernels (see PyTorch's reproducibility notes).
SEED = 0
torch.manual_seed(SEED)

sampled_images = diffusion.sample(batch_size=batch_size)

sampling loop time step:  58%|████████████████████████████████████████████████████████████████████▌                                                  | 144/250 [01:04<00:46,  2.30it/s]

# Qualitative and visual results

In the following cell, we visualize the samples generated by the Denoising Diffusion Probabilistic Model (DDPM).

In [None]:
from src.utils.helpers import tensor_to_numpy_image  # NOTE(review): not used in this cell

# Reorder axes (0, 2, 3, 1) so each sample is channels-last, as matplotlib expects.
# Assumes sample() returns a 4-D (batch, channels, height, width) tensor — TODO confirm.
sampled_images_np = sampled_images.detach().cpu().numpy().transpose(0, 2, 3, 1)

# Size a near-square grid from the actual number of samples. A fixed 2x2 grid
# raises as soon as batch_size exceeds four.
n_samples = len(sampled_images_np)
n_cols = max(1, math.ceil(math.sqrt(n_samples)))
n_rows = max(1, math.ceil(n_samples / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.axis("off")  # also blanks the unused cells of a partially filled grid
for ax, sample in zip(axes.flat, sampled_images_np):
    # Drop a trailing channel axis of size 1 so the gray colormap is applied;
    # matplotlib ignores `cmap` for 3-channel (RGB) data.
    image = sample.squeeze(-1) if sample.shape[-1] == 1 else sample
    ax.imshow(image, cmap="gray")
fig.suptitle(f"DDPM samples (n={n_samples})")
plt.show()