In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torchvision
from torchvision import transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from diffusers import DDPMScheduler
from diffusers import UNet2DModel
from tqdm import tqdm

import sys
sys.path.append('/groups/mlprojects/dm_diffusion/Dark-Matter-Diffusion/src/')
from datasets import SlicedDataset, NPYDataset
from utils import sample, LogTransform

In [None]:
# Data pipeline and noise-scheduler configuration.
npy_file = "/groups/mlprojects/dm_diffusion/data/Maps_Mcdm_IllustrisTNG_LH_z=0.00.npy"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 0                    # change to get a different (but reproducible) run
CROP_SIZE = 64              # must match UNet2DModel(sample_size=...) in the model cell
BATCH_SIZE = 16
NUM_TRAIN_TIMESTEPS = 1000

# Seed numpy and torch (torch.manual_seed seeds the generators on all devices).
# RandomCrop, DataLoader shuffling, model weight init and the diffusion noise
# all draw from these generators, so a fresh-kernel Run All is reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop((CROP_SIZE, CROP_SIZE)),
    LogTransform(),
    # NOTE(review): mean 0 / std 1 makes this an identity transform. Either
    # plug in the dataset's post-log mean/std or drop it — TODO.
    transforms.Normalize((0.0,), (1.0,)),
])
dataset = NPYDataset(npy_file, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS)

In [None]:
# Sanity check: show the first map from one shuffled, cropped, transformed batch.
preview_batch = next(iter(dataloader))

fig, ax = plt.subplots(figsize=(2, 2))
# preview_batch[0] is presumably (C, H, W). Move channels last, then drop the
# channel axis when C == 1 so matplotlib colormaps a plain 2-D array.
preview_image = preview_batch[0].permute(1, 2, 0).squeeze(-1)
ax.imshow(preview_image.cpu().numpy())
ax.axis("off")
plt.show()

In [None]:
# Noise-prediction UNet for single-channel maps.
# Three resolution stages with channel widths 64 -> 128 -> 256. The spatial
# self-attention stages are the last down block and the first up block.
down_block_types = (
    "DownBlock2D",      # plain ResNet downsampling block
    "DownBlock2D",
    "AttnDownBlock2D",  # ResNet downsampling block + spatial self-attention
)
up_block_types = (
    "AttnUpBlock2D",    # ResNet upsampling block + spatial self-attention
    "UpBlock2D",
    "UpBlock2D",
)

model = UNet2DModel(
    sample_size=64,                     # matches the RandomCrop size in the data cell
    in_channels=1,                      # one channel per map (3 would be RGB)
    out_channels=1,                     # predicted noise has the input's channel count
    layers_per_block=2,                 # ResNet layers inside each UNet block
    block_out_channels=(64, 128, 256),  # one width per down/up block
    down_block_types=down_block_types,
    up_block_types=up_block_types,
)
model.to(device)

In [None]:
# Training loop: DDPM noise-prediction objective.
# For each batch: draw one random timestep per image, noise the clean images
# with the scheduler's forward process, and regress the UNet output onto the
# noise that was added (MSE).
num_epochs = 30
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

# Read from the scheduler config; direct attribute access
# (`noise_scheduler.num_train_timesteps`) is deprecated in diffusers.
num_train_timesteps = noise_scheduler.config.num_train_timesteps

# Explicitly enter training mode in case an earlier run of another cell
# (e.g. sampling) left the model in eval mode.
model.train()

losses = []
for epoch in range(num_epochs):
    epoch_loss_sum = 0.0
    num_batches = 0
    with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as tepoch:
        for image in tepoch:
            # Cast to float32 to match the model weights. NOTE(review): ToTensor
            # keeps a numpy array's dtype, so float64 maps would otherwise fail
            # in the forward pass; this is a no-op for float32 data.
            clean_images = image.to(device, dtype=torch.float32)
            # Gaussian noise with the batch's shape, device and dtype.
            noise = torch.randn_like(clean_images)
            bs = clean_images.shape[0]

            # One random timestep in [0, num_train_timesteps) per image.
            timesteps = torch.randint(
                0, num_train_timesteps, (bs,), device=clean_images.device, dtype=torch.long
            )

            # Forward diffusion: noise each image to its sampled timestep.
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # UNet2DModel returns an output object; `.sample` is the prediction.
            noise_pred = model(noisy_images, timesteps).sample

            loss = F.mse_loss(noise_pred, noise)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Single .item() call per step (each call synchronizes with the GPU).
            loss_value = loss.item()
            losses.append(loss_value)
            epoch_loss_sum += loss_value
            num_batches += 1
            tepoch.set_postfix(loss=loss_value)

    # Mean loss over the batches actually seen this epoch (guards an empty loader).
    loss_last_epoch = epoch_loss_sum / max(num_batches, 1)
    print(f"Epoch {epoch+1} loss: {loss_last_epoch}")


In [None]:
# Generate samples from the trained model with the project's `sample` helper
# (imported from the project's utils module). What it returns and/or plots is
# defined there. Whatever it returns is shown as this cell's output.
generated = sample(model, noise_scheduler)
generated