In [0]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

import torchvision
from torchvision import models, transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
from google.colab import drive
from timeit import default_timer as timer

import os
import math

!git clone https://github.com/jorge-pessoa/pytorch-msssim.git
%cd pytorch-msssim/
!python setup.py install
from pytorch_msssim import ssim, msssim

import matplotlib as mpl
mpl.rcParams['figure.dpi']= 100

Cloning into 'pytorch-msssim'...
remote: Enumerating objects: 136, done.[K
remote: Total 136 (delta 0), reused 0 (delta 0), pack-reused 136[K
Receiving objects: 100% (136/136), 1.24 MiB | 3.58 MiB/s, done.
Resolving deltas: 100% (58/58), done.
/content/pytorch-msssim/pytorch-msssim
running install
running build
running build_py
creating build
creating build/lib
creating build/lib/pytorch_msssim
copying pytorch_msssim/__init__.py -> build/lib/pytorch_msssim
running install_lib
copying build/lib/pytorch_msssim/__init__.py -> /usr/local/lib/python3.6/dist-packages/pytorch_msssim
byte-compiling /usr/local/lib/python3.6/dist-packages/pytorch_msssim/__init__.py to __init__.cpython-36.pyc
running install_egg_info
Removing /usr/local/lib/python3.6/dist-packages/pytorch_msssim-0.1.egg-info
Writing /usr/local/lib/python3.6/dist-packages/pytorch_msssim-0.1.egg-info


**Mount Google Drive**

In [0]:
# Mount Google Drive: the custom-op sources, datasets and checkpoints used below
# all live under /content/gdrive/My Drive/.
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


**JIT-compile and load the custom bilateral-slicing CUDA op**

In [0]:
# JIT-compile and load the bilateral slice-and-apply extension (C++ binding + CUDA kernel).
# torch.utils.cpp_extension.load builds with ninja, hence the install.
# NOTE(review): the install is unpinned; consider pinning a version.
!pip install ninja
from torch.utils.cpp_extension import load
op_path = '/content/gdrive/My Drive/bilateral_slice_op/'
# The kernel file name is spelled "bilteral_..." here. It must match the file on Drive
# exactly, so do not correct the spelling without renaming the file as well.
bsliceapply = load(name='bilateral_slicing', sources=[os.path.join(op_path,'bilateral_slicing.cpp'), os.path.join(op_path,'bilteral_slicing_kernel.cu')])



**Layer**

In [0]:
class Convolutional_Layer(nn.Module):
    """Conv2d -> optional BatchNorm2d -> optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias:
            passed straight to nn.Conv2d (padding defaults to 1).
        activation: activation *class*, instantiated with no arguments;
            None disables the activation.
        batch_norm: when truthy, a BatchNorm2d(out_channels) follows the conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding = 1, activation=nn.ReLU, batch_norm=False, bias=True):
        super(Convolutional_Layer, self).__init__()
        conv_args = dict(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=kernel_size, stride=stride,
                         padding=padding, bias=bias)
        # Keep the registration order conv -> batch_norm -> activation: it fixes
        # the order of model.parameters(), which saved optimizer state depends on.
        self.conv = nn.Conv2d(**conv_args)
        if batch_norm:
            self.batch_norm = nn.BatchNorm2d(out_channels)
        else:
            self.batch_norm = None
        self.activation = activation() if activation is not None else None

    def forward(self, x):
        # Run the enabled stages in order; disabled stages are stored as None.
        for stage in (self.conv, self.batch_norm, self.activation):
            if stage is not None:
                x = stage(x)
        return x

class Fully_Connected_Layer(nn.Module):
    """nn.Linear -> optional BatchNorm1d -> optional activation.

    Args:
        in_features, out_features, bias: passed straight to nn.Linear.
        activation: activation *class*, instantiated with no arguments;
            None disables the activation.
        batch_norm: when truthy, a BatchNorm1d(out_features) follows the linear layer.
    """

    def __init__(self, in_features, out_features, activation=nn.ReLU, batch_norm=False, bias=True):
        super(Fully_Connected_Layer, self).__init__()
        # Registration order FC -> batch_norm -> activation fixes parameter order.
        self.FC = nn.Linear(in_features, out_features, bias=bias)
        self.batch_norm = nn.BatchNorm1d(out_features) if batch_norm else None
        if activation is None:
            self.activation = None
        else:
            self.activation = activation()

    def forward(self, x):
        y = self.FC(x)
        y = y if self.batch_norm is None else self.batch_norm(y)
        return y if self.activation is None else self.activation(y)

class Slicing_Apply_Function(torch.autograd.Function):
    """Autograd wrapper around the JIT-compiled ``bsliceapply`` extension.

    forward(grid, guide, frinput) delegates entirely to ``bsliceapply.forward``.
    Per the names used in this notebook, it slices the bilateral ``grid`` with the
    ``guide`` map and applies the resulting coefficients to ``frinput``.
    The shape-test cell uses (not enforced here): grid [B, 12, 8, h, w],
    guide [B, H, W], frinput [B, 3, H, W] -> output [B, 3, H, W].
    """
    @staticmethod
    def forward(ctx, grid, guide, frinput):
        # The extension is always called with has_offset=True; the value is fixed here.
        has_offset = True;
        output = bsliceapply.forward(grid, guide, frinput, has_offset)
        # Keep the inputs and the flag for the backward pass.
        ctx.save_for_backward(grid, guide, frinput)
        ctx.offset = has_offset
        return output

    @staticmethod
    def backward(ctx, grad):
        """Return gradients w.r.t. forward's inputs, computed by ``bsliceapply.backward``."""
        grid, guide, frinput = ctx.saved_tensors
        has_offset = ctx.offset
        outputs = bsliceapply.backward(grid, guide, frinput, grad, has_offset)
        grad_grid, grad_guide, grad_frinput = outputs
        # NOTE(review): autograd matches returned gradients to forward's inputs by
        # position, i.e. (grid, guide, frinput). This returns the variable named
        # grad_frinput first and grad_grid third. grid and frinput have different
        # shapes, so a real mix-up would raise, and the gradcheck cell prints True.
        # Presumably the extension returns its gradients in (frinput, guide, grid)
        # order and the local names above are swapped; confirm against
        # bilateral_slicing.cpp.
        # The trailing None is one more gradient than forward has inputs; PyTorch
        # tolerates surplus trailing Nones.
        return grad_frinput, grad_guide, grad_grid, None


**Grad Check**

In [0]:
from torch.autograd import gradcheck

# Small random double-precision inputs moved to the GPU (the extension ships a
# .cu kernel): grid [2, 12, 8, 3, 3], guide [2, 3, 3], frinput [2, 3, 3, 3].
# They are drawn in this order so the RNG stream is consumed as before.
grid, guide, frinput = (
    torch.rand(*shape, dtype=torch.double, requires_grad=True).cuda()
    for shape in ((2, 12, 8, 3, 3), (2, 3, 3), (2, 3, 3, 3))
)


def grad_test(grid, guide, frinput):
  """Callable handed to gradcheck: applies the custom slicing Function."""
  return Slicing_Apply_Function.apply(grid, guide, frinput)

# NOTE(review): atol=1e1 is extremely loose for inputs in [0, 1); as written this
# mostly confirms that backward runs and returns well-shaped gradients.
is_grad_correct = gradcheck(grad_test, (grid, guide, frinput), eps=1e-3, atol=1e1, raise_exception=True)
print(is_grad_correct)

True


**Model**

In [0]:
class Lr_Splat(nn.Module):  # Extract low-level features
    """Low-resolution "splat" stack of stride-2 3x3 convolutions.

    The stack reduces the lris x lris input to sb x sb, one halving per layer.
    Layer i outputs cm * 2**i * lb channels. Every layer except the first uses
    batch norm when ``bn`` is truthy.
    """

    def __init__(self, cm, sb, lb, bn, lris):
        super(Lr_Splat, self).__init__()
        # Number of stride-2 layers needed to go from lris down to sb.
        depth = int(np.log2(lris / sb))
        widths = [3] + [cm * (2 ** i) * lb for i in range(depth)]
        self.splat_layers = nn.ModuleList(
            Convolutional_Layer(in_channels=widths[i],
                                out_channels=widths[i + 1],
                                kernel_size=3, stride=2,
                                batch_norm=bn if i > 0 else False)
            for i in range(depth)
        )

    def forward(self, x):
        for splat in self.splat_layers:
            x = splat(x)
        return x


class Lr_LocalFeatures(nn.Module):  # Local features in low-res stream
    """Two stride-1 3x3 convolutions over the splat output.

    The channel count equals the last Lr_Splat layer's output,
    cm * 2**(log2(lris / sb) - 1) * lb. The first conv uses batch norm (when
    ``bn`` is truthy) and Convolutional_Layer's default activation. The second
    conv has no activation.
    """

    def __init__(self, cm, sb, lb, bn, lris):
        super(Lr_LocalFeatures, self).__init__()
        channels = int(cm * (2 ** int(np.log2(lris / sb) - 1)) * lb)
        self.lf_layers = nn.ModuleList([
            Convolutional_Layer(in_channels=channels, out_channels=channels,
                                kernel_size=3, stride=1, batch_norm=bn or False),
            Convolutional_Layer(in_channels=channels, out_channels=channels,
                                kernel_size=3, stride=1, activation=None),
        ])

    def forward(self, x):
        features = x
        for conv in self.lf_layers:
            features = conv(features)
        return features


class Lr_GlobalFeatures(nn.Module):  # Global features in low-res stream
    """Global-feature branch.

    Stride-2 3x3 convs shrink the sb x sb splat output to 4x4. The result is
    flattened and passed through three fully-connected layers:
    32*cm*lb -> 16*cm*lb -> 8*cm*lb, the last one without an activation.
    """

    def __init__(self, cm, sb, lb, bn, lris):
        super(Lr_GlobalFeatures, self).__init__()
        channels = int(cm * (2 ** int(np.log2(lris / sb) - 1)) * lb)
        n_splat = int(np.log2(lris / sb))
        n_conv = int(np.log2(sb / 4))
        use_bn = bn if bn else False
        # Registration order (conv list, then FC list) is kept so parameter order
        # -- and any saved optimizer state -- is unchanged.
        self.gf_conv_layers = nn.ModuleList()
        self.gf_fc_layers = nn.ModuleList()
        for _ in range(n_conv):
            self.gf_conv_layers.append(
                Convolutional_Layer(in_channels=channels, out_channels=channels,
                                    kernel_size=3, stride=2, batch_norm=use_bn))
        # Spatial side after every stride-2 conv (splat + global) is
        # lris / 2**(n_splat + n_conv); flatten that map times the channel count.
        flat_features = int((lris / 2 ** (n_splat + n_conv)) ** 2) * channels
        fc_dims = [flat_features, 32 * cm * lb, 16 * cm * lb, 8 * cm * lb]
        for i in range(3):
            is_last = i == 2
            self.gf_fc_layers.append(
                Fully_Connected_Layer(in_features=fc_dims[i],
                                      out_features=fc_dims[i + 1],
                                      activation=None if is_last else nn.ReLU,
                                      batch_norm=False if is_last else use_bn))

    def forward(self,x):
        for conv in self.gf_conv_layers:
            x = conv(x)
        x = x.view(x.size(0), -1)  # flatten, keeping the batch dimension
        for fc in self.gf_fc_layers:
            x = fc(x)
        return x



class Fusion(nn.Module):
    """Fuse local and global features as ReLU(local + global).

    The per-sample global vector [batch, channels] is broadcast over every
    spatial position of the local map [batch, channels, H, W].
    """

    def __init__(self):
        super(Fusion, self).__init__()
        self.Relu = nn.ReLU()

    def forward(self,LrLocalFeatures, LrGlobalFeatures):
        # [batch, channels] -> [batch, channels, 1, 1] so it broadcasts over H x W.
        global_map = LrGlobalFeatures.view(LrGlobalFeatures.size(0), LrGlobalFeatures.size(1), 1, 1)
        return self.Relu(LrLocalFeatures + global_map)

class LinearPredict_BGrid(nn.Module):
    """1x1 convolution predicting the bilateral grid of coefficients.

    Maps 8 * cm * lb fused-feature channels to lb * nin * nout channels.
    forward() regroups them as [batch, nin * nout, lb, H, W], e.g.
    [2, 12, 8, 16, 16] in the shape-test cell.
    """

    def __init__(self,cm, lb, nin=4, nout=3):
        super(LinearPredict_BGrid, self).__init__()
        self.lb = lb
        self.conv = Convolutional_Layer(
            in_channels=8 * cm * lb,
            out_channels=lb * nin * nout,
            kernel_size=1, stride=1, padding=0,
            activation=None,  # linear output, no batch norm
        )

    def forward(self,x):
        coeffs = self.conv(x)
        # Consecutive groups of lb channels become one grid entry:
        # channel (k * lb + j) -> output[:, k, j].
        grid_slices = torch.split(coeffs, self.lb, dim=1)
        return torch.stack(grid_slices, dim=1)

class Guide_PointwiseNN(nn.Module):
    """Per-pixel guide network.

    Two 1x1 convolutions (3 -> guide_complexity -> 1) with a sigmoid output.
    The result is returned as [batch, H, W], with the channel dim squeezed.
    """

    def __init__(self, bn,guide_complexity=16):
        super(Guide_PointwiseNN, self).__init__()
        self.conv1 = Convolutional_Layer(in_channels=3, out_channels=guide_complexity,
                                         kernel_size=1, stride=1, padding=0,
                                         batch_norm=bn if bn else False)
        self.conv2 = Convolutional_Layer(in_channels=guide_complexity, out_channels=1,
                                         kernel_size=1, stride=1, padding=0,
                                         activation=nn.Sigmoid)

    def forward(self, x):
        return self.conv2(self.conv1(x)).squeeze(1)

class SliceNApply(nn.Module):
    """Parameter-free nn.Module wrapper around Slicing_Apply_Function."""

    def forward(self,Bilterial_Grid,Guide,fr):
        # Argument order matches Slicing_Apply_Function.forward(grid, guide, frinput).
        return Slicing_Apply_Function.apply(Bilterial_Grid, Guide, fr)

class Net(nn.Module):
    """Full model: low-res features -> bilateral grid, full-res guide, slice & apply.

    Args: _cm channel multiplier, _sb spatial bins, _lb luma bins,
    _bn batch-norm flag, _lris low-res input size.
    forward(lr, fr) returns (guide map, enhanced full-resolution image).
    """

    def __init__(self,_cm,_sb,_lb,_bn,_lris):
        super(Net,self).__init__()
        hparams = dict(cm=_cm, sb=_sb, lb=_lb, bn=_bn, lris=_lris)
        # Sub-modules are registered in this exact order: it fixes parameter order,
        # which the optimizer state saved in checkpoints relies on.
        self.splat = Lr_Splat(**hparams)
        self.localf = Lr_LocalFeatures(**hparams)
        self.globalf = Lr_GlobalFeatures(**hparams)
        self.fusion = Fusion()
        self.bgrid = LinearPredict_BGrid(cm=_cm, lb=_lb)
        self.guide = Guide_PointwiseNN(bn=_bn)
        self.slice_op = SliceNApply()

    def forward(self,lr, fr):
        features = self.splat(lr)
        fused = self.fusion(self.localf(features), self.globalf(features))
        grid = self.bgrid(fused)
        guide_map = self.guide(fr)
        return guide_map, self.slice_op(grid, guide_map, fr)

**Testing the model structure** (for debugging)

In [0]:
# Shape smoke test for every sub-module and for the full Net, using random inputs.
# Each result is printed and asserted against the expected shape.
# torch.no_grad() skips building autograd graphs: only shapes are checked here,
# and the 1080x1920 tensors make those graphs large.
with torch.no_grad():
    print("Splat")
    X = torch.rand(1,3,256,256)
    lrsplat = Lr_Splat(cm=1,sb=16,lb=8,bn=True,lris=256)
    out = lrsplat(X)
    print(list(out.size()))
    assert list(out.size()) == [1, 64, 16, 16]

    print("Local Features")
    X = torch.rand(1, 64, 16, 16)
    lrlocalfeatures = Lr_LocalFeatures(cm=1,sb=16,lb=8,bn=True,lris=256)
    out = lrlocalfeatures(X)
    print(list(out.size()))
    assert list(out.size()) == [1, 64, 16, 16]

    print("Global Features")
    # Batch of 2: BatchNorm1d in training mode needs more than one sample.
    X = torch.rand(2, 64, 16, 16)
    lrglobalfeatures = Lr_GlobalFeatures(cm=1,sb=16,lb=8,bn=True,lris=256)
    out = lrglobalfeatures(X)
    print(list(out.size()))
    assert list(out.size()) == [2, 64]

    print("Fusion and Bilterial Grid")
    LrLocal = torch.rand(2, 64, 16, 16)
    LrGlobal = torch.rand(2,64)
    fusion = Fusion()
    bg = LinearPredict_BGrid(cm=1,lb=8)
    out = fusion(LrLocal,LrGlobal)
    bg_out = bg(out)
    print(list(bg_out.size()))
    assert list(bg_out.size()) == [2, 12, 8, 16, 16]

    print("GuideNN")
    X = torch.rand(2, 3, 1080, 1920)
    guide = Guide_PointwiseNN(bn=True)
    g_out = guide(X)
    print(list(g_out.size()))
    assert list(g_out.size()) == [2, 1080, 1920]

    print("Slice and Apply Coefficients")
    slice_op = SliceNApply()
    bg_out = bg_out.cuda()
    g_out = g_out.cuda()
    X = X.cuda()
    out = slice_op(bg_out,g_out,X)
    print(list(out.size()))
    assert list(out.size()) == [2, 3, 1080, 1920]

    print("Net")
    lr = torch.rand(2,3,256,256).cuda()
    fr = torch.rand(2, 3, 1080, 1920).cuda()
    # `net` is also used by the model-summary cell below.
    net = Net(_cm=1,_sb=16,_lb=8,_bn=True,_lris=256).cuda()
    out = net(lr,fr)[1]
    print(out.shape)
    assert list(out.shape) == [2, 3, 1080, 1920]

Splat
[1, 64, 16, 16]
Local Features
[1, 64, 16, 16]
Global Features
[2, 64]
Fusion and Bilterial Grid
[2, 12, 8, 16, 16]
GuideNN
[2, 1080, 1920]
Slice and Apply Coefficients
[2, 3, 1080, 1920]
Net
torch.Size([2, 3, 1080, 1920])


**Model Summary**

In [0]:
# Per-layer output shapes and parameter counts for the two-input Net
# (low-res 3x256x256 input first, full-res 3x1080x1920 input second).
# `net` is the model built in the shape-test cell, which must run first.
# NOTE(review): the install is unpinned.
!pip install torchsummary
from torchsummary import summary
summary(net, input_size=[(3, 256, 256),(3,1080,1920)])


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 8, 128, 128]             224
              ReLU-2          [-1, 8, 128, 128]               0
Convolutional_Layer-3          [-1, 8, 128, 128]               0
            Conv2d-4           [-1, 16, 64, 64]           1,168
       BatchNorm2d-5           [-1, 16, 64, 64]              32
              ReLU-6           [-1, 16, 64, 64]               0
Convolutional_Layer-7           [-1, 16, 64, 64]               0
            Conv2d-8           [-1, 32, 32, 32]           4,640
       BatchNorm2d-9           [-1, 32, 32, 32]              64
             ReLU-10           [-1, 32, 32, 32]               0
Convolutional_Layer-11           [-1, 32, 32, 32]               0
           Conv2d-12           [-1, 64, 16, 16]          18,496
      BatchNorm2d-13           [-1, 64, 16, 16]             128
             ReLU-14           [-1,

**Dataset**

In [0]:
# import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

class LowLightDataSet(Dataset):
    """Paired low-light / normal-light image dataset.

    Expects inputs at ``data_dir/low/<name>`` and ground truths at
    ``data_dir/high/<name>``, with identical file names. Each item is the tuple
    ``(input_hr, input_lr, gt_hr, gt_grey)``:

    * input_hr -- full-resolution input, [3, H, W]
    * input_lr -- input resized to (low_res_size, low_res_size), [3, s, s]
    * gt_hr    -- full-resolution ground truth, [3, H, W]
    * gt_grey  -- single-channel greyscale ground truth, [1, H, W]

    Args:
        data_dir: directory containing the ``low`` and ``high`` sub-folders.
        low_res_size: side length of the low-resolution input. The default is 256;
            it should match the model's ``lris``.
    """

    def __init__(self, data_dir, low_res_size=256):
        self.dataset_input_dir = os.path.join(data_dir, 'low')
        self.dataset_gt_dir = os.path.join(data_dir, 'high')
        self.list_of_files = self.listAllInputImageFiles(self.dataset_input_dir)
        self.HighResTransform = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.LowResTransform = transforms.Compose([
            transforms.Resize((low_res_size, low_res_size)),
            transforms.ToTensor(),
        ])
        self.greyScaleTransform = transforms.Compose([
            transforms.Grayscale(1),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        image_name = self.list_of_files[index]
        hr_input_image_path = os.path.join(self.dataset_input_dir,image_name)
        hr_gt_image_path = os.path.join(self.dataset_gt_dir,image_name)
        # convert('RGB') guarantees three channels regardless of the file's mode.
        with Image.open(hr_input_image_path) as img:
            tmp_image = img.convert('RGB')
            input_image_hr = self.HighResTransform(tmp_image)
            input_image_lr = self.LowResTransform(tmp_image)
        with Image.open(hr_gt_image_path) as img2:
            tmp_image = img2.convert('RGB')
            gt_image_hr = self.HighResTransform(tmp_image)
            grey_image = self.greyScaleTransform(tmp_image)
        return input_image_hr, input_image_lr, gt_image_hr, grey_image

    def __len__(self):
        return len(self.list_of_files)

    @staticmethod
    def listAllInputImageFiles(data_dir):
        """Return the sorted names of the regular files directly inside ``data_dir``.

        os.listdir order is arbitrary. Sorting makes dataset indices (and
        therefore which image ``dataset[i]`` returns) reproducible across runs.
        """
        return sorted(
            name for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )

**Testing on input images** 

In [0]:
def showIMG(imgs):
  """Display each CHW image tensor in its own matplotlib figure.

  3-channel tensors are shown as RGB. 1-channel tensors (e.g. guide maps) are
  shown in greyscale. Tensors with any other leading dimension are skipped with a
  message, instead of leaving an empty figure behind.
  """
  for img in imgs:
    channels = img.shape[0]
    if channels == 3:
      # imshow clips float RGB data to [0, 1] anyway (with a warning per image);
      # clamping here gives the same picture without the warning.
      img_np = img.permute(1,2,0).cpu().detach().clamp(0, 1).numpy()
      print("showIMG() Output",img_np.shape)
      plt.figure()
      plt.imshow(img_np,interpolation="bilinear")
    elif channels == 1:
      g_out_np = img.permute(1,2,0).cpu().detach().squeeze(2).numpy()
      print("showIMG() Guide Map",g_out_np.shape)
      plt.figure()
      plt.imshow(g_out_np,cmap='gray',interpolation="bilinear")
    else:
      print("showIMG() skipped tensor with shape", tuple(img.shape))

def showResultMap(img):
  """Show a CHW RGB tensor as a greyscale map of its per-pixel channel mean."""
  plt.figure()
  hwc = img.permute(1,2,0).cpu().detach().numpy()
  channel_mean = sum(hwc[:, :, c] for c in range(3)) / 3
  print("showResultMap() Output",channel_mean.shape)
  plt.imshow(channel_mean,interpolation="bilinear", cmap="gray")

In [0]:
# Optional: inspect a training sample instead (each item holds four tensors):
# t_training_dataset = LowLightDataSet("/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/our485")
# i_fr, i_lr, o_fr, grey_fr = t_training_dataset[0]
t_testing_dataset = LowLightDataSet("/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15")
# Item 7 of eval15: full-res input, 256x256 low-res input, full-res ground truth,
# greyscale ground truth. NOTE(review): which file index 7 maps to depends on the
# order of the dataset's file list.
i_fr, i_lr, o_fr, grey_fr= t_testing_dataset[7]
showIMG([i_fr, i_lr, o_fr,grey_fr])

In [0]:
# Greyscale view (mean of the RGB channels) of the full-resolution input.
showResultMap(i_fr)

In [0]:
# Run a freshly initialised (untrained) guide network on the full-res input. This
# only demonstrates the branch runs. Note: this rebinds the global name `guide`.
guide = Guide_PointwiseNN(bn=True)
i_fr_ = i_fr.unsqueeze(0)  # add a batch dimension: [1, 3, H, W]
g_out = guide(i_fr_)  # [1, H, W], shown via showIMG's 1-channel branch
showIMG([g_out])

**Evaluation and Loss Functions**

In [0]:
def compute_PSNR(output, target, max_val=1.0):
  """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE).

  Args:
    output, target: tensors of the same shape.
    max_val: peak signal value. The default 1.0 suits images scaled to [0, 1],
      as produced by transforms.ToTensor() for 8-bit images; use 255.0 for
      0-255 data.
  Returns:
    0-dim tensor. It is +inf when output equals target (MSE of zero).
  """
  return 10 * torch.log10(max_val ** 2 / F.mse_loss(output, target))

class MS_SSIM(nn.Module):
    """nn.Module wrapper around pytorch_msssim.msssim.

    Returns a similarity, not a loss; training uses ``1 - MS_SSIM()(out, gt)``.

    Args:
        window_size: forwarded to msssim as ``window_size``.
        size_average: forwarded to msssim as ``size_average``.
        normalize: forwarded to msssim as ``normalize``.
    """

    def __init__(self, window_size=11, size_average=True, normalize=True):
        super(MS_SSIM, self).__init__()
        self.window_size, self.size_average, self.normalize = window_size, size_average, normalize

    def forward(self, output, target):
        options = dict(window_size=self.window_size,
                       size_average=self.size_average,
                       normalize=self.normalize)
        return msssim(output, target, **options)
        
class SSIM(nn.Module):
    """nn.Module wrapper around pytorch_msssim.ssim (a similarity score).

    Args:
        window_size: forwarded to ssim as ``window_size``.
        size_average: forwarded to ssim as ``size_average``.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size, self.size_average = window_size, size_average

    def forward(self, output, target):
        return ssim(output, target,
                    window_size=self.window_size,
                    size_average=self.size_average)

          


**Training**

In [0]:
# MODEL PARAMETERS and TRAINING PARAMETERS
# (names in parentheses are the matching train() argument names)

spatial_bins = 16  # (sb) low-res features are reduced to sb x sb
luma_bins = 8  # (lb) size of the bilateral grid's luma dimension
channel_multiplier = 1  # (cm) width multiplier for conv/FC layers; original value 1
low_res_input_size = 256  # (lris) side length of the resized low-res input
epochs = 50
weight_decay = 10 ** -8  # (w_decay) Adam weight decay; original value 10 ** -8
batch_size= 16  # (bs)
batch_norm = True  # (bn)
learning_rate = 10 ** -4  # Adam learning rate; original value 10 ** -4

In [0]:
def train(sb,lb,cm,lris,epochs,w_decay,bs,bn,learning_rate,isResume=True, colab=True, checkptfolder='/content/gdrive/My Drive/Checkpoint/1/',training_path = "/content/gdrive/My Drive/Data/RetinexNetData/BrighteningTrain/", checkptname='checkpt.pth',checkptinterval=10):
    """Train Net on a LowLightDataSet, optimising 1 - MS-SSIM.

    Args:
        sb, lb, cm, lris, bn: model hyper-parameters (see Net).
        epochs: total number of epochs. When resuming, only the epochs not yet
            completed by the checkpoint are run.
        w_decay, learning_rate: Adam weight decay and learning rate.
        bs: batch size.
        isResume: load ``checkptfolder/checkptname`` first. If the file does not
            exist, print a message and return without training.
        colab: mount Google Drive before reading data.
        checkptfolder: folder to load checkpoints from and save them to.
        training_path: dataset root with ``low``/``high`` sub-folders.
        checkptname: checkpoint file to resume from.
        checkptinterval: save a checkpoint after every ``checkptinterval``-th epoch.

    Returns None. Returns early, without training, when CUDA is unavailable
    (the slicing op is built from a CUDA kernel).
    """
    torch.cuda.empty_cache()

    if not torch.cuda.is_available():
        print("Cuda", False, "- no GPU available, training skipped")
        return
    dl_pin_memory = True
    device = torch.device("cuda")
    print("Cuda",dl_pin_memory)

    if colab:
        drive.mount('/content/gdrive')

    training_dataset = LowLightDataSet(training_path)
    train_loader = torch.utils.data.DataLoader(dataset=training_dataset, batch_size=bs, shuffle=True, pin_memory=dl_pin_memory)

    model = Net(_cm=cm,_sb=sb,_lb=lb,_bn=bn,_lris=lris)
    msssim_criterion = MS_SSIM()
    ssim_criterion = SSIM()
    mse_criterion = torch.nn.MSELoss()
    l1_criterion = torch.nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=w_decay)
    curr_epoch = 0
    losslogger = []

    checkpt_path = os.path.join(checkptfolder, checkptname)
    if isResume:
        if not os.path.isfile(checkpt_path):
            print("=> no checkpoint found at '{}'".format(checkpt_path))
            return
        print("=> Loading checkpoint'{}".format(checkpt_path))
        checkpt = torch.load(checkpt_path)
        # checkpt['epoch'] is the index of the last *completed* epoch, because
        # saveCheckpt is called before curr_epoch is incremented. Continue with the
        # next index: restarting at the same index re-ran that epoch under the same
        # label and could overwrite its checkpoint file.
        curr_epoch = checkpt['epoch'] + 1
        epochs -= curr_epoch
        model.load_state_dict(checkpt['model_state_dict'])
        optimizer.load_state_dict(checkpt['optimizer_state_dict'])
        losslogger = checkpt['losslogger']
        # Optimizer state tensors must be on the same device as the parameters.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        print("=> loaded checkpoint '{}' (last completed epoch {}, resuming at epoch {})".format(
            checkpt_path, checkpt['epoch'], curr_epoch))
    else:
        print("=> No checkpoint will be used")

    model.to(device)

    for _ in range(epochs):
        e_losslogger = []
        model.train()

        for batch_idx, (fr,lr,target,target_map) in enumerate(train_loader):
            optimizer.zero_grad()
            lr = lr.to(device)
            fr = fr.to(device)
            target_map = target_map.to(device).squeeze(1)  # greyscale GT; unused by the current loss
            target = target.to(device)
            g_out, output = model(lr, fr)

            # Every candidate loss term keeps its graph, so `loss` can be switched to
            # a combination (e.g. msssimloss + l1loss + l2loss). Only MS-SSIM is used now.
            msssimloss = 1 - msssim_criterion(output, target)
            ssimloss = 1 - ssim_criterion(output, target)
            l1loss = l1_criterion(output, target)
            l2loss = mse_criterion(output, target)
            loss = msssimloss
            # Store a float. Appending the tensor itself kept every batch's autograd
            # nodes referenced for the whole epoch and put GPU tensors into the
            # losslogger that is saved in checkpoints.
            e_losslogger.append(loss.item())
            loss.backward()

            # Metrics for the progress line only; no autograd graph is needed.
            with torch.no_grad():
                ssim_val = ssim(output, target, size_average=True).item()
                msssim_val = msssim(output, target, size_average=True).item()
                psnr_val = compute_PSNR(output, target).item()
            print("Epoch: {}, Batch: {}, Loss: {}, SSIM: {}, MS-SSIM: {}, L1 Loss: {}, L2 Loss:{}, PSNR: {} dB".format(
                curr_epoch, batch_idx, loss.item(), ssim_val, msssim_val,
                l1loss.item(), l2loss.item(), psnr_val))
            optimizer.step()

        losslogger.append(e_losslogger)
        if (curr_epoch+1) % checkptinterval == 0:
            saveCheckpt(checkptfolder,curr_epoch,model,optimizer,losslogger)
        curr_epoch += 1
 
def saveCheckpt(checkptfolder,curr_epoch,model,optimizer,losslogger,batch_idx='_'):
  """Save model/optimizer state and the loss log.

  The file is written to
  ``<checkptfolder>/checkpt_epoch_<curr_epoch>_batch_<batch_idx>.pth``.
  The folder is created if needed. The path is built with os.path.join, so
  ``checkptfolder`` may omit the trailing slash. The stored ``'epoch'`` is
  ``curr_epoch`` as given; train() passes the index of the epoch just finished.
  """
  checkptname = 'checkpt_epoch_{}_batch_{}.pth'.format(curr_epoch,batch_idx)
  if checkptfolder:
    os.makedirs(checkptfolder, exist_ok=True)
  print("Saving checkpoint {}".format(checkptname))
  torch.save({
    'epoch': curr_epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'losslogger': losslogger,
    }, os.path.join(checkptfolder, checkptname))
  print("{} is saved".format(checkptname))


In [0]:
#Train from scratch
# Fresh run (isResume=False) on the LOL our485 split. Checkpoints are written to
# checkptfolder every checkptinterval epochs (train()'s default is 10).
train(sb=spatial_bins,lb=luma_bins,cm=channel_multiplier,lris=low_res_input_size,epochs=epochs,
      w_decay=weight_decay,bs=batch_size,bn=batch_norm,learning_rate=learning_rate,isResume=False,
      training_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/our485",
      checkptfolder='/content/gdrive/My Drive/Checkpoint/1/')

In [0]:
#Train from checkpoint
# Resume from Checkpoint/3/checkpt_epoch_10_batch__.pth and continue up to a total
# of 100 epochs. train() returns without training if that file is missing.
train(sb=spatial_bins,lb=luma_bins,cm=channel_multiplier,lris=low_res_input_size,epochs=100,
      w_decay=weight_decay,bs=batch_size,bn=batch_norm,learning_rate=learning_rate,
      isResume=True,checkptname='checkpt_epoch_10_batch__.pth', checkptfolder='/content/gdrive/My Drive/Checkpoint/3/',
      training_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/our485")

**Demo**

In [0]:
def demo(input,isShowGuideMap=False, sb=16,lb=8,cm=1,lris=256,bn=True,checkptfolder='/content/gdrive/My Drive/Checkpoint/',checkptname='checkpt.pth', colab=True,testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15"):
  """Enhance a single image with a trained checkpoint and display the result.

  Args:
    input: path to the image file.
    isShowGuideMap: also display the predicted guide map.
    sb, lb, cm, lris, bn: model hyper-parameters; they must match the checkpoint.
      loadIMG() produces a 256x256 low-res input, so lris should stay 256.
    checkptfolder, checkptname: checkpoint location.
    colab: mount Google Drive first.
    testing_path: unused; kept so the signature mirrors test().

  Returns None. Returns early when CUDA is unavailable or the checkpoint is
  missing, since the output of an untrained model is meaningless
  (this matches test()).
  """
  torch.backends.cudnn.benchmark = True
  torch.backends.cudnn.fastest = True

  if not torch.cuda.is_available():
    print("Cuda", False, "- no GPU available, demo skipped")
    return
  device = torch.device("cuda")
  print("Cuda", True)

  if colab:
    drive.mount('/content/gdrive')

  model = Net(_cm=cm,_sb=sb,_lb=lb,_bn=bn,_lris=lris)

  checkpt_path = os.path.join(checkptfolder, checkptname)
  if not os.path.isfile(checkpt_path):
    print("=> no checkpoint found at '{}'".format(checkpt_path))
    return
  print("=> Loading checkpoint'{}".format(checkpt_path))
  checkpt = torch.load(checkpt_path, map_location=device)
  model.load_state_dict(checkpt['model_state_dict'])
  print("=> loaded checkpoint '{}' ".format(checkpt_path))

  model.eval()
  model.to(device)
  lr, fr = loadIMG(input)
  # Inference only: skipping autograd avoids keeping full-resolution activations.
  with torch.no_grad():
    g_out, output = model(lr.to(device), fr.to(device))
  imgs_output = [output.squeeze(0)]
  if isShowGuideMap:
    imgs_output.append(g_out)
  showIMG(imgs_output)

def loadIMG(image_path, low_res_size=256):
  """Load one image as a (low-res, full-res) pair of batched tensors.

  Args:
    image_path: path to the image file.
    low_res_size: side of the square low-res input (default 256, matching
      Net's default lris).
  Returns:
    (lr_image, fr_image): [1, 3, low_res_size, low_res_size] and [1, 3, H, W].

  The image is converted to RGB, as LowLightDataSet does, so greyscale, RGBA and
  palette files also yield three channels. The file handle is closed after loading.
  """
  fr_img_loader = transforms.Compose([transforms.ToTensor()])
  lr_img_loader = transforms.Compose([transforms.Resize((low_res_size, low_res_size)), transforms.ToTensor()])
  with Image.open(image_path) as img:
    image = img.convert('RGB')
  fr_image = fr_img_loader(image)
  lr_image = lr_img_loader(image)
  print("loadIMG()",fr_image.shape)
  return lr_image.unsqueeze(0), fr_image.unsqueeze(0)

**Demo image**

In [0]:
# Low-light evaluation image (LOL eval15 split) used by the demo cells below.
testing_img = '/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15/low/665.png'




---





**Demo: model trained on the LOL dataset (cm = 1, loss = MS-SSIM)**

In [0]:
# The timing covers the whole demo() call (Drive mount, model build, checkpoint load,
# inference and plotting), not inference alone.
start = timer()
demo(testing_img,checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_CM1_MSSSIM/', cm=1, isShowGuideMap = True)
end = timer()
print(end - start,"seconds")


**Demo: model trained on the LOL dataset (cm = 3, loss = MS-SSIM)**

In [0]:
# cm=3 checkpoint; the timing covers the whole demo() call, not inference alone.
start = timer()
demo(testing_img,checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_CM3_MSSSIM/', cm=3, isShowGuideMap = True)
end = timer()
print(end - start,"seconds")


**Demo: model trained on the LOL dataset (cm = 3, loss = L1)**

In [0]:
# cm=3, L1-trained checkpoint; the timing covers the whole demo() call.
start = timer()
demo(testing_img,checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_CM3_L1/', cm=3, isShowGuideMap = True)
end = timer()
print(end - start,"seconds")




---



**Demo: model trained on the LOL dataset (cm = 3, losses = MS-SSIM, L1, L2)**

In [0]:
# cm=3 checkpoint from the ep49-99 MS-SSIM/L1/L2 run (epoch index 59); the timing
# covers the whole demo() call.
start = timer()
demo(testing_img,checkptname='checkpt_epoch_59_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-99_mssimloss_l2loss_l1_loss/', cm=3, isShowGuideMap = True)
end = timer()
print(end - start,"seconds")


**Demo: model trained on the LOL dataset (cm = 2, losses = MS-SSIM, L1, L2)**


In [0]:
# cm=2 checkpoint (MS-SSIM/L1/L2, per the file name); the timing covers the whole demo() call.
start = timer()
demo(testing_img,checkptname='LOL_cm2_ep_49_msssimloss_l2loss_l1loss.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/lol_cm2_ep49_msssim_l1_l2/', cm=2,isShowGuideMap=True)
end = timer()
print(end - start,"seconds")

**Demo: model trained on the LOL dataset (cm = 1, losses = MS-SSIM, L1, L2)**

In [0]:
# cm=1 checkpoint (MS-SSIM/L1/L2, per the folder name); the timing covers the whole demo() call.
start = timer()
demo(testing_img,checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm1_ep4849_msssimloss_l1loss_l2loss/', cm=1,isShowGuideMap=True)
end = timer()
print(end - start,"seconds")



---



**Demo: model trained on the LOL dataset (cm = 3, loss = L2)**

In [0]:
# cm=3, L2-trained checkpoint (per the folder name); the timing covers the whole demo() call.
start = timer()
demo(testing_img,checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-51_l2loss/', cm=3,isShowGuideMap=True)
end = timer()
print(end - start,"seconds")


**Demo: model trained on the LOL dataset (cm = 3, losses = L2, MS-SSIM)**

In [0]:
# cm=3 checkpoint (MS-SSIM + L2, per the folder name); the timing covers the whole demo() call.
start = timer()
demo(testing_img,checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep47-49_msssimloss_l2_loss/', cm=3,isShowGuideMap=True)
end = timer()
print(end - start,"seconds")


**Demo: model trained on the LOL dataset (cm = 3, losses = MS-SSIM, L1, L2; repeat of the earlier demo with the same checkpoint)**




In [0]:
# NOTE(review): identical call to the earlier cm=3 MS-SSIM/L1/L2 demo cell
# (same folder, epoch index 59). The timing covers the whole demo() call.
start = timer()
demo(testing_img,checkptname='checkpt_epoch_59_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-99_mssimloss_l2loss_l1_loss/', cm=3, isShowGuideMap = True)
end = timer()
print(end - start,"seconds")

**Evaluation**

In [0]:
def test(isShowIMG=False,isShowGuideMap=False, sb=16,lb=8,cm=1,lris=256,bn=True,checkptfolder='/content/gdrive/My Drive/Checkpoint/',checkptname='checkpt.pth', colab=True,testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15"):
  """Evaluate a checkpoint on a LowLightDataSet and print per-image and mean metrics.

  For each image (batch size 1), prints SSIM, MS-SSIM, L1, L2 (MSE) and PSNR of
  the model output against the ground truth, then prints the averages.

  Args:
    isShowIMG: display the model output for every image.
    isShowGuideMap: with isShowIMG, also display the ground truth and the full-res
      input. Despite the name, the guide map itself is not shown.
    sb, lb, cm, lris, bn: model hyper-parameters; they must match the checkpoint.
    checkptfolder, checkptname: checkpoint location.
    colab: mount Google Drive before reading data.
    testing_path: dataset root with ``low``/``high`` sub-folders.

  Returns None. Returns early when CUDA is unavailable, the checkpoint is missing,
  or the dataset is empty.
  """
  torch.backends.cudnn.benchmark = True
  torch.backends.cudnn.fastest = True

  if not torch.cuda.is_available():
    print("Cuda", False, "- no GPU available, evaluation skipped")
    return
  dl_pin_memory = True
  device = torch.device("cuda")
  print("Cuda",dl_pin_memory)

  # Mount Drive *before* listing the dataset, since testing_path normally lives on Drive.
  if colab:
    drive.mount('/content/gdrive')

  testing_dataset = LowLightDataSet(testing_path)
  # Fixed order, so "Image: N" refers to the same file on every run.
  test_loader = torch.utils.data.DataLoader(dataset=testing_dataset, batch_size=1, shuffle=False, pin_memory=dl_pin_memory)

  model = Net(_cm=cm,_sb=sb,_lb=lb,_bn=bn,_lris=lris)

  checkpt_path = os.path.join(checkptfolder, checkptname)
  if not os.path.isfile(checkpt_path):
    print("=> no checkpoint found at '{}'".format(checkpt_path))
    return
  print("=> Loading checkpoint'{}".format(checkpt_path))
  checkpt = torch.load(checkpt_path, map_location=device)
  model.load_state_dict(checkpt['model_state_dict'])
  print("=> loaded checkpoint '{}' ".format(checkpt_path))

  model.eval()
  model.to(device)

  l1_criterion = torch.nn.L1Loss()
  l2_criterion = torch.nn.MSELoss()
  totals = {'ssim': 0.0, 'msssim': 0.0, 'l1': 0.0, 'l2': 0.0, 'psnr': 0.0}
  count = 0

  # Without no_grad, summing metric tensors across images chained their autograd
  # history together, keeping every image's graph (and activations) alive.
  with torch.no_grad():
    for batch_idx, (fr,lr,target,target_map) in enumerate(test_loader):
      lr = lr.to(device)
      fr = fr.to(device)
      target = target.to(device)
      g_out, output = model(lr, fr)
      # Each metric is computed once and stored as a plain float.
      metrics = {
          'ssim': ssim(output, target, size_average=True).item(),
          'msssim': msssim(output, target, size_average=True).item(),
          'l1': l1_criterion(output, target).item(),
          'l2': l2_criterion(output, target).item(),
          'psnr': compute_PSNR(output, target).item(),
      }
      print("Image: {}, SSIM: {}, MS-SSIM: {}, L1 Loss: {}, L2 Loss:{}, PSNR: {} dB ".format(
          batch_idx, metrics['ssim'], metrics['msssim'], metrics['l1'], metrics['l2'], metrics['psnr']))
      count += 1
      for key, value in metrics.items():
        totals[key] += value

      if isShowIMG:
        imgs_output = [output.squeeze(0)]
        if isShowGuideMap:
          imgs_output.append(target.squeeze(0))
          imgs_output.append(fr.squeeze(0))
        showIMG(imgs_output)

  if count == 0:
    print("No test images found in '{}'".format(testing_path))
    return
  print("Average: SSIM: {}, MS-SSIM: {}, L1 Loss: {}, L2 Loss: {}, PSNR: {} dB ".format(
      totals['ssim'] / count, totals['msssim'] / count, totals['l1'] / count,
      totals['l2'] / count, totals['psnr'] / count))


**6.2.4.3. Different Loss Functions**

In [0]:
# Free cached GPU memory from earlier cells, then evaluate the cm=1 MS-SSIM checkpoint on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_CM1_MSSSIM/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=1, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate the cm=1 L1 checkpoint (per the folder name) on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_CM1_L1/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=1, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate the cm=1 L2 checkpoint (per the folder name) on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_CM1_L2/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=1, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate the cm=1 MS-SSIM/L1/L2 checkpoint (per the folder name) on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm1_ep4849_msssimloss_l1loss_l2loss/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=1, isShowGuideMap=True, isShowIMG=True)


**6.2.5.4. Number of Channels**

In [0]:
# Evaluate the cm=3 L1 checkpoint (per the folder name) on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_CM3_L1/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=3, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate the cm=3 MS-SSIM checkpoint (per the folder name) on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_CM3_MSSSIM/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=3, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate the cm=3 L2 checkpoint (per the folder name) on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-51_l2loss/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=3, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate epoch index 49 of the cm=3 MS-SSIM/L1/L2 run (per the folder name) on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_49_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-99_mssimloss_l2loss_l1_loss/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=3, isShowGuideMap=True, isShowIMG=True)



**6.2.5.5. Different Combinations of Loss Functions and Numbers of Epochs**

In [0]:
# Evaluate epoch index 69 of the cm=3 MS-SSIM/L1/L2 run on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_69_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-99_mssimloss_l2loss_l1_loss/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=3, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate epoch index 79 of the cm=3 MS-SSIM/L1/L2 run on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_79_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-99_mssimloss_l2loss_l1_loss/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=3, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate epoch index 89 of the cm=3 MS-SSIM/L1/L2 run on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_89_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-99_mssimloss_l2loss_l1_loss/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=3, isShowGuideMap=True, isShowIMG=True)


In [0]:
# Evaluate epoch index 99 (the last) of the cm=3 MS-SSIM/L1/L2 run on eval15.
torch.cuda.empty_cache()

test(checkptname='checkpt_epoch_99_batch__.pth',checkptfolder='/content/gdrive/My Drive/Checkpoint/LOL_cm3_ep49-99_mssimloss_l2loss_l1_loss/', testing_path = "/content/gdrive/My Drive/Data/RetinexNetData/LOLdataset/eval15", cm=3, isShowGuideMap=True, isShowIMG=True)
