# Diffusion Model

## Config

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Dataset root folder and the fraction of it held out for testing.
PATH = 'augmented_data'
TEST_TRAIN_SPLIT = 0.3

# Optimisation hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 1e-4

print(DEVICE)

## UNet Model

### Noise scheduler

In [4]:
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return `timesteps` evenly spaced betas from `beta_start` to `beta_end`.

    Both endpoints are included; the result is a 1-D tensor of length
    `timesteps`.
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule

def get_index_from_list(vals, t, x_shape):
    """Pick `vals[t[i]]` for every batch element and make it broadcastable.

    The indices in `t` are copied to the CPU before the lookup (so `vals` is
    expected to live on the CPU). The gathered values are reshaped to
    `(batch, 1, ..., 1)` with as many dimensions as `x_shape` has, and then
    moved to the device `t` is on.
    """
    n_batch = t.shape[0]
    # One singleton axis for every non-batch dimension of the target shape.
    broadcast_shape = (n_batch,) + (1,) * (len(x_shape) - 1)
    picked = torch.gather(vals, -1, t.cpu())
    return picked.reshape(broadcast_shape).to(t.device)

# Forward-process schedule, precomputed once for every timestep.
timesteps = 1000
betas = linear_beta_schedule(timesteps=timesteps)
alphas = 1.0 - betas
# Running product of the alphas up to (and including) each timestep.
alphas_cumprod = alphas.cumprod(dim=0)
# Scale applied to the clean image / to the noise in the forward process.
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

### Model architecture

In [5]:
# handle input shape mismatch
def pad_to_match(x1, x2):
    """Pad `x1` so its last two dimensions match those of `x2`.

    Both tensors are indexed as 4-D (dimensions 2 and 3 are read as height
    and width). The padding is added only at the bottom and on the right;
    `F.pad` fills it with zeros by default.
    """
    extra_rows = x2.size(2) - x1.size(2)
    extra_cols = x2.size(3) - x1.size(3)
    # F.pad takes the last dimension first: (left, right, top, bottom).
    return F.pad(x1, (0, extra_cols, 0, extra_rows))

