In [1]:
import dataclasses
import os

import torch


@dataclasses.dataclass
class Hyperparameter:
    """All tunable settings of the PixelSNAIL experiment, gathered in one object.

    The rest of the notebook only reads and writes attributes on ``hp``, so an
    instance is used: mutable run state such as ``step`` then lives on the
    instance and the class-level defaults stay untouched.
    """
    # --- model ---
    n_channels: int = 256           # feature channels passed to PixelSNAIL
    n_res_layers: int = 5           # gated residual layers inside each attention block
    n_attn_layers: int = 12         # number of attention blocks
    attn_n_hidden: int = 1          # number of attention heads
    attn_d_query: int = 16          # query/key channels (all heads together)
    attn_d_value: int = 128         # value channels (all heads together)
    attn_drop_rate: float = 0       # dropout probability on the attention logits
    n_logistic_mix: int = 10        # mixture components of the output distribution
    # --- run environment ---
    seed: int = 0                   # random seed for reproducibility
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    # --- data ---
    n_cond_classes: int = 10        # one-hot conditioning size; a falsy value disables conditioning
    n_bits: int = 4                 # bit depth the images are quantised to
    image_dims: tuple = (1, 28, 28) # (channels, height, width)
    # --- optimisation ---
    lr: float = 5e-4
    lr_decay: float = 0.999995      # gamma of ExponentialLR; the scheduler is stepped once per batch
    polyak: float = 0.9995          # decay of the parameter moving average
    batch_size: int = 4
    n_epochs: int = 1
    step: int = 0                   # global step counter, incremented by train_epoch
    start_epoch: int = 0
    log_interval: int = 50          # NOTE(review): not read anywhere in the visible notebook
    eval_interval: int = 10         # checkpoint / evaluate / sample every this many epochs
    n_samples: int = 8              # images generated per class (or in total when unconditional)
    num_workers: int = 0
    # Portable default instead of a user-specific absolute path.
    dataset_root: str = os.path.join(os.path.expanduser('~'), '.data')
    is_mini_dataset: bool = True    # keep only the first 4 * batch_size training images
    output_dir: str = 'assets'      # where checkpoints and sample grids are written


# Instantiate the dataclass: binding the class itself would make every
# `hp.attr = ...` / `hp.step += 1` silently rewrite the class defaults.
hp = Hyperparameter()

In [2]:
import torchvision
import torchvision.transforms as T

from functools import partial


def preprocess(x, n_bits):
    """Rescale intensity levels linearly so that 0 maps to -1 and 2**n_bits - 1 maps to +1.

    Args:
        x: tensor of intensity levels (cast to float before scaling).
        n_bits: bit depth; the largest level is taken to be ``2**n_bits - 1``.

    Returns:
        Float tensor with the same shape as ``x``.
    """
    top_level = 2**n_bits - 1
    unit_range = x.float().div(top_level)
    return unit_range.mul(2).add(-1)

# Image pipeline: tensor conversion -> quantisation to hp.n_bits -> rescale via preprocess.
transform = T.Compose([
    T.ToTensor(), 
    # Scale to 0..255 and floor-divide down to 2**n_bits integer levels
    # (presumably the input is in [0, 1] after ToTensor -- TODO confirm).
    lambda x: x.mul(255).div(2**(8-hp.n_bits)).floor(), 
    partial(preprocess, n_bits=hp.n_bits)
])  
# Labels become one-hot rows of an identity matrix when class conditioning is
# enabled (hp.n_cond_classes truthy); otherwise labels are left untouched.
target_transform = \
    (lambda y: torch.eye(hp.n_cond_classes)[y]) if hp.n_cond_classes else None

train_dataset = torchvision.datasets.MNIST(hp.dataset_root, train=True, download=True,
                                           transform=transform, target_transform=target_transform)
test_dataset = torchvision.datasets.MNIST(hp.dataset_root, train=False, download=True,
                                          transform=transform, target_transform=target_transform)

# Smoke-test mode: keep only the first 4 batches' worth of training images.
# The test set is not truncated.
if hp.is_mini_dataset:
    train_dataset.data = train_dataset.data[:hp.batch_size*4]
    train_dataset.targets = train_dataset.targets[:hp.batch_size*4]

# Both loaders shuffle, including the test loader.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=hp.batch_size, shuffle=True, num_workers=hp.num_workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hp.batch_size, shuffle=True, num_workers=hp.num_workers)

In [3]:
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose ``weight`` is weight-normalised as soon as it is built.

    Accepts exactly the constructor arguments of ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Reparameterise the kernel in place; 'weight' is the default target.
        nn.utils.weight_norm(self, name='weight')
        
class DownShiftedConv2d(Conv2d):
    """Convolution that zero-pads only above and symmetrically left/right.

    With this padding the kernel never covers rows below the output row.
    """

    def forward(self, x):
        kernel_h, kernel_w = self.kernel_size
        side = (kernel_w - 1) // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        padded = F.pad(x, (side, side, kernel_h - 1, 0))
        return super().forward(padded)
    
