<a href="https://colab.research.google.com/github/hrishikeshps94/PDCRN_UDC/blob/master/udc_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchmetrics
!pip install wandb
!pip install torchsummary
!pip install gdown

# Dataset Download

In [None]:
!gdown --id 1l_QOnq1Y-O-xPIBu9a_cl5SX8dCalMiF #Link to the dataset in drive

In [None]:
!unzip /content/ds.zip -d /content/

##### Library import

In [1]:
import argparse,os,wandb,tqdm
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset,DataLoader
import PIL.Image as Image
from torchmetrics import PeakSignalNoiseRatio,StructuralSimilarityIndexMeasure
from torchsummary import summary
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

##### Parameter Intilisation
* Here we intialise the paramets such as epoch,batch size, train img size etc.

In [2]:
mode = 'train' #['test', 'train']
checkpoint_folder ='checkpoint'
model_type = 'PDCRN'
train_path = '/media/hrishi/data/WORK/RESEARCH/2022/journal-2022/UDC/ds/Poled/train'
test_path = '/media/hrishi/data/WORK/RESEARCH/2022/journal-2022/UDC/ds/Poled/val'
batch_size = 1
epochs = 1000
LR = 1e-4
num_filters  = 8
dilation_rates = (3, 2, 1, 1, 1, 1)
nPyramidFilters = 8
log_name = 'logger'
in_ch = 3

##### Dataset Intialisation

In [4]:
class Custom_Dataset(Dataset):
    def __init__(self,root_dir, is_train=False):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.is_train = is_train
        self.hq_im_file_list = []
        self.lq_im_file_list = []
        for dir_path, _, file_names in os.walk(root_dir):
            for f_paths in sorted(file_names):
                if dir_path.endswith('HQ'):
                    self.hq_im_file_list.append(os.path.join(dir_path,f_paths))
                elif dir_path.endswith('LQ'):
                    self.lq_im_file_list.append(os.path.join(dir_path,f_paths))
        self.tensor_convert = T.ToTensor()
        self.train_transform = T.Compose(
    [T.RandomCrop((256,256)),T.RandomVerticalFlip(p=0.5),T.RandomHorizontalFlip(p=0.5)\
       ,T.RandomAffine((0,360))])

        # self.val_transform = T.Compose([T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    def __len__(self):
        return len(self.hq_im_file_list)

    def __getitem__(self, idx):
        image_hq_fname = self.hq_im_file_list[idx]
        image_lq_fname = self.lq_im_file_list[idx]
        hq_image = self.tensor_convert(Image.open(image_hq_fname).convert('RGB')).unsqueeze(dim=0)
        lq_image = self.tensor_convert(Image.open(image_lq_fname).convert('RGB')).unsqueeze(dim=0)
        concat_img = torch.cat([hq_image,lq_image],dim=0)
        if self.is_train:
            image = self.train_transform(concat_img)
        else:
            # image = self.val_transform(concat_img)
            image = concat_img
        hq_img,lq_img = image.tensor_split(2)
        return lq_img.squeeze(0),hq_img.squeeze(0)

In [5]:
train_ds = Custom_Dataset(train_path,is_train=True)
train_dataloader = DataLoader(train_ds,batch_size=batch_size,shuffle=True,num_workers=os.cpu_count())
val_ds = Custom_Dataset(test_path,is_train=False)
val_dataloader = DataLoader(val_ds,batch_size=1,shuffle=False,num_workers=os.cpu_count())

##### DWT and IWT layer

In [6]:
class DWT(nn.Module):
    def __init__(self):
        super(DWT,self).__init__()
    def forward(self,input):
        x01 = input[:,:,0::2,:] / 4.0
        x02 = input[:,:,1::2,:] / 4.0
        x1 = x01[:, :,:, 0::2]
        x2 = x01[:, :,:, 1::2]
        x3 = x02[:, :,:, 0::2]
        x4 = x02[:, :,:, 1::2]
        y1 = x1+x2+x3+x4
        y2 = x1-x2+x3-x4
        y3 = x1+x2-x3-x4
        y4 = x1-x2-x3+x4
        return torch.cat([y1, y2, y3, y4], axis=1)

In [7]:
class IWT(nn.Module):
    def __init__(self,scale=2):
        super(IWT,self).__init__()
        self.upsampler = nn.PixelShuffle(scale)

    def kernel_build(self,input_shape):
        c = input_shape[1]
        out_c = c >> 2
        kernel = np.zeros((c, c,1, 1), dtype=np.float32)
        for i in range(0, c, 4):
            idx = i >> 2
            kernel[idx,idx::out_c,0,0]          = [1, 1, 1, 1]
            kernel[idx+out_c,idx::out_c,0,0]    = [1,-1, 1,-1]
            kernel[idx+out_c*2,idx::out_c,0,0]  = [1, 1,-1,-1]
            kernel[idx+out_c*3,idx::out_c,0,0]  = [1,-1,-1, 1]
        self.kernel = torch.tensor(data=kernel,dtype=torch.float32,requires_grad=True).to(device)
        return None

    def forward(self,input):
        self.kernel_build(input.shape)
        y = nn.functional.conv2d(input,weight = self.kernel, padding='same')
        y = self.upsampler(y)
        return y

