In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
import os
os.chdir('gdrive/My Drive/Super_resolution/brainSR')

In [3]:
from data.image_list import create_image_list
from data.dataset import create_dataset
from data.data_loader import create_dataloader
import numpy as np
import torch
import time
import random
from collections import OrderedDict
import math
from math import log10
from utils import util, convert, metric

In [4]:
opt = {
    "name": "espcn_x4",
    "model": "espcn",
    "gpu_ids": [0],
    "upscale_factor":4,
    "device": "cuda",

    "datasets":
        {
            "path":
                {
                "root": "data/BrainTumour",
                "folders": ["imagesTr"]
                },
            "train":
                {
                "phase": "train",
                "batch_size": 64,
                "use_shuffle": True,
                "upscale_factor": 4,
                "scale": True,
                "no_of_images": 450,
                "n_workers": 1,
                "gaussian": dict(use=True, sigma=1, kernel_size=3, dim=2)                 
                },
            "valid":
                {
                "phase": "valid",
                "no_of_images": 34,
                "scale": True,
                "depth_padding": 1,
                "upscale_factor": 4,
                "guassian": dict(use=True, sigma=1, kernel_size=3, dim=2)
                }
    },

    
  "path": 
    {
        "root": ""
    },

  "network":
    {
        "norm_type": "batch",
#         "ngf": 64,
#         "ngb": 16,
#         "input_ngc": 1,
#         "output_ngc": 1
    },

  "train" : 
    {
        "manual_seed": 777,
        "niter": 80,
        "val_freq": 1,
        "lr": 1e-4,
        "criterion": "mse",
    },
    
  "logger": 
    {
        "print_freq": 150, # images
        "save_checkpoint_freq": 1 # iteration
    }
}


