In [1]:
%load_ext autoreload
%autoreload 2
import os
from copy import deepcopy
import random
from collections import OrderedDict, defaultdict
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Compose, Normalize
import numpy as np
import pandas as pd
import PIL
from PIL import Image
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('seaborn-v0_8')

from mnist import MNIST, MNISTTrain, MNISTTest
# Debugging aid: make CUDA kernel launches synchronous so device-side errors are
# reported at the call that caused them (slows GPU work). It only takes effect if
# set before the CUDA context is created, hence its place in the setup cell.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Seed every RNG the notebook uses so Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# MNIST train/test loaders; images are converted to tensors only (no normalisation).
batch_size = 1024 * 4
mnist_transform = Compose([ToTensor()])

train_loader = DataLoader(
    dataset=MNISTTrain(transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    dataset=MNISTTest(transform=mnist_transform),
    batch_size=batch_size,
    shuffle=False,
)


In [39]:

def pad_down_shift(filter_sz):
    """Build a zero-padding function for a "down-shifted" convolution.

    ``filter_sz`` is ``[height, width]`` of the convolution that follows. The
    returned callable pads the last two dims of its input (via ``F.pad``) with:

    * ``int((width - 1) / 2)`` columns on the left and on the right,
    * ``height - 1`` rows on top,
    * no rows at the bottom.

    For an odd ``width``, an unpadded conv of size ``filter_sz`` applied
    afterwards keeps the spatial size, and each output row only sees input rows
    at or above it.
    """
    def _pad(t):
        side = int((filter_sz[1] - 1) / 2)
        top = filter_sz[0] - 1
        return F.pad(t, (side, side, top, 0))
    return _pad


def pad_right_shift(filter_sz):
    """Build a zero-padding function for a "down-right-shifted" convolution.

    ``filter_sz`` is ``[height, width]`` of the convolution that follows. The
    returned callable pads the last two dims of its input (via ``F.pad``) with
    ``width - 1`` columns on the left and ``height - 1`` rows on top only.
    An unpadded conv of size ``filter_sz`` applied afterwards then keeps the
    spatial size and only sees input positions above and to the left.
    """
    def _pad(t):
        rows_above = filter_sz[0] - 1
        cols_left = filter_sz[1] - 1
        return F.pad(t, (cols_left, 0, rows_above, 0))
    return _pad


def down_shift(x: torch.Tensor):
    """Shift a 4-D tensor down by one row along dim 2.

    A zero row is inserted at the top and the last row is dropped, so the
    output has the same shape as ``x``.
    """
    n, c, h, w = (x.size(d) for d in range(4))
    zero_row = x.new_zeros([n, c, 1, w])
    kept_rows = x[:, :, :h - 1, :]
    return torch.cat([zero_row, kept_rows], dim=2)


def right_shift(x: torch.Tensor):
    """Shift a 4-D tensor right by one column along dim 3.

    A zero column is inserted on the left and the last column is dropped, so
    the output has the same shape as ``x``.
    """
    n, c, h, w = (x.size(d) for d in range(4))
    zero_col = x.new_zeros([n, c, h, 1])
    kept_cols = x[:, :, :, :w - 1]
    return torch.cat([zero_col, kept_cols], dim=3)


class Block(nn.Module):
    """Placeholder module — not implemented yet.

    TODO: implement. ``forward`` currently ignores its input and returns ``None``.
    """

    def __init__(self):
        super(Block, self).__init__()

    def forward(self, x: torch.Tensor):
        # Stub: produces no output until the block is written.
        return None
    
class GatedNet(nn.Module):
    """Placeholder module — not implemented yet.

    TODO: implement. ``forward`` currently ignores its input and returns ``None``.
    """

    def __init__(self):
        super(GatedNet, self).__init__()

    def forward(self, x: torch.Tensor):
        # Stub: produces no output until the network is written.
        return None
    
    
class PixelCNNPP(nn.Module):
    """Work-in-progress PixelCNN++-style model; only the three input stems exist.

    ``forward`` appends a constant ones channel to the input and runs it through
    three stems, each "pad -> conv -> shift":

    * ``net1``: ``pad_down_shift([2, 3])`` -> (2, 3) conv -> ``down_shift``
    * ``net2``: ``pad_down_shift([1, 3])`` -> (1, 3) conv -> ``down_shift``
    * ``net3``: ``pad_right_shift([2, 1])`` -> (2, 1) conv -> ``right_shift``

    Args:
        in_dim: Number of input channels (one ones-channel is appended internally).
        nlayers: Stored as ``self.nlayers``; not used yet.
        nfilters: Output channels of each stem convolution.
        nlogmix: Stored as ``self.nlogmix``; not used yet.
        dropout: Stored as ``self.dropout_rate``; not applied yet.
        activation: One of ``"relu"``, ``"elu"``, ``"gelu"``, ``"tanh"``.

    Raises:
        KeyError: if ``activation`` is not one of the names above.
    """

    def __init__(self, in_dim=3, nlayers=5, nfilters=128, nlogmix=10, dropout=0.5, activation="relu"):
        super().__init__()
        self.nlayers = nlayers
        self.nfilters = nfilters
        self.nlogmix = nlogmix
        # Honour the argument (it used to be hard-coded to 0.5).
        self.dropout_rate = dropout
        self.act = {"relu": nn.ReLU(), "elu": nn.ELU(), "gelu": nn.GELU(), "tanh": nn.Tanh()}[activation]

        stem_in = in_dim + 1  # +1 for the ones channel appended in forward()

        # Convs are assigned as attributes so nn.Module registers them. Inside a
        # plain Python list they would be invisible to .parameters(), .to(device)
        # and state_dict(), so an optimizer would never train them.
        # Creation order matches the stem order, so seeded init is unchanged.
        self.conv1 = nn.Conv2d(in_channels=stem_in, out_channels=nfilters, kernel_size=(2, 3))
        self.conv2 = nn.Conv2d(in_channels=stem_in, out_channels=nfilters, kernel_size=(1, 3))
        self.conv3 = nn.Conv2d(in_channels=stem_in, out_channels=nfilters, kernel_size=(2, 1))

        # Shape comments assume a (2, 3, 32, 32) input and nfilters=128.
        self.net1 = [
            pad_down_shift([2, 3]),   # (2, 4, 33, 34)
            self.conv1,               # (2, 128, 32, 32)
            down_shift,               # (2, 128, 32, 32)
        ]
        self.net2 = [
            pad_down_shift([1, 3]),   # (2, 4, 32, 34) -- no top padding for height 1
            self.conv2,               # (2, 128, 32, 32)
            down_shift,               # (2, 128, 32, 32)
        ]
        self.net3 = [
            pad_right_shift([2, 1]),  # (2, 4, 33, 32)
            self.conv3,               # (2, 128, 32, 32)
            right_shift,              # (2, 128, 32, 32)
        ]

    @staticmethod
    def _run_stem(layers, inp):
        """Apply ``layers`` in order and return the list of every intermediate output."""
        states = []
        for layer in layers:
            inp = layer(inp)
            states.append(inp)
        return states

    def forward(self, x: torch.Tensor):
        """Run the three stems on ``x`` (indexed as a 4-D tensor, sizes 0, 2, 3).

        Returns:
            The ``torch.Size`` of the ``net3`` output. This is a debugging return
            value while the model is under construction.
        """
        # Append a constant ones channel along dim 1.
        ones = x.new_ones(x.size(0), 1, x.size(2), x.size(3))
        x_pad = torch.cat([x, ones], 1)

        # No .clone() needed: each stem starts with F.pad, which is out-of-place,
        # so x_pad is never mutated.
        states1 = self._run_stem(self.net1, x_pad)
        states2 = self._run_stem(self.net2, x_pad)
        states3 = self._run_stem(self.net3, x_pad)

        # TODO: x1 / x2 are meant to feed the not-yet-written layers; the sum also
        # checks that the net2 and net3 outputs have matching shapes.
        x1, x2 = states1[-1], states2[-1] + states3[-1]
        return states3[-1].size()

# Re-seed so this cell gives the same result however many times it is re-run.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(0)

# Smoke test: forward a random batch and display the resulting shape.
x = torch.randn(2, 3, 32, 32)
PixelCNNPP()(x)

torch.Size([2, 128, 32, 32])