In [None]:
!pip install torchmetrics
!pip install wandb
!pip install torchsummary
!pip install gdown

# Dataset Download

In [None]:
!gdown 1bbIy5P_aP_xarxUw4yPuPk4CglHPGLco

In [None]:
!gdown 1JH7CQ77wjkV0Xnp14l8u4bqJF0ENGMqp

In [None]:
!unzip UDC_train.zip

In [None]:
!unzip UDC_validation_input.zip

In [None]:
import os
os.makedirs('ds/val/GT',exist_ok = True)
os.makedirs('ds/val/input',exist_ok = True)
os.makedirs('ds/train/GT',exist_ok = True)
os.makedirs('ds/train/input',exist_ok = True)

In [None]:
import shutil
val_list = sorted(os.listdir('training/GT/'))[:26]
for im_name in val_list:
  shutil.move(f'training/GT/{im_name}',f'ds/val/GT/{im_name}')
for im_name in val_list:
  shutil.move(f'training/input/{im_name}',f'ds/val/input/{im_name}')

In [None]:
train_list = sorted(os.listdir('training/GT/'))
for im_name in train_list:
  shutil.move(f'training/GT/{im_name}',f'ds/train/GT/{im_name}')
for im_name in train_list:
  shutil.move(f'training/input/{im_name}',f'ds/train/input/{im_name}')

##### Library import

In [1]:
import argparse,os,wandb,tqdm
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset,DataLoader
import PIL.Image as Image
from torchmetrics import PeakSignalNoiseRatio,StructuralSimilarityIndexMeasure
from torchsummary import summary
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

##### Parameter Intilisation
* Here we intialise the paramets such as epoch,batch size, train img size etc.

In [8]:
mode = 'train' #['test', 'train']
checkpoint_folder ='PDCRN_SYNTH'
model_type = 'PDCRN'
train_path = 'ds/train'
test_path = 'ds/val'
batch_size = 48
epochs = 1000
LR = 1e-4
num_filters  = 64
dilation_rates = (3, 2, 1, 1, 1, 1)
nPyramidFilters = 64
log_name = 'logger'
in_ch = 3

##### Dataset Intialisation

In [None]:
class Custom_Dataset(Dataset):
    def __init__(self,root_dir, is_train=False):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.is_train = is_train
        self.hq_im_file_list = []
        self.lq_im_file_list = []
        self.hq_train_files = {}
        self.lq_train_files = {}
        for dir_path, _, file_names in os.walk(root_dir):
            for f_paths in sorted(file_names):
                if dir_path.endswith('GT'):
                    self.hq_im_file_list.append(os.path.join(dir_path,f_paths))
                elif dir_path.endswith('input'):
                    self.lq_im_file_list.append(os.path.join(dir_path,f_paths))
        for im_names in self.hq_im_file_list:
            self.hq_train_files[im_names] = np.load(im_names)
        for im_names in self.lq_im_file_list:
            self.lq_train_files[im_names] = np.load(im_names)
        # self.tensor_convert = T.ToTensor()
        self.train_transform = T.Compose(
    [T.RandomCrop((512,512)),T.RandomVerticalFlip(p=0.5),T.RandomHorizontalFlip(p=0.5)\
       ,T.RandomAffine((0,360))])

        # self.val_transform = T.Compose([T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    def __len__(self):
        return len(self.hq_im_file_list)
    def tone_transform(self,im,c=0.25):
      mapped_x = im / (im + c)
      return mapped_x
    def __getitem__(self, idx):
        image_hq_fname = self.hq_train_files[self.hq_im_file_list[idx]]
        image_lq_fname = self.lq_train_files[self.lq_im_file_list[idx]]
        hq_image = torch.from_numpy(self.tone_transform(image_hq_fname)).unsqueeze(dim=0).permute(0,3,1,2)
        lq_image = torch.from_numpy(self.tone_transform(image_lq_fname)).unsqueeze(dim=0).permute(0,3,1,2)
        concat_img = torch.cat([hq_image,lq_image],dim=0)
        if self.is_train:
            image = self.train_transform(concat_img)
        else:
            # image = self.val_transform(concat_img)
            image = concat_img
        hq_img,lq_img = image.tensor_split(2)
        return lq_img.squeeze(0),hq_img.squeeze(0)

