In [14]:
import numpy as np
import pandas as pd
import os
import glob
import tqdm
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.optim import Adam, SGD
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F
# Machine epsilon of Python's float type, kept as a module-level constant.
eps = np.finfo(float).eps

# NOTE: tqdm is already imported in the import list above; this repeat is harmless.
import tqdm

# Default size (width, height in inches) for every matplotlib figure in this notebook.
plt.rcParams['figure.figsize'] = 10, 10
# Render figures inline in the notebook output.
%matplotlib inline
# Apply seaborn's default plot styling.
sns.set()

# Automatically reload imported modules before executing each cell.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
# Master switch for GPU execution: set to True to request CUDA.
use_cuda = False
# A request is only honoured when a CUDA device is actually available.
if use_cuda:
    use_cuda = torch.cuda.is_available()

In [16]:
def load_data(path):
    """Read the pickled dataset stored at ``path`` and return its two splits.

    The pickle must contain a mapping with ``'train'`` and ``'test'`` entries,
    each of which supports ``.astype``; both are converted to ``float32``.

    NOTE: ``pickle.load`` can execute arbitrary code, so only open files
    that come from a trusted source.

    Args:
        path: Location of the pickle file on disk.

    Returns:
        tuple: ``(train, test)`` as ``float32`` arrays.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    train, test = (payload[split].astype(np.float32) for split in ('train', 'test'))
    return train, test

# Load the pickled arrays, then keep the first 80% of the training array for
# training and hold out the remaining 20% as a validation split.
train_valid, x_test = load_data('../data/mnist-hw1.pkl')
x_train, x_valid = np.split(train_valid, [int(len(train_valid) * 0.8)])

batch_size = 128

# Wrap every split as a tensor, moving it to the GPU when CUDA is enabled.
x_train, x_valid, x_test = (
    torch.from_numpy(split).cuda() if use_cuda else torch.from_numpy(split)
    for split in (x_train, x_valid, x_test)
)

# One shuffling DataLoader per split, all sharing the same batch size.
dataloader_train, dataloader_val, dataloader_test = (
    DataLoader(TensorDataset(split), batch_size=batch_size, shuffle=True)
    for split in (x_train, x_valid, x_test)
)

In [17]:
# Pull one batch from the training loader; TensorDataset yields 1-tuples,
# so index 0 is the image tensor itself.
x = next(iter(dataloader_train))[0]

In [18]:
# TODO: transpose? The batch is laid out channels-last (see the size printed
# below), whereas nn.Conv2d layers take channels-first input (N, C, H, W).
x.size()

torch.Size([128, 28, 28, 3])

In [19]:
def init_weights(module):
    """Xavier-initialise every ``nn.Linear`` / ``nn.ConvTranspose2d`` in ``module``.

    Weights are drawn with ``xavier_normal_`` and biases (when present) are
    set to zero. ``module.modules()`` already yields ``module`` itself and
    every nested sub-module (including children of ``nn.Sequential``), so a
    single flat pass initialises each matching layer exactly once; no manual
    recursion into containers is needed.

    NOTE(review): plain ``nn.Conv2d`` layers are not matched by the type
    check, so they keep PyTorch's default initialisation. Confirm that this
    is intended for conv-only models.

    Args:
        module: Any ``nn.Module``; layers are modified in place.
    """
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.ConvTranspose2d)):
            init.xavier_normal_(m.weight)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias, 0.0)
                
def one_hot(labels, n_class):
    """Convert integer class labels into a one-hot matrix.

    Args:
        labels: Tensor of class indices, shaped ``[N]`` or ``[N, 1]``.
        n_class: Number of classes, i.e. the width of the returned matrix.

    Returns:
        A ``[N, n_class]`` float64 tensor on the same device as ``labels``
        with a 1 at each row's label position and 0 elsewhere.

    NOTE(review): the original allocated the mask through an undefined
    ``type_tdouble()`` helper; float64 is assumed from that name. Confirm.
    """
    # Ensure labels are [N x 1]
    if labels.dim() == 1:
        labels = labels.unsqueeze(1)
    # scatter_ requires an int64 index tensor
    labels = labels.long()
    mask = torch.zeros(labels.size(0), n_class,
                       dtype=torch.float64, device=labels.device)
    # scatter dimension, position indices, fill_value
    return mask.scatter_(1, labels, 1)


In [54]:
class MaskedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is multiplied by a fixed 0/1 mask on every forward pass.

    The mask keeps all kernel rows above the centre row and, on the centre
    row, the columns left of the centre. The centre tap itself is dropped
    for ``mask_type == 'A'`` and kept for any other value. The mask is
    stored as a buffer named ``mask``.

    Args:
        mask_type: ``'A'`` to exclude the centre tap, anything else to keep it.
        *args, **kwargs: Forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mask_type = mask_type

        kernel_h, kernel_w = self.weight.shape[-2:]
        centre_row, centre_col = kernel_h // 2, kernel_w // 2

        # Build the mask locally, then register it as a buffer.
        mask = torch.zeros_like(self.weight.data)
        mask[:, :, :centre_row, :] = 1.0
        mask[:, :, centre_row, :centre_col] = 1.0
        mask[:, :, centre_row, centre_col] = 0.0 if mask_type == 'A' else 1.0
        self.register_buffer('mask', mask)

    def forward(self, x):
        # Zero the masked taps in place on the stored weights before convolving.
        self.weight.data.mul_(self.mask)
        return super().forward(x)

    
class ResidualBlock(nn.Module):
    """Bottleneck residual block built around a type-B masked 3x3 convolution.

    The branch is ``1x1 conv (h -> h//2)``, ``masked 3x3 conv (h//2 -> h//2,
    padding=1)``, ``1x1 conv (h//2 -> h)``, each followed by batch norm and
    ReLU. The block input is added to the branch output and a final ReLU is
    applied, so the output has the same shape as the input.

    Args:
        h: Number of channels entering and leaving the block.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        bottleneck = h // 2
        self.network = nn.Sequential(
            # squeeze channels
            nn.Conv2d(h, bottleneck, (1, 1)),
            nn.BatchNorm2d(bottleneck),
            nn.ReLU(),
            # masked spatial convolution; padding=1 keeps height and width
            MaskedConv2d('B', bottleneck, bottleneck, (3, 3), padding=1),
            nn.BatchNorm2d(bottleneck),
            nn.ReLU(),
            # restore channels
            nn.Conv2d(bottleneck, h, (1, 1)),
            nn.BatchNorm2d(h),
            nn.ReLU(),
        )

    def forward(self, x):
        return F.relu(self.network(x) + x)


class PixelCNN(nn.Module):
    """PixelCNN built from masked convolutions and residual blocks.

    Layer stack, in order:
      1. 7x7 type-A masked conv (``image_channels -> in_channels``) + BN + ReLU
      2. ``n_res_blocks`` residual blocks at ``in_channels``
      3. 3x3 type-B masked conv + BN + ReLU
      4. 1x1 conv + BN + ReLU, then a final 1x1 conv to ``out_channels``

    Args:
        in_channels: Width (number of feature maps) of the hidden layers.
        out_channels: Number of logits produced per pixel.
        image_channels: Channels of the input image (default 3, the value
            that was previously hard-coded).
        n_res_blocks: Number of residual blocks (default 15, as before).
    """

    def __init__(self, in_channels, out_channels, image_channels=3, n_res_blocks=15):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.image_channels = image_channels
        self.n_res_blocks = n_res_blocks

        # 7x7 Conv input, type A
        layers = [
            MaskedConv2d('A', image_channels, in_channels, (7, 7), padding=3),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
        ]

        layers.extend(ResidualBlock(in_channels) for _ in range(n_res_blocks))

        # 3x3 Conv, type B
        layers.extend([
            MaskedConv2d('B', in_channels, in_channels, (3, 3), padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
        ])

        # 1x1 Conv head
        layers.extend([
            nn.Conv2d(in_channels, in_channels, (1, 1)),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, (1, 1)),
        ])

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network and return per-pixel logits and probabilities.

        Args:
            x: Channels-first image batch ``(N, image_channels, H, W)``.

        Returns:
            tuple: ``(logits, probs)``, both shaped ``(N, H, W, out_channels)``;
            ``probs`` is the softmax of ``logits`` over the last axis.
        """
        logits = self.network(x)  # (N, out_channels, H, W)
        # Move the channel axis last. The earlier torch.reshape call discarded
        # its result, so the softmax ran over the width axis; a plain reshape
        # would also have scrambled the axes, so permute is used instead.
        logits = logits.permute(0, 2, 3, 1).contiguous()
        # NOTE(review): this normalises over all out_channels jointly. If
        # out_channels encodes separate per-colour distributions (e.g. 3*4),
        # split the last axis before the softmax. Confirm intent.
        probs = F.softmax(logits, dim=-1)
        return logits, probs

In [55]:
def loss_bcel(x, x_hat):
    """Mean binary cross-entropy between targets ``x`` and raw logits ``x_hat``.

    Both tensors are viewed as a single column of elements, so they only
    need to hold the same number of elements.

    Returns:
        Scalar tensor: the loss averaged over all elements.
    """
    logits = x_hat.view(-1, 1)
    targets = x.view(-1, 1)
    return F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

# def nll(logits, x):
#     loss = tf.nn.softmax_cross_entropy_with_logits(tf.one_hot(tf.cast(x, dtype=tf.uint8), depth=4),
#                                                   logits, axis=-1) 
#     return tf.cast(tf.reduce_mean(loss), tf.float32)

def loss_bce(x, x_hat):
    """Mean binary cross-entropy between targets ``x`` and probabilities ``x_hat``.

    Unlike ``binary_cross_entropy_with_logits``, ``F.binary_cross_entropy``
    expects ``x_hat`` to already hold probabilities. Both tensors are
    viewed as a single column of elements before the loss is computed.

    Returns:
        Scalar tensor: the loss averaged over all elements.
    """
    probabilities = x_hat.view(-1, 1)
    targets = x.view(-1, 1)
    return F.binary_cross_entropy(probabilities, targets, reduction='mean')


def train_validate(model, dataloader, optim, loss_fn, train):
    """Run one pass over ``dataloader``, optionally updating the model.

    Args:
        model: Network to run. It may return a single tensor or a tuple/list
            whose first element is the prediction passed to ``loss_fn``
            (e.g. ``(logits, probabilities)``).
        dataloader: Yields batches whose first element is the input tensor.
        optim: Optimiser; only used when ``train`` is truthy.
        loss_fn: Callable ``loss_fn(x, x_hat)`` returning a scalar tensor.
        train: Truthy to train (gradients + optimiser steps), falsy to only
            evaluate.

    Returns:
        float: Sum of the per-batch losses divided by the number of samples
        in ``dataloader.dataset``.

    NOTE(review): with a mean-reduced ``loss_fn`` this value is roughly the
    per-sample mean divided by the batch size. Dividing by
    ``len(dataloader)`` may be what is wanted. Confirm before comparing
    numbers across batch sizes.
    """
    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    # Skip autograd bookkeeping entirely while evaluating.
    with torch.set_grad_enabled(bool(train)):
        for batch in dataloader:
            x = batch[0]
            output = model(x)
            # Models that return (prediction, extras...) would otherwise
            # crash inside loss_fn, which expects a tensor.
            x_hat = output[0] if isinstance(output, (tuple, list)) else output
            loss = loss_fn(x, x_hat)

            if train:
                optim.zero_grad()
                loss.backward()
                optim.step()

            total_loss += loss.item()
    return total_loss / len(dataloader.dataset)


In [None]:
# input_size = x_train.size(1)
# output_size = x_train.size(1)
batch_size = 128

# 128 hidden feature maps; 3*4 output logits per pixel.
model = PixelCNN(128, 3*4)
if use_cuda:
    model = model.cuda()
# Module.apply calls init_weights on the model and on every sub-module.
model.apply(init_weights)

In [52]:
# Smoke-test the forward pass on a dummy channels-first batch.
# torch.zeros gives well-defined values, whereas torch.Tensor(128, 3, 28, 28)
# only allocates uninitialised memory that may contain NaN/inf.
sample = torch.zeros(128, 3, 28, 28)
# Keep the dummy batch on the same device as the model.
sample = sample.cuda() if use_cuda else sample
# NOTE: this rebinds `x`, which an earlier cell used for a training batch.
x, sm = model(sample)

RuntimeError: The size of tensor a (26) must match the size of tensor b (28) at non-singleton dimension 3

In [None]:
# def nll(logits, x):
#     loss = tf.nn.softmax_cross_entropy_with_logits(tf.one_hot(tf.cast(x, dtype=tf.uint8), depth=4),
#                                                   logits, axis=-1) 
#     return tf.cast(tf.reduce_mean(loss), tf.float32)
