In [1]:
%load_ext autoreload
%autoreload 2
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
from ai.diffusion.tiny_diffusion import ddpm
from ai.diffusion.tiny_diffusion import datasets
from dataclasses import dataclass

import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np

from ai.diffusion.tiny_diffusion.positional_embeddings import PositionalEmbedding
from ai.diffusion.tiny_diffusion.ddpm import MLP, NoiseScheduler

@dataclass
class Config:
    experiment_name: str = 'base'
    dataset: str = 'dino'
    train_batch_size: int = 32
    eval_batch_size: int = 1000
    num_epochs: int = 200
    learning_rate: float = 1e-3
    num_timesteps: int = 50
    beta_schedule: str = "linear"
    embedding_size: int = 128
    hidden_size: int = 128
    hidden_layers: int = 3
    time_embedding: str = "sinusoidal"
    input_embedding: str = "sinusoidal"
    save_images_step: int = 1
    
config = Config()

In [4]:
device = torch.device('cuda')
dataset = datasets.get_dataset('moons')

dataloader = DataLoader(
    dataset, batch_size=config.train_batch_size, shuffle=True, drop_last=True)

model = MLP(
    hidden_size=config.hidden_size,
    hidden_layers=config.hidden_layers,
    emb_size=config.embedding_size,
    time_emb=config.time_embedding,
    input_emb=config.input_embedding).to(device)

noise_scheduler = NoiseScheduler(
    num_timesteps=config.num_timesteps,
    beta_schedule=config.beta_schedule)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
)

global_step = 0
frames = []
losses = []

print("Training model...")

for epoch in range(config.num_epochs):
    model.train()
    for step, batch in enumerate(dataloader):
        batch = batch[0]
        noise = torch.randn(batch.shape)
        timesteps = torch.randint(
            0, noise_scheduler.num_timesteps, (batch.shape[0],)
        ).long()
        noisy = noise_scheduler.add_noise(batch, noise, timesteps)
        noise_pred = model(noisy.to(device), timesteps.to(device))
        loss = F.mse_loss(noise_pred, noise.to(device))
        loss.backward(loss)

        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

        # progress_bar.update(1)
        # logs = {"loss": loss.detach().item(), "step": global_step}
        losses.append(loss.detach().item())
        
        # progress_bar.set_postfix(**logs)
        if global_step % 300 == 0:
            print(f'{epoch}|{global_step} loss={loss.detach().item():.2f}')
        
        global_step += 1

    if epoch % config.save_images_step == 0 or epoch == config.num_epochs - 1:
        # generate data with the model to later visualize the learning process
        model.eval()
        sample = torch.randn(config.eval_batch_size, 2)
        timesteps = list(range(len(noise_scheduler)))[::-1]
        for i, t in enumerate(timesteps):
            t = torch.from_numpy(np.repeat(t, config.eval_batch_size)).long()
            with torch.no_grad():
                residual = model(sample.to(device), t.to(device)).to('cpu')
            sample = noise_scheduler.step(residual, t[0], sample)
        frames.append(sample.numpy())


Training model...
0|0 loss=1.42
0|100 loss=0.79
0|200 loss=0.84
1|300 loss=0.87
1|400 loss=0.67
2|500 loss=0.81
2|600 loss=0.59
2|700 loss=0.86
3|800 loss=0.62
3|900 loss=0.48
4|1000 loss=0.77
4|1100 loss=0.58
4|1200 loss=0.70
5|1300 loss=0.43
5|1400 loss=0.92
6|1500 loss=0.37
6|1600 loss=0.78
6|1700 loss=0.80
7|1800 loss=0.68
7|1900 loss=0.89
8|2000 loss=0.66
8|2100 loss=0.87
8|2200 loss=0.57
9|2300 loss=0.69
9|2400 loss=0.48
10|2500 loss=0.60
10|2600 loss=0.77
10|2700 loss=0.57
11|2800 loss=0.64
11|2900 loss=0.86
12|3000 loss=0.69
12|3100 loss=0.70
12|3200 loss=0.53
13|3300 loss=0.75
13|3400 loss=0.65
14|3500 loss=0.55
14|3600 loss=0.71
14|3700 loss=0.83
15|3800 loss=0.84
15|3900 loss=0.41
16|4000 loss=0.64
16|4100 loss=0.50
16|4200 loss=0.85
17|4300 loss=0.70
17|4400 loss=0.74
18|4500 loss=0.61
18|4600 loss=0.52
18|4700 loss=1.08
19|4800 loss=0.67
19|4900 loss=0.79
20|5000 loss=0.78
20|5100 loss=0.96
20|5200 loss=0.55
21|5300 loss=0.47
21|5400 loss=0.60
22|5500 loss=0.51
22|5600 los

KeyboardInterrupt: 

In [5]:
print("Saving model...")
outdir = f"exps/{config.experiment_name}"
os.makedirs(outdir, exist_ok=True)
torch.save(model.state_dict(), f"{outdir}/model.pth")

print("Saving images...")
imgdir = f"{outdir}/images"
os.makedirs(imgdir, exist_ok=True)
frames = np.stack(frames)
xmin, xmax = -6, 6
ymin, ymax = -6, 6
for i, frame in enumerate(frames):
    plt.figure(figsize=(10, 10))
    plt.scatter(frame[:, 0], frame[:, 1])
    plt.xlim(xmin, xmax)
    plt.ylim(ymin, ymax)
    plt.savefig(f"{imgdir}/{i:04}.png")
    plt.close()

print("Saving loss as numpy array...")
np.save(f"{outdir}/loss.npy", np.array(losses))

print("Saving frames...")
np.save(f"{outdir}/frames.npy", frames)

Saving model...
Saving images...
Saving loss as numpy array...
Saving frames...
