In [5]:
from utils import mkExpDir
from dataset import dataloader
from model import TTSR
from loss.loss import get_loss_dict
from trainer import Trainer

import os
import torch
import torch.nn as nn
import warnings
import glob
import time

warnings.filterwarnings('ignore')

# Args

In [6]:
from argparse import Namespace

In [7]:
# Shared hyper-parameter container: an empty argparse Namespace whose attributes are
# filled in by the configuration cells that follow and then passed to every component
# (mkExpDir, get_dataloader, TTSR, get_loss_dict, Trainer).
args   = Namespace()

In [8]:
## log setting
args.save_dir      = 'trial'
args.reset         = True
args.log_file_name = 'TTSR.log'
args.logger_name   = 'TTSR'
## device setting
# Fall back to CPU automatically when no CUDA device is visible, instead of failing
# later when the model is moved to torch.device('cuda').
args.cpu           = not torch.cuda.is_available()
args.num_gpu       = 1
## dataset setting
args.dataset       = 'CUFED'                      # Which dataset to train and test
# Directory of dataset. Set the CUFED_DIR environment variable to point elsewhere
# instead of editing this machine-specific default.
args.dataset_dir   = os.environ.get('CUFED_DIR', r'/home/esteban/Datasets/CUFED/')
## dataloader setting
args.num_workers   = 9                            # The number of workers when loading data

In [9]:
## model / loss / optimizer hyper-parameters, written into `args` in one update
## (equivalent to assigning each attribute individually, in the same order).
vars(args).update(dict(
    # --- model ---
    num_res_blocks='16+16+8+4',  # residual blocks in each stage, '+'-separated
    n_feats=64,                  # number of channels in the network
    res_scale=1.0,               # residual scaling factor
    # --- loss ---
    GAN_type='WGAN_GP',          # GAN variant used for the adversarial loss
    GAN_k=2,                     # discriminator steps per generator step
    tpl_use_S=False,             # multiply the soft-attention map into the transferal perceptual loss
    tpl_type='l2',               # gram-matrix distance in the transferal perceptual loss [l1 / l2]
    rec_w=1.0,                   # weight of the reconstruction loss
    per_w=1e-2,                  # weight of the perceptual loss
    tpl_w=1e-2,                  # weight of the transferal perceptual loss
    adv_w=1e-3,                  # weight of the adversarial loss
    # --- optimizer (Adam) ---
    beta1=0.9,                   # Adam beta1
    beta2=0.999,                 # Adam beta2
    eps=1e-8,                    # Adam eps
    lr_rate=1e-4,                # main learning rate
    lr_rate_dis=1e-4,            # learning rate of the discriminator
    lr_rate_lte=1e-5,            # learning rate of the LTE
    decay=999999,                # LR decay setting; a value this large presumably disables step decay -- TODO confirm in trainer
    gamma=0.5,                   # LR decay factor for step decay
))

In [10]:
## training schedule, applied to `args` as keyword updates
vars(args).update(
    batch_size=9,        # training batch size
    train_crop_size=40,  # training data crop size
    num_init_epochs=2,   # initial epochs trained with the reconstruction loss only
    num_epochs=50,       # total number of training epochs
    print_every=600,     # print period
    save_every=5,        # checkpoint save period
    val_every=10,        # validation period
)

In [11]:
## evaluate / test / finetune switches and paths
vars(args).update({
    'eval': True,                           # evaluation mode
    'eval_save_results': False,             # save each image during evaluation
    'model_path': None,                     # checkpoint to evaluate (set later before plotting)
    'test': False,                          # test mode
    'lr_path': './test/demo/lr/lr.png',     # input LR image used in test mode
    'ref_path': './test/demo/ref/ref.png',  # reference image used in test mode
})

# Main



In [12]:
# make save_dir
# mkExpDir sets up the experiment directory for args.save_dir and returns the logger
# that get_loss_dict and Trainer receive below.
# NOTE(review): args.reset is True; what it does to an existing save_dir is decided
# inside utils.mkExpDir -- confirm before reusing a directory with results in it.
_logger     = mkExpDir(args)

In [13]:
# Dataloaders for the training and testing sets. In pure test mode (args.test) no
# dataset is loaded and _dataloader stays None.
if args.test:
    _dataloader = None
else:
    _dataloader = dataloader.get_dataloader(args)

In [14]:
# (number of training inputs, number of batches yielded by the test loader keyed '1')
tuple(len(part) for part in (_dataloader['train'].dataset.input_list,
                             _dataloader['test']['1']))