class DownRightShiftedConv2d(Conv2d):
    """Convolution that zero-pads only above and to the left.

    With this padding the kernel never covers rows below or columns to the
    right of the output position.
    """

    def forward(self, x):
        kernel_h, kernel_w = self.kernel_size
        top, left = kernel_h - 1, kernel_w - 1
        # F.pad order for the last two dims: (left, right, top, bottom).
        return super().forward(F.pad(x, (left, 0, top, 0)))
        

![](https://github.com/crlotwhite/ML_Study/blob/pixelsnail/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/Cap%202024-08-19%2004-10-42-669.png?raw=true)

In [4]:
class GatedResidualBlock(nn.Module):
    """Residual block with a sigmoid gate: ``x + value * sigmoid(gate)``.

    Args:
        conv: convolution class used for the two main convolutions; called as
            ``conv(in_channels, out_channels, kernel_size)``.
        n_channels: channels of the block's input and output.
        kernel_size: kernel size forwarded to ``conv``.
        drop_rate: dropout probability between the convolutions (no dropout
            layer is created when it is not positive).
        shortcut_channels: channels of the optional auxiliary input ``a``;
            when falsy no shortcut convolution is created.
        n_cond_classes: size of the conditioning vector ``h``; when falsy no
            conditioning projection is created.
    """

    def __init__(self, conv, n_channels, kernel_size, 
                 drop_rate=0, shortcut_channels=None, n_cond_classes=None):
        super().__init__()
        doubled = 2 * n_channels  # concat-ELU doubles channels; conv_2 emits value + gate

        self.conv_1 = conv(doubled, n_channels, kernel_size)
        if shortcut_channels:
            self.conv_1_sc = Conv2d(2 * shortcut_channels, n_channels, kernel_size=1)
        if drop_rate > 0:
            self.drop_1 = nn.Dropout(drop_rate)
        self.conv_2 = conv(doubled, doubled, kernel_size)

        if n_cond_classes:
            self.proj_h = nn.Linear(n_cond_classes, doubled)

    @staticmethod
    def _concat_elu(t):
        """ELU applied to ``[t, -t]`` stacked on the channel axis."""
        return F.elu(torch.cat([t, -t], dim=1))

    def forward(self, x, a=None, h=None):
        """Apply the block.

        Args:
            x: input feature map.
            a: optional auxiliary feature map added through the 1x1 shortcut
                convolution (requires ``shortcut_channels`` at construction).
            h: optional conditioning vectors, projected and broadcast over the
                spatial dimensions (requires ``n_cond_classes`` at construction).

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = self.conv_1(self._concat_elu(x))
        if a is not None:
            hidden = hidden + self.conv_1_sc(self._concat_elu(a))
        hidden = self._concat_elu(hidden)

        dropout = getattr(self, 'drop_1', None)
        if dropout is not None:
            hidden = dropout(hidden)

        gated = self.conv_2(hidden)
        if h is not None:
            gated += self.proj_h(h)[:, :, None, None]

        value, gate = gated.chunk(2, 1)
        return x + value * torch.sigmoid(gate)

![](https://github.com/crlotwhite/ML_Study/blob/pixelsnail/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/Cap%202024-08-19%2004-10-47-981.png?raw=true)

In [5]:
class AttentionGatedResidualBlock(nn.Module):
    """A stack of causal gated residual layers followed by masked self-attention.

    Args:
        n_channels: channels of the block's input and output.
        n_background_ch: channels of the positional ``background`` tensor that
            is concatenated to the attention inputs.
        n_res_layers: number of gated residual layers before the attention.
        n_cond_classes: size of the conditioning vector (forwarded to the
            residual layers).
        drop_rate: dropout probability inside the residual layers.
        n_hidden: number of attention heads.
        d_query: total query/key channels (split evenly across heads).
        d_value: total value channels (split evenly across heads).
        attn_drop_rate: dropout probability applied to the attention logits.

    Raises:
        ValueError: if ``n_hidden`` is not positive or does not divide both
            ``d_query`` and ``d_value``.
    """

    def __init__(self, n_channels, n_background_ch, n_res_layers, n_cond_classes,
                 drop_rate, n_hidden, d_query, d_value, attn_drop_rate):
        super().__init__()

        # Fail at construction rather than with an opaque reshape error in forward().
        if n_hidden < 1 or d_query % n_hidden or d_value % n_hidden:
            raise ValueError('d_query and d_value must be divisible by a positive n_hidden')

        self.n_hidden = n_hidden
        self.d_query = d_query
        self.d_value = d_value
        self.attn_drop_rate = attn_drop_rate

        self.input_gated_resnet = nn.ModuleList([
            GatedResidualBlock(DownRightShiftedConv2d, n_channels, (2, 2), 
                               drop_rate, None, n_cond_classes) 
            for _ in range(n_res_layers)])

        # Keys/values see the block input, the residual-stack output and the background.
        self.in_proj_kv = nn.Sequential(
            GatedResidualBlock(Conv2d, 2 * n_channels + n_background_ch, 1, 
                               drop_rate, None, n_cond_classes), 
            Conv2d(2 * n_channels + n_background_ch, d_query + d_value, 1)
        )
        # Queries see only the residual-stack output and the background.
        self.in_proj_q = nn.Sequential(
            GatedResidualBlock(Conv2d, n_channels + n_background_ch, 1, 
                               drop_rate, None, n_cond_classes),
            Conv2d(n_channels + n_background_ch, d_query, 1)
        )
        self.out_proj = GatedResidualBlock(Conv2d, n_channels, 1, drop_rate, 
                                           d_value, n_cond_classes)

    def forward(self, x, background, attn_mask, h=None):
        """Run the residual stack and the masked attention.

        Args:
            x: input feature map; its spatial size fixes H and W below.
            background: positional features with the same batch and spatial
                size as ``x``.
            attn_mask: mask broadcastable to ``(B, n_hidden, H*W, H*W)``; entry
                ``[..., i, j] == 0`` forbids query position ``i`` from
                attending to key position ``j``.
            h: optional conditioning vectors; only the input residual stack
                receives them (the projection blocks are called without ``h``).

        Returns:
            Feature map with the same shape as ``x``.
        """
        ul = x
        for layer in self.input_gated_resnet:
            ul = layer(ul, h=h)

        kv = self.in_proj_kv(torch.cat([x, ul, background], dim=1))
        k, v = kv.split([self.d_query, self.d_value], dim=1)
        q = self.in_proj_q(torch.cat([ul, background], dim=1))

        B, dq, H, W = q.shape
        _, dv, _, _ = v.shape
        d_head = dq // self.n_hidden

        # (B, heads, channels_per_head, H*W); queries carry the 1/sqrt(d) scaling.
        flat_q = q.reshape(B, self.n_hidden, d_head, H, W).flatten(3) * d_head ** -0.5
        flat_k = k.reshape(B, self.n_hidden, d_head, H, W).flatten(3)
        flat_v = v.reshape(B, self.n_hidden, dv // self.n_hidden, H, W).flatten(3)

        blocked = attn_mask == 0
        logits = flat_q.transpose(2, 3) @ flat_k
        logits = F.dropout(logits, p=self.attn_drop_rate, 
                           training=self.training, inplace=True)
        logits = logits.masked_fill(blocked, -1e10)
        weights = F.softmax(logits, dim=-1)
        # A query with no permitted key has every logit at -1e10, which softmax
        # turns into a uniform distribution over all positions, including the
        # forbidden ones. Zero the forbidden weights so such a query attends to
        # nothing. Rows with at least one permitted key are unaffected, since
        # their forbidden weights already underflow to 0.
        weights = weights.masked_fill(blocked, 0.0)

        attn_out = weights @ flat_v.transpose(2, 3)
        attn_out = attn_out.transpose(2, 3)
        attn_out = attn_out.reshape(B, -1, H, W)
        return self.out_proj(ul, attn_out)
        

![](https://github.com/crlotwhite/ML_Study/blob/pixelsnail/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/Cap%202024-08-19%2004-11-03-466.png?raw=true)

In [6]:
def down_shift(x):
    """Shift a 4-D tensor one step along dim 2: a zero row enters at index 0, the last row is dropped."""
    padded = F.pad(x, (0, 0, 1, 0))
    return padded[:, :, :x.shape[2], :]


def right_shift(x):
    """Shift a 4-D tensor one step along dim 3: a zero column enters at index 0, the last column is dropped."""
    padded = F.pad(x, (1, 0))
    return padded[:, :, :, :x.shape[3]]


class PixelSNAIL(nn.Module):
    """Autoregressive image model built from attention-augmented gated residual blocks.

    Args:
        image_dims: ``(C, H, W)`` of the images.
        n_channels: feature channels used throughout the network.
        n_res_layers: gated residual layers inside each attention block.
        attn_n_layers: number of attention blocks.
        attn_n_hidden: number of attention heads.
        attn_d_query: query/key channels of the attention.
        attn_d_value: value channels of the attention.
        attn_drop_rate: dropout probability on the attention logits.
        n_logistic_mix: mixture components; the output has
            ``(3 * C + 1) * n_logistic_mix`` channels.
        n_cond_classes: size of the optional conditioning vector.
        drop_rate: dropout probability inside the residual layers.
    """

    def __init__(self, image_dims=(3, 32, 32), n_channels=128, n_res_layers=5, 
                 attn_n_layers=12, attn_n_hidden=1, attn_d_query=16, 
                 attn_d_value=128, attn_drop_rate=0, n_logistic_mix=10, 
                 n_cond_classes=None, drop_rate=0.5):
        super().__init__()

        n_img_ch, height, width = image_dims

        # Positional "background": centred row and column coordinates, each
        # repeated over n_img_ch channels, giving 2 * n_img_ch channels in total.
        row_coord = (torch.arange(height, dtype=torch.float) - height / 2) / 2
        col_coord = (torch.arange(width, dtype=torch.float) - width / 2) / 2
        grid_shape = (1, n_img_ch, height, width)
        background = torch.cat([
            row_coord.view(1, 1, -1, 1).expand(*grid_shape),
            col_coord.view(1, 1, 1, -1).expand(*grid_shape),
        ], 1)
        self.register_buffer('background', background)

        # Strictly lower-triangular mask over flattened pixel positions:
        # position i may attend only to positions j < i.
        n_pixels = height * width
        causal = torch.tril(torch.ones(1, 1, n_pixels, n_pixels), diagonal=-1).byte()
        self.register_buffer('attn_mask', causal)

        # The input gets one extra channel (see forward), hence n_img_ch + 1.
        self.ul_input_d = DownShiftedConv2d(n_img_ch + 1, n_channels, kernel_size=(1, 3))
        self.ul_input_dr = DownRightShiftedConv2d(n_img_ch + 1, n_channels, kernel_size=(2, 1))

        n_background_ch = self.background.shape[1]
        attention_blocks = []
        for _ in range(attn_n_layers):
            attention_blocks.append(AttentionGatedResidualBlock(
                n_channels, n_background_ch, n_res_layers, n_cond_classes,
                drop_rate, attn_n_hidden, attn_d_query, attn_d_value, attn_drop_rate))
        self.ul_modules = nn.ModuleList(attention_blocks)

        self.output_conv = Conv2d(n_channels, (3 * n_img_ch + 1) * n_logistic_mix, 1)

    def forward(self, x, h=None):
        """Compute the mixture parameters for every pixel.

        Args:
            x: image batch of shape ``(B, C, H, W)``.
            h: optional conditioning vectors forwarded to every attention block.

        Returns:
            Tensor of shape ``(B, (3 * C + 1) * n_logistic_mix, H, W)``.
        """
        # Append a constant channel of ones on the channel axis.
        padded = F.pad(x, (0, 0, 0, 0, 0, 1), value=1)

        ul = down_shift(self.ul_input_d(padded)) + right_shift(self.ul_input_dr(padded))

        background = self.background.expand(padded.shape[0], -1, -1, -1)
        for block in self.ul_modules:
            ul = block(ul, background, self.attn_mask, h)

        return self.output_conv(F.elu(ul))

**Polyak 평균**  
훈련 파라미터에 대한 Polyak 평균을 사용합니다.  
이는 파라미터의 변화가 부드럽게 진행되도록 도와줍니다.    
CIFAR-10의 경우 0.9995의 지수 이동 평균(Exponential Moving Average) 가중치를 사용하고, ImageNet의 경우 0.9997을 사용합니다.

In [7]:
class Adam(torch.optim.Adam):
    """``torch.optim.Adam`` that also keeps a Polyak (exponential moving) average of each parameter.

    After every optimisation step the average is updated as
    ``ema <- polyak * ema + (1 - polyak) * param``. ``swap_ema`` exchanges the
    live parameters with their averages, so a model can be evaluated with the
    averaged weights and then switched back.

    Args:
        *args, **kwargs: forwarded unchanged to ``torch.optim.Adam``.
        polyak: decay of the moving average, in ``[0, 1]``.

    Raises:
        ValueError: if ``polyak`` is outside ``[0, 1]``.
    """

    def __init__(self, *args, polyak=0.0, **kwargs):
        if not 0.0 <= polyak <= 1.0:
            raise ValueError('polyak value must be in [0, 1]')
        super().__init__(*args, **kwargs)
        self.defaults['polyak'] = polyak

    def step(self, closure=None):
        """Perform one Adam step, then update the moving averages.

        Returns:
            Whatever ``torch.optim.Adam.step`` returns (the closure's loss when
            a closure is given, otherwise ``None``).
        """
        loss = super().step(closure)
        decay = self.defaults['polyak']

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                if 'ema' not in state:
                    if p.grad is None:
                        # Never optimised so far: leave its state empty so Adam's
                        # lazy per-parameter initialisation (which keys on an
                        # empty state) still works if it gets a gradient later.
                        continue
                    # Start the average at the current value, not at zero; a
                    # zero start keeps the average biased towards 0 for roughly
                    # 1 / (1 - polyak) steps.
                    state['ema'] = p.data.clone()

                state['ema'] -= (1 - decay) * (state['ema'] - p.data)

        return loss

    def swap_ema(self):
        """Exchange each parameter's data with its moving average.

        Calling it twice restores the original assignment. Parameters without a
        moving average yet (for example before the first ``step``) are left
        untouched.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state.get(p)
                if not state or 'ema' not in state:
                    continue
                p.data, state['ema'] = state['ema'], p.data

    def __repr__(self):
        # Reuse the parent's repr body, but print the parent class name and
        # insert the polyak value as the first entry.
        s = super().__repr__()
        return self.__class__.__mro__[1].__name__ \
            + ' (\npolyak: {}\n'.format(self.defaults['polyak']) \
            + s.partition('\n')[2]

**Loss Function**

- **모델 개요**:
  - 이 모델은 **변분 오토인코더(Variational Autoencoder, VAE)**의 아이디어를 활용하여 색상 강도인 **ν**를 모델링합니다.
  - ν는 **연속 분포**를 가정하며, 관측된 서브 픽셀 값 **x**는 ν를 가장 가까운 8비트 표현으로 반올림하여 얻습니다.
- **연속 분포 선택**:
  - ν를 모델링하기 위해 **로지스틱 분포(Logistic distribution)** 같은 단순한 연속 분포를 선택합니다.
  - 이를 통해 **x**에 대한 부드럽고 메모리 효율적인 예측 분포를 생성합니다.
- **로그 가능도 및 혼합 모델**:
  - ν는 여러 로지스틱 분포의 혼합으로 표현됩니다:
    $$
    \nu \sim \sum_{i=1}^{K} \pi_i \text{logistic}(\mu_i, s_i) \quad (1)
    $$
    - 여기서 **K**는 혼합 성분의 수입니다.
    - **πᵢ**는 각 성분의 혼합 가중치입니다.
    - **μᵢ**와 **sᵢ**는 각각 성분의 평균(위치)과 스케일(scale, 분포가 퍼진 정도)을 나타냅니다.
- **관측된 값 x의 가능도**:
  - 관측된 값 **x**의 확률은 다음과 같이 주어집니다:
    $$
    P(x | \pi, \mu, s) = \sum_{i=1}^{K} \pi_i [\sigma((x + 0.5 - \mu_i) / s_i) - \sigma((x - 0.5 - \mu_i) / s_i)] \quad (2)
    $$
    - **σ()**는 로지스틱 시그모이드 함수입니다.
    - 위 식에서 **σ(·)**는 입력을 0과 1 사이의 값으로 변환하며, 로지스틱 분포의 누적 분포 함수(CDF) 역할을 합니다.
  - 경계 값 처리:
    - 0의 경우: $x - 0.5$를 $-\infty$로 대체합니다.
    - 255의 경우: $x + 0.5$를 $+\infty$로 대체합니다.


In [8]:
def discretized_mix_logistic_loss(l, x, n_bits):
    """Negative log-likelihood of ``x`` under a mixture of discretized logistics.

    Args:
        l: model output of shape ``(B, (1 + 3 * C) * n_mix, H, W)``. The first
            ``n_mix`` channels are mixture logits; the rest hold means,
            log-scales and channel-coupling coefficients.
        x: targets of shape ``(B, C, H, W)``, scaled so the lowest level is -1
            and the highest is +1.
        n_bits: bit depth of the data; half a bin is ``1 / (2**n_bits - 1)``.

    Returns:
        Tensor of shape ``(B,)`` with the summed negative log-likelihood of
        each image.
    """
    batch, n_ch, height, width = x.shape
    n_mix = l.shape[1] // (1 + 3 * n_ch)

    # Split the raw output into mixture logits and per-component parameters.
    mix_logits = l[:, :n_mix, :, :]
    params = l[:, n_mix:, :, :].reshape(batch, 3 * n_mix, n_ch, height, width)
    means, log_scales, coeffs = params.split(n_mix, 1)
    log_scales = log_scales.clamp(min=-7)
    coeffs = coeffs.tanh()

    target = x.unsqueeze(1).expand_as(means)
    if n_ch != 1:
        # Multi-channel case: the mean of each later channel is shifted by a
        # linear function of the true values of the earlier channels.
        first, second = target[:, :, 0, :, :], target[:, :, 1, :, :]
        means = torch.stack([
            means[:, :, 0, :, :],
            means[:, :, 1, :, :] + coeffs[:, :, 0, :, :] * first,
            means[:, :, 2, :, :] + coeffs[:, :, 1, :, :] * first + coeffs[:, :, 2, :, :] * second,
        ], 2)

    # Standardised upper and lower edges of the bin around each target value.
    inv_scales = torch.exp(-log_scales)
    half_bin = 1 / (2 ** n_bits - 1)
    upper = inv_scales * (target - means + half_bin)
    lower = inv_scales * (target - means - half_bin)

    log_cdf_upper = upper - F.softplus(upper)      # log sigmoid(upper)
    log_tail_lower = -F.softplus(lower)            # log(1 - sigmoid(lower))
    bin_mass = torch.sigmoid(upper) - torch.sigmoid(lower)
    log_bin_mass = torch.log(bin_mass.clamp(min=1e-12))

    # The lowest level absorbs the whole left tail and the highest level the
    # whole right tail; every other level uses the mass of its own bin.
    not_lowest = torch.where(target > 0.999, log_tail_lower, log_bin_mass)
    log_probs = torch.where(target < -0.999, log_cdf_upper, not_lowest)

    # Sum over channels, add log mixture weights, marginalise the components,
    # then sum over the spatial dimensions.
    log_probs = log_probs.sum(2) + F.log_softmax(mix_logits, 1)
    return -log_probs.logsumexp(1).sum([1, 2])


In [9]:
from tqdm import tqdm

def sample_from_discretized_mix_logistic(l, image_dims):
    """Draw one sample per pixel from the mixture of logistics described by ``l``.

    Args:
        l: model output of shape ``(B, (1 + 3 * C) * n_mix, H, W)``, laid out as
            mixture logits followed by means, log-scales and coefficients.
        image_dims: ``(C, H, W)``; only ``C`` is read.

    Returns:
        Tensor of shape ``(B, C, H, W)`` clamped to ``[-1, 1]``.
    """
    batch, _, height, width = l.shape
    n_ch = image_dims[0]
    n_mix = l.shape[1] // (1 + 3 * n_ch)

    mix_logits = l[:, :n_mix, :, :]
    params = l[:, n_mix:, :, :].reshape(batch, 3 * n_mix, n_ch, height, width)
    means, log_scales, coeffs = params.split(n_mix, 1)
    log_scales = log_scales.clamp(min=-7)
    coeffs = coeffs.tanh()

    # Choose a component per pixel with the Gumbel-max trick.
    gumbel_u = torch.rand_like(mix_logits).uniform_(1e-5, 1 - 1e-5)
    chosen = torch.argmax(mix_logits - torch.log(-torch.log(gumbel_u)), dim=1)
    one_hot = torch.eye(n_mix, device=mix_logits.device)[chosen]
    one_hot = one_hot.permute(0, 3, 1, 2).unsqueeze(2)   # (B, n_mix, 1, H, W)

    # Keep only the chosen component's parameters.
    means = (means * one_hot).sum(1)
    log_scales = (log_scales * one_hot).sum(1)
    coeffs = (coeffs * one_hot).sum(1)

    # Inverse-CDF sampling of a logistic: mean + scale * logit(u).
    u = torch.rand_like(means).uniform_(1e-5, 1 - 1e-5)
    x = means + log_scales.exp() * (torch.log(u) - torch.log1p(-u))

    if n_ch == 1:
        return x.clamp(-1, 1)

    # Multi-channel case: later channels are shifted by the sampled earlier ones.
    first = torch.clamp(x[:, 0, :, :], -1, 1)
    second = torch.clamp(x[:, 1, :, :] + coeffs[:, 0, :, :] * first, -1, 1)
    third = torch.clamp(x[:, 2, :, :] + coeffs[:, 1, :, :] * first + coeffs[:, 2, :, :] * second, -1, 1)
    return torch.stack([first, second, third], 1)

def generate_fn(model, n_samples, image_dims, device, h=None):
    """Generate images pixel by pixel in raster order.

    Starting from an all-zero canvas, the model is run once per pixel and only
    that pixel is filled in from the sampled output, so H * W forward passes
    are needed.

    Args:
        model: callable ``model(images, h)`` returning mixture parameters.
        n_samples: number of images to generate.
        image_dims: ``(C, H, W)`` of each image.
        device: device on which the canvas is allocated.
        h: optional conditioning passed to every model call.

    Returns:
        Tensor of shape ``(n_samples, C, H, W)``.
    """
    height, width = image_dims[1], image_dims[2]
    out = torch.zeros(n_samples, *image_dims, device=device)
    # Sampling never needs gradients. Disabling them here keeps the function
    # safe to call directly, without relying on the caller to have done so,
    # and avoids building a graph for every forward pass.
    with torch.no_grad(), \
            tqdm(total=height * width, desc='Generating {} images'.format(n_samples)) as pbar:
        for yi in range(height):
            for xi in range(width):
                l = model(out, h)
                sampled = sample_from_discretized_mix_logistic(l, image_dims)
                out[:, :, yi, xi] = sampled[:, :, yi, xi]
                pbar.update()
    return out


In [10]:
# Apply the configured seed before any weights are created, so initialisation
# and later random draws from the global RNG are reproducible across runs.
torch.manual_seed(hp.seed)

# Keyword arguments keep each hyperparameter bound to the intended constructor
# parameter. `drop_rate` is not passed, so the constructor default applies.
model = PixelSNAIL(
    image_dims=hp.image_dims,
    n_channels=hp.n_channels,
    n_res_layers=hp.n_res_layers,
    attn_n_layers=hp.n_attn_layers,
    attn_n_hidden=hp.attn_n_hidden,
    attn_d_query=hp.attn_d_query,
    attn_d_value=hp.attn_d_value,
    attn_drop_rate=hp.attn_drop_rate,
    n_logistic_mix=hp.n_logistic_mix,
    n_cond_classes=hp.n_cond_classes,
).to(hp.device)

# Adam with a Polyak average of the parameters; the learning rate is
# multiplied by hp.lr_decay each time the scheduler is stepped.
optimizer = Adam(model.parameters(), lr=hp.lr, betas=(0.95, 0.9995),
                 polyak=hp.polyak, eps=1e-5)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, hp.lr_decay)

  WeightNorm.apply(module, name, dim)


