In [102]:

import torch
import torchvision
import numpy as np
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
import warnings
# import UserWarning

warnings.simplefilter("ignore", UserWarning)

device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.cuda.get_device_name(0)


In [103]:

training_set = torch.utils.data.DataLoader(torchvision.datasets.MNIST('data', train=True, download=True, transform=torchvision.transforms.ToTensor()),
                     batch_size=16, shuffle=True)
validation_set = torch.utils.data.DataLoader(torchvision.datasets.MNIST('data', train=False, download=True, transform=torchvision.transforms.ToTensor()),
                     batch_size=16, shuffle=False)



In [104]:

class VerticalMaskedConv2d(torch.nn.Conv2d):
    
    def __init__(self, device="cpu", *arg, **kargs):
        super(VerticalMaskedConv2d, self).__init__(*arg, **kargs)
#         print(self.kernel_size)
        self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
        
        self.mask = torch.ones_like(self.weight.data)
        self.mask[:, :, self.kernel_size[0]//2:] = 0
        self.mask = self.mask.to(device)
        
    def forward(self, images):
        self.weight.data = self.weight.data * self.mask
        pred = super(VerticalMaskedConv2d, self).forward(images)
        return pred
    

In [105]:

# Test vertical conv2d
vertical_mask = VerticalMaskedConv2d(in_channels=5, out_channels=3, kernel_size=(7,7))
test_tensor = torch.randn(1, 5, 16, 16)
test_result = vertical_mask(test_tensor)
print(test_result.shape)


torch.Size([1, 3, 16, 16])


In [118]:


class HorizontalMaskedConv2d(torch.nn.Conv2d):
    
    def __init__(self, mask_type, device="cpu", *arg, **kargs):
        super(HorizontalMaskedConv2d, self).__init__(*arg, **kargs)

        self.padding = (0, self.kernel_size[1] // 2)
        batch, channel, height, width = self.weight.data.shape
        
        self.mask = torch.ones_like(self.weight.data)
        self.mask[:,:,0, width//2:] = 0
        if mask_type == "B":
             self.mask[:, :, 0, width//2] = 1        
        self.mask = self.mask.to(device)
        
    def forward(self, images):
        self.weight.data = self.weight.data * self.mask
        pred = super(HorizontalMaskedConv2d, self).forward(images)
        return pred


In [119]:

# horizontal_mask = HorizontalMaskedConv2d(mask_type="A", in_channels=1, out_channels=1, kernel_size=(1,11))
# test_tensor = torch.randn(1, 1, 16, 16)
# test_result = horizontal_mask(test_tensor)
# print(test_result.shape)
# print()

# horizontal_mask = HorizontalMaskedConv2d(mask_type="B", in_channels=1, out_channels=1, kernel_size=(1,11))
# test_tensor = torch.randn(1, 1, 16, 16)
# test_result = horizontal_mask(test_tensor)
# print(test_result.shape)

In [120]:

class Gated_Conv_Block(torch.nn.Module):
    
    def __init__(self, horizontal_mask_type, in_channels, out_channels, kernel_size):
        super(Gated_Conv_Block, self).__init__()
        
        self.out_channels = out_channels
        
        self.vertical_conv = VerticalMaskedConv2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size)
        self.vertical_to_horizontal_conv = torch.nn.Conv2d(in_channels=out_channels*2, out_channels=out_channels*2, kernel_size=1, padding=0)
        self.horizontal_conv = HorizontalMaskedConv2d(mask_type=horizontal_mask_type, in_channels=in_channels, 
                                                             out_channels=out_channels*2, kernel_size=(1, kernel_size))
        self.res_conv = torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, images):
        v_conv = self.vertical_conv(images)
        v_tanh, v_sigmoid = torch.split(tensor=v_conv, split_size_or_sections=self.out_channels, dim=1)
        v_output = torch.tanh(v_tanh) * torch.sigmoid(v_sigmoid)
        
        v_to_h = self.vertical_to_horizontal_conv(v_conv)
        h_conv = self.horizontal_conv(images)
        h_conv = h_conv + v_to_h
        h_tanh, h_sigmoid = torch.split(tensor=h_conv, split_size_or_sections=self.out_channels, dim=1)
        h_res = torch.tanh(h_tanh) * torch.sigmoid(h_sigmoid)
        h_res = self.res_conv(h_res)
        h_output = images + h_res
        return v_output, h_output
        


In [121]:

# Test gated conv
gated_conv = Gated_Conv_Block(horizontal_mask_type="A", in_channels=1, out_channels=5, kernel_size=5)
test_input = torch.randn(3, 1, 16, 16)
# v_output, h_output = gated_conv(test_input)


In [127]:

class Gated_Pixel_CNN_Mnist(torch.nn.Module):
    def __init__(self, nb_gated_block, hidden_layer_dim, in_channels, out_channels, kernel_size):
        super(Gated_Pixel_CNN_Mnist, self).__init__()
        
        self.nb_gated_block = nb_gated_block
        self.first_gated_block = Gated_Conv_Block(horizontal_mask_type="A", in_channels=in_channels, 
                                                  out_channels=hidden_layer_dim, kernel_size=kernel_size)
        
        self.gated_block_list = []
        self.batch_norm_list = []
        for i in range(1, nb_gated_block):
            new_cell = Gated_Conv_Block(horizontal_mask_type="B", in_channels=hidden_layer_dim, 
                                                  out_channels=hidden_layer_dim, kernel_size=kernel_size)
            self.gated_block_list.append(new_cell)
            self.batch_norm_list.append(torch.nn.BatchNorm2d(hidden_layer_dim))
        self.gated_block_list = torch.nn.ModuleList(self.gated_block_list)
        self.batch_norm_list = torch.nn.ModuleList(self.batch_norm_list)
        
        self.last_conv = torch.nn.Conv2d(in_channels=hidden_layer_dim, out_channels=out_channels, kernel_size=1)
        

    def forward(self, images):
        v_output, h_output = self.first_gated_block(images)
        for i in range(self.nb_gated_block-1):
            v_output, h_output = self.gated_block_list[i](h_output)
            h_output = self.batch_norm_list[i](h_output)
        pred = self.last_conv(h_output)
        return pred

gated_pixel_cnn = Gated_Pixel_CNN_Mnist(nb_gated_block=10, hidden_layer_dim=128, in_channels=1, out_channels=2, kernel_size=5)
gated_pixel_cnn.to(device=device)
optim = torch.optim.Adam(gated_pixel_cnn.parameters(), lr=0.001)


In [134]:

epochs = 20

for epoch in range(epochs):
    for images, label in tqdm(training_set):
        images = images.to(device)
        pred = gated_pixel_cnn(images)
        loss = torch.nn.CrossEntropyLoss()(pred, images[:,0,...].long())
        optim.zero_grad()
        loss.backward()
        optim.step()




HBox(children=(FloatProgress(value=0.0, max=3750.0), HTML(value='')))




In [125]:
loss = torch.nn.CrossEntropyLoss()(pred, images)


tensor(1.)

In [130]:
print(pred.shape)

torch.Size([16, 2, 28, 28])


In [132]:
print(images.shape)

torch.Size([16, 1, 28, 28])
