# Testing Diffusion Model for generating new Emissivities


In [None]:
from utils import *
from common import *

In [2]:
class DiffSXRDataset(SXRDataset):
    """SXR dataset variant whose items are only the emissivity maps.

    After the parent constructor has populated ``self.em``, the maps are
    checked to be 64x64 and reshaped to ``(-1, 1, 64, 64)`` so every item
    is a single-channel image.
    """

    def __init__(self, n, gs, real=False, noise_level:float=0.0, random_remove:int=0, ks=1.0):
        super().__init__(n, gs, real, noise_level, random_remove, ks)
        height, width = self.em.shape[-2], self.em.shape[-1]
        assert width == height == 64, f'Expected 64x64, got {self.em.shape}'
        # Insert an explicit channel axis in front of the two spatial axes.
        self.em = self.em.view(-1, 1, 64, 64)

    def __getitem__(self, idx):
        """Return the emissivity map(s) stored at ``idx``."""
        return self.em[idx]

In [3]:
# Dataset size passed as the first argument of DiffSXRDataset in the cells below.
N_DS = 10_000
# N_DS = 50_000
# Mini-batch size used by the DataLoaders and by the training loop.
BATCH_SIZE = 128

In [None]:
def show_images(data, num_samples=16, cols=4):
    """ Plots some samples from the dataset

    Draws up to ``num_samples`` items of ``data`` on a grid with ``cols``
    columns, each as an "inferno" heat map with its own colorbar.

    Args:
        data: iterable of items exposing ``.cpu().numpy()``; each item must
            hold exactly one 2-D image in its last two dimensions.
        num_samples: maximum number of items to draw.
        cols: number of grid columns.
    """
    plt.figure(figsize=(7,7))
    # Ceiling division: exactly as many rows as needed, so no empty trailing
    # row is reserved when num_samples is a multiple of cols.
    rows = max(1, -(-num_samples // cols))

    for i, img in enumerate(data):
        if i >= num_samples:
            break
        plt.subplot(rows, cols, i + 1)
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
        frame = img.cpu().numpy()
        # Drop leading singleton axes by keeping only the two spatial ones,
        # instead of hard-coding a 64x64 image size.
        plt.imshow(frame.reshape(frame.shape[-2:]), cmap="inferno")
        plt.colorbar()

# Build the dataset (positional args: n, gs, real, noise_level, random_remove, ks)
# and preview its first samples.
# NOTE(review): GSIZE and KS are not defined in this file — presumably they come
# from the star-imports at the top; confirm.
data = DiffSXRDataset(N_DS, GSIZE, True, 0, 0, KS/255)
print(f"Dataset: {len(data)}")
show_images(data)

Later in this notebook we will make some additional modifications to this dataset, for example making the images smaller, converting them to tensors, etc.

# Building the Diffusion Model


## Step 1: The forward process = Noise scheduler




We first need to build the inputs for our model, which are more and more noisy images. Instead of doing this sequentially, we can use the closed form provided in the papers to calculate the image for any of the timesteps individually.

**Key Takeaways**:
- The noise-levels/variances can be pre-computed
- There are different types of variance schedules
- We can sample each timestep image independently (a sum of Gaussians is also Gaussian)
- No model is needed in this forward step

In [5]:
import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` evenly spaced beta values running from ``start`` to ``end`` inclusive."""
    betas = torch.linspace(start, end, steps=timesteps)
    return betas

def get_index_from_list(vals, t, x_shape):
    """
    Look up ``vals[t]`` for every entry of the index tensor ``t``.

    The result is shaped ``(len(t), 1, ..., 1)`` with as many dimensions as
    ``x_shape`` so it broadcasts against a batch of that shape, and it is
    moved to the device of ``t``. The lookup itself runs with CPU indices.
    """
    n_trailing = len(x_shape) - 1
    picked = torch.gather(vals, -1, t.cpu())
    target_shape = (t.shape[0],) + (1,) * n_trailing
    return picked.reshape(target_shape).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it

    Uses the closed form
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``
    with the module-level tables ``sqrt_alphas_cumprod`` and
    ``sqrt_one_minus_alphas_cumprod``.

    Args:
        x_0: clean image tensor (a single image or a batch).
        t: integer tensor of timestep indices; its length is used as the
            leading (batch) dimension of the looked-up coefficients.
        device: device on which the result is assembled and returned.
            This argument is now honoured; previously it was ignored in
            favour of the global ``DEV``.

    Returns:
        Tuple ``(x_t, noise)``, both on ``device``; ``noise`` is the standard
        normal sample that was mixed into ``x_0``.
    """
    noise = torch.randn_like(x_0).to(device)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    x_t = sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise
    return x_t, noise


# Define beta schedule
# T is the number of diffusion steps; betas holds one noise variance per step.
T = 300
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
# All tables below are 1-D with one entry per timestep and are indexed by t
# through get_index_from_list.
alphas = 1. - betas
# Running product alpha_1 * ... * alpha_t.
alphas_cumprod = torch.cumprod(alphas, axis=0)
# Same running product shifted right by one step, with 1.0 prepended for the first step.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# Coefficients of x_0 and of the noise in the closed-form forward sample.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Variance used when re-noising during sampling; its first entry is 0 because
# alphas_cumprod_prev starts at 1.0.
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

Let's test it on our dataset ...

In [None]:
def show_tensor_image(image):
    """Scatter-plot one image over the (RR, ZZ) points, coloured by its values.

    For a 4-D input only the first element along dimension 0 is drawn.
    The plot goes to the current matplotlib axes, with equal aspect,
    hidden axis decorations and a colorbar.
    """
    if image.ndim == 4:
        image = image[0]  # first image of the batch
    plt.scatter(RR, ZZ, s=1, c=image, cmap="plasma")
    plt.axis("equal")
    plt.axis("off")
    plt.colorbar()


# Rebuild the dataset and wrap it in a shuffling loader.
# drop_last=True discards a final incomplete batch, so every batch has BATCH_SIZE items.
data = DiffSXRDataset(N_DS, GSIZE, True, 0, 0, KS/255)
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

#print stats about dataloader
print(f"Number of batches: {len(dataloader)}")

In [None]:
# Visualise the forward process: noise one sample at evenly spaced timesteps
# and draw the results side by side.
image = next(iter(dataloader))[0]  # first image of the first (shuffled) batch

plt.figure(figsize=(15,1))
plt.axis('off')
num_images = 12
stepsize = T // num_images

for idx in range(0, T, stepsize):
    t = torch.tensor([idx], dtype=torch.int64)
    # Panels are filled left to right in order of increasing timestep.
    plt.subplot(1, num_images+1, idx // stepsize + 1)
    img, noise = forward_diffusion_sample(image, t)
    show_tensor_image(img.cpu())

## Step 2: The backward process = U-Net



For a great introduction to UNets, have a look at this post: https://amaarora.github.io/2020/09/13/unet.html.


**Key Takeaways**:
- We use a simple form of a UNet to predict the noise in the image
- The input is a noisy image, the output is the noise in the image
- Because the parameters are shared across time, we need to tell the network in which timestep we are
- The Timestep is encoded by the transformer Sinusoidal Embedding
- We output one single value (mean), because the variance is fixed


In [None]:
from torch import nn
import math


class Block(nn.Module):
    """One U-Net stage: conv -> add time embedding -> conv -> resample.

    With ``up=False`` the stage ends in a stride-2 convolution. With
    ``up=True`` the first convolution takes ``2*in_ch`` input channels and
    the stage ends in a stride-2 transposed convolution.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded
        # initialisation and the state_dict layout stay stable.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        first_conv_in = 2 * in_ch if up else in_ch
        resample_layer = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_conv_in, out_ch, 3, padding=1)
        self.transform = resample_layer(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply the stage to feature map ``x`` given the time embedding ``t``."""
        # First convolution, activation, then normalisation.
        h = self.conv1(x)
        h = self.bnorm1(self.relu(h))
        # Project the time embedding to out_ch values per sample and
        # broadcast them over the two spatial dimensions.
        emb = self.relu(self.time_mlp(t))
        h = h + emb[..., None, None]
        # Second convolution, activation, then normalisation.
        h = self.conv2(h)
        h = self.bnorm2(self.relu(h))
        # Down- or up-sample.
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    """Parameter-free sinusoidal embedding of a batch of timesteps.

    ``forward`` maps a 1-D tensor of ``B`` timesteps to a tensor of shape
    ``(B, 2 * (dim // 2))``. The first half holds the sines and the second
    half the cosines of the timestep multiplied by ``dim // 2``
    geometrically spaced frequencies running from 1 down to 1/10000.

    ``dim`` should be even and at least 4: an odd ``dim`` yields ``dim - 1``
    features, and ``dim < 4`` divides by zero when the frequency spacing is
    computed.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        log_spacing = math.log(10000) / (half_dim - 1)
        # Build the frequencies on the input's own device so the module works
        # wherever `time` lives instead of depending on a global device.
        freqs = torch.exp(
            torch.arange(half_dim, device=time.device) * -log_spacing
        )
        angles = time[:, None] * freqs[None, :]
        # TODO: Double check the ordering here
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class SimpleUnet(nn.Module):
    """
    A small U-Net that takes a single-channel image plus its timestep and
    returns a single-channel map.

    Four down stages are followed by four up stages; each up stage receives
    the output of the matching down stage concatenated on the channel axis.
    """
    def __init__(self):
        super().__init__()
        image_channels = 1
        out_dim = 1
        time_emb_dim = 16
        # Channel widths are the classic U-Net widths divided by 8:
        # (8, 16, 32, 64, 128) on the way down, mirrored on the way up.
        width_divisor = 8
        down_channels = tuple(
            width // width_divisor for width in (64, 128, 256, 512, 1024)
        )
        up_channels = down_channels[::-1]

        # Timestep -> embedding vector shared by all stages
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Lift the image to the first feature width
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Encoder stages
        self.downs = nn.ModuleList([
            Block(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(down_channels, down_channels[1:])
        ])
        # Decoder stages
        self.ups = nn.ModuleList([
            Block(c_in, c_out, time_emb_dim, up=True)
            for c_in, c_out in zip(up_channels, up_channels[1:])
        ])

        # 1x1 convolution back to the output channel count
        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        """Run the network on image batch ``x`` at ``timestep``."""
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)
        for up in self.ups:
            # The most recent encoder output is consumed first.
            x = up(torch.cat((x, skips.pop()), dim=1), t)
        return self.output(x)

# Instantiate the network and report its total number of parameter elements.
model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))


**Further improvements that can be implemented:**
- Residual connections
- Different activation functions like SiLU, GELU, ...
- BatchNormalization
- GroupNormalization
- Attention
- ...

## Step 3: The loss



**Key Takeaways:**
- After some maths we end up with a very simple loss function
- There are other possible choices like L2 loss etc.


In [9]:
def get_loss(model, x_0, t):
    """Return the MSE between the noise injected at timestep ``t`` and the model's prediction of it.

    ``x_0`` is noised with ``forward_diffusion_sample`` on ``DEV`` and the
    noisy result is passed to ``model`` together with ``t``.
    """
    noisy_images, true_noise = forward_diffusion_sample(x_0, t, DEV)
    predicted_noise = model(noisy_images, t)
    return F.mse_loss(true_noise, predicted_noise)

## Sampling
- Without adding @torch.no_grad() we quickly run out of memory, because PyTorch tracks all the previous images for gradient calculation
- Because we pre-calculated the noise variances for the forward pass, we also have to use them when we sequentially perform the backward process

In [10]:
@torch.no_grad()
def sample_timestep(x, t):
    """
    Perform one reverse-diffusion step on ``x`` at timestep ``t``.

    The global ``model`` predicts the noise in ``x``; that prediction is
    used to compute the denoised mean. For ``t == 0`` the mean is returned
    as is, otherwise fresh Gaussian noise scaled by the square root of the
    posterior variance is added. ``t`` must hold a single timestep because
    it is used directly in a boolean test.
    """
    def coeff(table):
        # Per-timestep scalar, shaped to broadcast against x.
        return get_index_from_list(table, t, x.shape)

    beta_t = coeff(betas)
    sqrt_one_minus_alpha_bar_t = coeff(sqrt_one_minus_alphas_cumprod)
    sqrt_recip_alpha_t = coeff(sqrt_recip_alphas)

    # Current image minus the scaled noise prediction
    predicted_noise = model(x, t)
    model_mean = sqrt_recip_alpha_t * (
        x - beta_t * predicted_noise / sqrt_one_minus_alpha_bar_t
    )
    posterior_variance_t = coeff(posterior_variance)

    if t == 0:
        # Final step: the timesteps are offset from those in the paper,
        # so no noise is added here.
        return model_mean
    return model_mean + torch.sqrt(posterior_variance_t) * torch.randn_like(x)

@torch.no_grad()
def sample_plot_image():
    """Generate one image from pure noise and plot snapshots of the denoising trajectory.

    Starts from a ``(1, 1, GSIZE, GSIZE)`` Gaussian sample on ``DEV``, runs
    ``sample_timestep`` from ``T - 1`` down to ``0`` and draws the current
    image whenever the timestep is a multiple of ``T // 10``.
    """
    num_images = 10
    stepsize = T // num_images

    # Start from pure noise
    img = torch.randn((1, 1, GSIZE, GSIZE), device=DEV)
    plt.figure(figsize=(15,2))
    plt.axis('off')

    for i in reversed(range(T)):
        t = torch.full((1,), i, device=DEV, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize != 0:
            continue
        # Panel position follows the timestep, so t = 0 ends up leftmost.
        plt.subplot(1, num_images, i // stepsize + 1)
        show_tensor_image(img.detach().cpu())
    plt.tight_layout()
    plt.show()

## Training

In [None]:
# assert False
from torch.optim import Adam
from tqdm import tqdm

import os

# Fresh dataset and shuffling loader for training.
data = DiffSXRDataset(N_DS, GSIZE, True, 0, 0, KS/255)
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

model.to(DEV)
optimizer = Adam(model.parameters(), lr=0.001)
epochs = 80 # Try more! (50)

# Create the checkpoint directory up front so the first save cannot fail
# on a missing folder.
checkpoint_path = f"mg_data/{JOBID}/diffusion_model.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(epochs):
    for step, batch in enumerate(tqdm(dataloader, leave=False, desc=f"Epoch {epoch}")):
        optimizer.zero_grad()
        # One random timestep per sample. Sized from the batch itself so the
        # loop stays correct if the loader ever yields a batch whose size
        # differs from BATCH_SIZE.
        t = torch.randint(0, T, (batch.shape[0],), device=DEV).long()
        loss = get_loss(model, batch, t)
        loss.backward()
        optimizer.step()

        # Every 5th epoch, report the loss of the first batch and plot a sample.
        # NOTE(review): the model is never switched to eval mode here, so the
        # BatchNorm layers run in training mode while sampling — confirm this
        # is intended.
        if epoch % 5 == 0 and step == 0:
            print(f"\rEpoch {epoch} | Loss: {loss.item()} ")
            sample_plot_image()
    # save the model (overwritten after every epoch)
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
# load the model
model = SimpleUnet()
# map_location remaps the stored tensors onto DEV, so a checkpoint written
# on one device (e.g. a GPU) also loads where that device is unavailable.
model.load_state_dict(
    torch.load(f"mg_data/{JOBID}/diffusion_model.pth", map_location=DEV)
)
model.to(DEV)

# sample and plot

@torch.no_grad()
def generate_samples(n):
    """Generate ``n`` images by running the full reverse process from Gaussian noise.

    Returns a tensor of shape ``(n, 1, GSIZE, GSIZE)`` on ``DEV``. All ``n``
    images share the same single-element timestep tensor at every step.
    """
    # Start from pure noise
    imgs = torch.randn((n, 1, GSIZE, GSIZE), device=DEV)
    for step in reversed(range(T)):
        t = torch.full((1,), step, device=DEV, dtype=torch.long)
        imgs = sample_timestep(imgs, t)
    return imgs

def _plot_sample_grid(samples, n, title, out_path):
    """Draw the first n*n images of ``samples`` on an n-by-n grid, save the figure to ``out_path`` and show it."""
    plt.figure(figsize=(15, 20))
    plt.axis('off')
    for k in range(n * n):
        plt.subplot(n, n, k + 1)
        # Clip to [0, 10000] before plotting so negative values are drawn as 0.
        show_tensor_image(np.clip(samples[k].detach().cpu(), 0, 10000))
    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.show()


# Generate samples
n = 7
samples = generate_samples(n*n)
_plot_sample_grid(
    samples, n, "Generated samples",
    f"mg_data/{JOBID}/diffusion_generated_samples.png",
)

# samples from the dataset, drawn on an identical grid for comparison
data = DiffSXRDataset(N_DS, GSIZE, True, 0, 0, KS/255)
dataloader = DataLoader(data, batch_size=n*n, shuffle=True, drop_last=True)
samples = next(iter(dataloader))
_plot_sample_grid(
    samples, n, "Real samples",
    f"mg_data/{JOBID}/diffusion_real_samples.png",
)