In [None]:
train_ds = Custom_Dataset(train_path,is_train=True)
train_dataloader = DataLoader(train_ds,batch_size=batch_size,shuffle=True,num_workers=8)
val_ds = Custom_Dataset(test_path,is_train=False)
val_dataloader = DataLoader(val_ds,batch_size=1,shuffle=False,num_workers=8)

##### DWT and IWT layer

In [2]:
class DWT(nn.Module):
    def __init__(self):
        super(DWT,self).__init__()
    def forward(self,input):
        x01 = input[:,:,0::2,:] / 4.0
        x02 = input[:,:,1::2,:] / 4.0
        x1 = x01[:, :,:, 0::2]
        x2 = x01[:, :,:, 1::2]
        x3 = x02[:, :,:, 0::2]
        x4 = x02[:, :,:, 1::2]
        y1 = x1+x2+x3+x4
        y2 = x1-x2+x3-x4
        y3 = x1+x2-x3-x4
        y4 = x1-x2-x3+x4
        return torch.cat([y1, y2, y3, y4], axis=1)

In [3]:
class IWT(nn.Module):
    def __init__(self,scale=2):
        super(IWT,self).__init__()
        self.upsampler = nn.PixelShuffle(scale)

    def kernel_build(self,input_shape):
        c = input_shape[1]
        out_c = c >> 2
        kernel = np.zeros((c, c,1, 1), dtype=np.float32)
        for i in range(0, c, 4):
            idx = i >> 2
            kernel[idx,idx::out_c,0,0]          = [1, 1, 1, 1]
            kernel[idx+out_c,idx::out_c,0,0]    = [1,-1, 1,-1]
            kernel[idx+out_c*2,idx::out_c,0,0]  = [1, 1,-1,-1]
            kernel[idx+out_c*3,idx::out_c,0,0]  = [1,-1,-1, 1]
        self.kernel = torch.tensor(data=kernel,dtype=torch.float32,requires_grad=True).to(device)
        return None

    def forward(self,input):
        self.kernel_build(input.shape)
        y = nn.functional.conv2d(input,weight = self.kernel, padding='same')
        y = self.upsampler(y)
        return y

#### Model Components and Model Intialisation