In [11]:
import numpy as np


def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, epoch, hp):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: network called as ``model(x, h)``.
        dataloader: yields ``(x, y)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: optional LR scheduler, also stepped once per batch.
        loss_fn: called as ``loss_fn(logits, x, n_bits)``; its result is
            averaged over dim 0.
        epoch: epoch index shown in the progress bar.
        hp: hyperparameters; ``hp.step`` is incremented once per batch.
    """
    model.train()

    # Divisor that converts a per-image loss in nats to bits per dimension.
    nats_per_image_bit = np.log(2) * np.prod(hp.image_dims)
    title = 'epoch {}/{}'.format(epoch, hp.start_epoch + hp.n_epochs)

    with tqdm(total=len(dataloader), desc=title) as pbar:
        for images, labels in dataloader:
            hp.step += 1

            images = images.to(hp.device)
            cond = labels.to(hp.device) if hp.n_cond_classes else None
            loss = loss_fn(model(images, cond), images, hp.n_bits).mean(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler:
                scheduler.step()

            pbar.set_postfix(bits_per_dim='{:.4f}'.format(loss.item() / nats_per_image_bit))
            pbar.update()

In [12]:
@torch.no_grad()
def evaluate(model, dataloader, loss_fn, hp):
    """Return the average per-image loss of ``model`` over ``dataloader``.

    The average is taken over images, not over batches, so a smaller final
    batch does not get the same weight as a full one.

    Args:
        model: network called as ``model(x, h)``; put into eval mode here.
        dataloader: yields ``(x, y)`` batches.
        loss_fn: called as ``loss_fn(logits, x, n_bits)``; expected to return
            one loss value per image along dim 0.
        hp: hyperparameters (``device``, ``n_cond_classes``, ``n_bits``).

    Returns:
        Python float: mean loss per image.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.eval()

    total_loss = 0.0
    n_images = 0
    for x, y in tqdm(dataloader, desc='Evaluate'):
        x = x.to(hp.device)
        cond = y.to(hp.device) if hp.n_cond_classes else None
        per_image = loss_fn(model(x, cond), x, hp.n_bits)
        total_loss += per_image.sum().item()
        n_images += per_image.shape[0]
    return total_loss / n_images

