![PNG%EC%9D%B4~3.PNG](https://user-images.githubusercontent.com/75057952/167243294-0400b833-5bbd-46ce-adc5-bcd1073969df.png)

# PixelCNN++ : IMPROVING THE PIXELCNN WITH DISCRETIZED LOGISTIC MIXTURE LIKELIHOOD AND OTHER MODIFICATIONS 
Tim Salimans, Andrej Karpathy, Xi Chen, Diederik P. Kingma
### Recap - **PixelCNN : Generative model of images with a tractable likelihood**
- Fully factorizes the probability density function on an image x over all its sub-pixels (color channels in a pixel)
- $p(x) = \prod_i p(x_i \mid x_{<i})$ : the conditional distributions $p(x_i \mid x_{<i})$ are parameterized by CNNs
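
As a rough illustration of this factorization, the log-likelihood of a flattened sequence of sub-pixels is the sum of the conditional log-probabilities. In the sketch below, `uniform_conditional` is a hypothetical stand-in for the CNN (it ignores the context entirely), and `log_likelihood` is an illustrative name, not part of any library.

```python
import numpy as np

def log_likelihood(x, conditional_fn):
    """Sum of log p(x_i | x_{<i}) over a flattened sequence of sub-pixel values."""
    total = 0.0
    for i in range(len(x)):
        probs = conditional_fn(x[:i])   # distribution over the 256 possible values of x_i
        total += np.log(probs[x[i]])
    return total

def uniform_conditional(prefix):
    # Hypothetical placeholder for the CNN: uniform over 0..255, regardless of context.
    return np.full(256, 1.0 / 256)

x = np.array([12, 200, 7, 255])
ll = log_likelihood(x, uniform_conditional)
```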
---
### Modifications to PixelCNN
**1. Discretized logistic mixture likelihood**

**PixelCNN**
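- The PixelCNN model factorizes the image into conditional distributions at the level of each pixel's color channels
- Each distribution is computed with a 256-way softmax (values 0 to 255), which is costly: what is essentially a regression problem is turned into a 256-class classification problem
- Gradients are very sparse, especially in the early training steps
- The values 127 and 128 are treated as different classes; the model does not know that they are numerically adjacent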
---
**PixelCNN++**
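- As in VAEs, the goal is to estimate the distribution of a latent variable $\nu$ that explains the conditional probability of the pixel value
- Here, $\nu$ is assumed to follow a mixture of logistic distributions
- $\nu \sim \sum_{i=1}^K \pi_i \, \text{logistic}(\mu_i, s_i)$
- $P(x \mid \pi, \mu, s) = \sum_{i=1}^K \pi_i \left[\sigma\left(\frac{x+0.5-\mu_i}{s_i}\right) - \sigma\left(\frac{x-0.5-\mu_i}{s_i}\right)\right]$ (excluding the edge cases: for $x = 0$, $x - 0.5$ is replaced by $-\infty$; for $x = 255$, $x + 0.5$ is replaced by $+\infty$)
- Plot of the conditional pixel distribution on the CIFAR-10 dataset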
<img src = "https://user-images.githubusercontent.com/75057952/167244052-6db2a090-c151-42a0-805c-bde66e81fb81.png" width = "500dp"></img>
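
The formula above can be sketched in NumPy as follows. This is a minimal illustration that works on the integer scale 0 to 255, exactly as the formula is written; the function names and the example parameters (`pi`, `mu`, `s`) are assumptions made for the sketch.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def discretized_logistic_mixture_pmf(pi, mu, s):
    """Return P(x) for x = 0..255 under a mixture of K discretized logistics.

    pi, mu, s: arrays of shape (K,) with mixture weights, means and scales.
    """
    x = np.arange(256)[:, None]             # shape (256, 1), broadcasts against (K,)
    upper = sigmoid((x + 0.5 - mu) / s)     # logistic CDF at x + 0.5
    lower = sigmoid((x - 0.5 - mu) / s)     # logistic CDF at x - 0.5
    upper[255] = 1.0                        # edge case: x = 255 absorbs all mass above 254.5
    lower[0] = 0.0                          # edge case: x = 0 absorbs all mass below 0.5
    return ((upper - lower) * pi).sum(axis=1)

# Illustrative parameters (not taken from the paper)
pi = np.array([0.7, 0.3])
mu = np.array([100.0, 200.0])
s = np.array([5.0, 10.0])
pmf = discretized_logistic_mixture_pmf(pi, mu, s)
```

Because neighbouring bins share the same CDF boundaries, the probabilities sum to 1, and nearby values such as 127 and 128 receive similar probability by construction.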
---
**2. Conditioning on whole pixels**

**PixelCNN**
- The channels are ordered R > G > B, and the conditional distribution is factorized channel by channel
- A generative model is built over the 3 sub-pixels of each pixel
- This can describe general dependencies, but it is overly complicated
---
**PixelCNN++**
- Assumption: the inter-channel dependency is not a complex relationship
- A simple factorized model is considered sufficient to describe it
- $C_{i,j}$ : context vector (information from the previous pixels; the mixture indicator is shared across all 3 channels)
- $p(r_{i,j}, g_{i,j}, b_{i,j} | C_{i,j}) = P(r_{i,j}|\mu_r(C_{i,j}), s_r(C_{i,j})) \times P(g_{i,j}|\mu_g(C_{i,j}, r_{i,j}), s_g(C_{i,j})) \times P(b_{i,j}|\mu_b(C_{i,j}, r_{i,j}, g_{i,j}), s_b(C_{i,j}))$
- $\mu_g{(C_{i,j}, r_{i,j})} = \mu_g{(C_{i,j})} + \alpha(C_{i,j}) r_{i,j}$
- $\mu_b(C_{i,j}, r_{i,j}, g_{i,j}) = \mu_b(C_{i,j}) + \beta(C_{i,j})r_{i,j} + \gamma(C_{i,j})g_{i,j}$
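
A minimal sketch of this linear conditioning, assuming the network has already produced the context-dependent quantities for one pixel: the green mean is shifted by the observed red value, and the blue mean by the observed red and green values. The function name and the numbers are illustrative only.

```python
import numpy as np

def conditioned_means(mu_r, mu_g, mu_b, alpha, beta, gamma, r, g):
    """Shift the green and blue means linearly by the already observed sub-pixels."""
    mean_r = mu_r                            # red depends on the context only
    mean_g = mu_g + alpha * r                # green additionally depends on red
    mean_b = mu_b + beta * r + gamma * g     # blue additionally depends on red and green
    return mean_r, mean_g, mean_b

# Illustrative values for a single pixel (not from the paper)
mean_r, mean_g, mean_b = conditioned_means(
    mu_r=0.1, mu_g=0.0, mu_b=-0.2,
    alpha=0.5, beta=0.25, gamma=0.5,
    r=0.4, g=0.2,
)
```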

---
**3.  Downsampling vs. dilated convolution**

**PixelCNN**
- Uses convolution filters with a small receptive field (mostly 3×3)
- Good at capturing local dependencies, but poorly suited to modeling long-range structure
---
**PixelCNN++**
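- To enlarge the receptive field, dilated convolution can be used to obtain a multi-resolution view
- Alternatively, the feature maps are downsampled (stride-2 convolutions) before further convolutions
    - Downsampling loses information
    - This is compensated by introducing additional short-cut connections
    - In the end, performance is about the same as with dilated convolution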
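
To make the receptive-field argument concrete, the sketch below applies the standard receptive-field recurrence along one spatial axis. It assumes ordinary symmetric 3-wide kernels rather than the shifted causal filters used by the model, and the layer configurations are illustrative.

```python
def receptive_field(layers):
    """layers: list of (kernel_size, stride, dilation) along one spatial axis."""
    rf, jump = 1, 1
    for kernel, stride, dilation in layers:
        rf += (kernel - 1) * dilation * jump   # growth measured in input pixels
        jump *= stride                         # spacing of output positions in input pixels
    return rf

plain = receptive_field([(3, 1, 1)] * 4)                                 # four plain 3-wide convs
dilated = receptive_field([(3, 1, d) for d in (1, 2, 4, 8)])             # dilation doubles per layer
strided = receptive_field([(3, 2, 1), (3, 2, 1), (3, 2, 1), (3, 1, 1)])  # stride-2 downsampling
```

Under these assumptions, four plain layers see 9 input pixels, while both the dilated and the strided stacks see 31; the strided stack also runs its later layers on smaller feature maps.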
---
**4.  Adding short-cut connections**
<img src = "https://user-images.githubusercontent.com/75057952/167244057-a190e3de-84e6-4ca7-bb0f-a44e448d3737.png" width = "700dp"></img>
- U-Net-like architecture
- The network consists of six layers in total; layers at symmetric levels of abstraction are linked by convolutional connections
- The skip connections compensate for the information lost during the downsampling and upsampling process
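
The pairing of symmetric levels can be sketched as a stack: the up pass pushes every feature map, and the down pass pops them in reverse order. The sketch below tracks only spatial resolutions and mirrors the bookkeeping of the `model_spec` function in this notebook (`u_list.append` / `u_list.pop`); the function name `skip_pairing` is hypothetical.

```python
def skip_pairing(nr_resnet=5, size=32):
    """Return (current_resolution, skip_resolution) pairs and the leftover stack."""
    # Up pass: initial conv output, then nr_resnet blocks per level,
    # with a stride-2 downsampling between levels.
    stack = [size]
    for level in range(3):
        stack += [size] * nr_resnet
        if level < 2:
            size //= 2
            stack.append(size)

    # Down pass: start from the last feature map, pop one skip input per block,
    # and upsample by 2 between levels.
    current = stack.pop()
    pairs = []
    for level, n_blocks in enumerate([nr_resnet, nr_resnet + 1, nr_resnet + 1]):
        for _ in range(n_blocks):
            pairs.append((current, stack.pop()))
        if level < 2:
            current *= 2
    return pairs, stack

pairs, leftover = skip_pairing()
```

Every popped skip input has the same resolution as the feature map it is merged into, and the stack ends up empty, which is what the `assert len(u_list) == 0` checks in `model_spec` verify.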

**5. Regularization using dropout**

### Experiments (and Ablation Study)

**Unconditional generation on CIFAR-10 : negative log-likelihood**
<img src = "https://user-images.githubusercontent.com/75057952/167244061-c205d93e-bcee-46e7-9f16-6f7cc43662af.png" width = "300dp"></img>
---
**Network depth and field of view size**
- Hypothesis the authors formed from the PixelCNN experiments: the receptive field size and the removal of blind spots have a large effect on performance
- When PixelCNN++ is made shallower or its receptive field size is restricted (the "Plain" variant), it still outperforms PixelCNN despite the reduced network capacity
- Modifications on top of Plain: introducing NIN (Network in Network) and the Autoregressive Channel (improves both capacity and performance)
    - NIN : <img src = "https://user-images.githubusercontent.com/75057952/167244066-88d40748-6444-4eec-8108-c97e7169d283.png" width = "300dp"></img>
    - Autoregressive Channel : adds skip connections between channels (gated ResNet blocks with 1×1 convolutions)
<img src = "https://user-images.githubusercontent.com/75057952/167244068-4165cf0a-5c3a-4192-82f1-ce3c03acd8c5.png" width = "300dp"></img>
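
A NIN layer amounts to a 1×1 convolution, i.e. the same dense layer applied independently at every pixel. The sketch below is a simplified NumPy version under that assumption (no weight normalization or initialization logic); the function signature is illustrative and is not the `nn.nin` API used in the code cells.

```python
import numpy as np

def nin(x, weights, bias):
    """Apply one dense layer at every spatial position.

    x: (N, H, W, C_in), weights: (C_in, C_out), bias: (C_out,)
    """
    n, h, w, c_in = x.shape
    flat = x.reshape(-1, c_in)        # treat every pixel as one row
    out = flat @ weights + bias       # shared dense layer across all pixels
    return out.reshape(n, h, w, -1)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 4, 3))
weights = rng.normal(size=(3, 5))
bias = rng.normal(size=(5,))
y = nin(x, weights, bias)
```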

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import arg_scope
import pixel_cnn_pp.nn as nn

![image.png](https://user-images.githubusercontent.com/75057952/167244074-db550f12-4875-4796-9d94-3256d36f546f.png)

In [None]:
def model_spec(x, h=None, init=False, ema=None, dropout_p=0.5, nr_resnet=5, nr_filters=160, nr_logistic_mix=10, resnet_nonlinearity='concat_elu', energy_distance=False):

    counters = {}
    with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense], counters=counters, init=init, ema=ema, dropout_p=dropout_p):

        # parse resnet nonlinearity argument
        if resnet_nonlinearity == 'concat_elu':
            resnet_nonlinearity = nn.concat_elu
        elif resnet_nonlinearity == 'elu':
            resnet_nonlinearity = tf.nn.elu
        elif resnet_nonlinearity == 'relu':
            resnet_nonlinearity = tf.nn.relu
        else:
            raise ValueError('resnet nonlinearity ' + str(resnet_nonlinearity) + ' is not supported')

        with arg_scope([nn.gated_resnet], nonlinearity=resnet_nonlinearity, h=h):

            # ////////// up pass through pixelCNN ////////
            xs = nn.int_shape(x)
            '''
            def int_shape(x):
                return list(map(int, x.get_shape()))
            '''
            x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on
            '''
            def down_shift(x):
                xs = int_shape(x)
                return tf.concat([tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]],1)
            '''
            u_list = [nn.down_shift(nn.down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 3]))] # stream for pixels above
            ul_list = [nn.down_shift(nn.down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \
                       nn.right_shift(nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left
            '''
            down_shift: prepends one row of zeros at the top of the input and drops the last row
            down_shifted_conv2d/deconv2d:
            def down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs):
                # Padding amounts are given per axis of the NHWC tensor; only H (top) and W (left/right) are padded.
                x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]])
                return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs)
            '''
            for rep in range(nr_resnet): # nr_resnet = 5 (default)
                u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
                ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))
            '''
            gated_resnet: see the figure above
            '''
            u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2]))
            ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2]))

            for rep in range(nr_resnet):
                u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
                ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

            u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2]))
            ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2]))

            for rep in range(nr_resnet):
                u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d))
                ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d))

            # remember nodes
            for t in u_list+ul_list:
                tf.add_to_collection('checkpoints', t)

            # /////// down pass ////////
            u = u_list.pop()
            ul = ul_list.pop()
            for rep in range(nr_resnet):
                u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
                ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
                tf.add_to_collection('checkpoints', u)
                tf.add_to_collection('checkpoints', ul)

            u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2])
            ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2])

            for rep in range(nr_resnet+1):
                u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
                ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
                tf.add_to_collection('checkpoints', u)
                tf.add_to_collection('checkpoints', ul)

            u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2])
            ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2])

            for rep in range(nr_resnet+1):
                u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d)
                ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d)
                tf.add_to_collection('checkpoints', u)
                tf.add_to_collection('checkpoints', ul)

                '''
                nin : network in network
                '''
            if energy_distance:
                f = nn.nin(tf.nn.elu(ul), 64)

                # generate 10 samples
                fs = []
                for rep in range(10):
                    fs.append(f)
                f = tf.concat(fs, 0)
                fs = nn.int_shape(f)
                f += nn.nin(tf.random_uniform(shape=fs[:-1] + [4], minval=-1., maxval=1.), 64)
                f = nn.nin(nn.concat_elu(f), 64)
                x_sample = tf.tanh(nn.nin(nn.concat_elu(f), 3, init_scale=0.1))

                x_sample = tf.split(x_sample, 10, 0)

                assert len(u_list) == 0
                assert len(ul_list) == 0

                return x_sample

            else:
                # 10 outputs per mixture component: 1 mixture logit, 3 means, 3 log-scales, 3 channel coefficients
                x_out = nn.nin(tf.nn.elu(ul),10*nr_logistic_mix)

                assert len(u_list) == 0
                assert len(ul_list) == 0

                return x_out
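
The shift operations quoted in the comments above can be tried out in NumPy. `down_shift` follows the TensorFlow helper quoted in the code cell; `right_shift` is written here by analogy and is an assumption about how the library version behaves. Both operate on NHWC arrays.

```python
import numpy as np

def down_shift(x):
    """Move every row down by one: the top row becomes zeros, the last row is dropped."""
    zeros = np.zeros_like(x[:, :1])
    return np.concatenate([zeros, x[:, :-1]], axis=1)

def right_shift(x):
    """Move every column right by one: the left column becomes zeros, the last column is dropped."""
    zeros = np.zeros_like(x[:, :, :1])
    return np.concatenate([zeros, x[:, :, :-1]], axis=2)

x = np.arange(9).reshape(1, 3, 3, 1)
down = down_shift(x)[0, :, :, 0]
right = right_shift(x)[0, :, :, 0]
```

After the shift, the feature at position (i, j) only contains information from the row above (or the column to the left), which keeps the prediction for a pixel from depending on the pixel itself.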