# Model Definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
#import pytorch_colors as colors
import numpy as np

class enhance_net_nopool(nn.Module):
    """Per-pixel curve-estimation network (no pooling in the forward path).

    Seven 3x3 convolutions (stride 1, padding 1, so spatial size is preserved)
    with skip connections that concatenate earlier feature maps. The last layer
    emits 24 channels squashed to [-1, 1] by tanh. These are split into eight
    3-channel parameter maps r1..r8, and each is applied to the image in turn as
    ``x <- x + r * (x**2 - x)``.

    The input ``x`` must have 3 channels (``e_conv1`` takes 3 input channels).

    forward(x) returns a tuple:
        enhance_image_1: the image after the first four curve steps.
        enhance_image:   the image after all eight curve steps.
        r:               the 24-channel parameter map (r1..r8 concatenated).
    """

    def __init__(self):
        super(enhance_net_nopool, self).__init__()

        self.relu = nn.ReLU(inplace=True)

        number_f = 32
        self.e_conv1 = nn.Conv2d(3, number_f, 3, 1, 1, bias=True)
        self.e_conv2 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv3 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv4 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        # conv5..conv7 take two concatenated feature maps, hence number_f*2 inputs.
        self.e_conv5 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv6 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        # 24 output channels = 8 curve steps x 3 colour channels.
        self.e_conv7 = nn.Conv2d(number_f * 2, 24, 3, 1, 1, bias=True)

        # Not used by forward(). Kept so existing attribute references keep working.
        # Neither module has parameters, so the state_dict is unaffected.
        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x):
        x1 = self.relu(self.e_conv1(x))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))

        # Skip connections: later layers also see the matching earlier features.
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))

        # torch.tanh replaces the deprecated F.tanh (same values).
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
        curves = torch.split(x_r, 3, dim=1)

        enhance_image_1 = x
        for step, r_k in enumerate(curves, start=1):
            x = x + r_k * (torch.pow(x, 2) - x)
            if step == 4:
                # Intermediate result after the first half of the curve steps.
                enhance_image_1 = x

        r = torch.cat(curves, 1)
        return enhance_image_1, x, r

# Data loader

In [None]:
import os
import sys

import torch
import torch.utils.data as data

import numpy as np
from PIL import Image
import glob
import random
import cv2

# Seeds Python's `random` module, which populate_train_list() uses for random.shuffle.
random.seed(1143)


def populate_train_list(lowlight_images_path, pattern="*.JPG"):
    """Return the image paths matching ``pattern`` under ``lowlight_images_path``, shuffled.

    Args:
        lowlight_images_path: directory (with or without a trailing separator)
            or a filename prefix that ``pattern`` is appended to.
        pattern: glob pattern for the image files. The default keeps the
            original behaviour (upper-case ``.JPG`` only).

    Returns:
        A list of matching paths in a random order driven by Python's
        ``random`` module.
    """
    print(lowlight_images_path)
    base = lowlight_images_path
    if os.path.isdir(base):
        # Ensure a trailing separator. Plain concatenation of "dataset/low" +
        # "*.JPG" would glob "dataset/low*.JPG" and silently find nothing.
        base = os.path.join(base, "")
    image_list_lowlight = glob.glob(base + pattern)
    # glob's order depends on the filesystem; sort first so the seeded shuffle
    # gives the same order on every machine.
    train_list = sorted(image_list_lowlight)
    random.shuffle(train_list)
    return train_list



class lowlight_loader(data.Dataset):
    """Dataset of low-light training images, resized to ``size`` x ``size``.

    Each item is a float32 tensor of shape (3, size, size) with values scaled
    to [0, 1] (pixel values divided by 255).
    """

    def __init__(self, lowlight_images_path):
        self.train_list = populate_train_list(lowlight_images_path)
        self.size = 256

        self.data_list = self.train_list
        print("Total training examples:", len(self.train_list))

    def __getitem__(self, index):
        data_lowlight_path = self.data_list[index]

        # The context manager closes the underlying file. Otherwise the handle
        # stays open, because Image.open is lazy.
        with Image.open(data_lowlight_path) as img:
            # convert('RGB') guarantees 3 channels, so permute(2, 0, 1) below
            # also works for grayscale / RGBA / palette files.
            # Image.ANTIALIAS (removed in Pillow 10) was an alias of LANCZOS.
            img = img.convert('RGB').resize((self.size, self.size), Image.LANCZOS)

        data_lowlight = np.asarray(img) / 255.0
        data_lowlight = torch.from_numpy(data_lowlight).float()

        # HWC -> CHW, the layout nn.Conv2d expects.
        return data_lowlight.permute(2, 0, 1)

    def __len__(self):
        return len(self.data_list)

