## Conditional PixelCNN Implementation

In [1]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

#### Modifications needed for PixelCNN
Some points:
1. The residual block architecture follows Figure 5 of the PixelRNN paper: https://arxiv.org/pdf/1601.06759.pdf

In [2]:
class MaskConv2D(nn.Conv2d):
    """Masked 2D convolution used by PixelCNN.

    The mask zeroes kernel taps that would let an output pixel see pixels that
    come after it in raster-scan order. Type 'A' also hides the centre pixel
    (use it only for the first layer); type 'B' lets the centre pixel through.

    Args:
        mask_type: 'A' or 'B'.
        *args: positional arguments forwarded to ``nn.Conv2d`` in its order:
            (in_channels, out_channels, kernel_size, stride, padding, ...).
        class_conditioning: enables colour-channel masking (original name of
            the flag, kept for backward compatibility).
        conditional_size: optional shape of the conditioning input. ``(d,)``
            adds a per-channel bias computed by ``nn.Linear(d, out_channels)``;
            a longer tuple ``(c, ...)`` adds the output of a 3x3
            ``nn.Conv2d(c, out_channels)`` applied to the condition.
        color_conditioning: alias for ``class_conditioning``; either enables it.
            Input and output channels are split into three groups (R, G, B).
            At the centre tap, output group g only sees input groups < g
            (type 'A') or <= g (type 'B').
        **kwargs: keyword arguments forwarded to ``nn.Conv2d``.
    """

    def __init__(self, mask_type, *args, class_conditioning = False, conditional_size = None,
                 color_conditioning = False, **kwargs):
        assert mask_type in ['A', 'B'], "Unknown Mask Type"
        super().__init__(*args, **kwargs)
        self.conditional_size = conditional_size
        self.color_conditioning = bool(class_conditioning or color_conditioning)
        # Same shape/dtype/device as the kernel; a buffer so it moves with .to().
        self.register_buffer('mask', torch.zeros_like(self.weight))
        self._build_mask(mask_type)

        if self.conditional_size:
            if len(self.conditional_size) == 1:
                self.cond_op = nn.Linear(conditional_size[0], self.out_channels)
            else:
                self.cond_op = nn.Conv2d(conditional_size[0], self.out_channels, stride = 1,
                                         kernel_size = 3, padding = 1)

    def _build_mask(self, mask_type):
        """Fill ``self.mask`` with the causal (and optional colour) pattern."""
        h, w = self.kernel_size
        ch, cw = h // 2, w // 2
        mask = self.mask
        # Every row above the centre row, and the taps left of centre on it.
        mask[:, :, :ch, :] = 1
        mask[:, :, ch, :cw] = 1

        if not self.color_conditioning:
            # Type B may also look at the current pixel itself.
            if mask_type == 'B':
                mask[:, :, ch, cw] = 1
            return

        # Colour ordering R -> G -> B at the centre tap. Any remainder channels
        # (when a count is not divisible by 3) fall into the last group.
        in_third, out_third = self.in_channels // 3, self.out_channels // 3
        if mask_type == 'B':
            mask[:out_third, :in_third, ch, cw] = 1
            mask[out_third:2 * out_third, :2 * in_third, ch, cw] = 1
            mask[2 * out_third:, :, ch, cw] = 1
        else:
            mask[out_third:2 * out_third, :in_third, ch, cw] = 1
            mask[2 * out_third:, :2 * in_third, ch, cw] = 1

    def forward(self, x, cond = None):
        """Apply the masked convolution, then add the conditional term if configured.

        The mask is applied to a copy of the weight instead of overwriting
        ``self.weight``, so stored parameters are never mutated. Padding is
        applied as zero padding via ``F.conv2d``; a non-zero ``padding_mode``
        would leak future pixels anyway.
        """
        out = F.conv2d(x, self.weight * self.mask, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)
        if self.conditional_size:
            if len(self.conditional_size) == 1:
                # (B, out_channels) -> (B, out_channels, 1, 1): broadcast over H x W.
                out = out + self.cond_op(cond).view(x.shape[0], -1, 1, 1)
            else:
                out = out + self.cond_op(cond)
        return out