#### Model Components and Model Intialisation

In [8]:
class DilationPyramid(nn.Module):
    def __init__(self,num_filters,dilation_rates):
        super(DilationPyramid,self).__init__()
        self.layer_1 = nn.Conv2d(num_filters,num_filters*2,3,padding='same')
        self.layer_2 = nn.Conv2d(num_filters*2,num_filters,3,padding='same',dilation=dilation_rates[0])
        self.layer_3 = nn.Sequential(*[nn.Conv2d(num_filters,num_filters,3,padding='same',dilation=dil_rate) \
            for dil_rate in dilation_rates[1:]])
        self.layer_4 = nn.Conv2d(num_filters*2,num_filters,1,padding='same')
    def forward(self,input):
        x = self.layer_1(input)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = torch.cat([input,x],dim=1)
        out = self.layer_4(x)
        return out

In [9]:
class PyramidBlock(nn.Module):
    def __init__(self,num_filters,dilation_rates,nPyramidFilters):
        super(PyramidBlock,self).__init__()
        self.feat_extract = nn.Sequential(*[DilationPyramid(nPyramidFilters,dilation_rates),\
        nn.Conv2d(num_filters,num_filters,3,padding='same')])
    def forward(self,input):
        x = self.feat_extract(input)*0.1
        return input+x