# Loss functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.models.vgg import vgg16
import numpy as np


class L_color(nn.Module):
    """Colour-constancy loss on the per-image mean of each RGB channel.

    For a (B, 3, H, W) input, let (R, G, B) be the spatial means of the three
    channels. The returned (B, 1, 1, 1) tensor is
    sqrt(((R-G)^2)^2 + ((R-B)^2)^2 + ((B-G)^2)^2).
    """

    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x):
        # Unpacking enforces a 4-D input, exactly like before.
        _batch, _channels, _height, _width = x.shape

        channel_means = x.mean(dim=[2, 3], keepdim=True)
        mean_r, mean_g, mean_b = channel_means.split(1, dim=1)

        sq_rg = (mean_r - mean_g) ** 2
        sq_rb = (mean_r - mean_b) ** 2
        sq_gb = (mean_b - mean_g) ** 2

        total = sq_rg ** 2 + sq_rb ** 2 + sq_gb ** 2
        return total ** 0.5


class L_spa(nn.Module):
    """Spatial-consistency loss between two images.

    Both inputs are averaged over channels and 4x4 average-pooled. For each of
    four directions (left/right/up/down), a fixed 3x3 kernel (zero padding at
    the border) computes the difference between a pooled cell and its
    neighbour. The result is the sum over directions of the squared difference
    between the two images' neighbour differences. For 4-D inputs its shape is
    (B, 1, H // 4, W // 4).

    Args:
        device: where to create the fixed kernels. Defaults to CUDA when it is
            available, otherwise CPU; CUDA was previously hard-coded. In
            forward() the kernels are moved to the input's device/dtype anyway.
    """

    def __init__(self, device=None):
        super(L_spa, self).__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).to(device).unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).to(device).unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).to(device).unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).to(device).unsqueeze(0).unsqueeze(0)
        # Fixed, non-trainable kernels. They stay nn.Parameter so the
        # attribute names and state_dict keys are unchanged.
        self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
        self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
        self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
        self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
        self.pool = nn.AvgPool2d(4)

    def forward(self, org, enhance):
        org_pool = self.pool(torch.mean(org, 1, keepdim=True))
        enhance_pool = self.pool(torch.mean(enhance, 1, keepdim=True))

        # The previous weight_diff / E_1 terms were computed but never used
        # (and forced CUDA tensors on every call), so they are dropped.
        E = 0
        for kernel in (self.weight_left, self.weight_right, self.weight_up, self.weight_down):
            # No-op when already on the right device/dtype. Lets the loss run on
            # CPU and on inputs living on a different device than the kernels.
            kernel = kernel.to(device=org_pool.device, dtype=org_pool.dtype)
            D_org = F.conv2d(org_pool, kernel, padding=1)
            D_enhance = F.conv2d(enhance_pool, kernel, padding=1)
            E = E + torch.pow(D_org - D_enhance, 2)
        return E
class L_exp(nn.Module):
    """Exposure-control loss.

    Averages the input over channels, average-pools it with ``patch_size``
    windows, and returns the mean squared distance of those patch means from
    ``mean_val`` (a 0-dim tensor).

    Args:
        patch_size: kernel (and stride) of the average pooling.
        mean_val: target mean intensity per patch.
    """

    def __init__(self, patch_size, mean_val):
        super(L_exp, self).__init__()
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val

    def forward(self, x):
        x = torch.mean(x, 1, keepdim=True)
        mean = self.pool(x)

        # Build the target on the input's device/dtype instead of hard-coding
        # .cuda(), so the loss also works on CPU.
        target = torch.tensor([self.mean_val], dtype=mean.dtype, device=mean.device)
        d = torch.mean(torch.pow(mean - target, 2))
        return d