(11871, 126)

In [17]:
# Compute device for the model and for every batch moved by prepare().
if args.cpu:
    device = torch.device('cpu')
else:
    device = torch.device('cuda')

In [18]:
# Build the network on the selected device. Wrap it in DataParallel only when running
# on CUDA with more than one GPU requested.
_model = TTSR.TTSR(args).to(device)
if not args.cpu and args.num_gpu > 1:
    _model = nn.DataParallel(_model, device_ids=list(range(args.num_gpu)))

In [19]:
from torchinfo import summary

# Layer summary for a batch of one RGB sample. The first input is the LR patch
# (args.train_crop_size, 40 here). The remaining three inputs are 4x larger (160 here),
# in the same order the forward pass receives them. Deriving the sizes from the config
# keeps this cell consistent if train_crop_size changes.
summary(_model,
        [[1, 3, args.train_crop_size, args.train_crop_size]]
        + [[1, 3, 4 * args.train_crop_size, 4 * args.train_crop_size] for _ in range(3)])

Layer (type:depth-idx)                   Output Shape              Param #
TTSR                                     [1, 3, 160, 160]          555,340
├─LTE: 1-1                               [1, 64, 160, 160]         --
│    └─MeanShift: 2-1                    [1, 3, 160, 160]          (12)
│    └─Sequential: 2-2                   [1, 64, 160, 160]         --
│    │    └─Conv2d: 3-1                  [1, 64, 160, 160]         1,792
│    │    └─ReLU: 3-2                    [1, 64, 160, 160]         --
│    └─Sequential: 2-3                   [1, 128, 80, 80]          --
│    │    └─Conv2d: 3-3                  [1, 64, 160, 160]         36,928
│    │    └─ReLU: 3-4                    [1, 64, 160, 160]         --
│    │    └─MaxPool2d: 3-5               [1, 64, 80, 80]           --
│    │    └─Conv2d: 3-6                  [1, 128, 80, 80]          73,856
│    │    └─ReLU: 3-7                    [1, 128, 80, 80]          --
│    └─Sequential: 2-4                   [1, 256, 40, 40]          

In [20]:
# loss
# Loss modules built from the loss settings in args (weights and types configured
# above -- TODO confirm which ones get_loss_dict actually reads).
_loss_all = get_loss_dict(args, _logger)
# trainer
# `t` wraps the model, dataloaders, losses and logger; later cells use it for
# t.load(...), t.evaluate(...) and t.model inside plot_results.
t = Trainer(args, _logger, _dataloader, _model, _loss_all)

In [21]:
# Load the weights in ./TTSR.pt into the trainer's model and run evaluation on the test
# set. current_epoch=-1 only shows up as the epoch label in the evaluation log output.
t.load('./TTSR.pt')
t.evaluate(current_epoch = -1)

[2023-05-22 13:05:10,978] - [trainer.py file line:48] - INFO: load_model_path: ./TTSR.pt
[2023-05-22 13:05:11,037] - [trainer.py file line:121] - INFO: Epoch -1 evaluation process...
[2023-05-22 13:05:30,223] - [trainer.py file line:150] - INFO: Ref  PSNR (now): 25.402 	 SSIM (now): 0.7600
[2023-05-22 13:05:30,224] - [trainer.py file line:157] - INFO: Ref  PSNR (max): 25.402 (-1) 	 SSIM (max): 0.7600 (-1)
[2023-05-22 13:05:30,225] - [trainer.py file line:160] - INFO: Evaluation over.


In [23]:
# Load the weights in ./TTSR-rec.pt (presumably the reconstruction-loss-only variant,
# given the "-rec" suffix -- TODO confirm) and evaluate them on the test set.
# current_epoch=-1 only shows up as the epoch label in the evaluation log output.
t.load('./TTSR-rec.pt')
t.evaluate(current_epoch = -1)

[2023-05-22 13:06:32,225] - [trainer.py file line:48] - INFO: load_model_path: ./TTSR-rec.pt
[2023-05-22 13:06:32,292] - [trainer.py file line:121] - INFO: Epoch -1 evaluation process...
[2023-05-22 13:06:51,596] - [trainer.py file line:150] - INFO: Ref  PSNR (now): 26.991 	 SSIM (now): 0.8003
[2023-05-22 13:06:51,597] - [trainer.py file line:157] - INFO: Ref  PSNR (max): 26.991 (-1) 	 SSIM (max): 0.8003 (-1)
[2023-05-22 13:06:51,598] - [trainer.py file line:160] - INFO: Evaluation over.


