In [8]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll
from tqdm import tqdm
import os
from google.colab import drive
# Mount Google Drive at /content/gdrive; the training and evaluation cells below
# read and write checkpoints and figures under /content/gdrive/MyDrive/.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [9]:
# Data preparation
def sample_batch(batch_size, device='cpu'):
    """Draw a batch of 2-D Swiss-roll points.

    The 3-D Swiss roll from ``make_swiss_roll`` is projected onto its
    (third, first) coordinates, divided by 10, and the second projected
    coordinate is negated.

    :param batch_size: number of points to draw
    :param device: device the returned tensor is moved to
    :returns: tensor of shape [batch_size, 2]; its dtype follows the NumPy array
        (presumably float64 -- note ``get_loss`` casts with ``.float()``)
    """
    swiss_roll_3d = make_swiss_roll(batch_size)[0]
    planar = swiss_roll_3d[:, [2, 0]] / 10
    mirrored = planar * np.array([1, -1])
    return torch.from_numpy(mirrored).to(device)

In [10]:
# Data visualization
def plot(model, file_name, device):
    """Save a 2x3 figure comparing the forward (noising) and learned reverse processes.

    Top row: ``N`` fresh Swiss-roll points x0 and their forward-diffused versions
    x_t at t = T // 2 and t = T.  Bottom row: the trajectory returned by
    ``model.sample`` at the same time indices.

    :param model: diffusion model exposing ``T``, ``forward_process`` and ``sample``
    :param file_name: path the figure is written to
    :param device: device used for sampling
    """
    fontsize = 14
    N = 5_000
    # Derived from model.T so the panels match the t=0 / T/2 / T titles for any T.
    time_steps = [0, model.T // 2, model.T]
    titles = [r'$t=0$', r'$t=\frac{T}{2}$', r'$t=T$']

    fig = plt.figure(figsize=(10, 6))

    # Plotting only: no autograd graph is needed for sampling or diffusing.
    with torch.no_grad():
        x0 = sample_batch(N).to(device)
        samples = model.sample(N, device=device)
        # forward_process returns (mu_q, std_q, xt); only the noisy sample xt is plotted.
        forward = [x0] + [model.forward_process(x0, t)[-1] for t in time_steps[1:]]

    for i, step in enumerate(time_steps):
        # Top row: forward process q.
        top = fig.add_subplot(2, 3, 1 + i)
        fwd_points = forward[i].detach().cpu().numpy()
        top.scatter(fwd_points[:, 0], fwd_points[:, 1], alpha=0.1, s=1)
        top.set_title(titles[i], fontsize=fontsize)

        # Bottom row: learned reverse process p.
        bottom = fig.add_subplot(2, 3, 4 + i)
        rev_points = samples[step].detach().cpu().numpy()
        bottom.scatter(rev_points[:, 0], rev_points[:, 1], alpha=0.1, c='r', s=1)

        for ax in (top, bottom):
            ax.set_xlim([-2, 2])
            ax.set_ylim([-2, 2])
            ax.set_aspect('equal')

        if i == 0:
            top.set_ylabel(r'$q(\mathbf{x}^{(0..T)})$', fontsize=fontsize)
            bottom.set_ylabel(r'$p(\mathbf{x}^{(0..T)})$', fontsize=fontsize)

    fig.savefig(file_name, bbox_inches='tight')
    plt.close(fig)

In [11]:
# Model to predict the mean and variance during the reverse diffusion process
class MLP(nn.Module):
    """Shared trunk plus one small head per step index, predicting a diagonal Gaussian.

    :param N: number of heads; ``forward`` selects one with its ``t`` argument
    :param data_dim: dimensionality of each data point
    :param hidden_dim: width of the hidden layers
    """

    def __init__(self, N=40, data_dim=2, hidden_dim=64):
        super().__init__()

        # Shared feature extractor used for every step index.
        self.network_head = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # One head per step index; each outputs [mean | log-variance] (2 * data_dim).
        # Attribute names and layer order determine the state_dict keys, so they are
        # kept as-is for checkpoint compatibility.
        heads = []
        for _ in range(N):
            heads.append(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, data_dim * 2),
            ))
        self.network_tail = nn.ModuleList(heads)

    def forward(self, x, t):
        """Predict the mean and standard deviation of a diagonal Gaussian.

        :param x: batch of shape [batch_size, data_dim]
        :param t: index of the head in ``network_tail`` to use
        :returns: ``(mu, std)``, each of shape [batch_size, data_dim]
        """
        features = self.network_head(x)               # [batch_size, hidden_dim]
        head_out = self.network_tail[t](features)     # [batch_size, data_dim * 2]
        mu, log_var = head_out.chunk(2, dim=1)
        # exp keeps the variance strictly positive; sqrt turns it into a std.
        std = log_var.exp().sqrt()
        return mu, std