class L_TV(nn.Module):
    """Total-variation smoothness penalty.

    For an (N, C, H, W) input, the summed squared vertical neighbour
    differences are divided by (H-1)*W, and the summed horizontal ones by
    H*(W-1). Their sum is multiplied by 2 * TVLoss_weight and divided by N.
    """

    def __init__(self, TVLoss_weight=1):
        super(L_TV, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        dims = x.size()
        batch_size, height, width = dims[0], dims[2], dims[3]

        vertical = (x[:, :, 1:, :] - x[:, :, :height - 1, :]).pow(2).sum()
        horizontal = (x[:, :, :, 1:] - x[:, :, :, :width - 1]).pow(2).sum()

        n_vertical = (height - 1) * width
        n_horizontal = height * (width - 1)
        return self.TVLoss_weight * 2 * (vertical / n_vertical + horizontal / n_horizontal) / batch_size
class Sa_Loss(nn.Module):
    """Mean per-pixel Euclidean deviation of RGB values from the per-image channel means.

    Expects a 4-D input with exactly 3 channels; returns a 0-dim tensor.
    """

    def __init__(self):
        super(Sa_Loss, self).__init__()

    def forward(self, x):
        # Unpacking enforces a 4-D input, as before.
        _n, _c, _h, _w = x.shape
        red, green, blue = torch.split(x, 1, dim=1)

        avg = torch.mean(x, [2, 3], keepdim=True)
        avg_r, avg_g, avg_b = torch.split(avg, 1, dim=1)

        squared = torch.pow(red - avg_r, 2) + torch.pow(blue - avg_b, 2) + torch.pow(green - avg_g, 2)
        return torch.mean(torch.pow(squared, 0.5))

class perception_loss(nn.Module):
    """Frozen VGG-16 feature extractor.

    Splits the first 23 layers of ``vgg16(pretrained=True).features`` into four
    sequential stages (named after the relu1_2 / relu2_2 / relu3_3 / relu4_3
    activations they end at). forward() runs all four stages and returns the
    output of the last one. All parameters have ``requires_grad = False``.

    NOTE(review): newer torchvision releases deprecate ``pretrained=`` in favour
    of ``weights=``. Confirm against the installed version before changing it.
    """

    def __init__(self):
        super(perception_loss, self).__init__()
        backbone = vgg16(pretrained=True).features

        # (attribute name, first layer index, end index exclusive). Sub-module
        # names stay str(layer_index), so state_dict keys are unchanged.
        stages = (
            ('to_relu_1_2', 0, 4),
            ('to_relu_2_2', 4, 9),
            ('to_relu_3_3', 9, 16),
            ('to_relu_4_3', 16, 23),
        )
        for stage_name, start, stop in stages:
            stage = nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), backbone[idx])
            setattr(self, stage_name, stage)

        # Feature extractor only: freeze every weight.
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, x):
        out = x
        for stage in (self.to_relu_1_2, self.to_relu_2_2, self.to_relu_3_3, self.to_relu_4_3):
            out = stage(out)
        return out

# Unpacking the dataset

In [None]:
import zipfile

with zipfile.ZipFile('drive/MyDrive/data/dataset.zip', mode='r') as archive:
    # No target path given, so every member is extracted into the current working directory.
    archive.extractall()