In [5]:
def parse(opt_path, is_train=True):
#     with open(opt_path, 'r') as f:
#         opt = json.load(f, object_pairs_hook=OrderedDict)
    opt['is_train'] = is_train

    for key, path in opt['path'].items():
        opt['path'][key] = os.path.expanduser(path)
    if is_train:
        experiments_root = os.path.join(opt['path']['root'], 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['options'] = experiments_root
        opt['path']['trained_models'] = os.path.join(experiments_root, 'trained_models')
        opt['path']['log'] = os.path.join(experiments_root, 'log')
    else:
        results_root = os.path.join(opt['path']['root'], 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = os.path.join(results_root, 'log')
        opt['path']['test_images'] = os.path.join(results_root, 'test_images')

    return opt

opt = parse(opt)


In [6]:
# seeding function for reproducibility
def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
SEED = 400
seed_everything(400)

In [7]:
training_image_list = create_image_list(opt, train=True)
valid_image_list = create_image_list(opt, valid=True)


0/450 images loaded
50/450 images loaded
100/450 images loaded
150/450 images loaded
200/450 images loaded
250/450 images loaded
300/450 images loaded
350/450 images loaded
400/450 images loaded
0/34 images loaded


In [8]:
training_set = create_dataset(opt["datasets"]["train"], training_image_list)
valid_set = create_dataset(opt["datasets"]["valid"], valid_image_list)


In [9]:
from models.models import create_model
model = create_model(opt)

Model [ESPCNModel] is created.


In [10]:
from utils.logger import Logger
logger = Logger(opt)

In [11]:
def validate(val_set, val_size, model, logger, epoch, img_no):
    print('Start validation phase ...')
    val_start_time = time.time()
    model.eval() # Change to eval mode. It is important for BN layers.

    val_results = OrderedDict()
    avg_psnr = 0.0 
    avg_mse = 0.0
    for val_data in val_set:
      assert val_data['H'].shape[0] == val_data['L'].shape[0]
      batch_no = val_data['H'].shape[0]

      model.feed_data(val_data)
      model.forward()
      
      visuals = model.get_current_visuals()
      
      sr_img = visuals['super-resolution'] # uint8
      gt_img = visuals['ground-truth'] # uint8
      
      mse = model.criterion(sr_img, gt_img).item()
      psnr = 10 * log10(1 / mse)
      avg_mse += mse
      avg_psnr += psnr

    avg_mse = avg_mse / val_size
    avg_psnr = avg_psnr / val_size

    val_results['mse'] = avg_mse
    val_results['psnr'] = avg_psnr

    val_duration = time.time() - val_start_time

    model.train() # Change back to train mode.
    return val_results, val_duration


In [None]:
# run this model
start_time = time.time()

val_size = len(valid_set)
depth = 624

avg_train_loss_list=[]
avg_valid_loss_list=[]
lr_list = []
psnr_list = []

for iteration in range(opt['train']['niter']+1):
  train_loss=0
  val_loss=0
  for i, train_data in enumerate(training_set):
    train_start_time = time.time()
    # training
    model.feed_data(train_data)
    model.optimize_parameters(i)
    train_duration = time.time() - train_start_time
    
    train_loss += model.loss.item() # this just gets a value, whereas get_current_losses gets a dict
    if i % opt['logger']['print_freq'] == 0:
      losses = model.get_current_losses()
      logger.print_results(losses, iteration, i, train_duration, 'loss')      
        
  avg_train_loss = train_loss/len(training_set)
  avg_train_loss_list.append(avg_train_loss)

  # write training results
  if i == len(training_set)-1:
    train_results = dict(avg_train_loss = avg_train_loss,
                         learning_rate = model.optimizer.state_dict()['param_groups'][0]['lr'])
    
    logger.print_results(train_results, iteration, i, 0, 'loss')

  # validation
  if iteration % opt['train']['val_freq'] == 0:
    val_results, val_duration = validate(valid_set, val_size, model, logger, iteration, i)
    print(val_results)
    # Save to log
    logger.print_results(val_results, iteration, i, val_duration, 'val')

  # save
  if iteration % opt['logger']['save_checkpoint_freq'] == 0:
    print('Saving the model at the end of iteration %d' % (iteration))
    model.save(iteration)

  # append to lists
  lr = model.optimizer.state_dict()['param_groups'][0]['lr']
  lr_list.append(lr)
  avg_valid_loss_list.append(val_results['mse'])
  psnr_list.append(val_results['psnr'])

  # update learning rate
  if opt['train'].get('scheduler') is not None:
    if iteration % opt['train'].get('scheduler_frequency') == 0:
      model.update_learning_rate(val_results['psnr'])

  print('end of iteration ' + str(iteration))

print('lr_list:', lr_list)
print('psnr_list:', psnr_list)
print('avg_train_loss_list', avg_train_loss_list)
print('avg_valid_loss_list', avg_valid_loss_list)

  "See the documentation of nn.Upsample for details.".format(mode))


(epoch:   0, iters:        0, time: 0.038) loss: 0.03082901 
(epoch:   0, iters:      150, time: 0.033) loss: 0.00181382 
(epoch:   0, iters:      300, time: 0.033) loss: 0.00121857 
(epoch:   0, iters:      449, time: 0.000) avg_train_loss: 0.00236369 learning_rate: 0.00010000 
Start validation phase ...
OrderedDict([('mse', 0.0007479374437346397), ('psnr', 31.35503714865684)])
(epoch:   0, iters:      449, time: 62.177) mse: 0.00074794 psnr: 31.35503715 
Saving the model at the end of iteration 0
end of iteration 0
(epoch:   1, iters:        0, time: 0.036) loss: 0.00049081 
(epoch:   1, iters:      150, time: 0.032) loss: 0.00069349 
(epoch:   1, iters:      300, time: 0.033) loss: 0.00072398 
(epoch:   1, iters:      449, time: 0.000) avg_train_loss: 0.00054655 learning_rate: 0.00010000 
Start validation phase ...
OrderedDict([('mse', 0.0005929996413589619), ('psnr', 32.36915115212756)])
(epoch:   1, iters:      449, time: 46.994) mse: 0.00059300 psnr: 32.36915115 
Saving the model