In [2]:
import os
import torch
import numpy as np
import pickle as pkl
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torchvision import transforms
from skimage.color import rgb2lab, lab2rgb
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

In [2]:
class FinalDataset(Dataset):
    """In-memory dataset built by concatenating pickled tensor chunks.

    File layout (as read below):
      * ``X_path``: a pickled int ``size`` followed by pickled X chunks.
      * ``y_path``: the matching pickled y chunks (no header).
    The first chunk of each file is always read, then ``size - 1`` more, and
    all chunks are concatenated along dim 0, so ``self[i] == (X[i], y[i])``.

    SECURITY: ``pickle.load`` can execute arbitrary code; only use files you
    produced yourself.
    """

    def __init__(self, X_path, y_path):
        with open(X_path, 'rb') as x_file, open(y_path, 'rb') as y_file:
            size = pkl.load(x_file)
            # Collect every chunk and concatenate once at the end; calling
            # torch.cat inside the loop re-copies the growing tensor on each
            # iteration, which is quadratic in the number of chunks.
            x_chunks = [pkl.load(x_file)]
            y_chunks = [pkl.load(y_file)]
            for _ in range(size - 1):
                x_chunks.append(pkl.load(x_file))
                y_chunks.append(pkl.load(y_file))
        self.X = torch.cat(x_chunks, dim=0)
        self.y = torch.cat(y_chunks, dim=0)

    def __len__(self):
        """Number of samples (rows along dim 0 of the concatenated X)."""
        return self.X.size(0)

    def __getitem__(self, idx):
        """Return the ``(X, y)`` pair at row ``idx``."""
        return self.X[idx], self.y[idx]

In [3]:
class FinalDatasetLite(Dataset):
    """Streaming dataset that reads one pickled (X, y) chunk per access.

    ``X_path`` begins with a pickled int ``size`` (the number of chunks),
    followed by the X chunks; ``y_path`` holds the matching y chunks. Only the
    chunk currently being returned is held in memory.

    NOTE: ``__getitem__`` ignores ``idx`` and returns the next chunk in file
    order, so the dataset must be consumed sequentially and only once; reading
    past the last chunk raises ``EOFError``. Presumably it is used with a
    DataLoader that does not shuffle and has ``num_workers=0`` — TODO confirm.

    The dataset owns two open file handles. Call ``close()`` or use it as a
    context manager (``with FinalDatasetLite(...) as ds:``) when finished.

    SECURITY: ``pickle.load`` can execute arbitrary code; only load files you
    produced yourself.
    """

    def __init__(self, X_path, y_path):
        self.x = open(X_path, 'rb')
        try:
            self.y = open(y_path, 'rb')
            self.size = pkl.load(self.x)
        except BaseException:
            # Don't leak the handle(s) already opened if construction fails.
            self.close()
            raise

    def __len__(self):
        """Number of chunks, as stored in the X file header."""
        return self.size

    def __getitem__(self, idx):
        """Read and return the next ``(X, y)`` chunk; ``idx`` is unused."""
        # squeeze(1) drops dim 1 of each stored chunk (assumed to be a
        # singleton axis — TODO confirm against the writer).
        return pkl.load(self.x).squeeze(1), pkl.load(self.y).squeeze(1)

    def close(self):
        """Close both underlying file handles (safe to call repeatedly)."""
        for handle in (getattr(self, 'x', None), getattr(self, 'y', None)):
            if handle is not None:
                handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def __del__(self):
        self.close()

    @staticmethod
    def permuteBatch(X):
        """Merge the first two dims of a 4-D tensor: (A, B, H, W) -> (A*B, H, W).

        Despite the name, no axes are permuted. The result is a ``view``, so X
        must be contiguous across its first two dims.

        This was previously declared without ``self`` or ``@staticmethod``,
        so calling it on an instance raised TypeError. It now works from both
        the class and an instance.
        """
        return X.view(X.shape[0] * X.shape[1], X.shape[2], X.shape[3])

