In [6]:
import numpy as np
import pandas as pd
import os
import glob
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.optim import Adam, SGD
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F
eps = np.finfo(float).eps

import tqdm

plt.rcParams['figure.figsize'] = 10, 10
%matplotlib inline
sns.set()

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Request GPU execution; set this to False to force the notebook onto the CPU.
use_cuda = True
# Drop the request when no usable CUDA device is present.
if not torch.cuda.is_available():
    use_cuda = False

In [20]:
# Training split of MNIST (downloaded into ./data on first use), served in
# shuffled mini-batches of 128 images converted to tensors.
dataloader_train = DataLoader(
    dataset=dset.MNIST(root='data', train=True, transform=T.ToTensor(), download=True),
    batch_size=128,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)
# Held-out split, iterated in a fixed order so validation runs are comparable.
dataloader_val = DataLoader(
    dataset=dset.MNIST(root='data', train=False, transform=T.ToTensor(), download=True),
    batch_size=128,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

In [99]:
def loss_bce(x, x_hat):
    """Mean binary cross-entropy between a reconstruction and its reference.

    Args:
        x: reference tensor; passed to ``F.binary_cross_entropy`` as the target.
        x_hat: predicted tensor with the same number of elements as ``x``;
            passed as the input (probabilities).

    Returns:
        Scalar tensor: the element-wise BCE averaged over every element.
    """
    # Flatten both tensors into a single column so they line up element-wise.
    prediction = x_hat.view(-1, 1)
    reference = x.view(-1, 1)
    return F.binary_cross_entropy(
        input=prediction, target=reference, reduction='mean')


def loss_ce(x, x_hat):
    """Mean categorical cross-entropy between class logits and class indices.

    The class dimension of ``x_hat`` (dim 1) is preserved; only the batch and
    any trailing spatial dimensions are flattened. Flattening the logits to a
    single column would leave exactly one class, which makes the loss
    meaningless.

    Args:
        x: target class indices. Must contain one integer-valued entry per
            prediction, i.e. ``x.numel() == x_hat.numel() / x_hat.size(1)``
            (for example ``(N,)``, ``(N, H, W)`` or ``(N, 1, H, W)``).
        x_hat: unnormalised logits of shape ``(N, C)`` or ``(N, C, d1, ...)``.

    Returns:
        Scalar tensor: cross-entropy averaged over every predicted position.
    """
    n_classes = x_hat.size(1)
    # Move the class axis last so each row of the flattened matrix holds the
    # logits of one position, in the same order as the flattened targets.
    class_last = x_hat.permute(0, *range(2, x_hat.dim()), 1)
    logits = class_last.reshape(-1, n_classes)
    target = x.reshape(-1).long()
    CE = F.cross_entropy(logits, target, reduction='mean')
    return CE


def train_validate(model, dataloader, optim, loss_fn, train):
    """Run one pass over ``dataloader`` in training or evaluation mode.

    Each batch's first channel is quantised into 256 integer intensity levels
    and used as the classification target for ``loss_fn``.

    Args:
        model: ``nn.Module`` mapping an image batch to per-pixel class logits.
        dataloader: iterable of ``(x, label)`` pairs; labels are ignored.
            ``x`` is assumed to hold intensities scaled to ``[0, 1]`` (as
            produced by ``ToTensor``).
        optim: optimizer stepped after every batch; only used when ``train``
            is truthy.
        loss_fn: callable ``loss_fn(x_hat, target)`` returning a scalar tensor.
            Assumed to be mean-reduced over the batch.
        train: truthy to update parameters, falsy for a gradient-free
            evaluation pass.

    Returns:
        The per-sample average loss over the pass (0.0 if the loader is empty).
    """
    model.train() if train else model.eval()

    # Feed the batches to whichever device the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_loss = 0.0
    n_samples = 0
    for x, _ in dataloader:
        if device is not None:
            x = x.to(device)
        # Map [0, 1] floats to integer levels 0..255. A bare .long() would
        # truncate almost every pixel to 0.
        target = (x[:, 0] * 255).round().clamp(0, 255).long()

        # Skip autograd bookkeeping entirely during validation.
        with torch.set_grad_enabled(bool(train)):
            x_hat = model(x)
            loss = loss_fn(x_hat, target)

        if train:
            optim.zero_grad()
            loss.backward()
            optim.step()

        # Weight each batch mean by its size so the final value is a true
        # per-sample average even when the last batch is smaller.
        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    return total_loss / n_samples if n_samples else 0.0


def init_weights(module):
    """Xavier-initialise every linear and convolutional layer inside ``module``.

    Weights of ``nn.Linear``, ``nn.Conv2d`` (including subclasses) and
    ``nn.ConvTranspose2d`` layers are drawn with ``xavier_normal_`` and their
    biases, when present, are set to zero. Other layer types are left as is.

    Args:
        module: the ``nn.Module`` to initialise in place; may be a single
            layer or an arbitrarily nested container.
    """
    # modules() already yields the module itself and every descendant, so no
    # explicit recursion into containers is required.
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            init.xavier_normal_(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias, 0.0)

In [100]:
class MaskedConv2d(nn.Conv2d):
    """2-D convolution whose kernel only sees pixels earlier in raster order.

    The kernel is multiplied by a fixed binary mask that keeps every row above
    the centre row and the entries left of the centre in the centre row.

    * Type ``'A'`` also masks out the centre tap, so an output position never
      depends on the input value at that same position.
    * Type ``'B'`` keeps the centre tap.

    Args:
        mask_type: ``'A'`` or ``'B'``.
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.

    Raises:
        ValueError: if ``mask_type`` is neither ``'A'`` nor ``'B'``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        if mask_type not in ('A', 'B'):
            raise ValueError(
                "mask_type must be 'A' or 'B', got {!r}".format(mask_type))
        super(MaskedConv2d, self).__init__(*args, **kwargs)
        self.mask_type = mask_type

        _, _, kernel_height, kernel_width = self.weight.size()
        half_h, half_w = kernel_height // 2, kernel_width // 2

        # Build the mask with the weight's dtype/device, then register it as a
        # buffer so it follows the module across .to()/.cuda() and state_dict.
        mask = self.weight.data.new_zeros(self.weight.size())
        mask[:, :, :half_h, :] = 1.0        # all rows above the centre row
        mask[:, :, half_h, :half_w] = 1.0   # centre row, left of the centre
        if self.mask_type == 'B':
            # Only type 'B' may look at the centre position itself.
            mask[:, :, half_h, half_w] = 1.0
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Zero the masked kernel taps in place, then apply the convolution."""
        self.weight.data *= self.mask
        return super(MaskedConv2d, self).forward(x)

    
class MaskedNet(nn.Module):
    """Stack of masked 7x7 convolution blocks followed by a 1x1 output conv.

    The network starts with one type-'A' masked block taking a single input
    channel, continues with ``n_layers - 1`` type-'B' masked blocks, and ends
    with a 1x1 convolution producing 256 channels (presumably one logit per
    8-bit intensity level -- confirm against the training target). Every
    masked block is ``MaskedConv2d -> BatchNorm2d -> ReLU(inplace)``.

    Args:
        n_layers: length of the hidden-width list; yields ``n_layers - 1``
            type-'B' blocks after the initial type-'A' block.
        n_filters: number of channels used by every hidden block.
    """

    def __init__(self, n_layers, n_filters):
        super().__init__()
        self.n_layers = n_layers
        self.n_filters = n_filters
        # Per-layer hidden widths; all identical in this architecture.
        self.hs = [self.n_filters] * self.n_layers

        def masked_block(kind, c_in, c_out):
            # 7x7 kernel, stride 1, padding 3 keeps the spatial size unchanged.
            return [
                MaskedConv2d(kind, c_in, c_out, 7, 1, 3, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        layers = masked_block('A', 1, self.n_filters)
        for c_in, c_out in zip(self.hs[:-1], self.hs[1:]):
            layers += masked_block('B', c_in, c_out)
        layers.append(nn.Conv2d(self.n_filters, 256, 1))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the full layer stack to ``x`` and return its output."""
        return self.network(x)

In [101]:
# Number of feature maps used by every hidden masked-convolution block.
fm = 64
model = MaskedNet(8, fm)
model = model.cuda() if use_cuda else model
# init_weights already walks every sub-module, so one direct call is enough.
init_weights(model)

optim = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-4)
# Multiply the learning rate by 0.1 every 45 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=45, gamma=0.1)

loss_fn = F.cross_entropy
n_epochs = 1

train_loss = []
val_loss = []
for epoch in tqdm.tqdm_notebook(range(0, n_epochs)):
    t_loss = train_validate(model, dataloader_train, optim, loss_fn, train=True)
    train_loss.append(t_loss)

    # Validation is only run every 100th epoch (including epoch 0).
    if epoch % 100 == 0:
        v_loss = train_validate(model, dataloader_val, optim, loss_fn, train=False)
        val_loss.append(v_loss)

    # Advance the schedule once per epoch, after the optimizer has stepped.
    scheduler.step()
        
    

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/Cellar/python/3.7.3/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 