In [12]:
class UNet(nn.Module):
    """Four-level encoder/decoder convolutional network with skip connections.

    The encoder widens the channels 64 -> 128 -> 256 -> 512 with a 2x2
    max-pool between levels, the bottleneck has 1024 channels, and the
    decoder mirrors the encoder using stride-2 transposed convolutions.
    Each upsampled map is padded with `pad_to_match` to the size of the
    matching encoder output before the two are concatenated, so input sizes
    that are not a multiple of 16 are accepted.

    Note: `forward` receives only the image tensor; no timestep is fed to
    the network.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        def double_conv(c_in, c_out):
            # Two rounds of (3x3 conv -> batch norm -> ReLU); only the first
            # convolution changes the channel count.
            layers = []
            for src in (c_in, c_out):
                layers.append(nn.Conv2d(src, c_out, kernel_size=3, padding=1))
                layers.append(nn.BatchNorm2d(c_out))
                layers.append(nn.ReLU(inplace=True))
            return nn.Sequential(*layers)

        def upsample(c_in, c_out):
            # Stride-2 transposed convolution doubles height and width.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        # Contracting path.
        self.encoder1 = double_conv(in_channels, 64)
        self.encoder2 = double_conv(64, 128)
        self.encoder3 = double_conv(128, 256)
        self.encoder4 = double_conv(256, 512)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = double_conv(512, 1024)

        # Expanding path. Every decoder block sees twice the upsampled
        # channel count because the skip connection is concatenated in.
        self.upconv4 = upsample(1024, 512)
        self.decoder4 = double_conv(1024, 512)
        self.upconv3 = upsample(512, 256)
        self.decoder3 = double_conv(512, 256)
        self.upconv2 = upsample(256, 128)
        self.decoder2 = double_conv(256, 128)
        self.upconv1 = upsample(128, 64)
        self.decoder1 = double_conv(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on `x` and return the `final_conv` projection."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool(skip1))
        skip3 = self.encoder3(self.pool(skip2))
        skip4 = self.encoder4(self.pool(skip3))

        out = self.bottleneck(self.pool(skip4))

        # Deepest level first: upsample, align with the skip, concat, refine.
        stages = (
            (skip4, self.upconv4, self.decoder4),
            (skip3, self.upconv3, self.decoder3),
            (skip2, self.upconv2, self.decoder2),
            (skip1, self.upconv1, self.decoder1),
        )
        for skip, up, refine in stages:
            out = pad_to_match(up(out), skip)
            out = refine(torch.cat((out, skip), dim=1))

        return self.final_conv(out)

### Forward process

In [7]:
def forward_diffusion_sample(x_0, t, device="cpu"):
    """Noise `x_0` to timestep `t` in a single step.

    Draws Gaussian noise shaped like `x_0` and blends it with the image:
    `sqrt_alphas_cumprod[t] * x_0 + sqrt_one_minus_alphas_cumprod[t] * noise`.

    Returns:
        A `(noisy_image, noise)` tuple, both on `device`.
    """
    # The noise is drawn where x_0 currently lives and only then moved.
    eps = torch.randn_like(x_0).to(device)
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape).to(device)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape).to(device)
    x_t = signal_scale * x_0.to(device) + noise_scale * eps
    return x_t, eps

## Load the data

In [9]:
# Load every image as a single-channel float tensor (ToTensor scales 8-bit
# pixel values to [0, 1]).
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

# ImageFolder derives the class labels from the sub-directory names of PATH.
dataset = datasets.ImageFolder(PATH, transform=transform)

n_test = int(np.floor(TEST_TRAIN_SPLIT * len(dataset)))
n_train = len(dataset) - n_test

# Seed the split so the held-out set is the same on every run of the
# notebook; otherwise a checkpoint reloaded in a later session could be
# evaluated on images it was trained on.
SPLIT_SEED = 42
train_ds, test_ds = random_split(
    dataset,
    [n_train, n_test],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)


# some useful info about the dataset
print(f"Classes: {dataset.classes}")
print(f"Number of training samples: {len(train_ds)}")
print(f"Number of testing samples: {len(test_ds)}")

# Peek at the first batch only, to report the tensor shape the model will see.
for x, _ in train_dl:
    print(f"Image shape: {x.shape}")
    break

Classes: ['AMD', 'DME', 'ERM', 'NO', 'RAO', 'RVO', 'VID']
Number of training samples: 1440
Number of testing samples: 616
Image shape: torch.Size([32, 1, 199, 546])


## Training

In [None]:
# Noise-prediction network on the selected device, trained with Adam to
# minimise the mean squared error between predicted and true noise.
model = UNet()
model = model.to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

def train(model, dataloader, optimizer, criterion, epochs=100, device="cpu"):
    """Train `model` to predict the noise added by the forward process.

    For every batch a random timestep is drawn per image, the image is
    noised with `forward_diffusion_sample`, and the model output for the
    noisy image is compared with the true noise using `criterion`. The
    model is called with the noisy image only; the timestep is not passed
    to it.

    Args:
        model: network mapping a noisy image batch to a noise estimate.
        dataloader: iterable of batches; only `batch[0]` (the images) is used.
        optimizer: optimizer over the parameters of `model`.
        criterion: loss called as `criterion(predicted_noise, noise)`.
        epochs: number of passes over `dataloader`.
        device: device the batches and timesteps are placed on.

    Raises:
        ValueError: if `dataloader` yields no batches in an epoch.
    """
    model.train()
    for epoch in tqdm(range(epochs), desc="Training", unit="epoch"):
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        loss_sum = 0.0
        n_batches = 0
        for batch in progress_bar:
            x_0 = batch[0].to(device)
            # One independent timestep per image; randint already yields int64.
            t = torch.randint(0, timesteps, (x_0.shape[0],), device=device)
            x_t, noise = forward_diffusion_sample(x_0, t, device)
            predicted_noise = model(x_t)
            loss = criterion(predicted_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            loss_sum += batch_loss
            n_batches += 1
            progress_bar.set_postfix(loss=batch_loss)

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")

        # Report the epoch mean: a single batch's loss depends heavily on
        # which timesteps happened to be drawn for it.
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss_sum / n_batches}")

# Run the full training schedule on the training split.
train(
    model=model,
    dataloader=train_dl,
    optimizer=optimizer,
    criterion=criterion,
    epochs=EPOCHS,
    device=DEVICE,
)

Training:   0%|          | 0/100 [00:00<?, ?epoch/s]

## Sampling

In [None]:
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=1, device="cpu"):
    """Generate images by iterating the reverse process from pure noise.

    Starts from Gaussian noise and walks the timesteps from `timesteps - 1`
    down to 0. At each step the model's noise estimate is removed with

        img = (img - (1 - alpha_t) / sqrt(1 - alpha_cumprod_t) * eps) / sqrt(alpha_t)

    and, for every step except the last, fresh noise scaled by
    `sqrt(beta_t)` is added back.

    Args:
        model: network mapping a noisy image batch to a noise estimate. It
            is switched to eval mode and left in that mode.
        image_size: `(height, width)` of the images to generate.
        batch_size: number of images to generate.
        channels: number of image channels.
        device: device on which all sampling tensors are created.

    Returns:
        Tensor of shape `(batch_size, channels, height, width)` clamped to
        `[-1, 1]`.
    """
    model.eval()
    # Create the starting noise on the requested `device`, not the global
    # DEVICE, so it always matches `t` and the schedule tensors below.
    img = torch.randn((batch_size, channels, image_size[0], image_size[1]), device=device)
    for i in reversed(range(timesteps)):
        t = torch.full((batch_size,), i, device=device, dtype=torch.long)
        predicted_noise = model(img)
        alpha_t = get_index_from_list(alphas, t, img.shape)
        sqrt_one_minus_alpha_cumprod_t = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, img.shape)
        img = (img - (1 - alpha_t) / sqrt_one_minus_alpha_cumprod_t * predicted_noise) / torch.sqrt(alpha_t)
        if i > 0:
            # No noise is added on the final step (i == 0).
            noise = torch.randn_like(img)
            beta_t = get_index_from_list(betas, t, img.shape)
            img += torch.sqrt(beta_t) * noise
    img = torch.clamp(img, -1., 1.)
    return img

# Draw one batch of new images at the resolution reported for the
# training images (199 x 546).
generated_images = sample(
    model,
    image_size=(199, 546),
    batch_size=BATCH_SIZE,
    channels=1,
    device=DEVICE,
)

In [14]:
# Persist only the learned parameters and buffers (the state dict), not the
# whole module object.
torch.save(obj=model.state_dict(), f="model.pth")