In [None]:
from google.colab import drive

# Mount Google Drive so the project's Data/ and Weights/ folders are reachable.
drive.mount(mountpoint='/content/drive')

In [None]:
import os

# Work from the project folder on the mounted Drive so the relative paths
# './Data/...' and 'Weights' used below resolve. An absolute path keeps this
# cell safe to re-run (a relative 'cd drive/...' fails once the cwd has already
# moved) and is plain Python rather than an IPython magic with an escaped space.
os.chdir('/content/drive/My Drive/Acad/ADS/Project/')

In [None]:
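'''
Notebook to train a conditional DDPM: the UNet learns to predict the noise added
to high-resolution lensing images, conditioned on the matching low-resolution
images.
'''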

# Import required libraries
import os
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
import torchvision
from PIL import Image
from modules_conditional import UNet, Diffusion
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

# Load training data
# High-resolution lensing data: the images the diffusion model learns to denoise.
x_trainHR = np.load('./Data/train_labels.npy').astype(np.float32)
# Low-resolution lensing data: passed to the UNet as the conditioning input.
x_trainLR = np.load('./Data/train_images.npy').astype(np.float32)
# astype() already produced fresh float32 arrays; from_numpy shares their memory
# instead of making a second full copy of the dataset.
x_trainHR = torch.from_numpy(x_trainHR)
x_trainLR = torch.from_numpy(x_trainLR)
# Print data dimensions
print(x_trainHR.shape)
print(x_trainLR.shape)

# Each batch yields (hr_images, lr_conditions). Shuffle so the model does not see
# the samples in the same fixed order every epoch.
dataset = TensorDataset(x_trainHR, x_trainLR)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

device = "cuda"
model = UNet().to(device)

ckpt_dir = "Weights"
ckpt_path = os.path.join(ckpt_dir, "Diff_ckpt_1.pt")
resume = False  # set True to continue training from the last saved checkpoint
if resume:
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

optimizer = optim.AdamW(model.parameters(), lr=3e-4)
mse = nn.MSELoss()
diffusion = Diffusion(img_size=128, device=device)
l = len(dataloader)  # batches per epoch; used to average the epoch loss
epochs = 100

# torch.save does not create missing directories.
os.makedirs(ckpt_dir, exist_ok=True)

# epochs + 1 so that exactly `epochs` epochs run (range(1, epochs) ran one fewer).
for epoch in range(1, epochs + 1):
    print(f"Starting epoch {epoch}:")
    pbar = tqdm(dataloader)
    avg_mse = 0  # running sum of batch losses; divided by l after the epoch
    for images, conditions in pbar:
        images = images.to(device)
        conditions = conditions.to(device)
        # One timestep per image in the batch; noise the HR images to that step.
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        # The UNet is trained to recover the added noise given the noisy image,
        # its timestep and the LR conditioning image.
        predicted_noise = model(x_t, t, conditions)
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device; call it once per step.
        loss_value = loss.item()
        pbar.set_postfix(MSE=loss_value)
        avg_mse += loss_value

    # Average over the real number of batches (previously a hard-coded 1000);
    # max() guards against an empty dataset.
    print(f'Average MSE: {avg_mse / max(l, 1):.5f}\n')
    torch.save(model.state_dict(), ckpt_path)