# **Review of PixelSNAIL: An Improved Autoregressive Generative Model**

## Abstract

Autoregressive generative models → best density-estimation results on high-dimensional data

e.g. images or audio

They pose density estimation as a sequence modeling task

To deal with long-range dependencies → ***causal convolutions combined with self-attention***

## Introduction

An autoregressive generative model can generally be expressed as follows:

![joint distribution as a product of conditionals](https://user-images.githubusercontent.com/66329748/178414558-6cf8bb1c-2125-43e4-80ea-bcf3364ab850.png)

joint distribution as a product of conditionals
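
In symbols, the factorization is $p(\mathbf{x}) = \prod_{i=1}^{n} p(x_i \mid x_1, \dots, x_{i-1})$. The sketch below is not from the paper: `toy_conditional` is a hypothetical hand-written conditional for binary sequences, standing in for the neural network. It only illustrates that the joint log-likelihood is the sum of the per-element conditional log-probabilities, which is why the likelihood is tractable.

```python
import itertools
import math


def toy_conditional(prefix):
    """Hypothetical p(x_i = 1 | x_1..x_{i-1}); depends on the previous element only."""
    if not prefix:
        return 0.5
    return 0.8 if prefix[-1] == 1 else 0.3


def log_likelihood(x):
    """Sums log p(x_i | x_<i) over the sequence (chain rule)."""
    total = 0.0
    for i, x_i in enumerate(x):
        p_one = toy_conditional(x[:i])
        total += math.log(p_one if x_i == 1 else 1.0 - p_one)
    return total


# The conditionals define a valid joint: probabilities over all sequences sum to 1.
total_prob = sum(
    math.exp(log_likelihood(seq)) for seq in itertools.product([0, 1], repeat=4)
)
print(round(total_prob, 6))
```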

Compared with GANs, neural autoregressive models

- offer tractable likelihood computation
- are easy to train
- have been shown to outperform latent variable models

Main Design Consideration

The model must be able to easily refer to earlier parts of the sequence. Three mechanisms are compared:

1. Traditional RNNs
    
    Information is stored in a hidden state and passed to the next timestep
    
    This makes it hard to model long-range relationships in the data (-)
    
2. Causal Convolutions
    
    the current prediction is influenced only by previous elements
    
    high-bandwidth access to the earlier parts of the sequence (+)
    
    finite size of receptive field (-)
    
    information from distant elements in the sequence is attenuated (-)
    
3. Self-Attention
    
    unbounded receptive field (+)
    
    undeteriorated access even to distant elements in the sequence (+)
    
    only pinpoint access to small amounts of information (-)
    
    needs an additional mechanism to use positional information (-)
    

***How the strengths and weaknesses of causal convolution and self-attention complement each other***

Causal convolution: high-bandwidth access over a finite context size (nearby elements only)

Self-attention: access over an infinitely large context (but only pinpoint access to a small amount of information)

***Interleaving → high-bandwidth access without constraints on the amount of information it can effectively use***
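
The difference in context size can be made concrete with a small calculation. This is an illustrative sketch, not code from the paper: it assumes undilated, stride-1 1D causal convolutions (so each layer adds `kernel_size - 1` positions) and a causal attention layer that sees every position up to and including the current one. Both function names are hypothetical.

```python
def causal_conv_receptive_field(n_layers, kernel_size):
    """Positions visible after stacking undilated, stride-1 1D causal convolutions."""
    return 1 + n_layers * (kernel_size - 1)


def causal_attention_receptive_field(position):
    """A single causal self-attention layer can see every position up to `position`."""
    return position + 1


seq_len = 1024
for n_layers in (1, 4, 16):
    rf = causal_conv_receptive_field(n_layers, kernel_size=3)
    print(f"{n_layers} conv layers: sees {min(rf, seq_len)} of {seq_len} positions")
print(f"1 attention layer: sees {causal_attention_receptive_field(seq_len - 1)}")
```

Under these assumptions the convolutional context grows only linearly with depth, while one attention layer already covers the whole prefix.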

## Model Architecture

![Screenshot 2022-07-07 7.44.02 PM.png](https://user-images.githubusercontent.com/66329748/178414679-7a6a7710-3aea-480e-a6a1-8b49a670ffe7.png)

![Screenshot 2022-07-08 11.11.31 PM.png](https://user-images.githubusercontent.com/66329748/178414723-09a63fef-3b18-4dd5-92e1-8a282aa82b13.png)

(a) residual block 

masked convolutions → the current pixel can access only the pixels above it and to its left

![Untitled](https://user-images.githubusercontent.com/66329748/178414765-bea9cc64-c548-4039-bb69-5640205e4c32.png)
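
One common way to enforce this constraint is to multiply the convolution weights by a fixed 0/1 mask. The sketch below builds such a mask in the PixelCNN style, which is an assumption here and not necessarily what the paper or any library does internally. `causal_kernel_mask` and its `mask_center` flag are hypothetical names chosen for illustration.

```python
def causal_kernel_mask(kernel_size, mask_center):
    """Returns a kernel_size x kernel_size 0/1 mask for raster-scan ordering.

    Rows above the center are fully visible, the center row is visible only to the
    left of the center, and the center itself is visible unless mask_center is True.
    """
    mid = kernel_size // 2
    mask = [[0] * kernel_size for _ in range(kernel_size)]
    for row in range(kernel_size):
        for col in range(kernel_size):
            above = row < mid
            left = row == mid and col < mid
            center = row == mid and col == mid and not mask_center
            mask[row][col] = int(above or left or center)
    return mask


for mask_row in causal_kernel_mask(3, mask_center=True):
    print(mask_row)
```

Masking the center is needed in the first layer, where the input is the image itself, so that a pixel's prediction never depends on its own value.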

(b) attention block

project input to lower dimensionality to produce keys & values
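
The attention step can be sketched over a flattened pixel sequence as follows. This is a minimal illustration under stated assumptions: the learned projections that produce queries, keys, and values are omitted (the vectors are passed in directly), scaled dot-product scoring is assumed, and `causal_attention` with its `mask_center` flag is a hypothetical helper, not the paper's or a library's implementation.

```python
import math


def causal_attention(queries, keys, values, mask_center=True):
    """Single-head attention where position i attends only to earlier positions.

    queries/keys are lists of d_k-dim vectors and values is a list of d_v-dim vectors.
    With mask_center=True, position i cannot attend to itself.
    """
    d_k = len(keys[0])
    d_v = len(values[0])
    outputs = []
    for i, q in enumerate(queries):
        last = i if mask_center else i + 1  # exclusive upper bound of visible keys
        if last == 0:
            outputs.append([0.0] * d_v)  # nothing to attend to yet
            continue
        scores = [
            sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(d_k)
            for j in range(last)
        ]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        norm = sum(weights)
        outputs.append(
            [
                sum(w * values[j][c] for j, w in enumerate(weights)) / norm
                for c in range(d_v)
            ]
        )
    return outputs


# Four flattened pixels with 2-dim keys/queries and 3-dim values (toy numbers).
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
outputs = causal_attention(queries, keys, values)
print(outputs[1])  # only the first pixel is visible, so this equals values[0]
```

Keeping the key dimension small makes the score computation cheap, which is the motivation for projecting to a lower dimensionality.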

## Comparison

![Untitled](https://user-images.githubusercontent.com/66329748/178414823-58cc95a9-ca3b-4d9a-b1f5-1002a7e7d8f4.png)

![Untitled](https://user-images.githubusercontent.com/66329748/178414958-65f1b1d1-47ea-471e-95be-781cd4c22eb8.png)

![Untitled](https://user-images.githubusercontent.com/66329748/178414995-78cbe6f1-1227-42ca-b353-4abafaf5344b.png)

## Conclusion

Autoregressive Generative Model

tractable likelihood (+)

strong empirical performance (+)

slow sampling (-)
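
The reason sampling is slow can be seen from the shape of the sampling loop: each pixel must be drawn before the next one can be predicted. The sketch below is illustrative only. `predict_probs` is a hypothetical stand-in for the model, and `constant_model` is a toy that ignores its input.

```python
import random


def sample_image(predict_probs, height, width, rng):
    """Samples a binary image pixel by pixel in raster-scan order.

    predict_probs is a hypothetical stand-in for the model: it maps the partially
    filled image to a height x width grid of Bernoulli probabilities.
    """
    image = [[0] * width for _ in range(height)]
    n_calls = 0
    for row in range(height):
        for col in range(width):
            probs = predict_probs(image)  # one full forward pass per pixel
            n_calls += 1
            image[row][col] = int(rng.random() < probs[row][col])
    return image, n_calls


def constant_model(image):
    """Toy model that ignores its input; a real model would condition on it."""
    return [[0.5] * len(image[0]) for _ in image]


image, n_calls = sample_image(constant_model, height=28, width=28, rng=random.Random(0))
print(n_calls)  # 784 sequential forward passes for one 28x28 image
```

Training does not have this cost, because all conditionals can be evaluated in a single pass when the full image is known.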

# PixelSNAIL Implementation

In [None]:
import torch
from torch import distributions
from torch import nn
from torch.nn import functional as F

from pytorch_generative import nn as pg_nn
from pytorch_generative.models import base

In [None]:
def _elu_conv_elu(conv, x):
    return F.elu(conv(F.elu(x)))


class ResidualBlock(nn.Module):
    """Residual block with a gated activation function."""

    def __init__(self, n_channels):
        """Initializes a new ResidualBlock.
        Args:
            n_channels: The number of input and output channels.
        """
        super().__init__()
        self._input_conv = nn.Conv2d(
            in_channels=n_channels, out_channels=n_channels, kernel_size=2, padding=1
        )
        self._output_conv = nn.Conv2d(
            in_channels=n_channels,
            out_channels=2 * n_channels,
            kernel_size=2,
            padding=1,
        )
        self._activation = pg_nn.GatedActivation(activation_fn=nn.Identity())

    def forward(self, x):
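        _, _, h, w = x.shape
        # kernel_size=2 with padding=1 yields (h + 1, w + 1) outputs; keeping the
        # top-left (h, w) crop means each position sees only itself and its
        # neighbours above and to the left.
        out = _elu_conv_elu(self._input_conv, x)[:, :, :h, :w]
        # The output conv emits 2 * n_channels channels for the gated activation,
        # which presumably splits them in half and multiplies one half by the sigmoid
        # of the other. The gated result is then added to the input.
        out = self._activation(self._output_conv(out)[:, :, :h, :w])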
        return x + out
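
The padding-and-crop trick used in `ResidualBlock.forward` can be checked without PyTorch. The sketch below is a pure-Python stand-in (the helper name `shifted_conv2d` is hypothetical) for a single-channel 2x2 convolution with `padding=1` followed by the `[:h, :w]` crop. It places one nonzero pixel in the input and shows that only positions at, below, or to the right of it are affected.

```python
def shifted_conv2d(image, kernel):
    """2x2 cross-correlation with zero padding of 1, cropped to the top-left H x W."""
    h, w = len(image), len(image[0])
    padded = [[0.0] * (w + 2) for _ in range(h + 2)]
    for i in range(h):
        for j in range(w):
            padded[i + 1][j + 1] = image[i][j]
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):  # keeping only i < h, j < w mirrors the [:h, :w] crop
        for j in range(w):
            out[i][j] = sum(
                kernel[di][dj] * padded[i + di][j + dj]
                for di in range(2)
                for dj in range(2)
            )
    return out


impulse = [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
ones_kernel = [[1.0, 1.0], [1.0, 1.0]]
response = shifted_conv2d(impulse, ones_kernel)
for response_row in response:
    print(response_row)
```

Because the current position itself is included, this kind of layer is only safe after a first layer that masks the center pixel.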

In [None]:
class PixelSNAILBlock(nn.Module):
    """Block comprised of a number of residual blocks plus one attention block.
    Implements Figure 5 of [1].
    """

    def __init__(
        self,
        n_channels,
        input_img_channels=1,
        n_residual_blocks=2,
        attention_key_channels=4,
        attention_value_channels=32,
    ):
        """Initializes a new PixelSNAILBlock instance.
        Args:
            n_channels: Number of input and output channels.
            input_img_channels: The number of channels in the original input_img. Used
                for the positional encoding channels and the extra channels for the key
                and value convolutions in the attention block.
            n_residual_blocks: Number of residual blocks.
            attention_key_channels: Number of channels (dims) for the attention key.
            attention_value_channels: Number of channels (dims) for the attention value.
        """
        super().__init__()

        def conv(in_channels):
            return nn.Conv2d(in_channels, out_channels=n_channels, kernel_size=1)

        # Per Figure 1 (a), the residual block is repeated R times.
        self._residual = nn.Sequential(
            *[ResidualBlock(n_channels) for _ in range(n_residual_blocks)]
        )
        self._attention = pg_nn.CausalAttention(
            in_channels=n_channels + 2,
            embed_channels=attention_key_channels,
            out_channels=attention_value_channels,
            mask_center=True,
            extra_input_channels=input_img_channels,
        )
        self._residual_out = conv(n_channels)
        self._attention_out = conv(attention_value_channels)
        self._out = conv(n_channels)

    def forward(self, x, input_img):
        """Computes the forward pass.
        Args:
            x: The input.
            input_img: The original image only used as input to the attention blocks.
        Returns:
            The result of the forward pass.
        """
        res = self._residual(x)
        # Positional encodings for the attention block. Per the docstring of
        # image_positional_encoding: given the NCHW shape of the image, it returns a
        # Tensor of shape (N, 2, H, W) of (x, y) pixel coordinates scaled to be
        # between -.5 and .5. These 2 channels explain in_channels=n_channels + 2.
        pos = pg_nn.image_positional_encoding(input_img.shape).to(res.device)
        attn = self._attention(torch.cat((pos, res), dim=1), input_img)
        res, attn = (
            _elu_conv_elu(self._residual_out, res),
            _elu_conv_elu(self._attention_out, attn),
        )
        return _elu_conv_elu(self._out, res + attn)

In [None]:
class PixelSNAIL(base.AutoregressiveModel):
    """The PixelSNAIL model.
    Unlike [1], we implement skip connections from each block to the output.
    We find that this makes training a lot more stable and allows for much faster
    convergence.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        n_channels=64,
        n_pixel_snail_blocks=8,
        n_residual_blocks=2,
        attention_key_channels=4,
        attention_value_channels=32,
        sample_fn=None,
    ):
        """Initializes a new PixelSNAIL instance.
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            n_channels: Number of channels to use for convolutions.
            n_pixel_snail_blocks: Number of PixelSNAILBlocks.
            n_residual_blocks: Number of ResidualBlocks to use in each PixelSNAILBlock.
            attention_key_channels: Number of channels (dims) for the attention key.
            attention_value_channels: Number of channels (dims) for the attention value.
            sample_fn: See the base class.
        """
        super().__init__(sample_fn)
        self._input = pg_nn.CausalConv2d(
            mask_center=True,
            in_channels=in_channels,
            out_channels=n_channels,
            kernel_size=3,
            padding=1,
        )
        self._pixel_snail_blocks = nn.ModuleList(
            [
                PixelSNAILBlock(
                    n_channels=n_channels,
                    input_img_channels=in_channels,
                    n_residual_blocks=n_residual_blocks,
                    attention_key_channels=attention_key_channels,
                    attention_value_channels=attention_value_channels,
                )
                for _ in range(n_pixel_snail_blocks)
            ]
        )
        self._output = nn.Sequential(
            nn.Conv2d(
                in_channels=n_channels, out_channels=n_channels // 2, kernel_size=1
            ),
            nn.Conv2d(
                in_channels=n_channels // 2, out_channels=out_channels, kernel_size=1
            ),
        )

    def forward(self, x):
        input_img = x
        x = self._input(x)
        for block in self._pixel_snail_blocks:
            # Skip connection; this differs slightly from the paper.
            x = x + block(x, input_img)
        return self._output(x)





In [None]:
def reproduce(
    n_epochs=457,
    batch_size=128,
    log_dir="/tmp/run",
    n_gpus=1,
    device_id=0,
    debug_loader=None,
):
    """Training script with defaults to reproduce results.
    The code inside this function is self contained and can be used as a top level
    training script, e.g. by copy/pasting it into a Jupyter notebook.
    Args:
        n_epochs: Number of epochs to train for.
        batch_size: Batch size to use for training and evaluation.
        log_dir: Directory where to log trainer state and TensorBoard summaries.
        n_gpus: Number of GPUs to use for training the model. If 0, uses CPU.
        device_id: The device_id of the current GPU when training on multiple GPUs.
        debug_loader: Debug DataLoader which replaces the default training and
            evaluation loaders if not 'None'. Do not use unless you're writing unit
            tests.
    """
    from torch import optim
    from torch.nn import functional as F
    from torch.optim import lr_scheduler

    from pytorch_generative import datasets
    from pytorch_generative import models
    from pytorch_generative import trainer

    train_loader, test_loader = debug_loader, debug_loader
    if train_loader is None:
        train_loader, test_loader = datasets.get_mnist_loaders(
            batch_size, dynamically_binarize=True
        )

    model = models.PixelSNAIL(
        in_channels=1,
        out_channels=1,
        n_channels=64,
        n_pixel_snail_blocks=8,
        n_residual_blocks=2,
        attention_value_channels=32,  # n_channels / 2
        attention_key_channels=4,  # attention_value_channels / 8
    )
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda _: 0.999977)

    def loss_fn(x, _, preds):
        batch_size = x.shape[0]
        x, preds = x.view((batch_size, -1)), preds.view((batch_size, -1))
        loss = F.binary_cross_entropy_with_logits(preds, x, reduction="none")
        return loss.sum(dim=1).mean()

    trainer = trainer.Trainer(
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_loader=train_loader,
        eval_loader=test_loader,
        lr_scheduler=scheduler,
        log_dir=log_dir,
        n_gpus=n_gpus,
        device_id=device_id,
    )
    trainer.interleaved_train_and_eval(n_epochs)
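
The `loss_fn` above sums the per-pixel binary cross-entropy within each image and then averages over the batch, which gives the negative log-likelihood per image in nats. The sketch below reproduces that reduction in plain Python as a sanity check. The helper names are hypothetical, and the per-pixel formula is the standard numerically stable form of binary cross-entropy on logits.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one pixel, in nats."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def nll_per_image(batch_logits, batch_targets):
    """Sums the per-pixel loss within each image, then averages over the batch."""
    per_image = [
        sum(bce_with_logits(l, t) for l, t in zip(logits, targets))
        for logits, targets in zip(batch_logits, batch_targets)
    ]
    return sum(per_image) / len(per_image)


# An untrained model that outputs logit 0 everywhere assigns p = 0.5 to each pixel.
n_pixels = 28 * 28
batch_logits = [[0.0] * n_pixels, [0.0] * n_pixels]
batch_targets = [[1.0] * n_pixels, [0.0] * n_pixels]
nll = nll_per_image(batch_logits, batch_targets)
print(nll / (n_pixels * math.log(2)))  # bits per pixel for the uniform guess
```

Dividing by the number of pixels times ln 2 converts the value to bits per pixel, which makes the result comparable across image sizes.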