In [13]:
@torch.no_grad()
def generate(model, generate_fn, hp):
    """Sample images from ``model`` and arrange them in one image grid.

    With class conditioning enabled, ``hp.n_samples`` images are generated for
    every class (one grid row per class); otherwise ``hp.n_samples`` images are
    generated in total.

    Args:
        model: network to sample from; put into eval mode here.
        generate_fn: called as
            ``generate_fn(model, n_samples, image_dims, device, h=...)``.
        hp: hyperparameters (``n_cond_classes``, ``n_samples``,
            ``image_dims``, ``device``).

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid``.
    """
    model.eval()

    if not hp.n_cond_classes:
        samples = generate_fn(model, hp.n_samples, hp.image_dims, hp.device)
    else:
        class_one_hots = torch.eye(hp.n_cond_classes)
        per_class = []
        for class_idx in range(hp.n_cond_classes):
            # Shape (1, n_cond_classes): one conditioning row for the whole batch.
            cond = class_one_hots[class_idx, None].to(hp.device)
            per_class.append(
                generate_fn(model, hp.n_samples, hp.image_dims, hp.device, h=cond))
        samples = torch.cat(per_class)

    return torchvision.utils.make_grid(samples.cpu(), normalize=True, scale_each=True, nrow=hp.n_samples)

