# 0. Data preprocessing

1. First, create a shortcut to this [shared directory](https://drive.google.com/drive/folders/1LYgWpfDm3-A4q_QVcPGupsaNjblfCfYZ) in your Google Drive root: *MyDrive*

2. Specify the path to the folder containing the shortcut as `src_path` in the cell below, e.g. we stored it in *MyDrive/shortcuts/*

3. Leave `target_path` as it is

4. Run cell below

In [1]:
from google.colab import drive
import os

# Mount the user's Google Drive under /gdrive.
print('Mounting Google Drive...')
drive.mount('/gdrive')

# Folder inside the mounted Drive whose contents are copied locally.
src_path = '/gdrive/MyDrive/shortcuts/' #@param {type: 'string'}
assert os.path.exists(src_path), f"Source '{src_path}' doesn't exist!"

# Local destination root; created when missing.
target_path = '.' #@param {type: 'string'}
os.makedirs(target_path, exist_ok=True)
assert os.path.exists(target_path), f"Target '{target_path}' doesn't exist!"

# NOTE: os.path.basename() of a path that ends with '/' is '', so with the
# default src_path the destination stays target_path itself (the output below
# prints "./"). The next cell reads project/* from the working directory, so
# it depends on this.
target_path = os.path.join(target_path, os.path.basename(src_path))
print(f'Copying from "{src_path}" to "{target_path}"...')
os.makedirs(target_path, exist_ok=True)
!cp -rf "$src_path"/* "$target_path"  # also work when source is a shortcut

Mounting Google Drive...
Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).
Copying from "/gdrive/MyDrive/shortcuts/" to "./"...


Unzip it to transfer it from GDrive to Colab:

In [2]:
%%capture
# Extract the archive matched by project/* into the working directory.
!sudo tar -xvf project/*

# Reload imported modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# transformers and wandb are imported below; timm is presumably a dependency
# of network_swinir - TODO confirm.
!pip install transformers wandb timm

## 0.1. Imports

In [3]:
# Torch
import torchvision.transforms as TT
import torchvision
import torch
from torch import nn

# Hugging Face training utilities
from transformers import TrainingArguments, Trainer, TrainerCallback

# Model and metrics (network_swinir.py and util_calculate_psnr_ssim.py must be placed in the root dir)
from network_swinir import SwinIR
from util_calculate_psnr_ssim import calculate_psnr, calculate_ssim

# Service imports
import sys
import os

# Plotting, reading, etc.
from PIL import Image
import matplotlib.pyplot as plt
from termcolor import colored

# GUI for metrics
import wandb

# numpy for life
import numpy as np

# Make modules stored in the Drive project folder importable.
# Google Drive is mounted at /gdrive in the first cell, so the path is built
# from that mount point rather than from sys.path[1], whose value depends on
# how the runtime happens to order sys.path.
sys.path.append(os.path.join('/gdrive', 'MyDrive', 'ML_proj'))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## 0.2 Data preparation

In [4]:
# BSDS500 image folders, relative to the directory the archive was unpacked in.
# Each path ends with '/' because file names are appended to it directly.
train_path = 'BSR/BSDS500/data/images/train/'
test_path = 'BSR/BSDS500/data/images/test/'
# Fixed: this used to repeat the test folder, which made every validation
# image a training image as well. BSDS500 ships a separate val/ split.
val_path = 'BSR/BSDS500/data/images/val/'

In [5]:
# Number of images the training list must end up with (asserted below).
TRAIN_SIZE = 432
# Number of images the validation list must end up with (asserted below).
VAL_SIZE = 68

### Extracting data to a comfortable representation in code

In [6]:
def _list_jpgs(folder):
    """Return the paths of the ``.jpg`` files in ``folder``, sorted by name.

    ``folder`` must end with a path separator because the file name is
    appended to it directly. ``os.listdir`` returns entries in arbitrary
    order, so the names are sorted to make the split reproducible between
    sessions.
    """
    return [folder + pic_name
            for pic_name in sorted(os.listdir(folder))
            if pic_name.split('.')[-1] == 'jpg']


# Every image of the train and test folders is used for training.
train_pics_list = _list_jpgs(train_path) + _list_jpgs(test_path)

# The head of the validation folder tops the training list up to TRAIN_SIZE.
val_candidates = _list_jpgs(val_path)
n_val_for_train = max(TRAIN_SIZE - len(train_pics_list), 0)
val_pics_for_train = val_candidates[:n_val_for_train]

train_pics_list += val_pics_for_train

assert len(train_pics_list) == TRAIN_SIZE, \
    f'expected {TRAIN_SIZE} training images, got {len(train_pics_list)}'

# The validation list starts right after the files moved to training.
# Previously it restarted from the beginning of the same listing, so those
# files ended up in both lists.
val_pics_list = val_candidates[n_val_for_train:n_val_for_train + VAL_SIZE]

# Never populated by this cell; kept so that the name stays defined.
remaining_pics_list = []

assert len(val_pics_list) == VAL_SIZE, \
    f'expected {VAL_SIZE} validation images, got {len(val_pics_list)}'

Here we introduce custom classes so that our code follows the paper's logic more closely.

In [7]:
class ImageNoiseAdditor(object):
    """Adds zero-mean Gaussian noise to PIL images.

    In training mode the noise level is drawn anew on every call; otherwise
    the fixed ``valid_params['sigma']`` is used. Sigma is applied to pixel
    values on the 0-255 scale.
    """

    def __init__(self, train_params, valid_params):
        '''
        train_params = dict(sigma_from: int [from 0 to 255],
                            sigma_to: int [from 0 to 255])

        valid_params = dict(sigma: int(from 0 to 255))
        '''
        self.train_params = train_params
        self.valid_params = valid_params

    def _pick_sigma(self, train):
        """Return the noise level for one call.

        Training draws an integer from ``[sigma_from, sigma_to)`` (the upper
        bound is exclusive, as in ``np.random.randint``). A degenerate range
        with ``sigma_from == sigma_to`` yields that single value instead of
        making ``randint`` raise.
        """
        if not train:
            return self.valid_params['sigma']

        low = self.train_params['sigma_from']
        high = self.train_params['sigma_to']
        if low == high:
            return low
        return np.random.randint(low=low, high=high)

    def apply(self, image, train=True):
        """Return ``(noisy_image, sigma)`` for a PIL ``image``.

        The input is converted to RGB first, so greyscale, palette or RGBA
        images produce a valid three-channel result. The noisy pixels are
        clipped to [0, 255] and truncated to ``uint8``.
        """
        assert isinstance(image, Image.Image), colored(f'wrong type of image (wait {Image.Image})', 'red') + f' Got:{type(image)}'
        img_numpy = np.asarray(image.convert('RGB')).astype(float)

        # Sigma is drawn before the noise field, keeping the RNG call order.
        sigma = self._pick_sigma(train)

        n = np.random.randn(*img_numpy.shape)
        img_numpy += float(sigma) * n
        img_numpy = np.clip(img_numpy, 0, 255).astype(np.uint8)

        # An (H, W, 3) uint8 array is read as RGB without an explicit mode.
        image_with_noise = Image.fromarray(img_numpy)

        return image_with_noise, sigma

In [8]:
class BSDDataset(torch.utils.data.Dataset):
  """Noisy/clean image pairs cut from a list of image files.

  Each item is a dict with ``inputs`` (noisy crop), ``labels`` (the same crop
  without noise) and ``sigma`` (the noise level used), the two images being
  resized to ``height`` x ``width`` tensors.

  The random horizontal flip and the random crop are applied whatever the
  value of ``train``; ``train`` only selects how the noise level is chosen.

  NOTE(review): noise is added before cropping and resizing, so the resize
  presumably changes the effective noise level relative to ``sigma`` - confirm
  that this is intended.
  """

  def __init__(self, img_list, transform, height, width, sigma_val=50, train=False):
    super(BSDDataset, self).__init__()
    self.img_list = img_list
    # Stored only; __getitem__ converts images with self.to_tensor.
    self.transform = transform
    self.height = height
    self.width = width
    self.to_tensor = TT.Compose([TT.Resize((height, width)),
                                 TT.ToTensor()])

    self.noise = ImageNoiseAdditor(train_params=dict(sigma_from=10, sigma_to=55),
                                   valid_params=dict(sigma=sigma_val))

    self.train = train

  def get_rand_from_to(self, from_, to_):
    """Return a random integer in ``[int(from_), int(to_))``."""
    return np.random.randint(low=int(from_), high=int(to_))

  def _random_span(self, full, target):
    """Pick a random ``(start, length)`` window along an axis of size ``full``.

    When the axis is longer than ``target`` the window is at least ``target``
    long; otherwise it starts in the first half and is at least half the axis
    long. The window always stays inside the axis.
    """
    if full - target > 0:
        start = self.get_rand_from_to(0, full - target)
        length = self.get_rand_from_to(target, full - start)
    else:
        start = self.get_rand_from_to(0, full / 2)
        length = self.get_rand_from_to(full / 2, full - start)
    return start, length

  def __len__(self):
    return len(self.img_list)

  def __getitem__(self, idx):
    # Decode inside a context manager so that the file handle is released,
    # and force RGB so that clean and noisy images have the same channels.
    with Image.open(self.img_list[idx]) as img_file:
        img = img_file.convert('RGB')
    img_with_noise, noise_val = self.noise.apply(img, train=self.train)

    # The same flip is applied to both images.
    if np.random.choice([True, False]):
        img = torchvision.transforms.functional.hflip(img)
        img_with_noise = torchvision.transforms.functional.hflip(img_with_noise)

    w, h = img.size

    # Vertical window first, then horizontal, keeping the RNG call order.
    top, height = self._random_span(h, self.height)
    left, width = self._random_span(w, self.width)

    img = torchvision.transforms.functional.crop(img, top, left, height, width)
    img_with_noise = torchvision.transforms.functional.crop(img_with_noise, top, left, height, width)

    return dict(inputs=self.to_tensor(img_with_noise),
                labels=self.to_tensor(img),
                sigma=torch.tensor(noise_val).float())

# 1. Training

Bringing images to a general representation, applying noise for training and packing it to dataloaders:

In [9]:
upscale = 8
window_size = 8
# Network input side: (512 // upscale // window_size + 1) windows of
# window_size pixels, i.e. 72 with the values above.
height = (512 // upscale // window_size + 1) * window_size
width = (512 // upscale // window_size + 1) * window_size

# Not passed to either dataset below.
random_transform = [
    TT.Compose([TT.RandomCrop(320, padding=0),
                TT.Pad(2, padding_mode='reflect')]),
    TT.Compose([TT.RandomCrop(320, padding=0),
                ]),
]

train_transform = TT.Compose([
    TT.Resize((height, width)),
    TT.ToTensor()
])

val_transform = TT.Compose([
    TT.Resize((height, width)),
    TT.ToTensor()
])

train_set = BSDDataset(train_pics_list, transform=train_transform,
                       height=height, width=width, train=True)

# train=False so that sigma_val (and the sigma that ModelTestCallback sets on
# the dataset) is actually used. With train=True the dataset ignored it and
# drew a random training sigma for every validation item.
val_set = BSDDataset(val_pics_list, transform=val_transform,
                     height=height, width=width, sigma_val=50, train=False)


print('Train size', len(train_set))
print('Test size', len(val_set))

Train size 432
Test size 68


In [10]:
# Stage-1 network, built for (height, width) inputs with the window size chosen above.
# NOTE(review): upscale=1 with upsampler=None presumably selects SwinIR's
# same-resolution (restoration) mode - confirm against network_swinir.
model = SwinIR(upscale=1, img_size=(height, width),
               window_size=window_size, img_range=1.,
               depths=[6, 6, 6, 6], embed_dim=120, num_heads=[6, 6, 6, 6], 
               mlp_ratio=4, upsampler=None, resi_connection='3conv')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [11]:
class CharbonnierLoss(nn.Module):
    """Smoothed root-mean-square error: ``sqrt(MSE(pred, target) + eps ** 2)``.

    ``eps`` keeps the square root differentiable when the two inputs match.
    """

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.l2 = torch.nn.MSELoss()
        self.eps = eps

    def forward(self, I_pred, I_traget):
        """Return the scalar loss between prediction and target."""
        mean_sq_err = self.l2(I_pred, I_traget)
        smoothing = self.eps ** 2
        return torch.sqrt(mean_sq_err + smoothing)

In [12]:
lr = 0.5 * 1e-5
epochs = 400
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# The deprecated `verbose` argument is no longer passed, and neither is
# `last_epoch`, which was given its default value (-1).
# NOTE(review): T_max counts scheduler.step() calls. If the Trainer steps the
# scheduler once per optimisation step rather than once per epoch, T_max
# should be the total number of optimisation steps (3600 in the training log
# below) - confirm.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=epochs,
                                                       eta_min=0)

criterion = CharbonnierLoss()

In [13]:
def compute_metrics(EvalPredict):
    """Return the mean PSNR and SSIM over a set of predictions.

    ``EvalPredict`` unpacks into ``(predictions, label_ids)``. Both are indexed
    along the first axis, and each item is transposed from channel-first to
    channel-last before scoring. Values are scaled by 255 and clipped to
    [0, 255] before the ``uint8`` cast; without the clip, predictions outside
    [0, 1] wrap around and distort both metrics.
    """
    predictions, label_ids = EvalPredict
    batch_size = predictions.shape[0]
    psnr_accumulator = np.zeros(batch_size)
    ssim_accumulator = np.zeros(batch_size)

    for b in range(batch_size):
        output = np.clip(predictions[b] * 255.0, 0, 255).round().astype(np.uint8).transpose(1, 2, 0)
        img_gt = np.clip(label_ids[b] * 255.0, 0, 255).round().astype(np.uint8).transpose(1, 2, 0)

        psnr_accumulator[b] = calculate_psnr(output, img_gt, crop_border=0, input_order='HWC')
        ssim_accumulator[b] = calculate_ssim(output, img_gt, crop_border=0, input_order='HWC')

    return {'psnr': psnr_accumulator.mean(), 'ssim': ssim_accumulator.mean()}

class CustomTrainer(Trainer):
    """Trainer that scores the model with a user-supplied loss module.

    ``loss`` is called as ``loss(model_output, labels)``; every other argument
    is forwarded to ``Trainer``.
    """

    def __init__(self, loss, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss = loss

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on ``inputs['inputs']`` and compare with ``inputs['labels']``.

        Extra keyword arguments are accepted and ignored, because newer Trainer
        versions pass additional ones (e.g. ``num_items_in_batch``).

        Returns the loss, or ``(loss, outputs)`` when ``return_outputs`` is
        true. ``outputs`` is a dict: ``Trainer.prediction_step`` treats anything
        else as a ``(loss, ...)`` tuple and keeps ``outputs[1:]``, which for a
        bare tensor drops the first image of every evaluation batch.
        """
        inp_, labels = inputs['inputs'], inputs['labels']

        # forward pass
        outputs = model(inp_)

        # custom loss between the restored image and the clean target
        loss = self.loss(outputs, labels)

        if return_outputs:
            return loss, {'logits': outputs}
        return loss

In [14]:
class ModelTestCallback(TrainerCallback):
    """Logs example (input, target, output) triplets to wandb after each evaluation."""

    def __init__(self, eval_data_set, **kwargs):
        super().__init__(**kwargs)
        self.eval_data_set = eval_data_set

    @staticmethod
    def _to_pil(chw_tensor):
        """Convert a channel-first tensor scaled to [0, 1] into a PIL image.

        Values are clipped to [0, 255] before the ``uint8`` cast so that
        out-of-range network outputs do not wrap around.
        """
        hwc = chw_tensor.detach().cpu().numpy().transpose(1, 2, 0)
        return Image.fromarray(np.clip(hwc * 255, 0, 255).astype(np.uint8))

    def on_evaluate(self, args, state, control, **kwargs): # on_evaluate, on_epoch_begin
        """Denoise one random validation image per noise level and log the result.

        The model stays on its current device and runs under ``no_grad``; the
        dataset's validation sigma and the model's train/eval mode are restored
        afterwards.
        """
        print(colored('Callback', 'green', attrs=['bold']))
        caption = ['input', 'target', 'output']

        model = kwargs['model']
        model_device = next(model.parameters()).device
        was_training = model.training

        valid_params = self.eval_data_set.noise.valid_params
        saved_sigma = valid_params['sigma']

        model.eval()
        try:
            for s in [15, 25, 50]:
                valid_params['sigma'] = s
                idx = np.random.randint(low=0, high=len(self.eval_data_set))
                data = self.eval_data_set[idx]

                inp, lbl = data['inputs'], data['labels']
                with torch.no_grad():
                    out = model(torch.unsqueeze(inp, dim=0).to(model_device))

                images = [self._to_pil(inp), self._to_pil(lbl), self._to_pil(out[0])]

                wandb.log({f"examples sigma = {s}": [wandb.Image(img, caption=caption[i]) for i, img in enumerate(images)]})

                del inp, lbl, out, images
        finally:
            valid_params['sigma'] = saved_sigma
            model.train(was_training)

In [15]:
batch_size = 6
batch_accumulation = 8
# Checkpoints are written to the mounted Drive. The path is built from the
# /gdrive mount point (see the first cell) rather than from sys.path[1], whose
# value depends on how the runtime orders sys.path.
path_for_save = os.path.join('/gdrive', 'MyDrive', 'ML_proj', 'checkpoints')

# NOTE(review): per the Transformers documentation, `resume_from_checkpoint`
# in TrainingArguments is not used by Trainer itself; resuming requires
# passing the path to trainer.train() - confirm that this is intended.
args = TrainingArguments(output_dir=path_for_save,
                         overwrite_output_dir=True,
                         report_to='wandb',
                         evaluation_strategy='steps',
                         logging_steps=5,
                         gradient_accumulation_steps=batch_accumulation,
                         per_device_train_batch_size=batch_size,
                         num_train_epochs=epochs,
                         load_best_model_at_end=True,
                         metric_for_best_model='ssim',
                         greater_is_better=True,
                         no_cuda=False,
                         save_strategy='steps',
                         save_steps=5,
                         save_total_limit=3,
                         ignore_data_skip=True,
                         resume_from_checkpoint=os.path.join(path_for_save, 'checkpoint-1220')
                         )

trainer = CustomTrainer(loss=criterion,
                        model=model,
                        train_dataset=train_set,
                        eval_dataset=val_set,
                        compute_metrics=compute_metrics,
                        optimizers=(optimizer, scheduler),
                        callbacks=[ModelTestCallback(eval_data_set=val_set)],
                        args=args)

Start training:

We use the **WandB** framework for tracking metrics, so you should register [here](https://wandb.ai/home) and open the highlighted link below (*it appears once the cell starts executing*):

In [16]:
# No checkpoint path is passed here, so training starts from the current model weights.
trainer.train()

***** Running training *****
  Num examples = 432
  Num Epochs = 400
  Instantaneous batch size per device = 6
  Total train batch size (w. parallel, distributed & accumulation) = 48
  Gradient Accumulation steps = 8
  Total optimization steps = 3600
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33mscalyvladimir[0m (use `wandb login --relogin` to force relogin)


KeyboardInterrupt: ignored

# 2. Projection layer

Introducing a trainable [projection layer](https://arxiv.org/abs/1711.07807)

In [17]:
class ProjectionLayer(torch.nn.Module):
    def __init__(self, img_size):
        super().__init__()
        self.img_size = img_size
        self.alpha = nn.Parameter(torch.randn(1), requires_grad=True)
        self.Nt = torch.prod(torch.tensor(self.img_size))
        

    def calculate_eps(self, sigma):
        return torch.exp(self.alpha) * sigma * torch.sqrt(self.Nt - 1)

    def forward(self, pred, inp, sigma):

        eps = self.calculate_eps(sigma)

        d = pred - inp
        
        denom = torch.max(torch.linalg.norm(d, dim=(1, 2, 3)), eps)

        nom = torch.mul(eps, d.transpose(3, 0))

        p = inp + torch.div(nom, denom).transpose(3, 0)
        return p


class SwinIRP(nn.Module):
    """SwinIR followed by a ``ProjectionLayer`` applied to its output.

    All keyword arguments are forwarded to ``SwinIR``. ``img_size`` is required
    and may be an ``(height, width)`` pair or a single integer for a square
    input.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.swin_ir = SwinIR(**kwargs)

        spatial = kwargs['img_size']
        if isinstance(spatial, int):
            spatial = (spatial, spatial)
        # Assumes SwinIR names its channel-count argument `in_chans` - TODO
        # confirm against network_swinir; 3 channels are used when it is absent.
        channels = kwargs.get('in_chans', 3)
        self.pj_layer = ProjectionLayer((*spatial, channels))

    def forward(self, x, sigma):
        """Restore ``x`` and project the result back towards ``x`` using ``sigma``."""
        # Call the module itself rather than .forward so that hooks run.
        x_ = self.swin_ir(x)
        x_ = self.pj_layer(x_, x, sigma)
        return x_

    def load_swin_model(self, checkpoint):
        """Load SwinIR weights from ``<checkpoint>/pytorch_model.bin``.

        Loading is non-strict; the result reported by ``load_state_dict`` is
        printed so that mismatched keys can be inspected.
        """
        print(colored('checkpoint:', 'grey', attrs=['bold']), checkpoint)
        state_dict = torch.load(os.path.join(checkpoint, 'pytorch_model.bin'), map_location="cpu")
        load_result = self.swin_ir.load_state_dict(state_dict, strict=False)
        print(load_result)

In [18]:
# Stage-2 network: the same SwinIR configuration as above, wrapped with the projection layer.
# This rebinds `model`, so the trainer built below uses this network.
model = SwinIRP(upscale=1, img_size=(height, width),
                window_size=window_size, img_range=1.,
                depths=[6, 6, 6, 6], embed_dim=120, num_heads=[6, 6, 6, 6], 
                mlp_ratio=4, upsampler=None, resi_connection='3conv')

Introducing a separate loss function and negative PSNR metric for projection layer training

In [19]:
class CharbonnierLoss(nn.Module):
    """Smoothed root-mean-square error: ``sqrt(MSE(pred, target) + eps ** 2)``.

    ``eps`` keeps the square root differentiable when the two inputs match.
    """

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.l2 = torch.nn.MSELoss()
        self.eps = eps

    def forward(self, I_pred, I_traget):
        """Return the scalar loss between prediction and target."""
        mean_sq_err = self.l2(I_pred, I_traget)
        smoothing = self.eps ** 2
        return torch.sqrt(mean_sq_err + smoothing)

class NegativePSNR(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        Nt = torch.prod(torch.tensor(pred.shape[1:]))
        p = torch.sqrt(Nt) * 255

        return torch.sum((-20) * torch.log10(p / torch.linalg.norm(pred - target, dim=(1, 2, 3))))

Empirical hyperparameters

In [20]:
lr = 2e-5
epochs = 400
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

# The deprecated `verbose` argument is no longer passed, and neither is
# `last_epoch`, which was given its default value (-1).
# NOTE(review): T_max counts scheduler.step() calls. If the Trainer steps the
# scheduler once per optimisation step rather than once per epoch, T_max
# should be the total number of optimisation steps (7200 in the training log
# below) - confirm.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=epochs,
                                                       eta_min=0)

criterion = NegativePSNR()

In [21]:
def compute_metrics(EvalPredict):
    """Return the mean PSNR and SSIM over a set of predictions.

    ``EvalPredict`` unpacks into ``(predictions, label_ids)``. Both are indexed
    along the first axis, and each item is transposed from channel-first to
    channel-last before scoring. Values are scaled by 255 and clipped to
    [0, 255] before the ``uint8`` cast; without the clip, predictions outside
    [0, 1] wrap around and distort both metrics.
    """
    predictions, label_ids = EvalPredict
    batch_size = predictions.shape[0]
    psnr_accumulator = np.zeros(batch_size)
    ssim_accumulator = np.zeros(batch_size)

    for b in range(batch_size):
        output = np.clip(predictions[b] * 255.0, 0, 255).round().astype(np.uint8).transpose(1, 2, 0)
        img_gt = np.clip(label_ids[b] * 255.0, 0, 255).round().astype(np.uint8).transpose(1, 2, 0)

        psnr_accumulator[b] = calculate_psnr(output, img_gt, crop_border=0, input_order='HWC')
        ssim_accumulator[b] = calculate_ssim(output, img_gt, crop_border=0, input_order='HWC')

    return {'psnr': psnr_accumulator.mean(), 'ssim': ssim_accumulator.mean()}


class CustomTrainer(Trainer):
    """Trainer for a model that takes the noise level as a second input.

    ``loss`` is called as ``loss(model_output, labels)``; every other argument
    is forwarded to ``Trainer``.
    """

    def __init__(self, loss, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss = loss

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on ``inputs['inputs']`` and ``inputs['sigma']`` and score it.

        Extra keyword arguments are accepted and ignored, because newer Trainer
        versions pass additional ones (e.g. ``num_items_in_batch``).

        Returns the loss, or ``(loss, outputs)`` when ``return_outputs`` is
        true. ``outputs`` is a dict: ``Trainer.prediction_step`` treats anything
        else as a ``(loss, ...)`` tuple and keeps ``outputs[1:]``, which for a
        bare tensor drops the first image of every evaluation batch.
        """
        inp_, labels, sigmas = inputs['inputs'], inputs['labels'], inputs['sigma']

        # forward pass
        outputs = model(inp_, sigmas)

        # custom loss between the restored image and the clean target
        loss = self.loss(outputs, labels)

        if return_outputs:
            return loss, {'logits': outputs}
        return loss


class ModelTestCallback(TrainerCallback):
    """Logs example (input, target, output) triplets to wandb after each evaluation."""

    def __init__(self, eval_data_set, **kwargs):
        super().__init__(**kwargs)
        self.eval_data_set = eval_data_set

    @staticmethod
    def _to_pil(chw_tensor):
        """Convert a channel-first tensor scaled to [0, 1] into a PIL image.

        Values are clipped to [0, 255] before the ``uint8`` cast so that
        out-of-range network outputs do not wrap around.
        """
        hwc = chw_tensor.detach().cpu().numpy().transpose(1, 2, 0)
        return Image.fromarray(np.clip(hwc * 255, 0, 255).astype(np.uint8))

    def on_evaluate(self, args, state, control, **kwargs): # on_evaluate, on_epoch_begin
        """Denoise one random validation image per noise level and log the result.

        The model stays on its current device and runs under ``no_grad``; the
        dataset's validation sigma and the model's train/eval mode are restored
        afterwards. Each wandb key carries the sigma reported by the dataset
        item, formatted as a plain number.
        """
        print(colored('Callback', 'green', attrs=['bold']))
        caption = ['input', 'target', 'output']

        model = kwargs['model']
        model_device = next(model.parameters()).device
        was_training = model.training

        valid_params = self.eval_data_set.noise.valid_params
        saved_sigma = valid_params['sigma']

        model.eval()
        try:
            for s in [15, 25, 50]:
                valid_params['sigma'] = s
                idx = np.random.randint(low=0, high=len(self.eval_data_set))
                data = self.eval_data_set[idx]

                inp, lbl, sigma = data['inputs'], data['labels'], data['sigma']
                with torch.no_grad():
                    out = model(torch.unsqueeze(inp, dim=0).to(model_device),
                                sigma.to(model_device))

                images = [self._to_pil(inp), self._to_pil(lbl), self._to_pil(out[0])]

                # sigma is a tensor; .item() keeps "tensor(...)" out of the key.
                wandb.log({f"examples sigma = {sigma.item():g}": [wandb.Image(img, caption=caption[i]) for i, img in enumerate(images)]})

                del inp, lbl, out, images
        finally:
            valid_params['sigma'] = saved_sigma
            model.train(was_training)

Some extra hyperparams

In [22]:
batch_size = 6
batch_accumulation = 4
# Checkpoints are written to the mounted Drive. The path is built from the
# /gdrive mount point (see the first cell) rather than from sys.path[1], whose
# value depends on how the runtime orders sys.path.
path_for_save = os.path.join('/gdrive', 'MyDrive', 'ML_proj', 'checkpoints_proj')

args = TrainingArguments(output_dir=path_for_save,
                         overwrite_output_dir=True,
                         report_to='wandb',
                         evaluation_strategy='steps',
                         logging_steps=5,
                         gradient_accumulation_steps=batch_accumulation,
                         per_device_train_batch_size=batch_size,
                         num_train_epochs=epochs,
                         load_best_model_at_end=True,
                         metric_for_best_model='ssim',
                         greater_is_better=True,
                         no_cuda=False,
                         save_strategy='steps',
                         save_steps=5,
                         save_total_limit=3,
                         ignore_data_skip=True)

trainer = CustomTrainer(loss=criterion,
                        model=model,
                        train_dataset=train_set,
                        eval_dataset=val_set,
                        compute_metrics=compute_metrics,
                        optimizers=(optimizer, scheduler),
                        callbacks=[ModelTestCallback(eval_data_set=val_set)],
                        args=args)

using `logging_steps` to initialize `eval_steps` to 5
PyTorch: setting up devices


Here you should specify the name of the checkpoint directory of the pretrained model, which is saved to your GDrive by default

In [23]:
# Name of the checkpoint directory to resume from; it is joined to the checkpoints_proj folder in the next cell.
checkpoint_dir_path = 'checkpoint-435'# example

In [24]:
# Resume from the chosen checkpoint inside the Drive checkpoint folder. The
# path is built from the /gdrive mount point (see the first cell) rather than
# from sys.path[1], whose value depends on how the runtime orders sys.path.
trainer.train(resume_from_checkpoint=os.path.join('/gdrive', 'MyDrive', 'ML_proj', 'checkpoints_proj', checkpoint_dir_path))

Loading model from /gdrive/MyDrive/ML_proj/checkpoints_proj/checkpoint-435).
***** Running training *****
  Num examples = 432
  Num Epochs = 400
  Instantaneous batch size per device = 6
  Total train batch size (w. parallel, distributed & accumulation) = 24
  Gradient Accumulation steps = 4
  Total optimization steps = 7200
  Continuing training from checkpoint, will skip to saved global_step
  Continuing training from epoch 24
  Continuing training from global step 435
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss


KeyboardInterrupt: ignored