# SEDD reproduce

1. dataset 만들기
    
    - toy dataset (2d) 
    
    - MNIST

In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import torch
from omegaconf import OmegaConf

# load config
device = torch.device('cuda')
dtype = torch.float32

config = {
    'noise' : {
        'num_train_timesteps' : 1000,
        'type'  : 'loglinear',
        'eps'   : 1e-4,
    },
    'graph' : {
        'type'  : 'absorb',
    },
    'dataset' : {
        'tokens' : 32,
        'samples' : 128,
    },
    'model' : {
        'hidden_size'   : 32,
        'cond_dim'      : 64,
        'n_heads'       : 1,
        'n_blocks'      : 3,
        'dropout'       : 0.1,
        'scale_by_sigma' : False,
    },
    'optim' : {
        'lr' : 1e-3,
    }
}

config = OmegaConf.create(config)

#######
# run #
#######
# path
pass

# loggin (wandb)
pass

In [2]:
from torchvision import datasets
from torchvision import transforms as tfs

# load dataset
trans = tfs.Compose([tfs.ToTensor(), tfs.Normalize(mean=[0.0], std=[1/config.dataset.tokens])])
ds = datasets.MNIST(root='/mnt/image-net-full/gayoung.lee/yonghyun.park/', train=True, download=True, transform=trans)
ds.data = ds.data[:config.dataset.samples]
ds.targets = ds.targets[:config.dataset.samples]

dl = torch.utils.data.DataLoader(ds, batch_size=128)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch.nn as nn
import torch.nn.functional as F
import math

class SEDD(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.time_emb = TimestepEmbedder(config.model.hidden_size)
        self.pos_emb = PositionalEncoding(config.model.hidden_size, dropout=0.0)

        self.W_in = nn.Embedding(
            config.dataset.tokens + 1, 
            config.model.hidden_size,
        )
        self.W_out = nn.Linear(
            config.model.hidden_size, 
            config.dataset.tokens + 1, 
            bias=False
        )
        self.blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(d_model=config.model.hidden_size, nhead=config.model.n_heads, batch_first=False)
                for _ in range(config.model.n_blocks)
            ]
        )
        

    def forward(self, x, t):
        assert len(x.shape) == 2
        
        # in
        x = self.W_in(x)
        t_emb = self.time_emb(t)[:, None, :]
        x = x + t_emb
        
        x = x.permute(1,0,2)
        x = self.pos_emb(x)

        # mid
        for block in self.blocks:
            x = block(x)

        # out
        x = self.W_out(x)
        x = x.permute(1,0,2)
        return x
    

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256, silu=True):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size


    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
    
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: torch.Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


In [11]:
# load model
# from sedd.model import SEDD
model = SEDD(config)

In [12]:
# load scheduler
from sedd.scheduler import Scheduler, ScoreEntropyLoss
scheduler = Scheduler(config)
loss_fn = ScoreEntropyLoss(scheduler)

In [13]:
# prepare training
optimizer = torch.optim.Adam(model.parameters(), lr=config.optim.lr)

In [14]:
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

def validation(output_dir, name):
    scheduler.set_timesteps(num_inference_steps=1000, offset=0, device=device)

    num_batch = 5
    size = 28

    xt = (scheduler.num_vocabs - 1) * torch.ones(num_batch, size**2, dtype=torch.long) # base distribution
    xt = xt.to(device)

    for t in tqdm(scheduler.timesteps):
        if t == scheduler.timesteps[999]:
            break

        with torch.no_grad():
            # forward
            t = torch.tensor([t], device=xt.device)
            score = model(xt, t).exp()
            
            # step
            xt = scheduler.step(score, t, xt)

            # prev_t
            prev_t = scheduler.timesteps[(scheduler.timesteps == t).long().argmax()+1].unsqueeze(0).repeat(xt.size(0))
            # xt = scheduler.add_noise(xt, prev_t)

    fig, axs = plt.subplots(1, num_batch, figsize=(20, 8))
    for i in range(num_batch):
        img = xt[i].view(size, size).cpu()
        axs[i].imshow(img, cmap='gray')
    plt.savefig(os.path.join(output_dir, name))
    plt.close()
    pass

In [15]:
from tqdm import tqdm

# training
epochs = 50000

output_dir = 'runs/mnist-subset'
# output_dir = 'runs/mnist'
os.makedirs(output_dir, exist_ok=True)

model.to(device, dtype)
scheduler.to(device, dtype)

loss_traj = []
for epoch in tqdm(range(epochs)):
    
    model.train()
    for x0, _ in dl:
        x0 = x0.to(device)
        x0 = x0.flatten(start_dim=1).long()
        
        # perturb x0
        t = torch.randint(1, config.noise.num_train_timesteps, (x0.size(0),), device=device)
        xt = scheduler.add_noise(x0, t)
        
        # model forward
        log_score = model(xt, t)
        
        # compute loss function 
        loss = loss_fn(log_score, t, xt, x0)

        if loss.isnan():
            raise ValueError('loss is nan')
        
        # update
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_traj.append(loss.item())

    if epoch % 1000 == 0:
        model.eval()
        validation(output_dir, name=f'{epoch}.png')

        plt.plot(loss_traj)
        # plt.yscale('log')
        plt.savefig(os.path.join(output_dir, 'loss.png'))
        plt.close()

100%|█████████▉| 999/1000 [00:04<00:00, 239.19it/s]
100%|█████████▉| 999/1000 [00:03<00:00, 262.44it/s]/s]
100%|█████████▉| 999/1000 [00:03<00:00, 271.64it/s]/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 267.33it/s]/s] 
100%|█████████▉| 999/1000 [00:04<00:00, 215.15it/s]/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 266.98it/s]/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 269.67it/s]/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 269.86it/s]/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 265.54it/s]/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 268.24it/s]/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 265.21it/s]t/s]
100%|█████████▉| 999/1000 [00:03<00:00, 258.80it/s]t/s] 
100%|█████████▉| 999/1000 [00:04<00:00, 243.25it/s]t/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 262.68it/s]t/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 266.02it/s]t/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 257.47it/s]t/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 257.08it/s]t/s] 
100%|█████████▉| 999/1000 [00:03<00:00, 260.44i