## PyTorch implementation of Deep Bilteral Learning for Real Time Image Enhancement

In [189]:
from __future__ import division, print_function, unicode_literals
import numpy as np
from PIL import Image
import os, sys, glob
# import matplotlib.pyplot as plt
#Torch Imports
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as data
import torchvision.models as models
import torchvision.transforms as transforms

In [190]:
size = (256, 256)
batch_size = 100
learning_rate = 0.01
num_epochs = 5

In [191]:
class Dataset(data.Dataset):
    def __init__(self, root_dir, train, transform=None):
        #Init Function
        super(Dataset, self).__init__()
        self.root_dir = root_dir
        self.train = train
        self.transform = transform
        
        self.full_res = []
        self.low_res = []
        
        if (train):
            dir = self.root_dir + '/train/'
        else :
            dir = self.root_dir + '/test/'
        
        for img_path in glob.glob (dir + '*.jpg'):
            
            himage = Image.open (img_path)
            limage = himage.resize (size)
            
            self.full_res.append (himage)
            self.low_res.append (limage)

    def __len__(self):
        #Length function ?
        return len(self.full_res)

    def __getitem__(self, idx):
        #Accessor Function
        if self.transform is None:
            return (self.full_res[idx],self.low_res[idx])
        else:
            limg_transformed = self.transform(self.low_res[idx])
            himg_transformed =  self.transform(self.full_res[idx])
            return (himg_transformed, limg_transformed)
       

In [192]:
composed_transform = transforms.Compose([transforms.ToTensor()])
train_dataset = Dataset (root_dir = '../data', train = True, transform = composed_transform)
im, im2 = train_dataset.__getitem__(0)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [292]:
class LocalFeatureNet (nn.Module):
           
    def __init__(self):
        super (LocalFeatureNet, self).__init__()
        
        self.relu  = nn.ReLU (inplace = True)
        
        self.conv1 = nn.Conv2d (in_channels = 3,   out_channels = 8, kernel_size = 3, stride = 2, padding = 1)
        self.conv2 = nn.Conv2d (in_channels = 8,  out_channels = 16,  kernel_size = 3, stride = 2, padding = 1)
        self.conv3 = nn.Conv2d (in_channels = 16, out_channels = 32,  kernel_size= 3, stride = 2,padding = 1)
        self.conv4 = nn.Conv2d (in_channels = 32, out_channels = 64,  kernel_size = 3, stride = 2, padding = 1)
        
        
        self.localconv1 = nn.Conv2d (in_channels = 64, out_channels = 64,  kernel_size = 3, stride = 1, padding = 1)
        self.localconv2 = nn.Conv2d (in_channels = 64, out_channels = 64,  kernel_size= 3, stride = 1, padding = 1)
        
        self.globalconv1 = nn.Conv2d (in_channels = 64, out_channels = 64,  kernel_size = 3, stride = 2,padding = 1)
        self.globalconv2 = nn.Conv2d (in_channels = 64, out_channels = 64,  kernel_size = 3, stride = 2, padding = 1)
        
        self.globalfc1 = nn.Linear (1024, 256)
        self.globalfc2 = nn.Linear (256, 128)
        self.globalfc3 = nn.Linear (128, 64)
        self.linear = nn.Conv2d(in_channels = 64, out_channels = 96,  kernel_size = 1, stride = 1)
        
        # Pixel Wise Network
        self.pixelwise_bias = nn.Parameter (torch.rand(3, 1), requires_grad=True)
        self.pixelwise_weight = nn.Parameter (torch.eye(3), requires_grad = True)
        self.pixelwise_obias = nn.Parameter (torch.eye(1), requires_grad = True)
        
        self.relu_slopes = nn.Parameter(torch.rand(3,16), requires_grad = True)
        self.relu_shifts = nn.Parameter(torch.rand(16,3), requires_grad = True)
        
    def custom_relu(self,channel,value):
        print("value",value.size())
        size = (16,value.size()[0],value.size()[1],value.size()[2])
        size_alt = (1L,value.size()[0],value.size()[1],value.size()[2])
        print("size: ",size,"size_alt",size_alt)
        value = value.expand(size)
        
        print(self.relu_shifts[:,channel])
        a = self.relu_shifts[:,channel].clone()
        a = a.view(16,1,1,1)
        print(a.data,a)
        value = self.relu(value - a.repeat([(size_alt)]))
        value = self.relu_slopes[channel,:] * value
        
        return value
    
    def forward (self, h, l):
        
#         print (x.size())
        
        x = self.relu (self.conv1 (l))
#         print (x.size())
        x = self.relu (self.conv2 (x))
#         print (x.size())
        x = self.relu (self.conv3 (x))
