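# SEDD reproduction

Reproduce SEDD (Score Entropy Discrete Diffusion) on small datasets; the code below currently trains on a small MNIST subset.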

1. Build datasets (dataset 만들기)
    
    - toy dataset (2D)
    
    - MNIST

In [1]:
%load_ext autoreload    
%autoreload 2

In [5]:
import torch
from omegaconf import OmegaConf

# load config
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32

# Seed torch's global RNG (it seeds every device) so timestep sampling, dropout
# and any sampling randomness are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

config = {
    'noise' : {
        # training timesteps are drawn from [1, num_train_timesteps) in the training cell
        'num_train_timesteps' : 1000,
        'type'  : 'loglinear',
        'eps'   : 1e-4,
    },
    'graph' : {
        'type'  : 'absorb',
    },
    'dataset' : {
        'tokens' : 4,        # number of gray levels each pixel is mapped to
        'size' : 28,         # image side length (MNIST images are 28x28)
        'samples' : 16,      # number of MNIST training images kept
        'batch_size' : 16,
    },
    'model' : {
        'positional_embedding': '2d',
        'hidden_size'   : 32,
        'cond_dim'      : 32,
        'n_heads'       : 1,
        'n_blocks'      : 2,
        'dropout'       : 0.1,
        'scale_by_sigma' : False,
    },
    'optim' : {
        'lr' : 1e-3,
    }
}

config = OmegaConf.create(config)

#######
# run #
#######
# TODO: run output path (currently built from config in the training cell).
# TODO: logging (wandb).

In [6]:
import os
from functools import partial

from torchvision import datasets
from torchvision import transforms as tfs

# MNIST download/cache directory. Override with the MNIST_ROOT environment
# variable instead of editing this personal absolute path.
DATA_ROOT = os.environ.get('MNIST_ROOT', '/mnt/image-net-full/gayoung.lee/yonghyun.park/')


def quantize_pixels(img, num_tokens):
    """Map a float image in [0, 1] (as produced by ``ToTensor``) to token ids in [0, num_tokens - 1].

    ``floor(img * num_tokens)`` alone sends the brightest pixel (255 -> 1.0) to
    ``num_tokens``, one past the last data token. The clamp folds it into the top
    bin. For the absorbing graph ``num_tokens`` is presumably the extra index the
    sampler starts from (``scheduler.num_vocabs - 1``) — TODO confirm in sedd.scheduler.

    Args:
        img: float tensor with values in [0, 1].
        num_tokens: number of gray levels.

    Returns:
        Long tensor of the same shape with values in [0, num_tokens - 1].
    """
    return (img * num_tokens).floor().clamp(max=num_tokens - 1).long()


# load dataset: the first `samples` MNIST training images, quantized to `tokens` gray levels
trans = tfs.Compose([
    tfs.ToTensor(),
    tfs.Lambda(partial(quantize_pixels, num_tokens=config.dataset.tokens)),
])
ds = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=trans)
ds.data = ds.data[:config.dataset.samples]
ds.targets = ds.targets[:config.dataset.samples]

# Shuffle so mini-batches differ across epochs whenever samples > batch_size.
dl = torch.utils.data.DataLoader(ds, batch_size=config.dataset.batch_size, shuffle=True)

In [7]:
# load model
# Build the SEDD network from the local `sedd` package using the experiment config.
# It is created on the default device; the training cell moves it with model.to(device, dtype).
from sedd.model import SEDD
model = SEDD(config)

In [8]:
# load scheduler
# Scheduler is built from the config. Later cells use it for add_noise (training),
# and for set_timesteps / step / num_vocabs (sampling in `validation`).
from sedd.scheduler import Scheduler, ScoreEntropyLoss
scheduler = Scheduler(config)
# Score-entropy loss over the same scheduler; it is called as loss_fn(log_score, t, xt, x0).
loss_fn = ScoreEntropyLoss(scheduler)

In [9]:
# prepare training: plain Adam over all model parameters, learning rate taken from config.optim
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=config.optim.lr,
)

In [11]:
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

def validation(output_dir, name, num_batch=5, size=None, num_inference_steps=1000):
    """Draw samples from the current model and save them as a single row of images.

    Every position starts at token ``scheduler.num_vocabs - 1`` (the base
    distribution for sampling). Each position is then updated with
    ``scheduler.step`` over ``scheduler.timesteps``, and the final timestep is
    skipped, as in the original loop. The function relies on the notebook
    globals ``model``, ``scheduler``, ``device`` and ``config``.

    Args:
        output_dir: directory the figure is written to.
        name: file name of the figure inside ``output_dir``.
        num_batch: number of images to sample.
        size: image side length; defaults to ``config.dataset.size``.
        num_inference_steps: number of sampling timesteps requested from the scheduler.
    """
    if size is None:
        size = config.dataset.size

    scheduler.set_timesteps(num_inference_steps=num_inference_steps, offset=0, device=device)

    base_token = scheduler.num_vocabs - 1
    xt = torch.full((num_batch, size**2), base_token, dtype=torch.long, device=device)

    was_training = model.training
    model.eval()
    with torch.no_grad():
        # Skip the last timestep. Slicing replaces the hard-coded `timesteps[999]`,
        # which only matched num_inference_steps == 1000.
        for t in tqdm(scheduler.timesteps[:-1]):
            t_batch = torch.as_tensor(t, device=device).reshape(1)
            score = model(xt, t_batch).exp()  # the model returns log-scores
            xt = scheduler.step(score, t_batch, xt)
    model.train(was_training)

    # squeeze=False keeps `axs` 2-D so that num_batch == 1 still indexes correctly.
    fig, axs = plt.subplots(1, num_batch, figsize=(4 * num_batch, 8), squeeze=False)
    for i, ax in enumerate(axs[0]):
        # A fixed color range gives every panel the same gray scale (the base token is brightest).
        ax.imshow(xt[i].view(size, size).cpu(), cmap='gray', vmin=0, vmax=base_token)
        ax.axis('off')
    fig.savefig(os.path.join(output_dir, name))
    plt.close(fig)

In [12]:
from tqdm import tqdm

# training
epochs = 50000
eval_every = 1000  # epochs between sampling, loss plot and checkpoint

output_dir = f'runs/mnist-{config.dataset.samples}'
os.makedirs(output_dir, exist_ok=True)
checkpoint_path = os.path.join(output_dir, 'checkpoint.pt')

model.to(device, dtype)
scheduler.to(device, dtype)

loss_traj = []
for epoch in tqdm(range(epochs)):

    model.train()
    for x0, _ in dl:
        x0 = x0.to(device).flatten(start_dim=1).long()

        # perturb x0 at one random timestep in [1, num_train_timesteps) per sample
        t = torch.randint(1, config.noise.num_train_timesteps, (x0.size(0),), device=device)
        xt = scheduler.add_noise(x0, t)

        # model forward
        log_score = model(xt, t)

        # compute loss function
        loss = loss_fn(log_score, t, xt, x0)

        # Stop on inf as well as NaN; either one would corrupt the optimizer state.
        if not torch.isfinite(loss):
            raise ValueError(f'loss is not finite ({loss.item()}) at epoch {epoch}')

        # update. Clear grads first so a step interrupted after backward() (e.g. by
        # KeyboardInterrupt) cannot leak stale gradients into a re-run of this cell.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_traj.append(loss.item())

    if epoch % eval_every == 0:
        model.eval()
        validation(output_dir, name=f'{epoch}.png')

        fig, ax = plt.subplots()
        ax.plot(loss_traj)
        ax.set(xlabel='step', ylabel='loss', title='training loss')
        fig.savefig(os.path.join(output_dir, 'loss.png'))
        plt.close(fig)

        # Save a checkpoint periodically so an interrupted run keeps its progress.
        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_traj': loss_traj,
            },
            checkpoint_path,
        )

100%|█████████▉| 999/1000 [00:02<00:00, 363.30it/s]
100%|█████████▉| 999/1000 [00:02<00:00, 362.74it/s] 
100%|█████████▉| 999/1000 [00:02<00:00, 370.11it/s]]  
100%|█████████▉| 999/1000 [00:02<00:00, 370.71it/s]]  
  7%|▋         | 3676/50000 [01:15<15:49, 48.79it/s]  


KeyboardInterrupt: 