class ResBlock(nn.Module):
    """Residual block of Type 'B' masked convolutions (PixelRNN paper, Figure 5).

    Layout: ReLU -> 1x1 (C -> C/2) -> ReLU -> k x k (C/2 -> C/2) -> ReLU ->
    1x1 (C/2 -> C), followed by an identity skip connection.

    Args:
        in_channels: number of input (and output) channels C.
        kernel_size: size of the middle masked convolution (default 7); it is
            padded by ``kernel_size // 2`` so the spatial size is preserved.
        **kwargs: forwarded to every ``MaskConv2D`` (e.g. conditioning options).
    """

    def __init__(self, in_channels, *, kernel_size = 7, **kwargs):
        super().__init__()
        hidden = in_channels // 2
        # Keyword arguments avoid the nn.Conv2d positional order
        # (kernel_size, stride, padding) being mixed up.
        self.net = nn.ModuleList([
            nn.ReLU(),
            MaskConv2D('B', in_channels, hidden, kernel_size = 1, **kwargs),
            nn.ReLU(),
            MaskConv2D('B', hidden, hidden, kernel_size = kernel_size,
                       padding = kernel_size // 2, **kwargs),
            nn.ReLU(),
            MaskConv2D('B', hidden, in_channels, kernel_size = 1, **kwargs),
        ])

    def forward(self, x, cond = None):
        """Return ``x + f(x)``, passing ``cond`` to every masked convolution."""
        out = x
        for layer in self.net:
            if isinstance(layer, MaskConv2D):
                out = layer(out, cond = cond)
            else:
                out = layer(out)
        return x + out

### Architecture of the Model
There is only one Type A mask (in the first layer), which preserves the autoregressive property; every subsequent masked layer uses a Type B mask.
<br>Network architecture (can be modified, but keep exactly one Type A mask):
1. Type A mask with kernel size 7 and padding = kernel_size // 2, so the spatial size of the input is preserved
2. Type B masked convolutions / residual blocks (repeated `no_of_layers` times)
3. Type B mask with kernel size 1
4. Type B mask with kernel size 1; the number of output channels is input_channels × number of colors (256 for 8-bit images, or 2 for binary 0/1 images)

In [14]:
class Conditional_PixelCNN(nn.Module):
    """PixelCNN over discrete pixel values, optionally conditioned on extra input.

    Args:
        input_shape: image shape (C, H, W).
        channels: number of hidden feature channels.
        colors: number of discrete values per sub-pixel (e.g. 256 for 8-bit
            images, 2 for binary images).
        no_of_layers: number of hidden Type 'B' blocks.
        color_conditioning: if True, channel k of a pixel is also conditioned
            on channels < k of the same pixel through grouped masks.
        use_ResBlock: use ``ResBlock`` for each hidden block instead of a
            single 7x7 masked convolution.
        conditional_size: None for an unconditional model. ``(d,)`` for a
            vector condition; integer class labels of shape (B,) are one-hot
            encoded to width d. ``(c, H, W)`` for an image condition, which is
            first processed by a small conv net.
        device: device used by ``get_samples``; if None, the device of the
            model's parameters is used.
    """

    def __init__(self, input_shape, channels, colors, no_of_layers,
                 color_conditioning, use_ResBlock, conditional_size = None, device = None):
        super().__init__()
        self.input_shape = input_shape
        self.device = device
        self.channels = channels
        self.color_channels = colors
        self.color_conditioning = color_conditioning
        self.conditional_size = conditional_size

        # Condition pre-processing. A 1D condition is used as-is. A spatial
        # condition is lifted to `cond_features` maps, so each masked conv must
        # be told its condition has that many channels.
        layer_cond_size = conditional_size
        if conditional_size:
            if len(conditional_size) == 1:
                self.cond_op = nn.Identity()
            else:
                cond_features = 64
                self.cond_op = nn.Sequential(
                    nn.Conv2d(conditional_size[0], cond_features, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(cond_features, cond_features, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(cond_features, cond_features, 3, padding=1),
                    nn.ReLU(),
                )
                layer_cond_size = (cond_features, *conditional_size[1:])

        # MaskConv2D's flag for colour-channel masking is named `class_conditioning`.
        kwargs = dict(
            class_conditioning = color_conditioning,
            conditional_size = layer_cond_size,
        )

        def make_block():
            """Build one hidden Type 'B' block (ResBlock or 7x7 masked conv)."""
            if use_ResBlock:
                return ResBlock(channels, **kwargs)
            return MaskConv2D('B', channels, channels, kernel_size = 7, padding = 7 // 2, **kwargs)

        # The only Type 'A' layer: 7x7 conv, padded to keep H x W.
        kernel_size = 7
        layers = [MaskConv2D('A', input_shape[0], channels, kernel_size = kernel_size,
                             padding = kernel_size // 2, **kwargs)]

        # NOTE(review): in training mode BatchNorm2d normalizes with statistics
        # pooled over all spatial positions, which can leak a little information
        # from "future" pixels. Consider removing it if exact causality matters.
        for _ in range(no_of_layers):
            layers.extend([nn.ReLU(), nn.BatchNorm2d(channels), make_block()])

        # Two 1x1 Type 'B' convs. The last one emits `colors` logits per input channel.
        layers.extend([
            nn.ReLU(),
            nn.BatchNorm2d(channels),
            MaskConv2D('B', channels, channels, kernel_size = 1, **kwargs),
            nn.ReLU(),
            nn.BatchNorm2d(channels),
            MaskConv2D('B', channels, self.color_channels * input_shape[0],
                       kernel_size = 1, **kwargs),
        ])
        self.net = nn.ModuleList(layers)

    def _prepare_cond(self, cond):
        """Validate the condition, one-hot encode integer labels, and apply ``cond_op``."""
        if cond is None:
            raise ValueError("this model was built with conditional_size; `cond` is required")
        if (len(self.conditional_size) == 1 and cond.dim() == 1
                and not torch.is_floating_point(cond)):
            cond = F.one_hot(cond.long(), self.conditional_size[0])
        return self.cond_op(cond.float())

    def forward(self, x, cond = None):
        """Return logits of shape (B, colors, C, H, W) for integer images ``x``."""
        batch_size = x.shape[0]
        # Rescale pixel values from [0, colors - 1] to [-1, 1].
        out = (x.float() / (self.color_channels - 1) - 0.5) / 0.5
        if self.conditional_size:
            cond = self._prepare_cond(cond)
        for layer in self.net:
            if isinstance(layer, MaskConv2D):
                out = layer(out, cond)
            else:
                out = layer(out)

        if self.color_conditioning:
            # Output channels are grouped per input channel (matching the colour
            # masks), so split as (C, colors) and move colors to dim 1.
            return out.view(batch_size, self.input_shape[0], self.color_channels,
                            *self.input_shape[1:]).permute(0, 2, 1, 3, 4)
        return out.view(batch_size, self.color_channels, *self.input_shape)

    def loss(self, x, cond = None):
        """Mean cross-entropy of the model's logits against the integer pixels of ``x``."""
        logits = self(x, cond)
        return F.cross_entropy(logits, x.long())

    def get_samples(self, n, cond = None):
        """Draw ``n`` images pixel by pixel; returns a numpy array (n, H, W, C).

        ``cond`` must be supplied (with batch size ``n``) for a conditional model.
        The model is switched to eval mode while sampling, so BatchNorm uses its
        running statistics; the previous train/eval mode is restored afterwards.
        """
        device = self.device if self.device is not None else next(self.parameters()).device
        channels, height, width = self.input_shape
        samples = torch.zeros(n, channels, height, width, device = device)
        if cond is not None:
            cond = cond.to(device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Raster order; channels innermost so R -> G -> B within a pixel.
                for r in range(height):
                    for c in range(width):
                        for k in range(channels):
                            logits = self(samples, cond)[:, :, k, r, c]
                            probs = F.softmax(logits, dim = 1)
                            samples[:, k, r, c] = torch.multinomial(probs, 1).squeeze(-1)
        finally:
            self.train(was_training)
        return samples.permute(0, 2, 3, 1).cpu().numpy()

In [0]:
def train(model, trainloader, optimizer, device):
    """Run one training epoch and return the per-batch losses.

    Args:
        model: module exposing ``loss(x, cond)``.
        trainloader: iterable of ``(images, labels)`` batches; the labels are
            passed to ``model.loss`` as the condition.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        list of float, one loss value per batch.
    """
    model.train()
    train_loss = []
    # Iterate batches directly: enumerate() would yield (index, batch) pairs.
    for img, labels in trainloader:
        img, labels = img.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = model.loss(img, labels)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    return train_loss

def evaluate(model, testloader, optimizer, device):
    """Return the per-sample mean loss of ``model`` over ``testloader``.

    Args:
        model: module exposing ``loss(x, cond)``; it is put in eval mode.
        testloader: iterable of ``(images, labels)`` batches; the labels are
            passed to ``model.loss`` as the condition.
        optimizer: unused; kept so the signature matches ``train``.
        device: device each batch is moved to.

    Returns:
        Batch losses weighted by batch size and divided by the number of
        samples seen, or ``float('nan')`` if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        # Iterate batches directly: enumerate() would yield (index, batch) pairs.
        for img, labels in testloader:
            img, labels = img.to(device), labels.to(device)
            batch = img.shape[0]
            # loss() is a batch mean, so weight it by the batch size.
            total_loss += model.loss(img, labels).item() * batch
            n_samples += batch
    if n_samples == 0:
        return float('nan')
    return total_loss / n_samples

In [None]:
class ModifiedDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each input ``x[i]`` with its target ``y[i]``."""

    def __init__(self, x, y):
        super().__init__()
        self.x, self.y = x, y

    def __len__(self):
        # The size is taken from the inputs alone.
        return len(self.x)

    def __getitem__(self, index):
        sample = self.x[index]
        target = self.y[index]
        return sample, target

In [0]:
epochs = 10
minibatch = 128
d = 2
# Assumes MNIST-style digit labels 0..9; adjust to the dataset actually used.
num_classes = 10
# Pixel values 0..255 need 256 classes (255 was off by one). For binarized
# data, `d` (2) could be used instead.
n_colors = 256

# train_dataset / test_dataset are expected from an earlier cell
# (e.g. ModifiedDataset instances yielding (image, label) pairs).
trainLoader = torch.utils.data.DataLoader(train_dataset, batch_size = minibatch, shuffle = True)
# Evaluation does not need shuffling.
testLoader = torch.utils.data.DataLoader(test_dataset, batch_size = minibatch, shuffle = False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Conditional_PixelCNN((1, 28, 28), channels = 64, colors = n_colors, no_of_layers = 4,
                             color_conditioning = False, use_ResBlock = False,
                             conditional_size = (num_classes,), device = device).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02)

train_loss = []
test_loss = []

for epoch in range(epochs):
    print ("Epoch No. " + str(epoch))
    train_loss.extend(train(model, trainLoader, optimizer, device))
    test_loss.append(evaluate(model, testLoader, optimizer, device))

Epoch No. 0
Epoch No. 1
Epoch No. 2
Epoch No. 3
Epoch No. 4
Epoch No. 5
Epoch No. 6
Epoch No. 7
Epoch No. 8
Epoch No. 9


In [13]:
# s = model.get_samples(25)

In [12]:
# size = 5
# fig, axs = plt.subplots(size, size)
# for i in range(0, size):
#     for j in range(0, size):
#         axs[i, j].imshow(s[size * i + j].reshape(28, 28))

In [10]:
# plt.plot(train_loss)
# plt.show()

In [11]:
# plt.plot(test_loss)
# plt.show()