In [11]:
class UDC_Arc(nn.Module):
    def __init__(self,in_ch,num_filters,dilation_rates,nPyramidFilters):
        super(UDC_Arc,self).__init__()
        self.encoder = nn.Sequential(*nn.ModuleList([DWT(),nn.PixelUnshuffle(downscale_factor=2),\
        nn.Conv2d(in_channels=in_ch*4*4,out_channels=num_filters,kernel_size=5,padding='same'),\
        nn.Conv2d(in_channels=num_filters,out_channels=num_filters,kernel_size=3,padding='same'),\
        PyramidBlock(num_filters,dilation_rates,nPyramidFilters),\
        nn.Conv2d(num_filters,num_filters*2,kernel_size=5,stride=2,padding=2),\
        PyramidBlock(num_filters*2,dilation_rates,nPyramidFilters*2),\
        nn.Conv2d(num_filters*2,num_filters*4,kernel_size=5,stride=2,padding=2),\
        PyramidBlock(num_filters*4,dilation_rates,nPyramidFilters*4)
        ]))
        self.decoder = nn.Sequential(*nn.ModuleList([PyramidBlock(num_filters*4,dilation_rates,nPyramidFilters*4),\
        nn.ConvTranspose2d(num_filters*4,num_filters*2,kernel_size = 4,stride=2,padding=1),\
        PyramidBlock(num_filters*2,dilation_rates,nPyramidFilters*2),\
        nn.ConvTranspose2d(num_filters*2,num_filters,kernel_size = 4,stride=2,padding=1),\
        PyramidBlock(num_filters,dilation_rates,nPyramidFilters),\
        nn.PixelShuffle(upscale_factor=2),nn.Conv2d(num_filters//4,in_ch*4,3,padding='same'),IWT()
        ]))
    def forward(self,input):
        x_enc = self.encoder(input)
        x_dec = self.decoder(x_enc)
        return x_dec

In [12]:
model = UDC_Arc(in_ch,num_filters,dilation_rates,nPyramidFilters)
model = model.to(device)
summary(model,input_size=(3,256,256))

Layer (type:depth-idx)                        Param #
├─Sequential: 1-1                             --
|    └─DWT: 2-1                               --
|    └─PixelUnshuffle: 2-2                    --
|    └─Conv2d: 2-3                            9,608
|    └─Conv2d: 2-4                            584
|    └─PyramidBlock: 2-5                      --
|    |    └─Sequential: 3-1                   5,968
|    └─Conv2d: 2-6                            3,216
|    └─PyramidBlock: 2-7                      --
|    |    └─Sequential: 3-2                   23,712
|    └─Conv2d: 2-8                            12,832
|    └─PyramidBlock: 2-9                      --
|    |    └─Sequential: 3-3                   94,528
├─Sequential: 1-2                             --
|    └─PyramidBlock: 2-10                     --
|    |    └─Sequential: 3-4                   94,528
|    └─ConvTranspose2d: 2-11                  8,208
|    └─PyramidBlock: 2-12                     --
|    |    └─Sequential: 3-5        

Layer (type:depth-idx)                        Param #
├─Sequential: 1-1                             --
|    └─DWT: 2-1                               --
|    └─PixelUnshuffle: 2-2                    --
|    └─Conv2d: 2-3                            9,608
|    └─Conv2d: 2-4                            584
|    └─PyramidBlock: 2-5                      --
|    |    └─Sequential: 3-1                   5,968
|    └─Conv2d: 2-6                            3,216
|    └─PyramidBlock: 2-7                      --
|    |    └─Sequential: 3-2                   23,712
|    └─Conv2d: 2-8                            12,832
|    └─PyramidBlock: 2-9                      --
|    |    └─Sequential: 3-3                   94,528
├─Sequential: 1-2                             --
|    └─PyramidBlock: 2-10                     --
|    |    └─Sequential: 3-4                   94,528
|    └─ConvTranspose2d: 2-11                  8,208
|    └─PyramidBlock: 2-12                     --
|    |    └─Sequential: 3-5        

#### Save and Load checkpoint

In [13]:
def save_checkpoint(checkpoint_folder,model_type,type='last'):
    checkpoint_folder = os.path.join(checkpoint_folder,model_type)
    if not os.path.exists(checkpoint_folder):
        os.makedirs(checkpoint_folder)
    checkpoint_filename = os.path.join(checkpoint_folder,f'{type}.pth')
    save_data = {
        'step': current_epoch,
        f'best_psnr':best_psnr,
        f'best_ssim':best_ssim,
        'generator_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(save_data, checkpoint_filename)

In [14]:
def load_model_checkpoint_for_training(checkpoint_folder,model_type,type ='last'):
    checkpoint_folder = os.path.join(checkpoint_folder,model_type)
    checkpoint_filename = os.path.join(checkpoint_folder, f'{type}.pth')
    if not os.path.exists(checkpoint_filename):
        print("Couldn't find checkpoint file. Starting training from the beginning.")
        return 0,0
    data = torch.load(checkpoint_filename)
    current_epoch = data['step']
    best_psnr = data['best_psnr']
    best_ssim = data['best_ssim']
    model.load_state_dict(data['generator_state_dict'])
    optimizer.load_state_dict(data['optimizer_state_dict'])
    print(f"Restored model at epoch {current_epoch}.")
    return best_psnr,best_ssim

#### Train and validation Pipeline

In [15]:
def train_epoch():
    model.train()
    for count,(inputs, gt) in enumerate(tqdm.tqdm(train_dataloader)):
        inputs = inputs.to(device)
        gt = gt.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = criterion(outputs,gt)
            loss.backward()
            optimizer.step()
    wandb.log({'train_l1_loss':loss.item()})
    wandb.log({'Learning rate':optimizer.param_groups[0]['lr']})
    return None

In [16]:
def val_epoch(checkpoint_folder,model_type,best_psnr,best_ssim):
    model.eval()
    for inputs, gt in tqdm.tqdm(val_dataloader):
        inputs = inputs.to(device)
        gt = gt.to(device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            _ = criterion(outputs,gt)
        psnr.update(outputs,gt)
        ssim.update(outputs,gt)
    wandb.log({'val_psnr':psnr.compute().item()})
    wandb.log({'val_ssim':ssim.compute().item()})
    val_psnr,val_ssim = psnr.compute().item(),ssim.compute().item()
    psnr.reset()
    ssim.reset()
    if val_psnr>best_psnr:
        best_psnr = val_psnr
        save_checkpoint(checkpoint_folder,model_type,'best')
    else:
        save_checkpoint(checkpoint_folder,model_type,'last')
    if val_ssim>best_ssim:
        best_ssim = val_ssim
    print(f'Epoch = {current_epoch} Val best PSNR = {best_psnr},Val best SSIM= {best_ssim},Val current PSNR = {val_psnr},Val currentSSIM= {val_ssim}')
    return best_psnr,best_ssim

#### Model Training

In [17]:
current_epoch = 0
best_psnr = 0
best_ssim = 0
optimizer = torch.optim.SGD(model.parameters(), lr=(batch_size*0.3)/256, momentum=0.9)
criterion = torch.nn.L1Loss().to(device)
psnr  = PeakSignalNoiseRatio().to(device)
ssim = StructuralSimilarityIndexMeasure().to(device)



In [18]:
best_psnr,best_ssim = load_model_checkpoint_for_training(checkpoint_folder,model_type)
wandb.init(project=f"UDC",name=log_name)
for epoch in range(current_epoch,epochs):
    current_epoch = epoch
    train_epoch()
    best_psnr,best_ssim = val_epoch(checkpoint_folder,model_type,best_psnr,best_ssim)

Couldn't find checkpoint file. Starting training from the beginning.


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhrishifms[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 230/230 [00:45<00:00,  5.11it/s]
100%|██████████| 10/10 [00:01<00:00,  6.41it/s]


RuntimeError: CUDA out of memory. Tried to allocate 1.19 GiB (GPU 0; 3.82 GiB total capacity; 2.16 GiB already allocated; 369.81 MiB free; 2.16 GiB reserved in total by PyTorch)