In [1]:
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import os
from skimage import io
from skimage.transform import resize, downscale_local_mean
import numpy as np
from Utilities.data import DataWrapper



from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class GenResLego(nn.Module):
    """Single residual unit: conv -> BN -> PReLU -> conv -> BN, plus identity skip.

    Args:
        inp_channels: number of channels of the input feature map.
        n: number of feature maps produced by both convolutions.
        k: (odd) kernel size of both convolutions; padding is ``k // 2``.
        s: stride of both convolutions.

    Note:
        The skip connection adds the input to the output element-wise, so the
        unit only works when ``inp_channels == n`` and the spatial size is
        preserved (``s == 1`` with an odd ``k``).
    """

    def __init__(self, inp_channels=64, n=64, k=3, s=1):
        super(GenResLego, self).__init__()
        p = k // 2
        self.conv1 = nn.Conv2d(inp_channels, n, kernel_size=k, padding=p, stride=s, bias=True)
        self.bn1 = nn.BatchNorm2d(n)
        self.prelu = nn.PReLU(n)
        # Use the same k-dependent padding as conv1; a fixed padding of 1 only
        # preserves the spatial size for k == 3 and breaks the skip connection
        # for any other kernel size.
        self.conv2 = nn.Conv2d(n, n, kernel_size=k, padding=p, stride=s, bias=True)
        self.bn2 = nn.BatchNorm2d(n)

    def forward(self, x):
        """Return ``bn2(conv2(prelu(bn1(conv1(x))))) + x``."""
        y = self.prelu(self.bn1(self.conv1(x)))
        z = self.bn2(self.conv2(y))
        out = z + x
        return out

In [3]:
class GenResBlock(nn.Module):
    """Stack of ``B`` residual units followed by conv + BN and an outer skip.

    Args:
        inp_channels: number of channels of the input feature map.
        B: number of ``GenResLego`` units in the stack (must be >= 1).
        n: number of feature maps used inside the stack.
        k: kernel size of every convolution; the final conv pads by ``k // 2``.
        s: stride of every convolution.

    Note:
        The output is ``final_bn(final_conv(stack(x))) + x``, so the input must
        have the same shape as the stack output (``inp_channels == n`` and a
        spatial-size-preserving stride).
    """

    def __init__(self, inp_channels=64, B=16, n=64, k=3, s=1):
        super(GenResBlock, self).__init__()

        assert B>=1

        self.building_Blocks = nn.Sequential()
        # Only the first unit sees `inp_channels`; every later unit maps n -> n.
        for b in range(B):
            in_ch = inp_channels if b == 0 else n
            self.building_Blocks.add_module('Residual_Block_' + str(b),
                                            GenResLego(inp_channels=in_ch, n=n, k=k, s=s))
        p = k // 2
        self.final_conv = nn.Conv2d(n, n, kernel_size=k, padding=p, stride=s, bias=True)
        self.final_bn = nn.BatchNorm2d(n)

    def forward(self, x):
        """Run the residual stack, the final conv + BN, and add the input back."""
        # Call the module rather than its .forward() so registered hooks run.
        y = self.building_Blocks(x)
        z = self.final_conv(y)
        out = self.final_bn(z) + x
        return out

