# Review: Conditional Image Generation with PixelCNN Decoders

## Paper Reviews

### Prior Research

+ If each pixel of an image is assumed to depend on the pixels that precede it, the image distribution can be written sequentially as a product of conditional distributions   

$$ p(\mathbf{x}) = \prod^{n^2}_{i=1}p(x_i|x_1, ... , x_{i-1}), \, \mathbf{x} \in \mathbb{R}^{n \times n}$$
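
As a rough illustration of this factorization, the sketch below evaluates the log-likelihood of a flattened binary image and draws a sample pixel by pixel. The function `toy_conditional` is a hypothetical stand-in for a learned model and is not taken from the paper.

```python
import math
import random

def toy_conditional(prefix):
    """Hypothetical stand-in for p(x_i = 1 | x_1..x_{i-1}) on binary pixels.

    A real model would be a neural network; here a pixel is simply more
    likely to be on when the previous pixel is on.
    """
    if not prefix:
        return 0.5
    return 0.8 if prefix[-1] == 1 else 0.2

def log_likelihood(pixels):
    """Sum log p(x_i | x_<i) over the flattened raster-scan sequence."""
    total = 0.0
    for i, x_i in enumerate(pixels):
        p_on = toy_conditional(pixels[:i])
        total += math.log(p_on if x_i == 1 else 1.0 - p_on)
    return total

def sample(n_pixels, rng):
    """Ancestral sampling: draw each pixel given the ones already drawn."""
    pixels = []
    for _ in range(n_pixels):
        p_on = toy_conditional(pixels)
        pixels.append(1 if rng.random() < p_on else 0)
    return pixels

image = sample(4 * 4, random.Random(0))  # a flattened 4x4 binary image
print(image, log_likelihood(image))
```

Because every factor is a proper conditional distribution, the probabilities of all possible images sum to one, which is what makes the likelihood explicit.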

+ PixelRNN (Diagonal BiLSTM), an autoregressive generative model built on this idea, showed that the approach works well  
+ Its distinctive strengths are that the distribution is expressed explicitly and that the model can generate diverse samples from it
+ However, because PixelRNN relies on a BiLSTM structure, it is hard to parallelize and its inference time is very slow

<img src="https://user-images.githubusercontent.com/86907286/161759587-e5dc5ef6-6790-4fa4-a904-b55ecbec5b3f.JPG" alt="1" width="400px" align="center" />

+ The PixelRNN paper also proposed **PixelCNN**, which uses convolution layers, although it performed worse than the Diagonal BiLSTM
+ Convolution layers parallelize better than RNNs, so improving PixelCNN would overcome the drawback of PixelRNN  
+ To this end, the authors propose the **Gated PixelCNN** architecture

<img src="https://user-images.githubusercontent.com/86907286/161759591-348cdb0c-c84c-4bc0-9f1d-f67beb5c00e3.JPG" alt="2" width="600px" align="center"/>

### Problem of PixelCNN

+ PixelCNN assumes that pixels depend on each other in raster scan order
+ It implements this with **Masked Convolution**, which convolves only over the pixels that come earlier in the raster scan sequence 
+ The mask keeps everything after the pixel being generated out of the computation, which enforces the dependency  
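
A minimal sketch of such a mask, assuming a square kernel with odd size. The `'A'`/`'B'` naming follows the mask types used by the reference implementation later in this review; the helper `raster_mask` itself is hypothetical.

```python
def raster_mask(kernel_size, mask_type):
    """Return a kernel_size x kernel_size 0/1 mask as a list of rows.

    Type 'A' hides the centre pixel (used on the input layer);
    type 'B' keeps it (used on later layers, whose centre feature
    already depends only on earlier input pixels).
    """
    assert kernel_size % 2 == 1 and mask_type in ("A", "B")
    c = kernel_size // 2
    mask = [[0] * kernel_size for _ in range(kernel_size)]
    for r in range(kernel_size):
        for col in range(kernel_size):
            above = r < c                        # any pixel in an earlier row
            left_in_same_row = r == c and col < c
            centre = r == c and col == c
            if above or left_in_same_row or (centre and mask_type == "B"):
                mask[r][col] = 1
    return mask

for row in raster_mask(3, "A"):
    print(row)
```

Multiplying the convolution weights by this mask zeroes every weight that would read the current or a later pixel.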

<img src="https://user-images.githubusercontent.com/86907286/161759593-6ab9412c-978f-4a01-a91b-c1f05d66db23.JPG" alt="3" width="400px" align="center"/>

+ However, a PixelCNN built on Masked Convolution has three problems
1. Each pixel has a relatively shorter dependency length than with a BiLSTM  
→ A BiLSTM carries the previous sequence forward as hidden input, whereas a convolution only sees neighboring pixels
2. Masked Convolution inherently produces a **Blind Spot** along the upper-right diagonal of the target pixel  
→ No matter how many layers are stacked, the shape of the mask leaves part of the previous sequence outside the receptive field
3. A BiLSTM has gates, so it can learn more complex nonlinearities  
→ A convolution layer applies a single activation to its output, while a BiLSTM applies separate activations for the forget, input, and output gates
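
The blind spot in Problem 2 can be checked by tracking which input offsets a stack of 3x3 masked convolutions can reach. This is a small sketch under that assumption (one type-A layer followed by type-B layers); the helper names are hypothetical.

```python
MASK_A = [(-1, -1), (-1, 0), (-1, 1), (0, -1)]  # 3x3 mask, centre hidden
MASK_B = MASK_A + [(0, 0)]                      # 3x3 mask, centre visible

def receptive_field(n_layers):
    """Offsets (d_row, d_col) of input pixels that can influence the output."""
    field = {(0, 0)}
    for layer in range(n_layers):
        offsets = MASK_A if layer == 0 else MASK_B
        field = {(r + dr, c + dc) for (r, c) in field for (dr, dc) in offsets}
    return field

def blind_spot(n_layers, radius):
    """Earlier pixels within `radius` that the stack can never see."""
    field = receptive_field(n_layers)
    earlier = [(r, c) for r in range(-radius, 1)
               for c in range(-radius, radius + 1) if r < 0 or c < 0]
    return sorted(p for p in earlier if p not in field)

# Even with 10 layers, some pixels up and to the right stay invisible.
print(blind_spot(n_layers=10, radius=3))
```

Each step up can move at most one column to the right, so a pixel that lies more columns to the right than rows above can never be reached, however deep the stack is.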


|Masked Convolution|Result|
|:-:|:-:|
| <img src="https://user-images.githubusercontent.com/86907286/161759594-fdd2ccbc-43c0-4856-9a4f-a8d36d8fcdff.gif" alt="4" width="200px"/> | <img src="https://user-images.githubusercontent.com/86907286/161759596-c2bea2d9-d985-4cd5-a1b7-cc026f9a6788.png" alt="5" width="200px"/>|

<img src="https://user-images.githubusercontent.com/86907286/161759599-5f9e6ac9-3151-4a0b-919f-d4d93397374e.JPG" alt="6" width="400px" align="center" />

### Vertical Stack, Horizontal Stack

+ To address Problem 2, the authors replace the masked CNN with two separate layers, a **Vertical Stack** and a **Horizontal Stack**, and combine them
+ The pixels that a generated pixel depends on fall into two groups
  1. Pixels in the same row that come earlier in the sequence
  2. All pixels in the previous rows
+ Two kinds of convolution filter, each covering only one group in its local receptive field, are applied and their feature maps are combined
+ The important point is to **preserve the dependency of each pixel**!
  

<img src="https://user-images.githubusercontent.com/86907286/161759602-398bd9c0-c0f6-4916-be92-4dc4c25081a0.JPG" alt="8" width="200px" align="center" />

<img src="https://user-images.githubusercontent.com/86907286/161759605-58a6bc3b-d939-441c-b0a2-67daa49f8503.JPG" alt="9" width="400px" align="center" />

+ Assume 3x3 is the filter size that Masked Convolution would use
+ The Vertical Stack can then be expressed as a 2x3 (rows x columns) filter and the Horizontal Stack as a 1x2 filter
+ However, if both filters are applied directly to the original feature map, the vertical convolution reads **information from pixels later in the sequence**  
+ This **violates the assumption that the image distribution can be expressed as a sequence of per-pixel conditionals**

<img src="https://user-images.githubusercontent.com/86907286/161759610-4746ea38-2369-4935-8381-93ae0d552f62.JPG" alt="10" width="300px" align="center" />

+ Therefore, for the Vertical Stack filter, **one row of padding is added to the top of the input feature map**  
+ At the pixel index where the Horizontal Stack filter is applied, the Vertical Stack feature map pixel then comes from **one row earlier**
+ Afterwards, **the bottom row is cropped** from the final Vertical Stack feature map
+ Each pixel of the Vertical Stack feature map therefore **represents the vertical local dependency of the corresponding pixel in the original feature map**
+ Combining the two feature maps obtained this way means the **Blind Spot** no longer occurs
+ At the same time, the convolutions that masking would waste on zeroed weights disappear, so the model can be built more efficiently
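
The effect of the pad-and-crop shift can be sketched with the same kind of offset bookkeeping. The sketch assumes a 2x3 vertical filter, a 1x2 horizontal filter whose first layer excludes the centre, and models padding plus cropping as a one-row shift; all names are hypothetical and this is not the paper's code.

```python
V_FILTER = [(dr, dc) for dr in (-1, 0) for dc in (-1, 0, 1)]  # 2 rows x 3 columns
H_FIRST = [(0, -1)]           # first layer: left neighbour only
H_LATER = [(0, -1), (0, 0)]   # later layers: left neighbour and centre

def shift_down(offsets):
    """Pad one row on top and crop the bottom row: row r now holds row r - 1."""
    return {(r - 1, c) for (r, c) in offsets}

def combine(field, filt):
    """Receptive field after applying one more filter on top of `field`."""
    return {(r + dr, c + dc) for (r, c) in field for (dr, dc) in filt}

def two_stack_field(n_layers):
    """Input offsets visible to the horizontal stack after n_layers."""
    v_field, h_field = {(0, 0)}, {(0, 0)}
    for layer in range(n_layers):
        v_field = combine(v_field, V_FILTER)
        h_filter = H_FIRST if layer == 0 else H_LATER
        # The horizontal stack sees its own row plus the shifted vertical stack.
        h_field = combine(h_field, h_filter) | shift_down(v_field)
    return h_field

field = two_stack_field(n_layers=4)
earlier = [(r, c) for r in range(-3, 1) for c in range(-3, 4) if r < 0 or c < 0]
print(all(p in field for p in earlier))                      # earlier pixels covered
print(any(r > 0 or (r == 0 and c >= 0) for r, c in field))   # current/future pixels leaked
```

With four layers, every earlier pixel within three rows and columns is covered, while the current pixel and all later pixels stay outside the field.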


<img src="https://user-images.githubusercontent.com/86907286/161759612-fb52eb3a-0f44-4de9-8b54-70ae97b1ff0e.JPG" alt="11" width="700px" align="center" />

|Masked Convolution|Vertical/Horizontal Stack|
|:-:|:-:|
| <img src="https://user-images.githubusercontent.com/86907286/161759567-07ca3354-c724-4017-b530-802e46ef11ed.gif" alt="11-1" width="300px" /> | <img src="https://user-images.githubusercontent.com/86907286/161759570-bdfcf60d-9357-4e9c-941f-53e8337d26a2.gif" alt="11-2" width="300px" />|

### Gated Convolutional Layers

+ For Problem 1, the authors explain that stacking many layers gives a wide enough receptive field to overcome it
+ Problem 3, however, is a fundamental limitation of Masked Convolution and needs to be improved  
+ The authors therefore split the feature map after the convolution into two flows and apply a different activation to each
+ One flow uses a **sigmoid to select which features pass**, the other uses a **tanh to normalize the feature map**  
+ An element-wise product of the two then lets the layer amplify or diminish specific elements of the feature map

$$ \mathbf{y} = \text{tanh}(W_{k, f} * \mathbf{x}) \odot \sigma (W_{k,g} * \mathbf{x}) $$

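+ Both the vertical and the horizontal stack apply this gate operation to their convolution outputs
+ **The vertical flow is fed into the horizontal flow through a 1x1 convolution before the gate operation**, so the horizontal gate sees the pre-gate vertical features
+ The horizontal flow has a **residual connection** in every layer except the first, and it is the flow used for the final prediction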
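
A dependency-free sketch of the gated activation for a single pixel. It assumes the convolution emits twice the number of feature channels and that the first half feeds the tanh and the second half the sigmoid, which is how the reference implementation later in this review splits them.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gated_activation(features):
    """Split the channel list in half: tanh on the first half, sigmoid gate on the second."""
    half = len(features) // 2
    f, g = features[:half], features[half:]
    return [math.tanh(a) * sigmoid(b) for a, b in zip(f, g)]

# One pixel with 2 * 2 channels coming out of the convolution.
# Channel 0 meets an open gate (+10), channel 1 a closed gate (-10).
print(gated_activation([2.0, -2.0, 10.0, -10.0]))
```

The first output stays close to `tanh(2.0)` because its gate is nearly 1, while the second is pushed towards 0 because its gate is nearly closed.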

<img src="https://user-images.githubusercontent.com/86907286/161759573-7565968c-888a-4d0f-9c6d-cdffa787c689.JPG" alt="12" width="500px" align="center" />

### Conditional PixelCNN

+ The structure described so far is a generative model in the unconditional setting
+ To model the distribution given some conditioning information, a **latent vector $\mathbf{h}$ is introduced**


$$ p(\mathbf{x}|\mathbf{h}) = \prod^{n^2}_{i=1}p(x_i|x_1, ... , x_{i-1}, \mathbf{h}), \, \mathbf{x} \in \mathbb{R}^{n \times n}, \, \mathbf{h} \in \mathbb{R}^d$$

+ The authors implement this by adding the latent vector inside each Gated Convolution Layer, just before the gated activation is applied  
+ A suitable transform $V$ is applied to match the dimension of each layer

$$ \mathbf{y} = \text{tanh}(W_{k, f} * \mathbf{x} + V^T_{k,f} \cdot \mathbf{h}) \odot \sigma (W_{k,g} * \mathbf{x} + V^T_{k,g} \cdot \mathbf{h}) $$

+ If $\mathbf{h}$ is a one-hot vector representing a class, this term acts as a class-dependent bias  
+ However, this **does not carry enough information for more complex conditions such as position or pose**
+ If a representation $\mathbf{s}$ that embeds such information can be produced, more complex conditions can be expressed
+ Using a deconvolutional neural network $\mathbf{s} = m(\mathbf{h})$ to produce it, the activation can be written as follows

$$ \mathbf{y} = \text{tanh}(W_{k, f} * \mathbf{x} + V^T_{k,f} * \mathbf{s}) \odot \sigma (W_{k,g} * \mathbf{x} + V^T_{k,g} * \mathbf{s}) $$
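
A small sketch of the vector-conditioned form for a single pixel, showing that a one-hot $\mathbf{h}$ simply selects one row of $V$ as a bias. The matrices and values are made-up numbers for illustration. The spatial variant with $\mathbf{s}$ would make this bias differ per pixel and is not covered here.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def project(V, h):
    """Compute V^T h for V with shape (d, channels) and h of length d."""
    channels = len(V[0])
    return [sum(V[k][c] * h[k] for k in range(len(h))) for c in range(channels)]

def conditional_gate(conv_f, conv_g, V_f, V_g, h):
    """tanh(conv_f + V_f^T h) * sigmoid(conv_g + V_g^T h) for one pixel."""
    bias_f, bias_g = project(V_f, h), project(V_g, h)
    return [math.tanh(a + bf) * sigmoid(b + bg)
            for a, bf, b, bg in zip(conv_f, bias_f, conv_g, bias_g)]

# Hypothetical numbers: 3 classes, 2 channels.
V_f = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
V_g = [[1.0, -1.0], [0.0, 0.0], [-1.0, 1.0]]
h = [0, 1, 0]                     # one-hot vector for class 1
print(project(V_f, h))            # row 1 of V_f, used as a per-class bias
print(conditional_gate([0.0, 0.0], [0.0, 0.0], V_f, V_g, h))
```

Since the same bias is added at every pixel location, the vector form says what should appear in the image but not where.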

<img src="https://user-images.githubusercontent.com/86907286/161759576-7ed0e29e-de3b-4216-9227-df8307f11d97.JPG" alt="13" width="700px" align="center" />

### PixelCNN as Auto-Encoders

+ PixelCNN can generate diverse samples from the distribution, and with a latent vector it can generate multimodally
+ This makes it possible to **use PixelCNN as the decoder of an Auto-Encoder**  
+ The experiments show that **it responds to the latent vector differently from a conventional Auto-Encoder**
+ A conventional Auto-Encoder attempts a reconstruction, whereas the **PixelCNN decoder tries to generate similar images** 


<img src="https://user-images.githubusercontent.com/86907286/161759578-3e1a4a50-07fc-4a9d-9682-3f2a5a07ba7c.JPG" alt="14" width="700px" align="center" />

## Implementation Reviews

### [Implementation](https://github.com/rogertrullo/Gated-PixelCNN-Pytorch) (Unconditional Gated PixelCNN)

The reference implementation is **built with masking rather than padding/cropping**

In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

class MaskedConv(nn.Conv2d):
    '''
    Class that implements the masking for both streams vertical and horizontal.
    It is different if it is the first layer (A) or subsequent layers (B)
    '''
    def __init__(self, in_channels, out_channels, kernel_size, mask_type='A', ver_or_hor='V', use_gpu=True):
        assert mask_type in ['A', 'B'], 'only A or B are possible mask types'
        assert ver_or_hor in ['V', 'H'], 'only H or V are possible ver_or_hor types'

        if ver_or_hor == 'H':  # 1XN mask
            pad = (0, (kernel_size - 1) // 2)
            ksz = (1, kernel_size)

        else:  # NxN mask vertical
            ksz = kernel_size
            pad = (kernel_size - 1) // 2

        super().__init__(in_channels, out_channels, kernel_size=ksz, padding=pad)
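        # Register the mask as a buffer so it follows the module across devices
        # (model.cuda() / model.to(device)); use_gpu is kept only for compatibility.
        self.register_buffer('mask', torch.zeros_like(self.weight))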

        if mask_type == 'A':
            if ver_or_hor == 'V':  # NXN mask
                self.mask[:, :, 0:self.mask.shape[2] // 2, :] = 1

            else:  # horizontal 1xN
                self.mask[:, :, :, 0:self.mask.shape[3] // 2] = 1
        else:  # B
            if ver_or_hor == 'V':  # NXN mask
                self.mask[:, :, 0:self.mask.shape[2] // 2, :] = 1
                self.mask[:, :, self.mask.shape[2] // 2, :] = 1

            else:  # horizontal 1xN
                self.mask[:, :, :, 0:self.mask.shape[3] // 2 + 1] = 1

    def __call__(self, x):
        self.weight.data *= self.mask  # mask weights
        # print(self.weight)
        return super().__call__(x)


class GatedConvLayer(nn.Module):
    '''
    Main building block of the framework. It implements figure 2 of the paper.
    '''
    def __init__(self, in_channels, nfeats, kernel_size=3, mask_type='A'):
        super(GatedConvLayer, self).__init__()
        self.nfeats = nfeats
        self.mask_type = mask_type
        self.vconv = MaskedConv(in_channels=in_channels, out_channels=2 * nfeats, kernel_size=kernel_size,
                                ver_or_hor='V', mask_type=mask_type)

        self.hconv = MaskedConv(in_channels=in_channels, out_channels=2 * nfeats, kernel_size=kernel_size,
                                ver_or_hor='H', mask_type=mask_type)

        self.v_to_h_conv = nn.Conv2d(in_channels=2 * nfeats, out_channels=2 * nfeats, kernel_size=1)  # 1x1 conv

        self.h_to_h_conv = nn.Conv2d(in_channels=nfeats, out_channels=nfeats, kernel_size=1)  # 1x1 conv

    def GatedActivation(self, x):
        return torch.tanh(x[:, :self.nfeats]) * torch.sigmoid(x[:, self.nfeats:])

    def forward(self, x):
        # x should be a list of two elements [v, h]
        iv, ih = x
        ov = self.vconv(iv)
        oh_ = self.hconv(ih)
        v2h = self.v_to_h_conv(ov)
        oh = v2h + oh_

        ov = self.GatedActivation(ov)

        oh = self.GatedActivation(oh)
        oh = self.h_to_h_conv(oh)

        ##############################################################################
        #Due to the residual connection, if we add it from the first layer, ##########
        #the current pixel is included, in my implementation I removed the first #####
        #residual connection to solve this issue #####################################
        ##############################################################################
        if self.mask_type == 'B':
            oh = oh + ih

        return [ov, oh]


class PixelCNN(nn.Module):
    '''
    Class that stacks several GatedConvLayers, the output has Klevel maps.
    Klevels indicates the number of possible values that a pixel can have e.g 2 for binary images or
    256 for gray level imgs.
    '''
    def __init__(self, nlayers, in_channels, nfeats, Klevels=2, ksz_A=5, ksz_B=3):
        super(PixelCNN, self).__init__()
        self.layers = nn.ModuleList(
            [GatedConvLayer(in_channels=in_channels, nfeats=nfeats, mask_type='A', kernel_size=ksz_A)])
        for i in range(nlayers):
            gatedconv = GatedConvLayer(in_channels=nfeats, nfeats=nfeats, mask_type='B', kernel_size=ksz_B)
            self.layers.append(gatedconv)
        #TODO make kernel sizes as params

        self.out_conv = nn.Sequential(
            nn.Conv2d(nfeats, nfeats, 1),
            nn.ReLU(True),
            nn.Conv2d(nfeats, Klevels, 1)
        )


    def forward(self, x):
        x = [x, x]
        for i, layer in enumerate(self.layers):
            x = layer(x)
        logits = self.out_conv(x[1])

        return logits

### Training

In [2]:
batch_size=64
num_workers=2

show_every=100 # show info every this number of iterations
nlayers=12 # number of layers for pixelcnn
inchans=1 #number of input channels (currently only one is supported)
nfeats=16 #number of feature maps across the network
Klevels=4 #number of levels to use in discretization
nepochs=5 #number of epochs to train
lr=1e-3 #learning rate for optimizer
generate_every=300
nimgs_to_generate=16

In [3]:
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = MNIST('./data', download=True, train=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                           num_workers=num_workers, pin_memory=True)

In [4]:
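def generate_imgs(model, shape, nimgs):
    # Sample images pixel by pixel in raster-scan order, starting from a blank canvas.
    x=torch.zeros((nimgs,1,shape[0],shape[1])).cuda()
    model.eval()
    with torch.no_grad():  # no gradients are needed while sampling
        for i in range(x.shape[2]):
            for j in range(x.shape[3]):
                logits=model(x)
                probs=torch.softmax(logits[:,:,i,j],1)
                sample=probs.multinomial(1)
                x[:,:,i,j]=sample.float()/(Klevels-1)  # map the level index back to [0,1]
    model.train()
    return x.cpu()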

def discretize_imgs(img_tensor, nlevels):
    '''
    Discretize a floating tensor into nlevels (quantization).
    The function assumes that the data is in [0,1].
    It returns two outputs: the first is again in [0,1] but takes only nlevels values.
    The second is the equivalent with integer indices in [0,nlevels-1].
    '''
    xnp=img_tensor.numpy()
    # np.long is deprecated since NumPy 1.20; np.int64 converts to a torch long tensor
    xnp_dig=(np.digitize(xnp, np.arange(nlevels) / nlevels) - 1).astype(np.int64)
    xnp=xnp_dig/(nlevels -1)
    return torch.from_numpy(xnp).float(), torch.from_numpy(xnp_dig)

In [5]:
model = PixelCNN(nlayers=nlayers, in_channels=inchans, nfeats=nfeats, Klevels=Klevels).cuda()
optimizer = Adam(model.parameters(), lr=lr, betas=(0, 0.99))
criteria=nn.CrossEntropyLoss()

list_imgs=[]
for epoch in range(nepochs):
    for it,(images, labels) in enumerate(train_loader):

        imgs,imgs_quant= discretize_imgs(images, Klevels)
        imgs=imgs.cuda()
        imgs_quant=imgs_quant.cuda()
        logits=model(imgs)
        loss=criteria(logits,imgs_quant.squeeze(1))  # drop only the channel dim, keep the batch dim
        optimizer.zero_grad() # Backward & update weights
        loss.backward()
        optimizer.step()
        
        if it%show_every==0:
            print(f'epoch: {epoch}, it:{it}/{len(train_loader)}, loss:{loss.item()}')
        if it%generate_every==0:
            print('generating imgs...')
            samples=generate_imgs(model, (imgs.shape[2],imgs.shape[3]), nimgs_to_generate)
            list_imgs.append(samples)



epoch: 0, it:0/938, loss:1.3460774421691895
generating imgs...
epoch: 0, it:100/938, loss:0.22751770913600922
epoch: 0, it:200/938, loss:0.19475848972797394
epoch: 0, it:300/938, loss:0.19172492623329163
generating imgs...
epoch: 0, it:400/938, loss:0.18021661043167114
epoch: 0, it:500/938, loss:0.19130343198776245
epoch: 0, it:600/938, loss:0.1697963923215866
generating imgs...
epoch: 0, it:700/938, loss:0.18709248304367065
epoch: 0, it:800/938, loss:0.18620853126049042
epoch: 0, it:900/938, loss:0.16931068897247314
generating imgs...
epoch: 1, it:0/938, loss:0.1788172423839569
generating imgs...
epoch: 1, it:100/938, loss:0.16182249784469604
epoch: 1, it:200/938, loss:0.1640923172235489
epoch: 1, it:300/938, loss:0.15908195078372955
generating imgs...
epoch: 1, it:400/938, loss:0.15759702026844025
epoch: 1, it:500/938, loss:0.16653862595558167
epoch: 1, it:600/938, loss:0.1695224642753601
generating imgs...
epoch: 1, it:700/938, loss:0.16781367361545563
epoch: 1, it:800/938, loss:0.1

### Results

In [10]:
!apt install imagemagick


In [11]:
from matplotlib.animation import FuncAnimation

print(len(list_imgs))
fig, ax=plt.subplots(figsize=(20,20))

def gen_grid(i):
    print(f'frame {i}')
    img=make_grid(list_imgs[i], nrow=4)
    npimg = img.numpy()
    ax.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    ax.set_title(f'iteration {i*generate_every}', fontsize=50)
    
  
anim = FuncAnimation(fig, gen_grid, frames=np.arange(len(list_imgs)), interval=500)
anim.save('digits.gif', dpi=80, writer='imagemagick')
plt.close()

20
frame 0
frame 0
frame 1
frame 2
frame 3
frame 4
frame 5
frame 6
frame 7
frame 8
frame 9
frame 10
frame 11
frame 12
frame 13
frame 14
frame 15
frame 16
frame 17
frame 18
frame 19


<img src="https://user-images.githubusercontent.com/86907286/161759581-0d88aff6-aa97-407e-af78-c365057a8782.gif" alt="digits" width="500px" align="center" />

## Reference

https://arxiv.org/pdf/1601.06759.pdf  
https://arxiv.org/pdf/1606.05328v2.pdf  
https://sergeiturukin.com/2017/02/24/gated-pixelcnn.html  
https://youtu.be/1BURwCCYNEI  
https://github.com/rogertrullo/Gated-PixelCNN-Pytorch