In [None]:
class DiffusionModel(torch.nn.Module):
    """Gaussian diffusion on low-dimensional data with a learned reverse process.

    The forward (noising) process uses a fixed sigmoid-shaped beta schedule; the
    reverse process is parameterised by ``model``, called as ``model(x_t, t_index)``
    and expected to return ``(mu, std)`` tensors shaped like ``x_t``.

    :param T: number of diffusion steps (``get_loss`` needs T >= 2)
    :param model: network predicting the reverse-process mean and std
    :param device: device the schedule tensors are placed on
    :param dim: data dimensionality, used for the initial noise in ``sample``
    """

    def __init__(self, T, model: nn.Module, device, dim=2):

        super().__init__()
        # betas[i]: variance of the noise added at forward step i+1. The schedule rises
        # sigmoidally from ~1e-5 to ~0.3, so early steps barely perturb the data.
        self.betas = (torch.sigmoid(torch.linspace(-18, 10, T)) * (3e-1 - 1e-5) + 1e-5).to(device)
        # alphas[i] = 1 - betas[i]: fraction of signal variance kept at step i+1.
        self.alphas = 1 - self.betas
        # alphas_bar[i] = prod_{s<=i} alphas[s]: signal variance kept after i+1 steps.
        self.alphas_bar = torch.cumprod(self.alphas, 0)
        # alphas_bar_prev[i] = alphas_bar[i-1], with alphas_bar_prev[0] = 1 (no noise
        # before the first step). Indexing alphas_bar[i - 1] directly would silently
        # wrap to the LAST entry when i == 0.
        # These are plain attributes (not buffers), so the state_dict / checkpoint
        # format is unchanged.
        self.alphas_bar_prev = torch.cat([self.alphas_bar.new_ones(1), self.alphas_bar[:-1]])

        self.T = T
        self.model = model
        self.dim = dim

    def forward_process(self, x0, t):
        """Diffuse ``x0`` for ``t`` steps and return the posterior q(x_{t-1} | x_t, x_0).

        :param x0: clean data batch
        :param t: number of diffusion steps, 1 <= t <= T (int or 1-element tensor)
        :returns: ``(mu_q, std_q, xt)`` -- posterior mean and std of x_{t-1}, and the
            noisy sample x_t
        """

        assert t > 0, 't should be greater than 0'
        assert t <= self.T, f't should be lower or equal than {self.T}'

        t = t - 1  # schedule tensors are 0-indexed

        # Closed form: x_t ~ N(sqrt(alpha_bar_t) * x0, (1 - alpha_bar_t) I).
        mu = torch.sqrt(self.alphas_bar[t]) * x0
        std = torch.sqrt(1 - self.alphas_bar[t])
        epsilon = torch.randn_like(x0)
        xt = mu + epsilon * std  # data ~ N(mu, std)

        # Standard DDPM posterior q(x_{t-1} | x_t, x_0). For t == 1 (index 0)
        # alpha_bar_prev is 1, giving mean x0 and zero std.
        alpha_bar_t = self.alphas_bar[t]
        alpha_bar_prev = self.alphas_bar_prev[t]
        std_q = torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t) * self.betas[t])
        m1 = torch.sqrt(alpha_bar_prev) * self.betas[t] / (1 - alpha_bar_t)
        m2 = torch.sqrt(self.alphas[t]) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
        mu_q = m1 * x0 + m2 * xt

        return mu_q, std_q, xt

    def reverse_process(self, xt, t):
        """Take one learned reverse step from x_t.

        :param xt: noisy batch at step t
        :param t: current step, 1 <= t <= T (int or 1-element tensor)
        :returns: ``(mu, std, x_{t-1})`` where x_{t-1} is drawn from N(mu, std^2)
        """

        assert t > 0, 't should be greater than 0'
        assert t <= self.T, f't should be lower or equal than {self.T}'

        t = t - 1  # the model's heads are 0-indexed

        mu, std = self.model(xt, t)
        epsilon = torch.randn_like(xt)

        return mu, std, mu + epsilon * std  # data ~ N(mu, std)

    @torch.no_grad()
    def sample(self, batch_size, device):
        """Run the learned reverse chain from pure Gaussian noise.

        Runs without autograd (generation only).

        :param batch_size: number of samples
        :param device: device for the initial noise
        :returns: list of T+1 tensors indexed by time step: element i (1 <= i <= T)
            is x_i and element T is the initial noise. The reverse step for t=1 is
            skipped (``get_loss`` never trains it), so element 0 repeats x_1.
        """
        x = torch.randn((batch_size, self.dim), device=device)

        samples = [x]
        for t in range(self.T, 0, -1):
            if t != 1:
                _, _, x = self.reverse_process(x, t)
            samples.append(x)

        return samples[::-1]

    def get_loss(self, x0):
        """KL training loss for one randomly chosen step.

        A single t is drawn uniformly from [2, T] for the whole batch (t=1 is never
        trained; see ``sample``). The loss is KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        averaged over batch and dimensions, omitting the constant -1/2 per
        dimension (it does not affect gradients).

        :param x0: batch [batch_size, self.dim]
        :returns: scalar loss to minimise
        """

        t = torch.randint(2, self.T + 1, (1,))
        mu_q, sigma_q, xt = self.forward_process(x0, t)

        # The network is evaluated in float32; xt follows x0's dtype (possibly float64).
        mu_p, sigma_p, _ = self.reverse_process(xt.float(), t)

        # Closed-form KL between diagonal Gaussians (KL >= 0, so it is minimised directly).
        kl = torch.log(sigma_p) - torch.log(sigma_q) + (
            sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2)

        return kl.mean()

