In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll
from tqdm import tqdm

In [2]:
def sample_batch(batch_size, device='cpu'):

    data, _ = make_swiss_roll(batch_size)
    data = data[:, [2,0]] / 10
    data = data *np.array([1, -1])
    return torch.from_numpy(data).to(device)

In [3]:
class MLP(nn.Module):

    def __init__(self, N=40, data_dim=2, hidden_dim=64):
        super(MLP, self).__init__()

        self.network_head = nn.Sequential(nn.Linear(data_dim, hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(hidden_dim, hidden_dim),
                                          nn.ReLU(),)
        
        self.network_tail = nn.ModuleList([nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                                         nn.ReLU(),
                                                         nn.Linear(hidden_dim, data_dim * 2),) for t in range(N)])
        

    def forward(self, x, t):

        h = self.network_head(x) # [batch_size, hidden_dim]
        tmp = self.network_tail[t](h) # [batch_size, data_dim * 2]
        mu, h = torch.chunk(tmp, 2, dim=1)
        var = torch.exp(h)
        std = torch.sqrt(var)

        return mu, std

In [4]:
model = torch.load('model_paper1')
t = 5
x = torch.randn((64, 2))
mu, std = model(x, t)

print(mu.shape)
print(std.shape)

torch.Size([64, 2])
torch.Size([64, 2])


In [5]:
class DiffusionModel():
    def __init__(self, T, model: nn.Module, device, dim=2):
        self.betas = (torch.sigmoid(torch.linspace(-18, 10, T)) * (3e-1 - 1e-5) + 1e-5).to(device) # scaling data
        self.alphas = 1-self.betas
        self.alphas_bar = torch.cumprod(self.alphas, 0)

        self.T = T
        self.model = model
        self.dim = dim


    def forward_process(self, x0, t):

        # param t: Number of diffusion steps
        # param x0: data

        assert t > 0, 't should be greater than 0'
        assert self.T <= self.T, 'T should be lower or equal than {self.T}'

        t = t - 1 # Because we start indexing at 0

        mu = torch.sqrt(self.alphas_bar[t].cuda()) * x0
        std = torch.sqrt(1 - self.alphas_bar[t])
        epsilon = torch.randn_like(x0)
        xt = mu + epsilon * std # data ~ N(mu, std)

        std_q = torch.sqrt((1 - self.alphas_bar[t-1]) / (1 - self.alphas_bar[t]) * self.betas[t])
        m1 = torch.sqrt(self.alphas_bar[t-1]) * self.betas[t] / (1 - self.alphas_bar[t])
        m2 = torch.sqrt(self.alphas[t]) * (1 - self.alphas_bar[t-1]) / (1 - self.alphas_bar[t])
        mu_q = m1 * x0 + m2 * xt

        return mu_q, std_q, xt
    
    def reverse_process(self, xt, t):

        # param t: Number of diffusion steps
        # param x0: data

        assert t > 0, 't should be greater than 0'
        assert self.T <= self.T, 'T should be lower or equal than {self.T}'

        t = t - 1 # Because we start indexing at 0

        mu, std = self.model(xt, t)
        epsilon = torch.randn_like(xt)

        return mu, std, mu + epsilon * std # data ~ N(mu, std)
    

    def sample(self, batch_size, device):
        
        noise = torch.randn((batch_size, self.dim)).to(device)
        x = noise

        samples = [x]

        for t in range(self.T, 0, -1):

            if not t == 1:
                _, _, x = self.reverse_process(x,t) # only reverse if t greater than 1

            samples.append(x)

        return samples[::-1]
    
    def get_loss(self, x0):

        # param x0: batch [batch_size, self.dim]

        t = torch.randint(2, 40+1, (1,))
        mu_q, sigma_q, xt = self.forward_process(x0, t)
        mu_p, sigma_p, xt_minus1 = self.reverse_process(xt.float(), t)

        KL = torch.log(sigma_p) - torch.log(sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2)
        K = - KL.mean() # should be maximized
        loss = - K # should be minimized

        return loss

       