In [14]:
import os


def train_and_evaluate(model, train_dataloader, test_dataloader, 
                       optimizer, scheduler, loss_fn, generate_fn, hp):
    """Train for ``hp.n_epochs`` epochs, periodically checkpointing, evaluating and sampling.

    Every ``hp.eval_interval`` epochs the model, optimizer and scheduler states
    are saved to ``hp.output_dir``. The model is then evaluated and sampled
    with the optimizer's moving-average weights swapped in, and the trained
    weights are restored afterwards.

    Args:
        model: network to train.
        train_dataloader: batches used for training.
        test_dataloader: batches used for evaluation.
        optimizer: optimizer exposing ``swap_ema()``.
        scheduler: optional LR scheduler.
        loss_fn: loss passed on to ``train_epoch`` and ``evaluate``.
        generate_fn: sampler passed on to ``generate``.
        hp: hyperparameters (``start_epoch``, ``n_epochs``, ``eval_interval``,
            ``output_dir``, ``step``, ``image_dims``).
    """
    # Divisor that converts a per-image loss in nats to bits per dimension.
    nats_per_image_bit = np.log(2) * np.prod(hp.image_dims)

    for epoch in range(hp.start_epoch, hp.start_epoch + hp.n_epochs):
        # train
        train_epoch(model, train_dataloader, 
                    optimizer, scheduler, loss_fn, epoch, hp)

        if (epoch + 1) % hp.eval_interval != 0:
            continue

        # save model (the directory may not exist yet on a fresh checkout)
        os.makedirs(hp.output_dir, exist_ok=True)
        torch.save({'epoch': epoch,
                    'global_step': hp.step,
                    'state_dict': model.state_dict()},
                   os.path.join(hp.output_dir, 'checkpoint.pt'))
        torch.save(optimizer.state_dict(), os.path.join(hp.output_dir, 'optim_checkpoint.pt'))
        if scheduler:
            torch.save(scheduler.state_dict(), os.path.join(hp.output_dir, 'sched_checkpoint.pt'))

        # swap params to ema values
        optimizer.swap_ema()
        try:
            # evaluate -- float() accepts both a plain number and a 0-dim tensor,
            # whereas calling .item() on a Python float raises AttributeError.
            eval_loss = float(evaluate(model, test_dataloader, loss_fn, hp))
            print('Evaluate bits per dim: {:.3f}'.format(eval_loss / nats_per_image_bit))

            # generate
            samples = generate(model, generate_fn, hp)
            torchvision.utils.save_image(samples, os.path.join(hp.output_dir, 'generation_sample_step_{}.png'.format(hp.step)))
        finally:
            # restore params to gradient optimized, even if evaluation or
            # sampling failed, so training never continues on the ema weights
            optimizer.swap_ema()