In [None]:
def train(diffusion_model, optimizer, batch_size, nb_epochs, device,
          output_dir="/content/gdrive/MyDrive/Diffusion_model/paper_1/training",
          save_every=5000):
    """Train ``diffusion_model`` on freshly sampled Swiss-roll batches.

    Each "epoch" here is a single optimizer step on one new batch. Every
    ``save_every`` steps (including step 0) the model state_dict is saved, the loss
    curve so far is written to ``figs/training_loss_epoch_{epoch}.png`` and a
    forward/reverse sample figure to ``figs/samples_epoch_{epoch}.png``.

    :param diffusion_model: model exposing ``get_loss(x0)``; also passed to ``plot``
    :param optimizer: optimizer over the model's parameters
    :param batch_size: points per batch
    :param nb_epochs: number of optimizer steps
    :param device: device batches are moved to
    :param output_dir: root folder; ``figs/`` and ``checkpoints/`` are created inside it
    :param save_every: checkpoint/figure interval in steps; 0 or None disables saving
    :returns: list of per-step loss values
    """
    figs_dir = os.path.join(output_dir, "figs")
    checkpoints_dir = os.path.join(output_dir, "checkpoints")
    os.makedirs(figs_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    training_loss = []
    for epoch in tqdm(range(nb_epochs)):
        x0 = sample_batch(batch_size).to(device)
        loss = diffusion_model.get_loss(x0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        training_loss.append(loss.item())

        if save_every and epoch % save_every == 0:

            checkpoint_path = os.path.join(checkpoints_dir, f'model_epoch_{epoch}.pth')
            torch.save(diffusion_model.state_dict(), checkpoint_path)
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f"Saved checkpoint to {checkpoint_path}")

            # Loss curve on its own explicit figure, so nothing else drawn on the
            # current pyplot figure leaks into it.
            loss_fig_path = os.path.join(figs_dir, f"training_loss_epoch_{epoch}.png")
            fig, ax = plt.subplots()
            ax.plot(training_loss)
            ax.set_xlabel("iteration")
            ax.set_ylabel("loss")
            fig.savefig(loss_fig_path)
            plt.close(fig)

            # Sample figure goes to a separate file so it does not overwrite the
            # loss curve saved just above.
            samples_fig_path = os.path.join(figs_dir, f"samples_epoch_{epoch}.png")
            plot(diffusion_model, samples_fig_path, device)

    return training_loss

# Reproducibility: seed NumPy (sample_batch calls make_swiss_roll without a
# random_state, which presumably falls back to NumPy's global RNG) and PyTorch.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to CPU when no GPU runtime is attached.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
mlp_model = MLP(hidden_dim=128).to(device)
model = DiffusionModel(40, mlp_model, device)
optimizer = torch.optim.Adam(mlp_model.parameters(), lr=1e-4)

# Assign the result so the 300k-element loss list is not echoed as cell output.
training_loss = train(model, optimizer, 64_000, 300_000, device)

Mounted at /content/gdrive


  0%|          | 0/300000 [00:00<?, ?it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_0.pth


  2%|▏         | 4998/300000 [01:56<1:44:54, 46.87it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_5000.pth


  3%|▎         | 9997/300000 [03:53<1:45:55, 45.63it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_10000.pth


  5%|▍         | 14997/300000 [05:34<1:59:37, 39.71it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_15000.pth


  7%|▋         | 19999/300000 [07:30<1:50:26, 42.26it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_20000.pth


  8%|▊         | 25000/300000 [09:24<1:35:49, 47.83it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_25000.pth


 10%|▉         | 29995/300000 [11:06<1:24:50, 53.04it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_30000.pth


 12%|█▏        | 34997/300000 [13:05<1:36:07, 45.95it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_35000.pth


 13%|█▎        | 39996/300000 [15:01<1:31:43, 47.24it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_40000.pth


 15%|█▌        | 45000/300000 [16:42<1:18:58, 53.82it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_45000.pth


 17%|█▋        | 49999/300000 [18:37<1:32:29, 45.05it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_50000.pth


 18%|█▊        | 55000/300000 [20:17<1:14:51, 54.54it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_55000.pth


 20%|█▉        | 59997/300000 [22:13<1:23:57, 47.64it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_60000.pth


 22%|██▏       | 64998/300000 [24:08<1:22:27, 47.50it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_65000.pth


 23%|██▎       | 69999/300000 [25:50<1:10:55, 54.04it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_70000.pth


 25%|██▍       | 74998/300000 [27:30<1:25:22, 43.93it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_75000.pth


 27%|██▋       | 80000/300000 [29:22<1:07:17, 54.49it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_80000.pth


 28%|██▊       | 85000/300000 [31:02<1:05:37, 54.61it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_85000.pth


 30%|███       | 90000/300000 [32:54<1:23:46, 41.78it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_90000.pth


 32%|███▏      | 94999/300000 [34:34<1:02:56, 54.29it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_95000.pth


 33%|███▎      | 100000/300000 [36:30<1:09:58, 47.64it/s]

Saved checkpoint to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_100000.pth


 34%|███▍      | 102480/300000 [37:22<1:02:16, 52.87it/s]

In [None]:
def test_model(model, checkpoint_path, device,
               file_name="/content/gdrive/MyDrive/Diffusion_model/paper_1/training/figs/test_plot_from_checkpoint.png"):
    """Load a saved state_dict into ``model`` and write a forward/reverse comparison plot.

    Best-effort: failures are reported with ``print`` (plus a traceback for
    unexpected errors) rather than raised, so a notebook "Run All" continues.

    :param model: module whose architecture matches the checkpoint
    :param checkpoint_path: path to a state_dict file written by ``train``
    :param device: device the checkpoint is mapped to and sampling runs on
    :param file_name: where the plot is saved (defaults to the Drive figs folder)
    """
    import traceback

    print(f"Loading model from checkpoint: {checkpoint_path}")
    # Checked up front so a FileNotFoundError raised later (e.g. by savefig when the
    # output folder is missing) is not misreported as a missing checkpoint.
    if not os.path.isfile(checkpoint_path):
        print(f"Error: Checkpoint file not found at {checkpoint_path}")
        return

    try:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.to(device)
        model.eval()  # Set the model to evaluation mode
        print("Model loaded successfully. Plotting...")

        os.makedirs(os.path.dirname(file_name) or ".", exist_ok=True)
        # Plotting only needs forward passes; skip building an autograd graph.
        with torch.no_grad():
            plot(model, file_name, device)
        print(f"Plot saved to {file_name}")

    except Exception as e:
        # Keep the best-effort behaviour, but show where it failed.
        print(f"An error occurred: {e}")
        traceback.print_exc()

checkpoint_to_load = '/content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_295000.pth'
# Fall back to CPU when no GPU runtime is attached; test_model maps the checkpoint
# tensors onto this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Must match the training configuration (MLP(hidden_dim=128) with its default
# N=40 heads, T=40 diffusion steps) for load_state_dict and sampling to agree.
test_mlp_model = MLP(hidden_dim=128).to(device)
test_diffusion_model = DiffusionModel(40, test_mlp_model, device)
test_model(test_diffusion_model, checkpoint_to_load, device)

Loading model from checkpoint: /content/gdrive/MyDrive/Diffusion_model/paper_1/training/checkpoints/model_epoch_295000.pth
Model loaded successfully. Plotting...
Plot saved to /content/gdrive/MyDrive/Diffusion_model/paper_1/training/figs/test_plot_from_checkpoint.png