In [6]:
# x0 = sample_batch(3_000)
# mlp_model = torch.load('model_paper1')
# model = DiffusionModel(40, mlp_model, device = 'cuda')
# xT = model.forward_process(x0, 20)

In [7]:
def plot(model, file_name, device):

    font_size = 14
    fig = plt.figure(figsize=(10, 6))

    N = 5_000
    x0 = sample_batch(N).to(device)
    samples = model.sample(N, device=device)

    data = [x0.cpu(), model.forward_process(x0, 20)[-1].cpu(), model.forward_process(x0, 40)[-1].cpu()]
    for i in range(3):

        plt.subplot(2, 3, 1+i)
        plt.scatter(data[i][:, 0].data.numpy(), data[i][:, 1].data.numpy(), alpha=0.1, s=1)
        plt.xlim([-2,2])
        plt.ylim([-2,2])
        plt.gca().set_aspect('equal')

        if i == 0:
            plt.ylabel(r'$q(\mathbf{x}^{(0..T)})$', fontsize = font_size)
            plt.title(r'$t=0$', fontsize = font_size)
        elif i == 1:
            plt.title(r'$t=\frac{T}{2}$', fontsize = font_size)
        elif i == 2:
            plt.title(r'$t=T$', fontsize = font_size)

    time_steps = [0, 20, 40]

    for i in range(3):

        plt.subplot(2, 3, 4+i)
        plt.scatter(samples[time_steps[i]][:, 0].data.cpu().numpy(), samples[time_steps[i]][:, 1].data.cpu().numpy(),
                    alpha=0.1, c='r', s=1)
        plt.xlim([-2,2])
        plt.ylim([-2,2])
        plt.gca().set_aspect('equal')

        if i == 0:
            plt.ylabel(r'$p(\mathbf{x}^{(0..T)})$', fontsize = font_size)


    plt.savefig(file_name, bbox_inches = 'tight')
    plt.close()

In [8]:
torch.cuda.is_available()

True

In [9]:
def train(diffusion_model, optimizer, batch_size, nb_epochs, device):


    training_loss = []

    for epoch in tqdm(range(nb_epochs)):

        x0 = sample_batch(batch_size).to(device)
        loss = diffusion_model.get_loss(x0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        training_loss.append(loss.item())

        if epoch % 5000 == 0:
            plt.plot(training_loss)
            plt.savefig(f'figs/training_loss_epoch_{epoch}.png')
            plt.close()

            plot(diffusion_model, f'figs/training_epoch_{epoch}.png', device)

    return training_loss


device = 'cuda'
mlp_model = MLP(hidden_dim=128).to(device)
model = DiffusionModel(40, mlp_model, device)
optimizer = torch.optim.Adam(mlp_model.parameters(), lr=1e-4)

train(model, optimizer, 64_000, 300_000, device)

100%|██████████| 300000/300000 [1:08:53<00:00, 72.59it/s]


[6.08629343079375,
 1.1292982364010222,
 5.820658242359925,
 3.717994088199966,
 1.2216253015541183,
 1.106133287881369,
 3.715602037454629,
 1.2204815674954717,
 1.1103958984481237,
 1.1036039820559094,
 1.0941890279809086,
 1.4875844043947117,
 1.0996790337136007,
 6.0266806564422915,
 3.3843460701752877,
 1.1004319960024975,
 1.085159769876299,
 2.354788451850932,
 6.148726904298986,
 1.1030782514257012,
 4.693061815538073,
 1.1108697969778447,
 5.939920778187928,
 1.107880284524887,
 1.703470035560936,
 1.7016903904101275,
 6.353447533332236,
 3.711755484593728,
 6.079645313353152,
 5.2648823085685,
 6.062612402068562,
 1.1041196952593775,
 4.063817403315206,
 1.1127500554856233,
 1.0827463991693216,
 1.0999909045676604,
 1.4850110073943907,
 1.482771415794581,
 6.0799197342241476,
 6.024881823329591,
 5.991397198151655,
 3.7088566436017016,
 2.0337345091046566,
 1.4787199444098886,
 5.2621452058144635,
 1.098718674173503,
 1.0921949004771865,
 3.019581993326693,
 6.206688781424399