In [None]:
# NOTE(review): this move failed (output below: target is not a directory). The
# training config reads images from 'dataset/low/' directly, so the move looks
# unnecessary. If it succeeded, it would also move the 'low' folder the config
# points at. If it is kept, create the destination first
# (e.g. `! mkdir -p ./data/train_data/`) and update lowlight_images_path.
! mv ./dataset/* ./data/train_data/

mv: target './data/train_data/' is not a directory


# Main training cell

In [None]:
from collections import namedtuple
import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import os
import sys
import argparse
import time
import numpy as np
from torchvision import transforms


def weights_init(m):
    """Initialise a module in place, selecting by its class name.

    Classes whose name contains 'Conv' get weights drawn from N(0, 0.02).
    Otherwise, classes whose name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and biases set to 0. All other modules are left untouched.
    Intended for use with ``nn.Module.apply``.
    """
    name = type(m).__name__
    if 'Conv' in name:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)





def train(config):
    """Train ``enhance_net_nopool`` with the TV, spatial, colour and exposure losses.

    Args:
        config: object exposing lowlight_images_path, lr, weight_decay,
            grad_clip_norm, num_epochs, train_batch_size, num_workers,
            display_iter, snapshot_iter, snapshots_folder, load_pretrain and
            pretrain_dir (e.g. the namedtuple built in the main cell).

    Side effects:
        Prints the loss every ``display_iter`` iterations. Every
        ``snapshot_iter`` iterations it saves the network state_dict to
        ``<snapshots_folder>/Epoch<epoch>.pth``, overwriting within an epoch.
        Requires CUDA.
    """
    # NOTE(review): only effective if CUDA has not been initialised yet in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    DCE_net = enhance_net_nopool().cuda()
    DCE_net.apply(weights_init)
    if config.load_pretrain:
        DCE_net.load_state_dict(torch.load(config.pretrain_dir))

    train_dataset = lowlight_loader(config.lowlight_images_path)
    print(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
    )

    l_color = L_color()
    l_spa = L_spa()
    l_exp = L_exp(16, 0.6)
    l_TV = L_TV()

    optimizer = torch.optim.Adam(DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    DCE_net.train()

    for epoch in range(config.num_epochs):
        for iteration, img_lowlight in enumerate(train_loader):
            img_lowlight = img_lowlight.cuda()

            enhanced_image_1, enhanced_image, A = DCE_net(img_lowlight)

            # Weighted sum of the four unsupervised loss terms.
            Loss_TV = 200 * l_TV(A)
            loss_spa = torch.mean(l_spa(enhanced_image, img_lowlight))
            loss_col = 5 * torch.mean(l_color(enhanced_image))
            loss_exp = 10 * torch.mean(l_exp(enhanced_image))
            loss = Loss_TV + loss_spa + loss_col + loss_exp

            optimizer.zero_grad()
            loss.backward()
            # clip_grad_norm (no trailing underscore) is deprecated; use the in-place API.
            torch.nn.utils.clip_grad_norm_(DCE_net.parameters(), config.grad_clip_norm)
            optimizer.step()

            if (iteration + 1) % config.display_iter == 0:
                print("Loss at iteration", iteration + 1, ":", loss.item())
            if (iteration + 1) % config.snapshot_iter == 0:
                # os.path.join: with snapshots_folder='snapshots' (no trailing
                # slash), concatenation wrote 'snapshotsEpoch0.pth' outside the folder.
                snapshot_path = os.path.join(config.snapshots_folder, "Epoch" + str(epoch) + '.pth')
                torch.save(DCE_net.state_dict(), snapshot_path)




if __name__ == "__main__":
    # Hyper-parameters (this replaces an earlier argparse setup). A namedtuple
    # gives train() the attribute access it uses (config.lr, config.num_epochs, ...).
    Config = namedtuple(typename='Config',
                        field_names='lowlight_images_path lr weight_decay grad_clip_norm num_epochs train_batch_size val_batch_size num_workers display_iter snapshot_iter snapshots_folder load_pretrain pretrain_dir')
    config = Config(lowlight_images_path='dataset/low/',
                    lr=1e-4,
                    weight_decay=1e-4,
                    grad_clip_norm=0.1,
                    num_epochs=10,
                    train_batch_size=8,
                    val_batch_size=4,
                    num_workers=4,
                    display_iter=10,
                    snapshot_iter=10,
                    # Trailing slash so checkpoints land inside the folder, even
                    # where the path is built by string concatenation.
                    snapshots_folder='snapshots/',
                    load_pretrain=False,
                    pretrain_dir='snapshots/Epoch99.pth')

    # Create the folder the checkpoints are written to, derived from the config
    # rather than a separately hard-coded name.
    os.makedirs(config.snapshots_folder, exist_ok=True)

    print(config.lowlight_images_path)
    train(config)

dataset/low/
dataset/low/
Total training examples: 739
<__main__.lowlight_loader object at 0x782dde3f6cb0>


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  torch.nn.utils.clip_grad_norm(DCE_net.parameters(),config.grad_clip_norm)


Loss at iteration 10 : 2.838231086730957
Loss at iteration 20 : 2.6426167488098145
Loss at iteration 30 : 3.000356435775757
Loss at iteration 40 : 2.610485315322876
Loss at iteration 50 : 2.359203577041626
Loss at iteration 60 : 1.9219133853912354
Loss at iteration 70 : 2.1342687606811523
Loss at iteration 80 : 1.9892079830169678
Loss at iteration 90 : 2.0536866188049316


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 1.8665355443954468
Loss at iteration 20 : 1.9111285209655762
Loss at iteration 30 : 0.9831662178039551
Loss at iteration 40 : 1.2037456035614014
Loss at iteration 50 : 1.2717481851577759
Loss at iteration 60 : 0.8526502847671509
Loss at iteration 70 : 0.9140684604644775
Loss at iteration 80 : 0.9310120940208435
Loss at iteration 90 : 0.7928346395492554


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 0.8503995537757874
Loss at iteration 20 : 0.8684111833572388
Loss at iteration 30 : 1.3968722820281982
Loss at iteration 40 : 0.5895294547080994
Loss at iteration 50 : 0.7322898507118225
Loss at iteration 60 : 0.8424534201622009
Loss at iteration 70 : 1.0184762477874756
Loss at iteration 80 : 0.6766331195831299
Loss at iteration 90 : 0.8297868967056274


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 0.9664385914802551
Loss at iteration 20 : 1.360647439956665
Loss at iteration 30 : 0.6319637298583984
Loss at iteration 40 : 1.1071019172668457
Loss at iteration 50 : 0.7230488061904907
Loss at iteration 60 : 0.7805615663528442
Loss at iteration 70 : 0.6183333396911621
Loss at iteration 80 : 1.014220118522644
Loss at iteration 90 : 0.818138599395752


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 1.0275758504867554
Loss at iteration 20 : 0.748002290725708
Loss at iteration 30 : 0.8631169199943542
Loss at iteration 40 : 0.8054494857788086
Loss at iteration 50 : 0.9308217763900757
Loss at iteration 60 : 0.9209845066070557
Loss at iteration 70 : 0.7389471530914307
Loss at iteration 80 : 0.9044817686080933
Loss at iteration 90 : 1.142937183380127


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 0.6785875558853149
Loss at iteration 20 : 0.8135749101638794
Loss at iteration 30 : 0.9307042360305786
Loss at iteration 40 : 0.8323526382446289
Loss at iteration 50 : 0.8005580902099609
Loss at iteration 60 : 0.8524062633514404
Loss at iteration 70 : 0.6912227869033813
Loss at iteration 80 : 0.817693293094635
Loss at iteration 90 : 0.6217324733734131


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 0.7488853931427002
Loss at iteration 20 : 0.750346302986145
Loss at iteration 30 : 1.1003437042236328
Loss at iteration 40 : 0.7912396192550659
Loss at iteration 50 : 0.7059059143066406
Loss at iteration 60 : 0.8259745836257935
Loss at iteration 70 : 0.6692497730255127
Loss at iteration 80 : 0.6367825865745544
Loss at iteration 90 : 1.0271201133728027


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 0.789138913154602
Loss at iteration 20 : 0.9202148914337158
Loss at iteration 30 : 1.1560499668121338
Loss at iteration 40 : 1.342419147491455
Loss at iteration 50 : 0.7387574911117554
Loss at iteration 60 : 0.951035737991333
Loss at iteration 70 : 0.7774948477745056
Loss at iteration 80 : 0.7299361228942871
Loss at iteration 90 : 0.556025505065918


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 0.42378002405166626
Loss at iteration 20 : 0.8534656763076782
Loss at iteration 30 : 0.6332070827484131
Loss at iteration 40 : 0.9819373488426208
Loss at iteration 50 : 0.7927204370498657
Loss at iteration 60 : 0.753942608833313
Loss at iteration 70 : 0.7407373785972595
Loss at iteration 80 : 1.221308708190918
Loss at iteration 90 : 0.6613943576812744


  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
  data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)


Loss at iteration 10 : 1.2602710723876953
Loss at iteration 20 : 0.856682538986206
Loss at iteration 30 : 0.7644873261451721
Loss at iteration 40 : 0.7333015203475952
Loss at iteration 50 : 0.8697239756584167
Loss at iteration 60 : 0.9214871525764465
Loss at iteration 70 : 1.0226349830627441
Loss at iteration 80 : 0.7942115068435669
Loss at iteration 90 : 0.6833429336547852