#         print (x.size())
        x = self.relu (self.conv4 (x))
#         print (x.size())
#         print ("lel")
        y = self.localconv1 (x)
#         print (y.size())
        y = self.localconv2 (y)
        
        print (y.size())
        
        z = self.globalconv1 (x)
        z = self.globalconv2 (z)
        z = self.globalfc1 (z.view(1, -1))
        z = self.globalfc2 (z)
        z = self.globalfc3 (z)
        
        print (z.size())
        
        
        fused = self.relu(z.view(-1,64,1,1)+y)
        print (fused.size())
        
        lin = self.linear(fused)
        print (lin.size())
        
        bilat = lin.view(-1, 8, 3, 4, 16, 16)
        
        
#         bilat = self.bilateralGrid (fused)
        
        print (bilat.size())
        
#         return bilat

# pixel wise network
        for i in range(0,3):
            a = self.pixelwise_weight[i,:].view(1,3)
            b = h.unsqueeze(0).view(3,-1)
            print(a.size(),b.size())
            p = torch.mm(a,b)
            print(p.size())
            p = p.view(h.size()[0],h.size()[2],h.size()[3]) + self.pixelwise_bias[i]
            print(p.size())
#             p = torch.bmm(self.pixelwise_weight[i,:], h.unsqueeze(0).view(3,-1)).view(h.size()) + self.pixelwise_bias[i]
            print(i)
            p += self.custom_relu(i,p)
        p += self.pixelwise_obias
        print(p.size())

In [293]:
model = LocalFeatureNet() 

# Add code for using CUDA here if it is available
use_gpu = False
if(torch.cuda.is_available()):
    use_gpu = True
    model.cuda()

# Loss function and optimizers
criterion = nn.CrossEntropyLoss()# Define cross-entropy loss
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)# Use Adam optimizer, use learning_rate hyper parameter

In [294]:
def train():
    # Code for training the model
    # Make sure to output a matplotlib graph of training losses
    loss_arr = []
    for epoch in range(num_epochs):
        for i, (himage, limage) in enumerate(train_loader):  
            # Convert torch tensor to Variable
            himage = Variable(himage)
            limage = Variable(limage)
            if(use_gpu):
                himage=himage.cuda()
                limage=limage.cuda()
            # Forward + Backward + Optimize
            optimizer.zero_grad()  # zero the gradient buffer
            outputs = model(himage, limage)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             loss_arr.append(loss.data[0])
#             if (i+1) % batch_size == 0:       
#                 print ('Epoch [%d/%d], Step [%d/%d], Loss: %.4f' 
#                        %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.data[0]))
    
#     plt.plot( np.array(range(1,len(loss_arr)+1)), np.array(loss_arr))
#     plt.show()


In [295]:
%time train()

torch.Size([1, 64, 16, 16])
torch.Size([1, 64])
torch.Size([1, 64, 16, 16])
torch.Size([1, 96, 16, 16])
torch.Size([1, 8, 3, 4, 16, 16])
torch.Size([1, 3]) torch.Size([3, 2073600])
torch.Size([1, 2073600])
torch.Size([1, 1080, 1920])
0
value torch.Size([1, 1080, 1920])
size:  (16, 1L, 1080L, 1920L) size_alt (1L, 1L, 1080L, 1920L)
Variable containing:
 0.8218
 0.6272
 0.1816
 0.5466
 0.6069
 0.6201
 0.0164
 0.9036
 0.9984
 0.6273
 0.6485
 0.3293
 0.9292
 0.1642
 0.0149
 0.9374
[torch.FloatTensor of size 16]


(0 ,0 ,.,.) = 
  0.8218

(1 ,0 ,.,.) = 
  0.6272

(2 ,0 ,.,.) = 
  0.1816

(3 ,0 ,.,.) = 
  0.5466

(4 ,0 ,.,.) = 
  0.6069

(5 ,0 ,.,.) = 
  0.6201

(6 ,0 ,.,.) = 
  0.0164

(7 ,0 ,.,.) = 
  0.9036

(8 ,0 ,.,.) = 
  0.9984

(9 ,0 ,.,.) = 
  0.6273

(10,0 ,.,.) = 
  0.6485

(11,0 ,.,.) = 
  0.3293

(12,0 ,.,.) = 
  0.9292

(13,0 ,.,.) = 
  0.1642

(14,0 ,.,.) = 
  0.0149

(15,0 ,.,.) = 
  0.9374
[torch.FloatTensor of size 16x1x1x1]
 Variable containing:
(0 ,0 ,.,.) = 
  0.8218

(1 

TypeError: torch.Size() takes an iterable of 'int' (item 0 is 'list')