In [4]:
class DilationPyramid(nn.Module):
    def __init__(self,num_filters,dilation_rates):
        super(DilationPyramid,self).__init__()
        self.layer_1 = nn.Conv2d(num_filters,num_filters*2,3,padding='same')
        self.layer_2 = nn.Conv2d(num_filters*2,num_filters,3,padding='same',dilation=dilation_rates[0])
        self.layer_3 = []
        for dil_rate in dilation_rates[1:]:
          self.layer_3.append(nn.Conv2d(num_filters,num_filters,3,padding='same',dilation=dil_rate))
          self.layer_3.append(nn.ReLU())
        self.layer_3 = nn.Sequential(*self.layer_3)
        # self.layer_3 = nn.Sequential(*[[nn.Conv2d(num_filters,num_filters,3,padding='same',dilation=dil_rate),nn.ReLU()] for dil_rate in dilation_rates[1:]])
        self.layer_4 = nn.Conv2d(num_filters*2,num_filters,1,padding='same')
    def forward(self,input):
        x = self.layer_1(input)
        x = nn.functional.relu(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = torch.cat([input,x],dim=1)
        x = self.layer_4(x)
        out = nn.functional.relu(x)
        return out

In [5]:
class PyramidBlock(nn.Module):
    def __init__(self,num_filters,dilation_rates,nPyramidFilters):
        super(PyramidBlock,self).__init__()
        self.feat_extract = nn.Sequential(*[DilationPyramid(nPyramidFilters,dilation_rates),\
        nn.Conv2d(num_filters,num_filters,3,padding='same')])
    def forward(self,input):
        x = self.feat_extract(input)*0.1
        return input+x

In [6]:
class UDC_Arc(nn.Module):
    def __init__(self,in_ch,num_filters,dilation_rates,nPyramidFilters):
        super(UDC_Arc,self).__init__()
        self.encoder = nn.Sequential(*nn.ModuleList([DWT(),nn.PixelUnshuffle(downscale_factor=2),\
        nn.Conv2d(in_channels=in_ch*4*4,out_channels=num_filters,kernel_size=5,padding='same'),\
        nn.Conv2d(in_channels=num_filters,out_channels=num_filters,kernel_size=3,padding='same'),\
        PyramidBlock(num_filters,dilation_rates,nPyramidFilters),\
        nn.Conv2d(num_filters,num_filters*2,kernel_size=5,stride=2,padding=2),\
        PyramidBlock(num_filters*2,dilation_rates,nPyramidFilters*2),\
        nn.Conv2d(num_filters*2,num_filters*4,kernel_size=5,stride=2,padding=2),\
        PyramidBlock(num_filters*4,dilation_rates,nPyramidFilters*4)
        ]))
        self.decoder = nn.Sequential(*nn.ModuleList([PyramidBlock(num_filters*4,dilation_rates,nPyramidFilters*4),\
        nn.ConvTranspose2d(num_filters*4,num_filters*2,kernel_size = 4,stride=2,padding=1),\
        PyramidBlock(num_filters*2,dilation_rates,nPyramidFilters*2),\
        nn.ConvTranspose2d(num_filters*2,num_filters,kernel_size = 4,stride=2,padding=1),\
        PyramidBlock(num_filters,dilation_rates,nPyramidFilters),\
        nn.PixelShuffle(upscale_factor=2),nn.Conv2d(num_filters//4,in_ch*4,3,padding='same'),IWT(),\
        # nn.Tanh()
        ]))
    def forward(self,input):
        x_enc = self.encoder(input)
        x_dec = self.decoder(x_enc)
        x_dec = torch.minimum(torch.maximum(torch.zeros_like(x_dec),x_dec),torch.ones_like(x_dec)*500)
        # x_dec = torch.nn.functional.relu(x_dec)
        return x_dec

In [None]:
model = UDC_Arc(in_ch,num_filters,dilation_rates,nPyramidFilters)
model = model.to(device)
summary(model,input_size=(3,256,256))

#### Save and Load checkpoint

In [None]:
def save_checkpoint(checkpoint_folder,model_type,type='last'):
    checkpoint_folder = os.path.join(checkpoint_folder,model_type)
    if not os.path.exists(checkpoint_folder):
        os.makedirs(checkpoint_folder)
    checkpoint_filename = os.path.join(checkpoint_folder,f'{type}.pth')
    save_data = {
        'step': current_epoch,
        f'best_psnr':best_psnr,
        f'best_ssim':best_ssim,
        'generator_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(save_data, checkpoint_filename)

In [None]:
def load_model_checkpoint_for_training(checkpoint_folder,model_type,type ='last'):
    checkpoint_folder = os.path.join(checkpoint_folder,model_type)
    checkpoint_filename = os.path.join(checkpoint_folder, f'{type}.pth')
    if not os.path.exists(checkpoint_filename):
        print("Couldn't find checkpoint file. Starting training from the beginning.")
        return 0,0,0
    data = torch.load(checkpoint_filename)
    current_epoch = data['step']
    best_psnr = data['best_psnr']
    best_ssim = data['best_ssim']
    model.load_state_dict(data['generator_state_dict'])
    optimizer.load_state_dict(data['optimizer_state_dict'])
    print(f"Restored model at epoch {current_epoch}.")
    return best_psnr,best_ssim,current_epoch

#### Train and validation Pipeline

In [None]:
def train_epoch():
    model.train()
    for count,(inputs, gt) in enumerate(tqdm.tqdm(train_dataloader)):
        inputs = inputs.to(device)
        gt = gt.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = criterion(outputs,gt)
            loss.backward()
            optimizer.step()
    wandb.log({'train_l1_loss':loss.item()})
    wandb.log({'Learning rate':optimizer.param_groups[0]['lr']})
    return None

In [None]:
def val_epoch(checkpoint_folder,model_type,best_psnr,best_ssim):
    model.eval()
    for inputs, gt in tqdm.tqdm(val_dataloader):
        inputs = inputs.to(device)
        gt = gt.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            _ = criterion(outputs,gt)
        psnr.update(outputs,gt)
        # ssim.update(outputs,gt)
    wandb.log({'val_psnr':psnr.compute().item()})
    wandb.log({'val_ssim':ssim.compute().item()})
    val_psnr,val_ssim = psnr.compute().item(),ssim.compute().item()
    psnr.reset()
    ssim.reset()
    if val_psnr>best_psnr:
        best_psnr = val_psnr
        save_checkpoint(checkpoint_folder,model_type,'best')
    else:
        save_checkpoint(checkpoint_folder,model_type,'last')
    if val_ssim>best_ssim:
        best_ssim = val_ssim
    print(f'Epoch = {current_epoch} Val best PSNR = {best_psnr},Val best SSIM= {best_ssim},Val current PSNR = {val_psnr},Val currentSSIM= {val_ssim}')
    return best_psnr,best_ssim

#### Model Training

In [None]:
current_epoch = 0
best_psnr = 0
best_ssim = 0
# optimizer = torch.optim.SGD(model.parameters(), lr=(batch_size*0.3)/256, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-6)
criterion = torch.nn.L1Loss().to(device)
psnr  = PeakSignalNoiseRatio().to(device)
ssim = StructuralSimilarityIndexMeasure().to(device)

In [None]:
best_psnr,best_ssim,current_epoch = load_model_checkpoint_for_training(checkpoint_folder,model_type)
wandb.init(project=f"UDC",name=log_name)
for epoch in range(current_epoch,epochs):
    current_epoch = epoch
    train_epoch()
    best_psnr,best_ssim = val_epoch(checkpoint_folder,model_type,best_psnr,best_ssim)

#### Test

In [8]:
import torch
from model.DBWN_D import DBWN_D
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
model = DBWN_D(device='cuda:0',num_filters=64)
model = model.to(device)

In [10]:
import cv2
from torchvision.transforms.functional import to_tensor
import matplotlib.pyplot as plt
import numpy as np

In [11]:
def test(inp):
    output = model(inp)
    return output

In [12]:
def inv_tone_transform(mapped_x,c=0.25):
    im = np.divide((mapped_x*c),(1-mapped_x), out=mapped_x, where=mapped_x!=1)
    return im.clip(0,500)

In [13]:
def tone_transform(im,c=0.25):
    mapped_x = im / (im + c)
    return mapped_x

In [14]:
test_dir = 'ds/test/input'

In [15]:
# from model.DBWN import DBWN
# test_model = DBWN().to(device)

In [16]:
# checkpoint_model_folder = os.path.join(checkpoint_folder,model_type)
# checkpoint_filename = os.path.join(checkpoint_model_folder, 'best.pth')
checkpoint_filename = 'checkpoint/final.pth'
data = torch.load(checkpoint_filename)
model.load_state_dict(data['generator_state_dict'])

<All keys matched successfully>

In [17]:
image_list = []
for dir_path, _, file_names in os.walk(test_dir):
    for f_paths in sorted(file_names):
        image_list.append(os.path.join(dir_path,f_paths))

In [18]:
os.makedirs('ds/test/output',exist_ok = True)

In [19]:
for count,im_name in enumerate(image_list):
    im = np.load(im_name)
    # im = tone_transform(im)    
    im = torch.from_numpy(tone_transform(im)).unsqueeze(0).permute(0,3,1,2).to(device)
    
    out = model(im)
    out = out[0,...].permute(1,2,0).detach().cpu().numpy()
    out = inv_tone_transform(out)
    im  = im[0,...].permute(1,2,0).detach().cpu().numpy()
    im = inv_tone_transform(im)
    print(count,file_names[count],im.max(),out.max())
    np.save(f'ds/test/output/{file_names[count]}',out.astype(np.float32))
    im_save_im = (np.concatenate([[im,out]],axis=1).clip(0,255)).astype(np.uint8)
    cv2.imwrite(f'ds/test/results/{file_names[count]}',cv2.cvtColor(im_save_im,cv2.COLOR_RGB2BGR))
    disp = [im,out]
    fig = plt.figure(figsize=(12,12))
    for i in range(2):
        plt.subplot(1,2,i+1)    # the number of images in the grid is 5*5 (25)
        plt.imshow(disp[i])
    plt.show()
    plt.close()

0 000.npy 64.276764 500.0


error: OpenCV(4.6.0) /io/opencv/modules/imgcodecs/src/loadsave.cpp:730: error: (-2:Unspecified error) could not find a writer for the specified extension in function 'imwrite_'


In [None]:
cd /content/validation/output

In [None]:
!zip /content/drive/MyDrive/2022/ACCV-22/result.zip -r *

In [None]:
!wget https://codalab.lisn.upsaclay.fr/my/datasets/download/b549bbbf-7c4e-4ada-bdfb-429523bac284

In [None]:
!unzip /content/b549bbbf-7c4e-4ada-bdfb-429523bac284

In [None]:
!pip install pyiqa

In [None]:
cd /content/test_code

In [None]:
import shutil
shutil.copytree('/content/drive/MyDrive/2022/ACCV-22/checkpoint','/content/checkpoint')