# SEDD (Score Entropy Discrete Diffusion): toy 2-D reproduction

In [1]:
# Automatically reload edited modules (e.g. the local `sedd` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from tqdm import tqdm
from omegaconf import OmegaConf

# Runtime settings. Prefer the GPU but fall back to CPU so the notebook still
# runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32

# Seed torch's RNGs (CPU and all CUDA devices) so torch-based randomness, such
# as the per-example training timestep draws, repeats across runs.
SEED = 0
torch.manual_seed(SEED)

# Hyper-parameters, wrapped in an OmegaConf object for attribute access
# (e.g. `config.optim.lr`). The whole object is handed to the sedd
# `SEDD` model and `Scheduler`; the keys they read are defined in `sedd`.
config = OmegaConf.create({
    'noise': {
        'num_train_timesteps': 1000,  # training timesteps are drawn from [0, num_train_timesteps)
        'type': 'loglinear',
        'eps': 1e-4,
    },
    'graph': {
        'type': 'absorb',
    },
    'dataset': {
        'tokens': 32,  # the toy dataset is built with `tokens - 1` (see the dataset cell)
    },
    'model': {
        'hidden_size': 32,
        'cond_dim': 64,
        'n_heads': 1,
        'n_blocks': 3,
        'dropout': 0.1,
        'scale_by_sigma': False,
    },
    'optim': {
        'lr': 1e-3,  # Adam learning rate
    },
})

In [4]:
from sedd.utils import ToyDataset

# Toy dataset of token sequences. It uses one token fewer than the config's
# vocabulary; presumably the last id is reserved for the absorbing/base state
# that sampling starts from (TODO confirm against sedd.utils.ToyDataset).
ds = ToyDataset(
    n_samples=1024,
    tokens=config.dataset.tokens - 1,
)

# batch_size equals n_samples, so each pass over `dl` is presumably one full batch.
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=1024)

In [5]:
# load model
from sedd.model import SEDD

# Build the SEDD network from the full config object; which keys it reads is
# defined in sedd.model. Its output is treated as log-scores downstream.
model = SEDD(config)

In [6]:
# load scheduler
from sedd.scheduler import Scheduler, ScoreEntropyLoss

# Noise scheduler built from the config; it provides both the forward
# corruption (`add_noise`) and the reverse sampling step (`step`).
scheduler = Scheduler(config)

# Score-entropy training objective, bound to the same scheduler instance.
loss_fn = ScoreEntropyLoss(scheduler)

In [7]:
# prepare training
# Move the model to its final device/dtype *before* building the optimizer, as
# the torch.optim documentation recommends, so the optimizer holds the very
# parameters that get trained. The later `model.to(device, dtype)` in the
# training cell is then effectively a no-op.
model.to(device, dtype)
optimizer = torch.optim.Adam(model.parameters(), lr=config.optim.lr)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import os

import matplotlib.pyplot as plt


def validation(output_dir, name, num_samples=256, seq_len=2, num_inference_steps=1000):
    """Sample with the reverse process and save a scatter plot against the data.

    Every sequence starts at token id ``scheduler.num_vocabs - 1`` (the base
    distribution; with the 'absorb' graph this is presumably the mask token).
    Then ``scheduler.step`` is applied for every inference timestep except the
    last one. Positions 0 and 1 of the samples are scattered on top of the
    training points ``ds.x``, and the figure is written to
    ``os.path.join(output_dir, name)``.

    Uses the notebook globals ``model``, ``scheduler``, ``ds`` and ``device``.
    ``scheduler.set_timesteps`` is called, so the scheduler's timesteps are
    replaced. The model's train/eval mode is restored before returning, so
    calling this from the training loop does not leave the model in eval mode.

    Args:
        output_dir: existing directory to write the figure into.
        name: file name of the figure, e.g. ``'1000.png'``.
        num_samples: number of sequences to sample (default 256).
        seq_len: length of each sampled sequence; must be >= 2 because the
            plot uses positions 0 and 1.
        num_inference_steps: number of timesteps passed to
            ``scheduler.set_timesteps`` (default 1000).
    """
    scheduler.set_timesteps(num_inference_steps=num_inference_steps, offset=0, device=device)

    was_training = model.training
    model.eval()
    try:
        # Base distribution: every position holds the last vocabulary id.
        xt = (scheduler.num_vocabs - 1) * torch.ones(num_samples, seq_len, dtype=torch.long)
        xt = xt.to(device)

        # The final timestep is skipped, matching the original loop, which broke
        # out before processing it.
        with torch.no_grad():
            for t in tqdm(scheduler.timesteps[:-1]):
                t_batch = t.reshape(1).to(xt.device)
                score = model(xt, t_batch).exp()  # the model outputs log-scores
                xt = scheduler.step(score, t_batch, xt)
    finally:
        # Restore whatever mode the caller had (train during training).
        model.train(was_training)

    fig, ax = plt.subplots()
    ax.scatter(xt[:, 0].cpu(), xt[:, 1].cpu(), label='samples')
    ax.scatter(ds.x[:, 0].cpu(), ds.x[:, 1].cpu(), s=1, label='data')
    ax.set(title=name, xlabel='position 0', ylabel='position 1')
    ax.legend()
    fig.savefig(os.path.join(output_dir, name))
    plt.close(fig)

In [9]:
import os

# Training loop:
#   - draw a random timestep per example and corrupt x0 with the scheduler's
#     forward process;
#   - minimise the score-entropy loss.
# Every `validate_every` epochs, and after the final epoch, draw samples with
# `validation` and refresh the log-scale loss curve in `output_dir`.
epochs = 15000
validate_every = 1000

output_dir = 'runs/toy-2d'
os.makedirs(output_dir, exist_ok=True)

model.to(device, dtype)
scheduler.to(device, dtype)
model.train()

loss_traj = []  # one loss value per optimizer step
for epoch in tqdm(range(epochs)):
    for x0 in dl:
        x0 = x0.to(device)

        # Perturb x0: one integer timestep in [0, num_train_timesteps) per example.
        t = torch.randint(0, config.noise.num_train_timesteps, (x0.size(0),), device=device)
        xt = scheduler.add_noise(x0, t)

        # Forward pass and loss; loss_fn receives the log-scores, t, xt and the clean x0.
        log_score = model(xt, t)
        loss = loss_fn(log_score, t, xt, x0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_traj.append(loss.item())

    if epoch % validate_every == 0 or epoch == epochs - 1:
        validation(output_dir, name=f'{epoch}.png')
        # validation() may leave the model in eval mode (e.g. dropout disabled);
        # switch back so the remaining epochs actually train in train mode.
        model.train()

        fig, ax = plt.subplots()
        ax.plot(loss_traj)
        ax.set_yscale('log')
        ax.set(title='Training loss', xlabel='step', ylabel='loss')
        fig.savefig(os.path.join(output_dir, 'loss.png'))
        plt.close(fig)

100%|█████████▉| 999/1000 [00:03<00:00, 296.79it/s]
100%|█████████▉| 999/1000 [00:03<00:00, 300.11it/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 301.98it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 301.35it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 301.73it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 300.96it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 302.70it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 253.21it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 298.87it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 268.03it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 302.77it/s]]
100%|█████████▉| 999/1000 [00:03<00:00, 300.68it/s]s]
 90%|████████▉ | 896/1000 [00:02<00:00, 302.30it/s]s]
 80%|████████  | 12000/15000 [03:06<00:46, 64.18it/s]


KeyboardInterrupt: 