## PyTorch implementation of Deep Bilteral Learning for Real Time Image Enhancement

In [35]:
from __future__ import division, print_function, unicode_literals
import numpy as np
from PIL import Image
import os, sys, glob
# import matplotlib.pyplot as plt
#Torch Imports
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as data
import torchvision.models as models
import torchvision.transforms as transforms

In [28]:
size = (256, 256)
batch_size = 100
learning_rate = 0.01
num_epochs = 5

In [29]:
class Dataset(data.Dataset):
    def __init__(self, root_dir, train, transform=None):
        #Init Function
        super(Dataset, self).__init__()
        self.root_dir = root_dir
        self.train = train
        self.transform = transform
        
        self.full_res = []
        self.low_res = []
        
        if (train):
            dir = self.root_dir + '/train/'
        else :
            dir = self.root_dir + '/test/'
        
        for img_path in glob.glob (dir + '*.jpg'):
            
            himage = Image.open (img_path)
            limage = himage.resize (size)
            
            self.full_res.append (himage)
            self.low_res.append (limage)

    def __len__(self):
        #Length function ?
        return len(self.full_res)

    def __getitem__(self, idx):
        #Accessor Function
        if self.transform is None:
            return (self.full_res[idx],self.low_res[idx])
        else:
            limg_transformed = self.transform(self.low_res[idx])
            himg_transformed =  self.transform(self.low_res[idx])
            return (limg_transformed, himg_transformed)
       

In [37]:
composed_transform = transforms.Compose([transforms.ToTensor()])
train_dataset = Dataset (root_dir = '../data', train = True, transform = composed_transform)
im, im2 = train_dataset.__getitem__(0)
# im2.show()
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [64]:
class LocalFeatureNet (nn.Module):

#     def fusionLayer (self, localfeat, globalfeat, bias = True):
#         self.local_feature = localfeat
#         self.global_feature = globalfeat
            
#         self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
#         if bias:
#             self.bias = nn.Parameter(torch.Tensor(output_features))
#         else:
#             # You should always register all possible parameters, but the
#             # optional ones can be None if you want.
#             self.register_parameter('bias', None)

#         # Not a very smart way to initialize weights
#         self.weight.data.uniform_(-0.1, 0.1)
#         if bias is not None:
#             self.bias.data.uniform_(-0.1, 0.1)
        
#     def bilateralGrid (self, locaFeat, globalFeat, bias = True):
           
    def __init__(self):
        super (LocalFeatureNet, self).__init__()
        
        self.relu  = nn.ReLU (inplace = True)
        
        self.conv1 = nn.Conv2d (in_channels = 3,   out_channels = 8, kernel_size = 3, stride = 2, padding = 1)
        self.conv2 = nn.Conv2d (in_channels = 8,  out_channels = 16,  kernel_size = 3, stride = 2, padding = 1)
        self.conv3 = nn.Conv2d (in_channels = 16, out_channels = 32,  kernel_size= 3, stride = 2,padding = 1)
        self.conv4 = nn.Conv2d (in_channels = 32, out_channels = 64,  kernel_size = 3, stride = 2, padding = 1)
        
        
        self.localconv1 = nn.Conv2d (in_channels = 64, out_channels = 64,  kernel_size = 3, stride = 1, padding = 1)
        self.localconv2 = nn.Conv2d (in_channels = 64, out_channels = 64,  kernel_size= 3, stride = 1, padding = 1)
        
        self.globalconv1 = nn.Conv2d (in_channels = 64, out_channels = 64,  kernel_size = 3, stride = 2,padding = 1)
        self.globalconv2 = nn.Conv2d (in_channels = 64, out_channels = 64,  kernel_size = 3, stride = 2, padding = 1)
        
        self.globalfc1 = nn.Linear (1024, 256)
        self.globalfc2 = nn.Linear (256, 128)
        self.globalfc3 = nn.Linear (128, 64)
        
    def forward (self, x):
        
#         print (x.size())
        
        x = self.relu (self.conv1 (x))
#         print (x.size())
        x = self.relu (self.conv2 (x))
#         print (x.size())
        x = self.relu (self.conv3 (x))
#         print (x.size())
        x = self.relu (self.conv4 (x))
#         print (x.size())
#         print ("lel")
        y = self.localconv1 (x)
#         print (y.size())
        y = self.localconv2 (y)
        
        
        print (y.size())
        
        z = self.globalconv1 (x)
        z = self.globalconv2 (z)
        z = self.globalfc1 (z.view(1, -1))
        z = self.globalfc2 (z)
        z = self.globalfc3 (z)
        
        print (z.size())
        
        
        fused = self.fusionLayer (y, z)
        
#         print fused.size()
        
#         bilat = self.bilateralGrid (fused)
        
#         print bilat.size()
        
#         return bilat

In [65]:
model = LocalFeatureNet() 

# Add code for using CUDA here if it is available
use_gpu = False
if(torch.cuda.is_available()):
    use_gpu = True
    model.cuda()

# Loss function and optimizers
criterion = nn.CrossEntropyLoss()# Define cross-entropy loss
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)# Use Adam optimizer, use learning_rate hyper parameter

In [66]:
def train():
    # Code for training the model
    # Make sure to output a matplotlib graph of training losses
    loss_arr = []
    for epoch in range(num_epochs):
        for i, (himage, limage) in enumerate(train_loader):  
            # Convert torch tensor to Variable
            himage = Variable(himage)
            limage = Variable(limage)
            if(use_gpu):
                himage=himage.cuda()
                limage=limage.cuda()
            # Forward + Backward + Optimize
            optimizer.zero_grad()  # zero the gradient buffer
            outputs = model(limage)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             loss_arr.append(loss.data[0])
#             if (i+1) % batch_size == 0:       
#                 print ('Epoch [%d/%d], Step [%d/%d], Loss: %.4f' 
#                        %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data[0]))
    
#     plt.plot( np.array(range(1,len(loss_arr)+1)), np.array(loss_arr))
#     plt.show()


In [67]:
%time train()

torch.Size([1, 64, 16, 16])
torch.Size([1, 64])
torch.Size([1, 64, 16, 16])
torch.Size([1, 64])
torch.Size([1, 64, 16, 16])
torch.Size([1, 64])
torch.Size([1, 64, 16, 16])
torch.Size([1, 64])
torch.Size([1, 64, 16, 16])
torch.Size([1, 64])
CPU times: user 52 ms, sys: 8 ms, total: 60 ms
Wall time: 38.4 ms