In [None]:
def imshow(X, y, max_images=5):
    """Draw up to ``max_images`` colourised images from a batch, side by side.

    The channels are rescaled back to CIE-Lab before conversion:
      * lightness ``(X + 1) * 50`` — i.e. X is presumably in [-1, 1];
      * chroma ``y * 110`` — i.e. y presumably holds a/b scaled by 1/110.
    Shapes are presumably X: (batch, 1, H, W), y: (batch, 2, H, W) — TODO
    confirm. Each image is converted with ``lab2rgb`` and shown without axes.

    Args:
        X: lightness tensor (may require grad / live on GPU).
        y: a/b tensor (may require grad / live on GPU).
        max_images: how many images to show (default 5, the former fixed value).
    """
    n_shown = min(max_images, X.shape[0])
    # Detach BOTH inputs: .numpy() fails on any tensor that requires grad.
    # Only the images that will be displayed are converted (lab2rgb is slow).
    lab = torch.cat([(X[:n_shown].detach() + 1) * 50., y[:n_shown].detach() * 110.], dim=1)
    lab_imgs = lab.permute(0, 2, 3, 1).cpu().numpy()  # (n_shown, H, W, C)

    fig, axes = plt.subplots(1, max(max_images, 1), squeeze=False)
    for ax in axes[0]:
        ax.axis("off")
    for ax, img in zip(axes[0], lab_imgs):
        ax.imshow(lab2rgb(img))
    plt.show()

In [None]:
class ResidualBlock(nn.Module):
    """Residual block: conv-BN-ReLU-conv-BN-Dropout2d, plus shortcut, then ReLU.

    Both 3x3 convs use padding 1 and stride 1, so spatial size is preserved.
    When ``downSample`` is True the block doubles the channel count. Despite
    the name, it does NOT reduce spatial resolution. The shortcut is then
    projected with a 1x1 conv + BN so it can be added to the branch output.

    A ReLU is applied between the two convs (as in the standard ResNet basic
    block); it was previously missing. ReLU has no parameters, so state_dicts
    saved before this change still load, but those weights were trained
    without it.

    Args:
        in_channel: number of input channels.
        drop_prob: ``nn.Dropout2d`` probability applied to the residual branch.
        downSample: if True, output has ``2 * in_channel`` channels.
    """

    def __init__(self, in_channel, drop_prob, downSample=False):
        super().__init__()
        out_channel = 2 * in_channel if downSample else in_channel
        if downSample:
            # 1x1 projection so the shortcut matches the doubled channel count.
            self.downSampleConv = nn.Conv2d(in_channel, out_channel, kernel_size=1)
            self.downSampleBatchNorm = nn.BatchNorm2d(out_channel)
        self.downSample = downSample
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.batchNorm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.batchNorm2 = nn.BatchNorm2d(out_channel)
        self.drop = nn.Dropout2d(drop_prob)
        self.act = nn.ReLU(True)
        self._initialize_weights()

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.batchNorm1(x)
        # Non-linearity between the convs. Without it, conv1-BN-conv2-BN is a
        # single affine map at inference time.
        x = self.act(x)
        x = self.conv2(x)
        x = self.batchNorm2(x)
        x = self.drop(x)
        if self.downSample:
            residual = self.downSampleBatchNorm(self.downSampleConv(residual))
        x += residual
        return self.act(x)

    def _initialize_weights(self):
        # Zero-init the last BN's scale, so at initialisation the block
        # outputs relu(shortcut) ("zero-gamma" residual initialisation).
        nn.init.constant_(self.batchNorm2.weight, 0)

In [None]:
def residual_layer(block, layerSize, in_channel, drop_prob):
    """Stack residual blocks into an ``nn.Sequential``.

    The first block is called as ``block(in_channel, drop_prob, True)``; for
    ``ResidualBlock`` that third argument is ``downSample``, which doubles
    the channel count. The remaining ``layerSize - 1`` blocks are called as
    ``block(2 * in_channel, drop_prob)``. At least one block is always
    created (``layerSize <= 1`` yields exactly one).

    Args:
        block: block factory/class, e.g. ``ResidualBlock``.
        layerSize: total number of blocks in the layer.
        in_channel: channel count entering the layer.
        drop_prob: dropout probability passed to every block.

    Returns:
        ``nn.Sequential`` of the constructed blocks.
    """
    # First block doubles the number of filters; the rest keep that width.
    layers = [block(in_channel, drop_prob, True)]
    layers.extend(block(2 * in_channel, drop_prob) for _ in range(layerSize - 1))
    return nn.Sequential(*layers)

In [None]:
def clearCache():
    """Release cached-but-unused GPU memory held by PyTorch.

    Delegates to ``torch.cuda.empty_cache()``.
    """
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()

In [None]:
def saveModel(model, filepath):
    """Save ``model.state_dict()`` to ``filepath`` and print a confirmation.

    Only parameters and buffers are stored, not the module class. Reload
    with ``model.load_state_dict(torch.load(filepath))``.

    Args:
        model: an ``nn.Module`` (anything exposing ``state_dict()``).
        filepath: destination as ``str`` or ``os.PathLike``.
    """
    torch.save(model.state_dict(), filepath)
    # f-string instead of '+' so os.PathLike paths (accepted by torch.save)
    # don't raise TypeError after the file has already been written.
    print(f'\nSaved model to {filepath}.')