In [None]:
import os

# ---------------------------------------------------------------------------
# Experiment configuration. Later cells read these module-level names directly,
# so tune values here rather than further down the notebook.
# ---------------------------------------------------------------------------

# TODO: change ID
# Run identifier; namespaces the checkpoint / output / store paths below.
ID = "ddim-v1"

# Passed to DiffusionModel as `mixed_precision` (fp16 autocast + GradScaler).
MIXED_PRECISION = True
# TODO: choose distributed or not
DISTRIBUTED = True
# Process rank / count come from the environment (presumably set by a launcher
# such as torchrun — confirm); defaults describe a single process.
RANK = int(os.getenv("RANK", 0))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1))

# TODO: change device
# GPU index for single-process runs.
# NOTE(review): later cells call `.cuda()` unconditionally, so a None (CPU)
# setting is unlikely to work end to end — confirm before relying on it.
DEVICE_ID = 0  # None for CPU
# GPUs used when DISTRIBUTED; DiffusionModel wraps its network in DDP on
# device_ids[rank].
DEVICE_IDS = [0, 1, 2, 3]
# Exported as the OMP_NUM_THREADS environment variable by the setup cell.
OMP_NUM_THREADS = 10
# Seeds Python's `random` and torch in the setup cell.
SEED = 42

# Dataset
# presumably passed as `repeats` to get_dataloader by the training cell (not
# visible here) — confirm.
DATASET_REPETITIONS = 5
# Square side length images are resized/cropped to; also the U-Net input size.
IMAGE_SIZE = 64

# KID = Kernel Inception Distance, see related section
# Resolution images are resized/cropped to before the InceptionV3 encoder.
KID_IMAGE_SIZE = 75
# Sampling steps used to generate images for KID (expensive, so kept small).
KID_DIFFUSION_STEPS = 5
# Sampling steps used when plotting generated images.
PLOT_DIFFUSION_STEPS = 20

# sampling
# Cosine-schedule bounds: signal rate at diffusion time 1 (MIN) and time 0 (MAX).
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95

# architecture
# Channels of the sinusoidal noise-level embedding (must be even).
EMBEDDING_DIM = 32
EMBEDDING_MAX_FREQUENCY = 1000.0
# Channel width per U-Net level; the last entry is the bottleneck width.
WIDTHS = [32, 64, 96, 128]
# Residual blocks per down/up stage and in the bottleneck.
BLOCK_DEPTH = 2

# training
# TODO: change epochs
START_EPOCH = 0
EPOCHS = 10
PLOT_EVERY = 1
# Per-process batch size, so the global batch across WORLD_SIZE processes is ~256.
BATCH_SIZE = 256 // WORLD_SIZE
# AdamW learning rate.
LEARNING_RATE = 1e-3
# Decay of the exponential-moving-average copy of the network.
EMA = 0.998  # TODO: change EMA
# AdamW weight decay.
WEIGHT_DECAY = 1e-4

CHECKPOINT_DIR = os.path.join("./ckpts/", ID)
# presumably passed to save_checkpoint as `checkpoint_name` (files then become
# "<name>_<epoch:03d>.pt") — confirm in the training cell.
CHECKPOINT_NAME = "ckpt"
OUTPUT_DIR = os.path.join("./outputs/", ID)
# presumably a file store for process-group rendezvous — confirm in training cell.
DISTRIBUTED_STORE = os.path.join("./dist_store/", ID)
SAVE_PLOTS = True

In [None]:
import os
import random
import torch
import warnings


# Reconcile the configured devices with the GPUs that actually exist.
if torch.cuda.is_available():
    gpus = torch.cuda.device_count()
    # DEVICE_ID may be None (CPU requested in the config cell). Never compare
    # None against an int; only remap an out-of-range GPU index to GPU 0.
    if DEVICE_ID is not None and DEVICE_ID >= gpus:
        DEVICE_ID = 0
    if DISTRIBUTED:
        # Distributed runs silently drop GPUs that are not present...
        DEVICE_IDS = [gpu_id for gpu_id in DEVICE_IDS if gpu_id < gpus]
        # ...but an empty list would only fail much later (DDP device lookup).
        if not DEVICE_IDS:
            raise ValueError("None of the requested GPUs is available.")
    else:
        # Single-process runs treat a missing GPU as a configuration error.
        for device_id in DEVICE_IDS:
            if device_id >= gpus:
                raise ValueError(f"GPU {device_id} is not available.")
else:
    DEVICE_ID = None

if DISTRIBUTED:
    print(f"Rank {RANK} Using distributed training with devices: {DEVICE_IDS}")
else:
    print(f"Using device id: {DEVICE_ID}")

# The OpenMP runtime typically reads OMP_NUM_THREADS once at start-up, and torch
# is already imported here, so also size torch's intra-op thread pool directly.
os.environ["OMP_NUM_THREADS"] = str(OMP_NUM_THREADS)
torch.set_num_threads(OMP_NUM_THREADS)

# Seed Python's RNG and torch (torch.manual_seed also seeds all CUDA devices).
random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): silences every warning for the rest of the session.
warnings.filterwarnings("ignore")

# Dataset

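Load the Flowers102 dataset. For training, the official `train` and `test` splits are concatenated (with random horizontal flips); the `val` split is loaded on its own without augmentation.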

In [None]:
import torch.utils.data
import torch.utils.data.distributed
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from IPython import get_ipython


def transform_clamp(x):
    """Clip every element of tensor ``x`` into the closed interval [0, 1]."""
    return x.clamp(min=0, max=1)


def get_dataloader(
    batch_size: int,
    image_size: int,
    split: str = "train",
    repeats: int = 1,
    pin_memory: bool = True,
    num_workers: int = 4,
    shuffle: bool = False,
    rank: int = 0,
    distributed: bool = False,
    root: str = "datasets/flowers102",
):
    """
    Build a DataLoader over the Flowers102 dataset.

    For ``split="train"`` the official train and test splits are concatenated
    (both with the training augmentations) to enlarge the training set.

    :param batch_size: The batch size (per process when distributed).
    :param image_size: Side length of the square output images.
    :param split: The split to load. Either "train", "val", or "test".
    :param repeats: How many times the dataset is repeated per epoch; must be >= 1.
    :param pin_memory: Whether the DataLoader uses pinned host memory.
    :param num_workers: The number of workers to use for loading the data.
    :param shuffle: Whether to shuffle the data (handled by the
        DistributedSampler when ``distributed`` is True).
    :param rank: Rank of this process; only used when ``distributed``.
    :param distributed: Shard the data with a DistributedSampler (which takes
        the world size from the default process group, so it must already be
        initialised).
    :param root: Directory Flowers102 is stored in / downloaded to.
    :return: A ``DataLoader`` yielding ``(images, labels)`` batches.
    :raises ValueError: If ``split`` is unknown or ``repeats`` < 1.
    """
    if split not in ("train", "val", "test"):
        raise ValueError(f"Invalid split: {split}")
    # Validate before any download; ConcatDataset([]) would otherwise fail
    # later with an opaque assertion.
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    if split == "train":
        transform = transforms.Compose(
            [
                # transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transform_clamp,
            ]
        )
    else:
        transform = transforms.Compose(
            [
                transforms.Resize(int(image_size * 1.1)),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ]
        )

    dataset = datasets.Flowers102(
        root=root,
        split=split,
        transform=transform,
        download=True,
    )
    if split == "train":
        # Reuse the official test split as extra training data.
        test_dataset = datasets.Flowers102(
            root=root,
            split="test",
            transform=transform,
            download=True,
        )
        dataset = torch.utils.data.ConcatDataset([dataset, test_dataset])
    dataset = torch.utils.data.ConcatDataset([dataset] * repeats)

    sampler = None
    if distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            rank=rank,
            shuffle=shuffle,
        )
        # DataLoader rejects shuffle=True together with a sampler; the sampler
        # now owns shuffling.
        shuffle = False

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
        num_workers=num_workers,
        sampler=sampler,
    )


ipy = get_ipython()
if ipy is not None:
    # Notebook-only sanity check: render one shuffled training batch as a grid.
    ipy.run_line_magic("matplotlib", "inline")

    import numpy as np
    import matplotlib.pyplot as plt

    def imshow(inp, title=None):
        """Draw a (C, H, W) image tensor on a fresh 16x12 matplotlib figure."""
        plt.figure(figsize=(16, 12))
        # matplotlib expects channels-last; clip keeps values in the displayable range.
        image_hwc = np.clip(inp.numpy().transpose((1, 2, 0)), 0, 1)
        plt.imshow(image_hwc)
        if title is not None:
            plt.title(title)

    train_loader = get_dataloader(32, IMAGE_SIZE, split="train", shuffle=True)
    inputs, classes = next(iter(train_loader))
    out = torchvision.utils.make_grid(inputs)
    imshow(out)
    plt.show()

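# Kernel Inception Distance (KID)

KID estimates the squared maximum mean discrepancy between InceptionV3 features of real and generated images, using a cubic polynomial kernel; lower values mean the two image sets are more alike.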

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchmetrics import Metric


class KID(Metric):
    """
    Kernel Inception Distance (KID) as a torchmetrics ``Metric``.

    Each ``update`` embeds a batch of real and a batch of generated images with
    a frozen InceptionV3 (dropout and classifier replaced by identities),
    estimates the squared maximum mean discrepancy between the two feature sets
    with a cubic polynomial kernel, and accumulates that per-batch estimate.
    ``compute`` returns the mean over all batches accumulated since ``reset``.
    """

    def __init__(self, kid_image_size=75, **kwargs):
        super(KID, self).__init__(**kwargs)

        # Resize/crop to the encoder resolution, then ImageNet-normalise
        # (assumes inputs are in [0, 1]).
        self.transforms = transforms.Compose(
            [
                transforms.Resize(kid_image_size),
                transforms.CenterCrop(kid_image_size),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        encoder: nn.Module = models.inception_v3(
            weights=models.Inception_V3_Weights.DEFAULT
        )
        # Strip dropout and the classification head so forward() yields features.
        encoder.dropout = nn.Identity()
        encoder.fc = nn.Identity()
        for param in encoder.parameters():
            param.requires_grad = False
        self.encoder = encoder
        self.encoder.eval()

        # Running sum of per-batch KID estimates and the number of batches,
        # both summed across processes when synchronised.
        self.add_state("kid", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state(
            "count", default=torch.tensor(0.0), dist_reduce_fx="sum"
        )

    def polynomial_kernel(
        self, x1: torch.Tensor, x2: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the cubic polynomial kernel ``(x1 @ x2.T / d + 1) ** 3``.

        :param x1 torch.Tensor: The first set of features. Shape (n1, feat_dim).
        :param x2 torch.Tensor: The second set of features. Shape (n2, feat_dim).
        :return: The kernel matrix. Shape (n1, n2).
        :rtype: torch.Tensor
        """
        feat_dim = x1.shape[1]
        return (x1 @ x2.T / feat_dim + 1.0) ** 3.0

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """
        Accumulate the KID estimate for one pair of image batches.

        :param preds: Generated images, shape (batch_size, 3, H, W).
        :param target: Real images; same shape and device as ``preds``.
        :raises ValueError: If the batch holds fewer than 2 images (the
            unbiased estimator divides by ``batch_size * (batch_size - 1)``).
        """
        assert preds.shape == target.shape
        assert preds.device == target.device
        assert (
            preds.device == self.device
        ), f"Expected device {self.device}, got {preds.device}"
        batch_size = preds.shape[0]
        if batch_size < 2:
            raise ValueError(
                f"KID needs at least 2 images per batch, got {batch_size}"
            )

        # Transform the images using the InceptionV3 preprocessing
        target = self.transforms(target)
        preds = self.transforms(preds)

        # Features of shape (batch_size, feat_dim)
        real_features: torch.Tensor = self.encoder(target)
        generated_features: torch.Tensor = self.encoder(preds)

        real_kernel = self.polynomial_kernel(real_features, real_features)
        generated_kernel = self.polynomial_kernel(
            generated_features, generated_features
        )
        cross_kernel = self.polynomial_kernel(
            real_features, generated_features
        )

        # Unbiased MMD^2: the within-set means exclude the diagonal k(x, x).
        # Subtracting the trace stays on the kernels' own device, unlike an
        # identity mask created with .cuda() (current device, CUDA-only).
        num_pairs = batch_size * (batch_size - 1)
        real_mean = (real_kernel.sum() - real_kernel.diagonal().sum()) / num_pairs
        generated_mean = (
            generated_kernel.sum() - generated_kernel.diagonal().sum()
        ) / num_pairs
        cross_mean = cross_kernel.sum() / batch_size**2

        kid = real_mean + generated_mean - 2.0 * cross_mean
        self.kid += kid
        self.count += 1

    def compute(self) -> torch.Tensor:
        """Return the mean of the per-batch KID estimates since the last reset."""
        return self.kid / self.count


def test_kid():
    """Smoke-test KID on identical, real-vs-noise and noise-vs-noise batches."""
    torch.cuda.set_device(DEVICE_ID)
    metric = KID(kid_image_size=KID_IMAGE_SIZE).cuda()

    def report(preds, target):
        # Score one pair of batches, print it, then clear the metric state.
        metric.update(preds, target)
        print("kid value:", metric.compute().item())
        metric.reset()

    loader = get_dataloader(128, IMAGE_SIZE, split="train")
    real_images = next(iter(loader))[0].cuda()
    report(real_images, real_images)  # baseline: identical batches

    noise_shape = (128, 3, IMAGE_SIZE, IMAGE_SIZE)
    noise_preds = torch.randn(noise_shape).cuda()
    noise_target = torch.randn(noise_shape).cuda()
    report(real_images, noise_target)
    report(noise_preds, noise_target)


# test_kid()

# Network architecture

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalEmbedding(nn.Module):
    """
    Map a per-sample scalar (here: a noise variance) to sin/cos features.

    Frequencies are spaced geometrically between ``embedding_min_frequency``
    and ``embedding_max_frequency``. An input broadcastable against
    (1, embedding_dims // 2, 1, 1), e.g. (B, 1, 1, 1), yields
    (B, embedding_dims, 1, 1): sines in the first half of the channels,
    cosines in the second half.
    """

    def __init__(
        self,
        embedding_min_frequency=1.0,
        embedding_max_frequency=1000.0,
        embedding_dims=32,
    ):
        super().__init__()
        assert embedding_dims % 2 == 0, "Embedding dimensions must be even"
        self.embedding_min_frequency = embedding_min_frequency
        self.embedding_max_frequency = embedding_max_frequency
        self.embedding_dims = embedding_dims

        log_lo = torch.log(torch.tensor(embedding_min_frequency))
        log_hi = torch.log(torch.tensor(embedding_max_frequency))
        freqs = torch.linspace(log_lo, log_hi, embedding_dims // 2).exp()
        # Kept as a frozen Parameter (rather than a buffer) so parameters(),
        # and therefore optimizer state, stays unchanged.
        self.angular_speeds = nn.Parameter(
            (2.0 * torch.pi * freqs).view(1, -1, 1, 1), requires_grad=False
        )

    def forward(self, x: torch.Tensor):
        phases = self.angular_speeds * x
        return torch.cat([torch.sin(phases), torch.cos(phases)], dim=1)


class ResidualBlock(nn.Module):
    """
    Pre-normalised residual block: BatchNorm -> Conv3x3 -> SiLU -> Conv3x3,
    added to the input (projected by a 1x1 conv when channel counts differ).

    Spatial size is preserved; channels go from ``in_channels`` to
    ``out_channels``. All conv weights use Xavier-normal init with the SELU
    gain, and conv biases start at zero.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order below is kept stable: it fixes state_dict key
        # order and the RNG draws made during initialisation.
        self.residual_conv = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )
        self.norm1 = nn.BatchNorm2d(in_channels, eps=1e-8, momentum=0.01)
        self.relu = nn.SiLU()  # named "relu" but applies SiLU
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=1
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1
        )
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Re-initialise a Conv2d: Xavier-normal weight, zero bias; skip others."""
        if not isinstance(module, nn.Conv2d):
            return
        selu_gain = nn.init.calculate_gain("selu")
        nn.init.xavier_normal_(module.weight, gain=selu_gain)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor):
        skip = self.residual_conv(x)
        h = self.conv2(self.relu(self.conv1(self.norm1(x))))
        return h + skip


class DownBlock(nn.Module):
    """
    Encoder stage: ``block_depth`` residual blocks, then 2x average pooling.

    Each residual block's output is appended, in order, to the caller's
    ``skips`` list so the matching UpBlock can pop it later.
    """

    def __init__(self, in_channels, out_channels, block_depth):
        super().__init__()
        assert block_depth > 0, "Block depth must be greater than 0"
        self.block_depth = block_depth
        # Only the first block changes the channel count.
        block_inputs = [in_channels] + [out_channels] * (block_depth - 1)
        self.residual_blocks = nn.ModuleList(
            ResidualBlock(channels, out_channels) for channels in block_inputs
        )
        self.pool = nn.AvgPool2d(kernel_size=2)

    def forward(self, x: torch.Tensor, skips: list):
        h = x
        for block in self.residual_blocks:
            h = block(h)
            skips.append(h)
        return self.pool(h)


class UpBlock(nn.Module):
    """
    Decoder stage: 2x bilinear upsampling, then ``block_depth`` residual
    blocks, each fed the running features concatenated with one skip popped
    from the end of ``skips``.
    """

    def __init__(self, in_channels, out_channels, block_depth):
        super().__init__()
        assert block_depth > 0, "Block depth must be greater than 0"
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block_depth = block_depth

        self.up = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )
        # Assumes every popped skip carries out_channels channels: the first
        # block sees in_channels + out_channels, later ones 2 * out_channels.
        first_block = ResidualBlock(in_channels + out_channels, out_channels)
        later_blocks = [
            ResidualBlock(out_channels * 2, out_channels)
            for _ in range(block_depth - 1)
        ]
        self.residual_blocks = nn.ModuleList([first_block] + later_blocks)

    def forward(self, x: torch.Tensor, skips: list):
        assert x.shape[1] == self.in_channels
        h = self.up(x)
        for block in self.residual_blocks:
            h = block(torch.cat([h, skips.pop()], dim=1))
        return h


class ResidualUNet(nn.Module):
    """
    Residual U-Net that predicts the noise in a noisy image, conditioned on the
    noise variance through a sinusoidal embedding.

    Layout for ``widths = [w0, ..., wk]``: a 1x1 stem conv to ``w0`` channels,
    concatenated with the ``embedding_dims``-channel embedding; ``k``
    DownBlocks (each halving the resolution); a bottleneck of ``block_depth``
    residual blocks ending at ``wk`` channels; ``k`` UpBlocks; and a
    zero-initialised 1x1 conv back to 3 channels.

    :param image_size: Side length of the square inputs; must be divisible by
        ``2 ** (len(widths) - 1)``.
    :param widths: Channel widths per level (at least 2 entries).
    :param block_depth: Residual blocks per stage.
    :param embedding_dims: Channels of the noise embedding (must be even).
    :param embedding_max_frequency: Highest embedding frequency.
    :raises ValueError: On too few widths or an incompatible image size.
    """

    def __init__(
        self,
        image_size,
        widths,
        block_depth,
        embedding_dims,
        embedding_max_frequency,
    ):
        super(ResidualUNet, self).__init__()
        if len(widths) < 2:
            raise ValueError(
                f"widths needs at least 2 entries, got {len(widths)}"
            )
        num_levels = len(widths) - 1  # number of pooling / upsampling steps
        if image_size % (2**num_levels) != 0:
            raise ValueError(
                f"image_size {image_size} must be divisible by {2**num_levels}"
            )
        self.image_size = image_size
        self.widths = widths
        self.block_depth = block_depth
        self.embedding_dims = embedding_dims

        # Initialize the components
        self.sinusoidal_embedding = SinusoidalEmbedding(
            embedding_max_frequency=embedding_max_frequency,
            embedding_dims=embedding_dims,
        )

        # Initial Conv2d layer
        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=1)

        # Down blocks. The first one receives the stem output (widths[0]
        # channels) concatenated with the embedding (embedding_dims channels).
        # Using image_size // 2 here only coincided with widths[0] when
        # image_size == 2 * widths[0].
        self.down_blocks = nn.ModuleList(
            [
                DownBlock(
                    widths[0] + self.embedding_dims,
                    widths[0],
                    block_depth=block_depth,
                )
            ]
            + [
                DownBlock(widths[i - 1], widths[i], block_depth=block_depth)
                for i in range(1, len(widths) - 1)
            ]
        )
        # Residual blocks in the bottleneck
        self.bottleneck = nn.ModuleList(
            [ResidualBlock(widths[-2], widths[-1])]
            + [
                ResidualBlock(widths[-1], widths[-1])
                for _ in range(block_depth - 1)
            ]
        )
        # Up blocks (mirror of the down path)
        self.up_blocks = nn.ModuleList(
            [UpBlock(widths[-1], widths[-2], block_depth=block_depth)]
            + [
                UpBlock(widths[i], widths[i - 1], block_depth=block_depth)
                for i in reversed(range(1, len(widths) - 1))
            ]
        )

        # Final convolution to map to 3 channels (output image)
        self.final_conv = nn.Conv2d(widths[0], 3, kernel_size=1)

        # Stem conv: same Xavier/SELU init as the residual blocks.
        nn.init.xavier_normal_(
            self.conv1.weight, gain=nn.init.calculate_gain("selu")
        )
        if self.conv1.bias is not None:
            nn.init.zeros_(self.conv1.bias)
        # Output conv starts at zero, so the initial prediction is all zeros.
        nn.init.zeros_(self.final_conv.weight)
        if self.final_conv.bias is not None:
            nn.init.zeros_(self.final_conv.bias)

    def forward(
        self,
        noisy_images: torch.Tensor,
        noise_variances: torch.Tensor,
    ):
        """
        Forward pass of the model.

        :param noisy_images torch.Tensor: The noisy images. Shape (batch_size, 3, H, W).
        :param noise_variances torch.Tensor: The noise variances. Shape (batch_size, 1, 1, 1).
        :return: The predicted noise. Shape (batch_size, 3, H, W).
        :rtype: torch.Tensor
        """
        assert (
            noisy_images.shape[2] == noisy_images.shape[3]
            and noisy_images.shape[2] == self.image_size
        ), f"Expected image size {self.image_size}, got {noisy_images.shape[2]}"
        # Embed the noise level and broadcast it over the full image plane.
        e = self.sinusoidal_embedding(noise_variances)
        e = F.interpolate(e, size=self.image_size, mode="nearest")

        # Initial conv layer with the noisy images
        x = self.conv1(noisy_images)
        x = torch.cat([x, e], dim=1)

        # Downsampling blocks (push skips)
        skips = []
        for down_block in self.down_blocks:
            x = down_block(x, skips)

        # Residual blocks in the bottleneck
        for block in self.bottleneck:
            x = block(x)

        # Upsampling blocks (pop skips)
        for up_block in self.up_blocks:
            x = up_block(x, skips)

        # Final convolution to get 3 output channels
        return self.final_conv(x)


def test_residual_unet():
    """Smoke-test ResidualUNet: print its parameter count and output shape."""
    torch.cuda.set_device(DEVICE_ID)
    model = ResidualUNet(
        IMAGE_SIZE, WIDTHS, BLOCK_DEPTH, EMBEDDING_DIM, EMBEDDING_MAX_FREQUENCY
    ).cuda()

    num_params = sum(param.numel() for param in model.parameters())
    print("Number of parameters of ResidualUNet is {}".format(num_params))

    image_shape = (1, 3, IMAGE_SIZE, IMAGE_SIZE)
    noisy = torch.randn(image_shape).cuda()
    variances = torch.randn(1, 1, 1, 1).cuda()

    result = model(noisy, variances)
    print(result.shape)  # expected: (1, 3, IMAGE_SIZE, IMAGE_SIZE)
    print("Test residual unet done.")


# test_residual_unet()

In [None]:
from time import perf_counter
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.parallel import DataParallel as DP
import matplotlib.pyplot as plt
from IPython import get_ipython


class DiffusionModel:
    def __init__(
        self,
        image_size,
        embedding_dims,
        embedding_max_frequency,
        widths,
        block_depth,
        kid_image_size,
        kid_diffusion_steps,
        plot_diffusion_steps,
        min_signal_rate=0.01,
        max_signal_rate=0.99,
        ema_decay=0.999,
        learning_rate=1e-3,
        weight_decay=1e-4,
        mixed_precision=False,
        distributed=False,
        rank=0,
        world_size=1,
        device_ids=[0],
    ):
        self.image_size = image_size
        self.widths = widths
        self.block_depth = block_depth
        self.kid_diffusion_steps = kid_diffusion_steps
        self.plot_diffusion_steps = plot_diffusion_steps
        self.min_signal_rate = min_signal_rate
        self.max_signal_rate = max_signal_rate
        self.ema_decay = ema_decay
        self.mixed_precision = mixed_precision
        self.distributed = distributed
        self.rank = rank
        self.world_size = world_size

        self.loss_func = nn.L1Loss()

        self.mean = (
            # torch.tensor([0.485, 0.456, 0.406])
            torch.tensor([0.4752, 0.3933, 0.3070])  # precomputed
            .view(3, 1, 1)
            .requires_grad_(False)
        )
        self.std = (
            # torch.tensor([0.229, 0.224, 0.225])
            torch.tensor([0.2902, 0.2372, 0.2679])  # precomputed
            .view(3, 1, 1)
            .requires_grad_(False)
        )
        self.network = (
            ResidualUNet(
                image_size,
                widths,
                block_depth,
                embedding_dims,
                embedding_max_frequency,
            )
            .cuda()
            .train()
        )
        self.ema_network = (
            ResidualUNet(
                image_size,
                widths,
                block_depth,
                embedding_dims,
                embedding_max_frequency,
            )
            .cuda()
            .eval()
        )
        self.ema_network.load_state_dict(self.network.state_dict())
        if self.distributed:
            self.network = DDP(self.network, device_ids=[device_ids[rank]])
        self.optimizer = optim.AdamW(
            self.network.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )

        # Metrics
        self.kid = KID(kid_image_size=kid_image_size).cuda()

        # Mixed precision training
        self.scaler = GradScaler("cuda")

    @torch.no_grad()
    def adapt_normalizer(self, dataloader):
        """
        Called before ddp spawning.
        """
        print("Adapting normalizer...")
        # calculate the mean and std of the images for normalization and denormalization
        total = torch.zeros(3).cuda()
        var = torch.zeros(3).cuda()
        if self.distributed:
            dataloader.sampler.set_epoch(0)
        for images, _ in dataloader:
            images = images.cuda()
            total += images.sum(dim=(0, 2, 3))
        if self.distributed:
            dist.all_reduce(total)
        mean = total / len(dataloader.dataset) / (IMAGE_SIZE * IMAGE_SIZE)
        for images, _ in dataloader:
            images = images.cuda()
            var += ((images - mean.view(1, 3, 1, 1)) ** 2).sum(dim=(0, 2, 3))
        if self.distributed:
            dist.all_reduce(var)
        var = var / len(dataloader.dataset) / (IMAGE_SIZE * IMAGE_SIZE)
        std = torch.sqrt(var)
        self.mean = mean.view(3, 1, 1)
        self.std = std.view(3, 1, 1)

        if self.rank == 0:
            print("mean: {}".format(mean))
            print("std: {}".format(std))

    def normalize(self, images: torch.Tensor):
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        # normalize the pixel values to have mean 0 and std 1
        images = (images - self.mean) / (self.std + 1e-8)
        return images

    def denormalize(self, images: torch.Tensor):
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()
        # convert the pixel values back to 0-1 range
        images = images * self.std + self.mean
        images = torch.clamp(images, 0.0, 1.0)
        return images

    def diffusion_schedule(self, diffusion_times):
        # Diffusion times -> angles
        diffusion_times = torch.clamp(diffusion_times, 0.0, 1.0)
        max_angle = torch.acos(torch.tensor(self.min_signal_rate))
        min_angle = torch.acos(torch.tensor(self.max_signal_rate))

        # Angles -> rates
        diffusion_angles = (
            diffusion_times * (max_angle - min_angle) + min_angle
        )
        # calculate the noise and signal rates
        noise_rates = torch.sin(diffusion_angles)
        signal_rates = torch.cos(diffusion_angles)

        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        # Predict the noise component and calculate the image component using it
        # Here, use the signal_rate and noise_rate to derive the output components
        network = self.network if training else self.ema_network
        if training:
            network.train()
        else:
            network.eval()
        pred_noises = network(noisy_images, noise_rates**2)
        pred_images = (noisy_images - pred_noises * noise_rates) / signal_rates

        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps):
        # reverse diffusion = sampling
        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps

        # important line:
        # at the first sampling step, the "noisy image" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_images = initial_noise
        for step in range(diffusion_steps):
            # This process gradually reduces the noise to generate a clearer image

            # remix the predicted components
            # Use the signal_rate and noise_rate from the next step
            # to recombine image and noise components
            diffusion_times = (
                torch.tensor(1 - step * step_size)
                .repeat(num_images, 1, 1, 1)
                .cuda()
            )  # (num_images, 1, 1, 1)
            noise_rates, signal_rates = self.diffusion_schedule(
                diffusion_times
            )  # (num_images, 1, 1, 1), (num_images, 1, 1, 1)
            pred_noises, pred_images = self.denoise(
                next_noisy_images, noise_rates, signal_rates, training=False
            )  # (num_images, 3, H, W), (num_images, 3, H, W)

            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )  # (num_images, 1, 1, 1), (num_images, 1, 1, 1)

            next_noisy_images = (
                pred_noises * next_noise_rates
                + pred_images * next_signal_rates
            )  # (num_images, 3, H, W)

        return pred_images

    def sample_noise(self, num_images):
        noise = torch.randn((num_images, 3, self.image_size, self.image_size))
        noise = noise.cuda()
        return noise

    def generate(self, num_images, diffusion_steps):
        # noise -> images -> denormalized images
        initial_noise = self.sample_noise(num_images)
        generated_images = self.reverse_diffusion(
            initial_noise, diffusion_steps
        )
        generated_images = self.denormalize(generated_images)
        return generated_images

    def train_step(self, images: torch.Tensor):
        batch_size = images.shape[0]

        with torch.no_grad():
            # normalize images to have standard deviation of 1, like the noises
            images = self.normalize(images)  # (batch_size, 3, H, W)
            noises = self.sample_noise(batch_size)

            # sample uniform random diffusion times
            diffusion_times = torch.rand((batch_size, 1, 1, 1)).cuda()
            noise_rates, signal_rates = self.diffusion_schedule(
                diffusion_times
            )
            # mix the images with noises accordingly
            noisy_images = signal_rates * images + noise_rates * noises

        with torch.set_grad_enabled(True), autocast(
            "cuda", dtype=torch.float16, enabled=self.mixed_precision
        ):
            # predict the noise and image components
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )
            noise_loss = self.loss_func(
                pred_noises, noises
            )  # used for training
            image_loss = self.loss_func(
                pred_images, images
            )  # only used as metric

        self.optimizer.zero_grad(set_to_none=True)
        if self.mixed_precision:
            self.scaler.scale(noise_loss).backward()
            torch.nn.utils.clip_grad_norm_(
                self.network.parameters(), max_norm=1.0
            )
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            noise_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.network.parameters(), max_norm=1.0
            )
            self.optimizer.step()

        self.ema_update()

        return {
            "n_loss": noise_loss.item(),
            "i_loss": image_loss.item(),
        }

    @torch.no_grad()
    def ema_update(self):
        # Update the EMA network with the main network's parameters
        for param, ema_param in zip(
            self.network.module.parameters(), self.ema_network.parameters()
        ):
            ema_param.data.mul_(self.ema_decay).add_(
                param.data, alpha=1.0 - self.ema_decay
            )

    def test_step(self, images):
        with torch.no_grad():
            batch_size = images.shape[0]

            # normalize images to have standard deviation of 1, like the noises
            images = self.normalize(images)
            noises = self.sample_noise(batch_size)

            # sample uniform random diffusion times
            diffusion_times = torch.rand((batch_size, 1, 1, 1)).cuda()
            noise_rates, signal_rates = self.diffusion_schedule(
                diffusion_times
            )
            # mix the images with noises accordingly
            noisy_images = signal_rates * images + noise_rates * noises

            with autocast("cuda", enabled=self.mixed_precision):
                pred_noises, pred_images = self.denoise(
                    noisy_images, noise_rates, signal_rates, training=False
                )

                noise_loss = self.loss_func(pred_noises, noises)
                image_loss = self.loss_func(pred_images, images)

            # measure KID between real and generated images
            # this is computationally demanding, kid_diffusion_steps has to be small
            images = self.denormalize(images)
            generated_images = self.generate(
                num_images=batch_size,
                diffusion_steps=self.kid_diffusion_steps,
            )
            self.kid.reset()
            self.kid.update(generated_images, images)

            return {
                "n_loss": noise_loss.item(),
                "i_loss": image_loss.item(),
                "kid": self.kid.compute().item(),
            }

    def plot_images(
        self,
        num_rows=3,
        num_cols=6,
        save=False,
        output_dir="outputs",
        epoch=None,
    ):
        ipy = get_ipython()
        if ipy is None and not save:
            return

        with torch.no_grad():
            generated_images = self.generate(
                num_images=num_rows * num_cols,
                diffusion_steps=self.plot_diffusion_steps,
            )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                idx = row * num_cols + col
                plt.subplot(num_rows, num_cols, idx + 1)
                plt.imshow(
                    generated_images[idx]
                    .detach()
                    .cpu()
                    .permute(1, 2, 0)
                    .numpy()
                )
                plt.axis("off")
        plt.tight_layout()
        if ipy is not None:
            plt.show()
        if save:
            os.makedirs(output_dir, exist_ok=True)
            plt.savefig(
                os.path.join(
                    output_dir,
                    "img{}.png".format(
                        "" if epoch is None else "_epoch{}".format(epoch)
                    ),
                )
            )
            print("Saved generated images at", output_dir)
        plt.close()

    def get_model_state_dict(self):
        if self.distributed:
            network = self.network.module
        else:
            network = self.network
        ema_network = self.ema_network
        return {
            "network": network.state_dict(),
            "ema_network": ema_network.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def set_model_state_dict(self, state_dict):
        if self.distributed:
            self.network.module.load_state_dict(state_dict["network"])
        else:
            self.network.load_state_dict(state_dict["network"])
        self.ema_network.load_state_dict(state_dict["ema_network"])
        self.optimizer.load_state_dict(state_dict["optimizer"])


def test_model():
    """Smoke-test a freshly built DiffusionModel on one batch of the test split.

    Selects ``DEVICE_ID`` as the current CUDA device, builds the model from the
    notebook's configuration constants, then runs one train step, one test step
    and a single-row plot, printing each stage's metrics and wall-clock time.
    Requires CUDA (the batch is moved with ``.cuda()``).
    """
    torch.cuda.set_device(DEVICE_ID)
    model_kwargs = dict(
        image_size=IMAGE_SIZE,
        embedding_dims=EMBEDDING_DIM,
        embedding_max_frequency=EMBEDDING_MAX_FREQUENCY,
        widths=WIDTHS,
        block_depth=BLOCK_DEPTH,
        kid_image_size=KID_IMAGE_SIZE,
        kid_diffusion_steps=KID_DIFFUSION_STEPS,
        plot_diffusion_steps=PLOT_DIFFUSION_STEPS,
        min_signal_rate=MIN_SIGNAL_RATE,
        max_signal_rate=MAX_SIGNAL_RATE,
        ema_decay=EMA,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        mixed_precision=MIXED_PRECISION,
    )
    model = DiffusionModel(**model_kwargs)
    model.setup()

    # A single batch from the test split is reused for both steps below.
    loader = get_dataloader(
        batch_size=BATCH_SIZE, image_size=IMAGE_SIZE, split="test"
    )
    batch, _labels = next(iter(loader))
    batch = batch.cuda()

    model.adapt_normalizer(loader)

    t0 = perf_counter()
    train_metrics = model.train_step(batch)
    print("train metrics:", train_metrics)
    print("train time:", perf_counter() - t0)

    t0 = perf_counter()
    test_metrics = model.test_step(batch)
    print("test metrics:", test_metrics)
    print("test time:", perf_counter() - t0)

    t0 = perf_counter()
    model.plot_images(num_rows=1)
    print("plot time:", perf_counter() - t0)
    print("Test model done.")


# test_model()

# Training


In [None]:
import os
import torch


def save_checkpoint(epoch, model_state_dict, checkpoint_dir, checkpoint_name):
    """
    Save a checkpoint to a specified directory.

    The file is named ``{checkpoint_name}_{epoch:03d}.pt``. It is first written
    to a ``.pt.tmp`` file next to it and then moved into place with
    ``os.replace``, so an interrupted save never leaves a truncated file under
    the final ``.pt`` name (which ``load_checkpoint`` would otherwise pick up).

    :param epoch: The epoch number, embedded in the file name.
    :param model_state_dict: The object to serialize with ``torch.save``
        (the training loop passes ``model.get_model_state_dict()``).
    :param checkpoint_dir: The directory to save the checkpoint in; created if
        missing.
    :param checkpoint_name: The base name of the checkpoint file, will be
        appended with the epoch number.
    :return: The path of the written checkpoint file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(
        checkpoint_dir, f"{checkpoint_name}_{epoch:03d}.pt"
    )
    # The temporary name does not end in ".pt", so load_checkpoint ignores it.
    tmp_path = checkpoint_path + ".tmp"
    try:
        torch.save(model_state_dict, tmp_path)
        # Atomic on POSIX and Windows when source and target share a filesystem.
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Saved checkpoint for epoch {epoch} at {checkpoint_path}")
    return checkpoint_path


def load_checkpoint(
    checkpoint_dir,
    checkpoint_name,
    epoch=None,
    device=None,
) -> dict:
    """
    Load a checkpoint from a specified directory.

    Checkpoint files are expected to be named ``{checkpoint_name}_{epoch:03d}.pt``
    (the naming used by ``save_checkpoint``).

    :param checkpoint_dir: The directory to search for the checkpoint.
    :param checkpoint_name: The base name of the checkpoint file, will be
        appended with the epoch number.
    :param epoch: The epoch to load the checkpoint from. If None, the checkpoint
        with the highest epoch number is loaded. Epochs are compared numerically,
        and files that don't match ``{checkpoint_name}_<digits>.pt`` are ignored.
    :param device: Passed to ``torch.load`` as ``map_location``.
    :return: The deserialized checkpoint, or None (after printing a message) if
        the directory or the requested checkpoint does not exist.
    """
    if epoch is None:
        if not os.path.isdir(checkpoint_dir):
            print("No checkpoint found.")
            return None
        prefix = f"{checkpoint_name}_"
        suffix = ".pt"
        candidates = []
        for file in os.listdir(checkpoint_dir):
            if not (file.startswith(prefix) and file.endswith(suffix)):
                continue
            stem = file[len(prefix):-len(suffix)]
            # Only "<name>_<digits>.pt" is a per-epoch checkpoint; skip e.g. "ckpt_best.pt".
            if stem.isdecimal():
                candidates.append((int(stem), file))
        if not candidates:
            print("No checkpoint found.")
            return None
        # Numeric max: a lexicographic sort would rank "ckpt_999.pt" above "ckpt_1000.pt".
        _, latest_file = max(candidates)
        checkpoint_path = os.path.join(checkpoint_dir, latest_file)
    else:
        checkpoint_path = os.path.join(
            checkpoint_dir, f"{checkpoint_name}_{epoch:03d}.pt"
        )
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint for epoch {epoch} not found.")
        return None
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints produced by this code / from trusted sources.
    model_state_dict = torch.load(
        checkpoint_path, weights_only=False, map_location=device
    )
    return model_state_dict

In [None]:
import socket
from datetime import timedelta
import torch
import torch.distributed as dist


def find_free_port():
    """
    Find a free TCP port on the local machine.

    Binds a throwaway IPv4 TCP socket to port 0 so the OS assigns an unused
    port, then closes it.

    Note: the port is only known to be free at the moment of the check; another
    process may take it before the caller binds to it.

    :return: The port number as a string.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # SO_REUSEADDR used to be set here *after* bind(), where it has no effect
        # on the already-bound socket; the dead call has been removed.
        s.bind(("", 0))
        return str(s.getsockname()[1])


def setup(rank, world_size, device_ids=[0]):
    """Setup the distributed process group."""
    # timeout = timedelta(seconds=30)
    timeout = timedelta(minutes=10)
    torch.cuda.set_device(device_ids[rank])
    dist.init_process_group(
        "nccl", timeout=timeout, rank=rank, world_size=world_size
    )
    print("Setupped process group for rank", rank)


def cleanup():
    """Destroy the default distributed process group and report that it is gone."""
    done_message = "Cleaned up process group"
    dist.destroy_process_group()
    print(done_message)

In [None]:
import time
from datetime import datetime
import torch


# Training cell: optionally join the distributed process group, build the model,
# resume from the START_EPOCH checkpoint, then train / checkpoint / validate /
# plot for EPOCHS epochs. The process group is torn down even if an epoch raises.
if EPOCHS > 0:
    if DISTRIBUTED:
        setup(RANK, WORLD_SIZE, device_ids=DEVICE_IDS)
        # setup() made this rank's GPU the current CUDA device, so plain
        # "cuda" refers to it.
        device = torch.device("cuda")
    elif DEVICE_ID is not None:
        # Make DEVICE_ID the current device as well, so tensors moved with
        # .to(device) and anything placed with a bare .cuda() end up together.
        torch.cuda.set_device(DEVICE_ID)
        device = torch.device(DEVICE_ID)
    else:
        device = torch.device("cpu")

    try:
        model = DiffusionModel(
            image_size=IMAGE_SIZE,
            embedding_dims=EMBEDDING_DIM,
            embedding_max_frequency=EMBEDDING_MAX_FREQUENCY,
            widths=WIDTHS,
            block_depth=BLOCK_DEPTH,
            kid_image_size=KID_IMAGE_SIZE,
            kid_diffusion_steps=KID_DIFFUSION_STEPS,
            plot_diffusion_steps=PLOT_DIFFUSION_STEPS,
            min_signal_rate=MIN_SIGNAL_RATE,
            max_signal_rate=MAX_SIGNAL_RATE,
            ema_decay=EMA,
            learning_rate=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY,
            mixed_precision=MIXED_PRECISION,
            distributed=DISTRIBUTED,
            rank=RANK,
            world_size=WORLD_SIZE,
            device_ids=DEVICE_IDS,
        )

        model_state_dict = load_checkpoint(
            CHECKPOINT_DIR,
            CHECKPOINT_NAME,
            epoch=START_EPOCH,
            device=device,
        )
        if model_state_dict is not None:
            model.set_model_state_dict(model_state_dict)
        elif START_EPOCH > 0:
            # Without this, training would silently restart from random weights
            # while still numbering epochs/checkpoints from START_EPOCH + 1.
            raise FileNotFoundError(
                f"No checkpoint for epoch {START_EPOCH} in {CHECKPOINT_DIR}; "
                "cannot resume training."
            )

        # Get the data loader
        train_loader = get_dataloader(
            batch_size=BATCH_SIZE,
            image_size=IMAGE_SIZE,
            split="train",
            repeats=DATASET_REPETITIONS,
            shuffle=True,
            rank=RANK,
            distributed=DISTRIBUTED,
        )
        val_loader = get_dataloader(
            batch_size=BATCH_SIZE,
            image_size=IMAGE_SIZE,
            split="val",
            rank=RANK,
            distributed=DISTRIBUTED,
        )

        for epoch in range(START_EPOCH + 1, EPOCHS + START_EPOCH + 1):
            print(
                "{}, epoch {:3d}/{:3d}".format(
                    datetime.now(), epoch, EPOCHS + START_EPOCH
                )
            )
            start_ts = time.perf_counter()
            # Per-epoch running sums, reallocated every epoch:
            # train = [n_loss, i_loss], val = [n_loss, i_loss, kid].
            epoch_train_metrics = torch.zeros(2, device=device)
            epoch_val_metrics = torch.zeros(3, device=device)

            # Training
            if DISTRIBUTED:
                # Presumably a DistributedSampler: set_epoch changes the
                # shuffle order from epoch to epoch.
                train_loader.sampler.set_epoch(epoch)
            for idx, (images, _) in enumerate(train_loader, start=1):
                images = images.to(device)
                train_metrics = model.train_step(images)
                epoch_train_metrics += torch.tensor(
                    [train_metrics["n_loss"], train_metrics["i_loss"]],
                    device=device,
                )
                if idx % 20 == 0:
                    print(
                        "rank {:2d}, train epoch {:3d}/{:3d}, batch {:4d}/{:4d}, n_loss: {:.4f}, i_loss: {:.4f}".format(
                            RANK,
                            epoch,
                            EPOCHS + START_EPOCH,
                            idx,
                            len(train_loader),
                            train_metrics["n_loss"],
                            train_metrics["i_loss"],
                        )
                    )

            # Save checkpoint
            if RANK == 0:
                save_checkpoint(
                    epoch,
                    model.get_model_state_dict(),
                    CHECKPOINT_DIR,
                    CHECKPOINT_NAME,
                )

            # Testing
            for images, _ in val_loader:
                images = images.to(device)
                test_metrics = model.test_step(images)
                epoch_val_metrics += torch.tensor(
                    [
                        test_metrics["n_loss"],
                        test_metrics["i_loss"],
                        test_metrics["kid"],
                    ],
                    device=device,
                )

            if DISTRIBUTED:
                # Average the per-rank sums; dividing by the per-rank batch
                # count below then yields the mean per batch.
                dist.all_reduce(epoch_train_metrics, op=dist.ReduceOp.AVG)
                dist.all_reduce(epoch_val_metrics, op=dist.ReduceOp.AVG)

            # Print metrics
            if RANK == 0:
                avg_train_metrics = epoch_train_metrics / len(train_loader)
                avg_val_metrics = epoch_val_metrics / len(val_loader)
                print(
                    "rank {:2d}, epoch {:3d}/{:3d}, n_loss: {:.4f}, i_loss: {:.4f}, val n_loss: {:.4f}, val i_loss: {:.4f}, val kid: {:.4f}, took {:.2f}s".format(
                        RANK,
                        epoch,
                        EPOCHS + START_EPOCH,
                        avg_train_metrics[0],
                        avg_train_metrics[1],
                        avg_val_metrics[0],
                        avg_val_metrics[1],
                        avg_val_metrics[2],
                        time.perf_counter() - start_ts,
                    )
                )

            if RANK == 0 and epoch % PLOT_EVERY == 0:
                model.plot_images(
                    save=SAVE_PLOTS,
                    output_dir=OUTPUT_DIR,
                    epoch=epoch,
                )

            if DISTRIBUTED:
                # Other ranks wait here while rank 0 checkpoints and plots.
                dist.barrier()
    finally:
        # Always release the process group, even if setup of the model/data or
        # an epoch raised.
        if DISTRIBUTED:
            cleanup()