# Test: load a trained checkpoint

In [74]:
# Checkpoint to visualise. args.model_path is also used as the figure title in
# plot_results.
# NOTE(review): this path points into ./save_dir/ while args.save_dir is 'trial' --
# confirm which directory actually holds the trained checkpoints.
args.model_path = './save_dir/model/model_00050.pt'                      # The path of model to evaluation
t.load(args.model_path)

[2022-01-28 10:10:40,403] - [trainer.py file line:53] - INFO: load_model_path: ./save_dir/model/model_00050.pt


# Plot Results

In [75]:
def prepare(sample_batched, target_device=None):
    """Move every tensor in a batch dict onto the compute device.

    Args:
        sample_batched: dict of tensors (e.g. 'LR', 'LR_sr', 'HR', 'Ref', 'Ref_sr').
            It is updated in place and also returned.
        target_device: device to move the tensors to. Defaults to the notebook-level
            ``device`` when None.

    Returns:
        The same dict, with each value replaced by its copy on ``target_device``.
    """
    if target_device is None:
        target_device = device
    # Snapshot the items so the dict is not iterated while its values are reassigned.
    for key, value in list(sample_batched.items()):
        sample_batched[key] = value.to(target_device)
    return sample_batched

In [76]:
import matplotlib.pyplot as plt
%matplotlib inline

In [77]:
def _to_uint8_image(batch):
    """Return the first image of ``batch`` as an (H, W, C) uint8 numpy array.

    Assumes ``batch`` is a (B, C, H, W) tensor with values nominally in [-1, 1];
    ``(x + 1) * 127.5`` maps that range to [0, 255]. Values are clamped before the
    uint8 cast so model outputs slightly outside [-1, 1] saturate instead of wrapping.
    """
    image = (batch[0] + 1.) * 127.5
    image = image.round().clamp(0, 255).to(torch.uint8)
    return image.permute(1, 2, 0).cpu().numpy()


def plot_results(model=None, total_images=5):
    """Run the SR model on test batches, plot the results and save key images.

    For each of the first ``total_images`` batches of ``_dataloader['test']['1']`` the
    model is run in eval mode under ``torch.no_grad()``. The first sample of the batch
    is shown as a 1x6 figure (LR, LR_sr, TTSR output, Ref, Ref_sr, Target/HR) titled
    with ``args.model_path``. The HR, Ref and TTSR images are written to
    ``./test/<i>_hr.png``, ``./test/<i>_ref.png`` and ``./test/<i>_ttsr.png``.

    Args:
        model: network to run. Defaults to the trainer's model ``t.model`` when None.
        total_images: number of test batches to visualise; values <= 0 plot nothing.
    """
    if total_images <= 0:
        return

    net = t.model if model is None else model
    names = ['LR', 'LR_sr', 'TTSR', 'Ref', 'Ref_sr', 'Target']
    out_dir = './test'
    os.makedirs(out_dir, exist_ok=True)  # plt.imsave fails if the folder is missing

    net.eval()
    for i, sample_batched in enumerate(_dataloader['test']['1']):
        if i >= total_images:
            break
        sample_batched = prepare(sample_batched)

        with torch.no_grad():
            sr, _, _, _, _ = net(lr=sample_batched['LR'], lrsr=sample_batched['LR_sr'],
                                 ref=sample_batched['Ref'], refsr=sample_batched['Ref_sr'])
            # Same order as `names`.
            images = [_to_uint8_image(x) for x in (
                sample_batched['LR'], sample_batched['LR_sr'], sr,
                sample_batched['Ref'], sample_batched['Ref_sr'], sample_batched['HR'])]
        by_name = dict(zip(names, images))

        fig, axes = plt.subplots(figsize=(20, 10), nrows=1, ncols=len(names),
                                 sharex=True, sharey=True)
        fig.suptitle(args.model_path)
        axes = axes.flatten()
        for idx, (ax, name, img) in enumerate(zip(axes, names, images)):
            ax.imshow(img)
            ax.set_xlabel(f'{idx}: {name}')
        # Hide ticks (the panels share their axes via sharex/sharey).
        axes[0].set_xticks([])
        axes[0].set_yticks([])
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across iterations

        plt.imsave(f'{out_dir}/{i}_hr.png', by_name['Target'])
        plt.imsave(f'{out_dir}/{i}_ref.png', by_name['Ref'])
        plt.imsave(f'{out_dir}/{i}_ttsr.png', by_name['TTSR'])