In [4]:
class GenUpsampleLego(nn.Module):
    """Single upsampling unit: conv -> PixelShuffle -> PReLU.

    Args:
        inp_channels: number of channels of the input feature map.
        n: number of feature maps produced by the convolution; must be
            divisible by ``upscale_factor ** 2`` for the pixel shuffle.
        k: kernel size of the convolution; padding is ``k // 2``.
        s: stride of the convolution.
        upscale_factor: spatial upscale factor of the pixel shuffle.

    The pixel shuffle turns ``n`` channels into ``n / upscale_factor ** 2``
    channels while multiplying height and width by ``upscale_factor``.
    """

    def __init__(self, inp_channels=64, n=256, k=3, s=1, upscale_factor=2):
        super(GenUpsampleLego, self).__init__()

        p = k // 2
        self.conv = nn.Conv2d(inp_channels, n, kernel_size=k, padding=p, stride=s, bias=True)
        self.upsamp = nn.PixelShuffle(upscale_factor = upscale_factor)
        # The PReLU runs on the shuffled tensor, so it needs one slope per
        # *output* channel. This equals `inp_channels` only in the special case
        # n == upscale_factor**2 * inp_channels (the default configuration).
        self.prelu = nn.PReLU(n // (upscale_factor * upscale_factor))

    def forward(self, x):
        """Return ``prelu(pixel_shuffle(conv(x)))``."""
        y = self.conv(x)
        z = self.upsamp(y)
        out = self.prelu(z)
        return out

In [5]:
class GenUpsampleBlock(nn.Module):
    """Sequence of ``upsample_B`` identical ``GenUpsampleLego`` units.

    Args:
        upsample_B: number of upsampling units to chain.
        inp_channels: number of channels entering (and leaving) every unit.
        n: number of feature maps produced by each unit's convolution.
        k: kernel size of each unit's convolution.
        s: stride of each unit's convolution.
        upscale_factor: pixel-shuffle factor of each unit.

    Raises:
        AssertionError: if ``n != upscale_factor ** 2 * inp_channels``. Every
            unit is built with the same ``inp_channels``, so each unit must
            hand the next one exactly ``inp_channels`` channels.
    """

    def __init__(self, upsample_B=2, inp_channels=64, n=256, k=3, s=1, upscale_factor=2):
        super(GenUpsampleBlock, self).__init__()

        errorlog ='The upscale factor controls the pixel shuffler and changes then number of channels. ' 
        errorlog +='Make sure that you choose the hyperparams in a way that n == upscale_factor^2 * inp_channels.'
        assert n == upscale_factor * upscale_factor * inp_channels ,  errorlog

        self.building_Blocks = nn.Sequential()
        for b in range(upsample_B):
            self.building_Blocks.add_module('Upsample_Block_' + str(b) ,
                                            GenUpsampleLego(inp_channels=inp_channels, n=n,
                                                            k=k, s=s, upscale_factor=upscale_factor))

    def forward(self, x):
        """Apply the chained upsampling units to ``x``."""
        # Call the module rather than its .forward() so registered hooks run.
        return self.building_Blocks(x)

In [6]:
class Generator(nn.Module):
    """Generator network: conv + PReLU -> residual stack -> upsampling -> conv.

    Each stage is configured through a dict of hyperparameters:

    Args:
        first_stage_hyperparams: ``'k'`` kernel size, ``'n'`` feature maps and
            ``'s'`` stride of the first convolution (3 input channels).
        residual_blocks_hyperparams: ``'k'``, ``'n'``, ``'s'`` and ``'B'``
            (number of residual units) forwarded to ``GenResBlock``.
        upsample_blocks_hyperparams: ``'k'``, ``'n'``, ``'s'``, ``'B'`` (number
            of upsampling units) and ``'f'`` (upscale factor) forwarded to
            ``GenUpsampleBlock``.
        last_stage_hyperparams: ``'k'`` kernel size and ``'s'`` stride of the
            last convolution (3 output channels).

    The hyperparameter dicts are only read, never modified.
    """

    def __init__(self, first_stage_hyperparams={'k':9, 'n':64, 's':1}, 
                 residual_blocks_hyperparams={'k':3, 'n':64, 's':1, 'B':16}, 
                 upsample_blocks_hyperparams={'k':3, 'n':256, 's':1, 'B':2, 'f':2}, 
                 last_stage_hyperparams={'k':9, 's':1} ):
        super(Generator, self).__init__()

        fsh = first_stage_hyperparams
        rbh = residual_blocks_hyperparams
        ubh = upsample_blocks_hyperparams
        lsh = last_stage_hyperparams

        # Keep the paddings in locals: writing them back into the dicts would
        # mutate the caller's objects and the shared default arguments.
        first_pad = fsh['k'] // 2
        last_pad = lsh['k'] // 2

        self.first_stage_conv = nn.Conv2d(3, fsh['n'], kernel_size=fsh['k'], 
                                          padding=first_pad, stride=fsh['s'], bias=True)
        self.first_stage_prelu = nn.PReLU(fsh['n'])
        self.ResBlocks = GenResBlock(inp_channels=fsh['n'], n=rbh['n'], k=rbh['k'], s=rbh['s'], B=rbh['B'])
        self.UpscaleBlocks = GenUpsampleBlock(upsample_B=ubh['B'], inp_channels=rbh['n'], n=ubh['n'], k=ubh['k'], 
                                              s=ubh['s'], upscale_factor=ubh['f'])
        # The upsampling block is built with inp_channels=rbh['n'] and asserts
        # n == f^2 * inp_channels, so it emits rbh['n'] channels again.
        self.last_stage_conv = nn.Conv2d(rbh['n'], 3, kernel_size=lsh['k'], 
                                          padding=last_pad, stride=lsh['s'], bias=True)

    def forward(self, x):
        """Map a 3-channel input batch to a 3-channel upsampled output batch."""
        y = self.first_stage_conv(x)
        z = self.first_stage_prelu(y)
        u = self.ResBlocks(z)
        v = self.UpscaleBlocks(u)
        out = self.last_stage_conv(v)
        return out

In [2]:
# Run on the GPU only when CUDA is actually available, so the notebook also
# executes (slowly) on CPU-only machines instead of failing on .cuda().
IS_GPU = torch.cuda.is_available()
import torch.nn.init as init

def conv_init(m):
    """Initialise a single module in place; meant for ``net.apply(conv_init)``.

    Modules whose class name contains ``'Conv'`` get Xavier-uniform weights
    with gain ``sqrt(2)`` and a zero bias. Modules whose class name contains
    ``'BatchNorm'`` get weight 1 and bias 0. Every other module is left
    untouched.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        # The in-place (underscore) initialisers replace the deprecated
        # init.xavier_uniform / init.constant.
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        # Layers built with bias=False expose bias as None.
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif 'BatchNorm' in classname:
        # Layers built with affine=False expose weight and bias as None.
        if m.weight is not None:
            init.constant_(m.weight, 1)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    
# Instantiate the generator defined above, spelling out the hyperparameters of
# every stage.
net = Generator(
    first_stage_hyperparams={'k':9, 'n':64, 's':1},
    residual_blocks_hyperparams={'k':3, 'n':64, 's':1, 'B':16},
    upsample_blocks_hyperparams={'k':3, 'n':256, 's':1, 'B':2, 'f':2},
    last_stage_hyperparams={'k':9, 's':1},
)

# Run conv_init over the network's modules to initialise their parameters.
net.apply(conv_init)

# For GPU training both the network and the data must live on the GPU:
# http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu
if IS_GPU:
    import torch.backends.cudnn as cudnn
    # Move the network to the GPU, then wrap it so batches are split across
    # every visible CUDA device.
    net = torch.nn.DataParallel(net.cuda(), device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

def count_parameters(model):
    """Return the total number of elements over all trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
print(count_parameters(net))

# Hand-computed parameter count for the configuration used above; it should
# match the number printed by count_parameters.
a = sum((
    # first stage: 9x9 conv 3 -> 64 (weights + bias) and a 64-slope PReLU
    (81*3 + 1)*64 + 64,
    # residual stage: 16 units of two 3x3 convs 64 -> 64 plus BN, PReLU, BN
    # (5*64), then the final 3x3 conv and its BN (2*64)
    ((9*64 + 1)*64*2 + 5*64)*16 + (9*64 + 1)*64 + 2*64,
    # upsampling stage: 2 units of a 3x3 conv 64 -> 256 and a 64-slope PReLU
    ((9*64 + 1)*256 + 64)*2,
    # last stage: 9x9 conv 64 -> 3 (weights + bias)
    (81*64 + 1)*3,
))
print(a)

#print(net)

1550659
1550659


In [3]:
########################################################################
# 3. Define a loss function and optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The criterion is the mean-squared error (nn.MSELoss) between the network
# output and the target handed to it, and the optimizer is Adam.

import torch.optim as optim

criterion = nn.MSELoss()

# Adam over all parameters of `net`; the learning rate is a value to tune.
# Alternative to compare against:
#   optim.SGD(net.parameters(), lr=0.005, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=0.005)

In [4]:
# Build the data wrapper; it exposes `.dataset` (used here) and `.loader`
# (used by the training loop).
training = DataWrapper()

# Smoke test: fetch only the first sample to check that the dataset can be
# read, then stop.
for i_batch, sample_batch in enumerate(training.dataset):
    break

In [5]:
EPOCHS=1
for epoch in range(EPOCHS):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(training.loader, 0):
        # Each batch is a mapping with a 'Low' entry (network input) and a
        # 'High' entry (target) -- presumably low-/high-resolution images;
        # confirm against DataWrapper.
        inputs = data['Low']
        labels = data['High']

        if IS_GPU:
            inputs = inputs.cuda()
            labels = labels.cuda()

        # Tensors are passed to the network directly; the Variable wrapper is
        # a deprecated no-op since PyTorch 0.4.

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: the MSE recomputed with numpy as a cross-check,
        # followed by the accumulated loss
        print(np.mean(np.square(outputs.detach().cpu().numpy() - labels.detach().cpu().numpy())))
        # loss is a 0-dim tensor; .item() extracts the Python float
        # (indexing it with [0] raises on current PyTorch).
        running_loss += loss.item()

        print(running_loss)
        # NOTE(review): stops after the first batch, so this cell is only a
        # single-step smoke test -- remove the break for real training.
        break

1.4342948
1.434294581413269