In [15]:
# Checkpoints and sample grids go to ./assets under the current working directory.
hp.output_dir = os.path.join(os.getcwd(), 'assets')
# Run the full train / checkpoint / evaluate / sample loop with the
# discretized mixture-of-logistics loss and the pixel-by-pixel sampler.
train_and_evaluate(model, train_loader, test_loader, optimizer, scheduler,
                   discretized_mix_logistic_loss, generate_fn, hp)

epoch 0/1: 100%|██████████| 4/4 [00:02<00:00,  1.34it/s, bits_per_dim=3.6236]


In [19]:
from torchviz import make_dot

# Trace a single forward pass on an all-zero 1x1x28x28 input (no conditioning
# vector is passed) and render the resulting autograd graph to 'pixelsnail.svg'.
# NOTE(review): the input shape is hard-coded here rather than taken from hp.image_dims.
x = torch.zeros(1, 1, 28, 28).to(hp.device)
(make_dot(model(x), params=dict(list(model.named_parameters())), show_attrs=True, show_saved=True)
 .render('pixelsnail', format='svg'))


'pixelsnail.svg'

레이어 구조: https://github.com/crlotwhite/ML_Study/blob/pixelsnail/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/pixelsnail.svg

In [20]:
# Sample a grid of images and save it next to the checkpoints.
# generate() is called directly here, with no optimizer.swap_ema() around it,
# so the model's current weights are used as they are.
samples = generate(model, generate_fn, hp)
torchvision.utils.save_image(samples, 
                             os.path.join(hp.output_dir, 
                                          'generation_sample_step_{}.png'.format(hp.step)))

Generating 8 images: 100%|██████████| 784/784 [03:29<00:00,  3.74it/s]
Generating 8 images: 100%|██████████| 784/784 [03:28<00:00,  3.76it/s]
Generating 8 images: 100%|██████████| 784/784 [03:29<00:00,  3.74it/s]
Generating 8 images: 100%|██████████| 784/784 [03:27<00:00,  3.78it/s]
Generating 8 images: 100%|██████████| 784/784 [03:26<00:00,  3.80it/s]
Generating 8 images: 100%|██████████| 784/784 [03:28<00:00,  3.75it/s]
Generating 8 images: 100%|██████████| 784/784 [03:28<00:00,  3.75it/s]
Generating 8 images: 100%|██████████| 784/784 [03:25<00:00,  3.81it/s]
Generating 8 images: 100%|██████████| 784/784 [03:26<00:00,  3.80it/s]
Generating 8 images: 100%|██████████| 784/784 [03:26<00:00,  3.79it/s]


**출력물**

![](https://github.com/crlotwhite/ML_Study/blob/pixelsnail/%EB%85%BC%EB%AC%B8%EA%B5%AC%ED%98%84/generative/assets/generation_sample_step_4.png?raw=true)

학습량이 매우 적어서(미니 데이터셋 16장으로 1 에폭, 총 4 스텝만 학습) 의미 있는 결과가 나오지 않았습니다.