In [None]:
# Mount Google Drive so the project folder on Drive is reachable under /content/gdrive.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
import os

# Work from the project folder on the mounted Drive. An absolute path keeps this
# cell idempotent: a relative 'gdrive/...' path only resolves while the cwd is
# /content and fails once the cell has already changed directory.
PROJECT_DIR = '/content/gdrive/My Drive/Super_resolution/brainSR'
os.chdir(PROJECT_DIR)

In [None]:
from data.image_list import create_image_list
from data.dataset import create_dataset
from data.data_loader import create_dataloader
import torch
import time
import random
from collections import OrderedDict
import math
from math import log10
from utils import util, convert, metric

In [None]:
# Experiment configuration for SRResNet with an upscale factor of 4.
# Read by parse(), create_image_list(), create_dataset(), create_model()
# and Logger() in the cells below.
opt = {
    "name": "SRResNet_x4",
    "model": "sr_resnet",
    "gpu_ids": [0],
    "upscale_factor": 4,
    "device": "cuda",

    # Data location plus per-phase dataset options.
    "datasets": {
        "path": {
            "root": "data/BrainTumour",
            "folders": ["imagesTr", "imagesTs"],
        },
        "train": {
            "phase": "train",
            "batch_size": 64,
            "use_shuffle": True,
            "upscale_factor": 4,
            "scale": True,
            # "transform_H": {},
            "no_of_images": 725,
            "n_workers": 1,
        },
        "valid": {
            "phase": "valid",
            "no_of_images": 25,
            "scale": True,
            "depth_padding": 1,
            "upscale_factor": 4,
            # "transform_H": {},
        },
    },

    # Output root; parse() derives the experiment / result sub-paths from it.
    "path": {
        "root": "",
    },

    # Generator network options.
    "network": {
        "norm_type": "batch",
        "which_model_G": "sr_resnet",
        "ngf": 64,
        "ngb": 16,
        "input_ngc": 1,
        "output_ngc": 1,
    },

    # Training schedule; niter / val_freq drive the training-loop cells.
    "train": {
        "manual_seed": 777,
        "niter": 30,
        "val_freq": 5,
        "lr_G": 1e-4,
        "criterion": "mse",
    },

    # Loss-print cadence (steps) and checkpoint cadence (epochs).
    "logger": {
        "print_freq": 100,
        "save_checkpoint_freq": 5,
    },
}


In [None]:
def parse(opt_path, is_train=True):
    """Normalise an option dict and derive the run's output directories.

    Args:
        opt_path: either the option dict itself (as built in this notebook) or
            a path to a JSON file holding it. A dict argument is updated in
            place and also returned; a JSON file is loaded into an OrderedDict.
        is_train: when True, add ``experiments_root``, ``options``,
            ``trained_models`` and ``log`` under ``<root>/experiments/<name>``;
            otherwise add ``results_root``, ``log`` and ``test_images`` under
            ``<root>/results/<name>``.

    Returns:
        The option dict with ``is_train`` set, every entry of ``opt['path']``
        passed through ``os.path.expanduser``, and the derived paths added.
    """
    # Resolve the options from the argument -- never from a global of the same
    # name, which would make the function silently ignore what it is given.
    if isinstance(opt_path, (str, os.PathLike)):
        import json
        with open(opt_path, 'r') as f:
            opt = json.load(f, object_pairs_hook=OrderedDict)
    else:
        opt = opt_path

    opt['is_train'] = is_train

    paths = opt['path']
    for key, path in paths.items():
        paths[key] = os.path.expanduser(path)

    root = paths['root']
    if is_train:
        experiments_root = os.path.join(root, 'experiments', opt['name'])
        paths['experiments_root'] = experiments_root
        paths['options'] = experiments_root
        paths['trained_models'] = os.path.join(experiments_root, 'trained_models')
        paths['log'] = os.path.join(experiments_root, 'log')
    else:
        results_root = os.path.join(root, 'results', opt['name'])
        paths['results_root'] = results_root
        paths['log'] = os.path.join(results_root, 'log')
        paths['test_images'] = os.path.join(results_root, 'test_images')

    return opt

opt = parse(opt)

# Seed the Python, NumPy and PyTorch RNGs from the configured train.manual_seed
# so runs are more repeatable. No other visible cell applies this seed
# (create_model may also seed internally; seeding twice is harmless). This runs
# before the dataset/model cells so their randomness is covered too.
seed = opt['train'].get('manual_seed')
if seed is not None:
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


In [None]:
# Build the train and validation image lists (train first, then valid);
# progress is printed while loading.
training_image_list, valid_image_list = (
    create_image_list(opt, train=True),
    create_image_list(opt, valid=True),
)


0/725 images loaded
50/725 images loaded
100/725 images loaded
150/725 images loaded
200/725 images loaded
250/725 images loaded
300/725 images loaded
350/725 images loaded
400/725 images loaded
450/725 images loaded
500/725 images loaded
550/725 images loaded
600/725 images loaded
650/725 images loaded
700/725 images loaded
0/25 images loaded


In [None]:
# Wrap each image list in a dataset configured with that phase's options
# (train is built before valid).
training_set, valid_set = (
    create_dataset(opt["datasets"][phase], images)
    for phase, images in (("train", training_image_list), ("valid", valid_image_list))
)


In [None]:
# Instantiate the SR model described by `opt` through the project factory
# (the output below reports an SRResNetModel).
# NOTE(review): this import belongs in the top import cell.
from models.models import create_model
model = create_model(opt)

  nn.init.kaiming_normal(self.C.weight, a=a, mode='fan_in')


---------- Model initialized -------------
Number of parameters in G: 1528724
-----------------------------------------------
Model [SRResNetModel] is created.


In [None]:
# Logger whose print_results() is used by validate() and the training loops
# to report losses and metrics.
# NOTE(review): this import belongs in the top import cell.
from utils.logger import Logger
logger = Logger(opt)

In [None]:
def validate(val_loader, val_size, depth, model, logger, epoch, current_step):
    """Run slice-wise validation and log the mean PSNR.

    Each item of ``val_loader`` is a dict whose ``'H'`` (high-res) and ``'L'``
    (low-res) tensors are indexed by slice along dim 0. Every slice is fed to
    the model on its own (with a leading dim of 1 added) and scored with
    ``model.criterion``.

    Args:
        val_loader: iterable of ``{'H': tensor, 'L': tensor}`` items.
        val_size: number of items in ``val_loader``. Kept for backward
            compatibility; the mean is taken over the slices actually scored.
        depth: maximum number of slices scored per item; clamped to the
            item's real size along dim 0.
        model: object exposing eval/train/feed_data/val/get_current_visuals
            and a ``criterion(a, b)`` returning a scalar tensor.
        logger: object with ``print_results(results, epoch, step, duration, tag)``.
        epoch, current_step: forwarded unchanged to the logger.

    PSNR is ``10 * log10(1 / MSE)``, i.e. it assumes intensities in [0, 1].
    Slices with MSE == 0 (PSNR = +inf, e.g. identical all-zero padding slices)
    are left out of the mean instead of raising ZeroDivisionError; if every
    slice is perfect the reported PSNR is +inf.
    """
    print('Start validation phase ...')
    val_start_time = time.time()
    model.eval()  # Change to eval mode. It is important for BN layers.

    val_results = OrderedDict()
    psnr_sum = 0.0
    n_scored = 0
    n_perfect = 0
    try:
        with torch.no_grad():  # evaluation only; no autograd graph needed
            for val_data in val_loader:
                hr_volume, lr_volume = val_data['H'], val_data['L']
                n_slices = min(depth, hr_volume.shape[0])
                for slice_no in range(n_slices):
                    model.feed_data(dict(H=hr_volume[slice_no, :, :, :].unsqueeze(0),
                                         L=lr_volume[slice_no, :, :, :].unsqueeze(0)))
                    model.val()

                    visuals = model.get_current_visuals()
                    # NOTE(review): the PSNR below assumes float intensities in
                    # [0, 1]; if these visuals are uint8 images, rescale first.
                    sr_img = visuals['super-resolution']
                    gt_img = visuals['ground-truth']

                    mse = model.criterion(sr_img, gt_img).item()
                    if mse <= 0:
                        n_perfect += 1  # PSNR would be +inf; keep it out of the mean
                        continue
                    psnr_sum += 10 * log10(1 / mse)
                    n_scored += 1
    finally:
        model.train()  # Change back to train mode, even if validation fails.

    if n_scored:
        avg_psnr = psnr_sum / n_scored
    else:
        avg_psnr = float('inf') if n_perfect else float('nan')
    val_results['psnr'] = avg_psnr

    val_duration = time.time() - val_start_time
    # Save to log
    logger.print_results(val_results, epoch, current_step, val_duration, 'val')


In [None]:
# Train the freshly created model for epochs 0..niter (inclusive, niter + 1
# epochs). Epoch 0 is trained but, by the `iteration != 0` guard, never
# validated or checkpointed.
# NOTE(review): training_set is iterated directly, so the batch_size /
# use_shuffle / n_workers options are not applied by this loop
# (create_dataloader is imported but unused) -- confirm this is intended.
model.train()

# run this model
start_time = time.time()

val_size = len(valid_set)
# Slices scored per validation item.
# NOTE(review): hard-coded; presumably the padded, upscaled depth of each
# validation volume -- confirm against the dataset.
depth = 624

for iteration in range(opt['train']['niter'] + 1):
    # Reset every epoch so validate() never receives a step index leaked from
    # an earlier epoch or cell when training_set yields nothing.
    i = -1
    for i, train_data in enumerate(training_set):
        train_start_time = time.time()
        # training
        model.feed_data(train_data)
        model.optimize_parameters(i)
        train_duration = time.time() - train_start_time

        # print losses
        if i % opt['logger']['print_freq'] == 0:
            losses = model.get_current_losses()
            logger.print_results(losses, iteration, i, train_duration, 'loss')

    if iteration != 0:
        # validation
        if iteration % opt['train']['val_freq'] == 0:
            validate(valid_set, val_size, depth, model, logger, iteration, i)

        # save
        if iteration % opt['logger']['save_checkpoint_freq'] == 0:
            print('Saving the model at the end of iteration %d' % (iteration))
            model.save(iteration)

    print('end of iteration ' + str(iteration))

print('Training finished in %.1f s' % (time.time() - start_time))

  "See the documentation of nn.Upsample for details.".format(mode))


(epoch:   0, iters:        0, time: 0.771) loss: 0.787772 
(epoch:   0, iters:      100, time: 0.751) loss: 0.008760 
(epoch:   0, iters:      200, time: 0.751) loss: 0.004576 
(epoch:   0, iters:      300, time: 0.750) loss: 0.003137 
(epoch:   0, iters:      400, time: 0.751) loss: 0.003046 
(epoch:   0, iters:      500, time: 0.755) loss: 0.001649 
(epoch:   0, iters:      600, time: 0.753) loss: 0.002698 
(epoch:   0, iters:      700, time: 0.753) loss: 0.002341 
end of iteration 0
(epoch:   1, iters:        0, time: 0.750) loss: 0.001461 
(epoch:   1, iters:      100, time: 0.750) loss: 0.001399 
(epoch:   1, iters:      200, time: 0.750) loss: 0.001242 
(epoch:   1, iters:      300, time: 0.756) loss: 0.001412 
(epoch:   1, iters:      400, time: 0.753) loss: 0.001344 
(epoch:   1, iters:      500, time: 0.750) loss: 0.000988 
(epoch:   1, iters:      600, time: 0.750) loss: 0.001349 
(epoch:   1, iters:      700, time: 0.753) loss: 0.001351 
end of iteration 1
(epoch:   2, iters

Continue training

In [None]:
# Continue training the model already in memory.
# NOTE: epoch numbering starts at `first_epoch`. With the default of 0 the
# model.save(iteration) calls reuse the epoch numbers of the first run
# (presumably overwriting its checkpoint files -- confirm model.save). Set it to
# the last completed epoch + 1 to keep numbering and checkpoints distinct.
# NOTE(review): training_set is iterated directly, so the batch_size /
# use_shuffle / n_workers options are not applied by this loop.
model.train()

# run this model
start_time = time.time()

first_epoch = 0
last_epoch = opt['train']['niter']  # inclusive

val_size = len(valid_set)
# Slices scored per validation item.
# NOTE(review): hard-coded; presumably the padded, upscaled depth of each
# validation volume -- confirm against the dataset.
depth = 624

for iteration in range(first_epoch, last_epoch + 1):
    # Reset every epoch so validate() never receives a step index leaked from
    # an earlier epoch or cell when training_set yields nothing.
    i = -1
    for i, train_data in enumerate(training_set):
        train_start_time = time.time()
        # training
        model.feed_data(train_data)
        model.optimize_parameters(i)
        train_duration = time.time() - train_start_time

        # print losses
        if i % opt['logger']['print_freq'] == 0:
            losses = model.get_current_losses()
            logger.print_results(losses, iteration, i, train_duration, 'loss')

    if iteration != 0:
        # validation
        if iteration % opt['train']['val_freq'] == 0:
            validate(valid_set, val_size, depth, model, logger, iteration, i)

        # save
        if iteration % opt['logger']['save_checkpoint_freq'] == 0:
            print('Saving the model at the end of iteration %d' % (iteration))
            model.save(iteration)

    print('end of iteration ' + str(iteration))

print('Training finished in %.1f s' % (time.time() - start_time))

2

In [None]:
# Reload generator weights from the epoch-25 checkpoint. The path is built from
# the configured trained_models directory (set by parse()) rather than
# hard-coded, so it follows changes to opt['name'] / opt['path']['root'].
# With the current config it resolves to
# experiments/SRResNet_x4/trained_models/25_G.pth.
resume_epoch = 25
model.load_path_G = os.path.join(opt['path']['trained_models'], '%d_G.pth' % resume_epoch)
model.load()

loading model for G [experiments/SRResNet_x4/trained_models/25_G.pth] ...


In [None]:
# Resume training after reloading the epoch-25 checkpoint above: run epochs
# 26..45 (inclusive), validating/saving on the usual val_freq /
# save_checkpoint_freq multiples.
# NOTE(review): only the generator weights were reloaded; any optimizer state
# presumably starts fresh -- confirm against model.load().
# NOTE(review): training_set is iterated directly, so the batch_size /
# use_shuffle / n_workers options are not applied by this loop.
model.train()

# run this model
start_time = time.time()

first_epoch = 26  # one past the reloaded checkpoint (25_G.pth)
last_epoch = 45   # inclusive

val_size = len(valid_set)
# Slices scored per validation item.
# NOTE(review): hard-coded; presumably the padded, upscaled depth of each
# validation volume -- confirm against the dataset.
depth = 624

for iteration in range(first_epoch, last_epoch + 1):
    # Reset every epoch so validate() never receives a step index leaked from
    # an earlier epoch or cell when training_set yields nothing.
    i = -1
    for i, train_data in enumerate(training_set):
        train_start_time = time.time()
        # training
        model.feed_data(train_data)
        model.optimize_parameters(i)
        train_duration = time.time() - train_start_time

        # print losses
        if i % opt['logger']['print_freq'] == 0:
            losses = model.get_current_losses()
            logger.print_results(losses, iteration, i, train_duration, 'loss')

    if iteration != 0:
        # validation
        if iteration % opt['train']['val_freq'] == 0:
            validate(valid_set, val_size, depth, model, logger, iteration, i)

        # save
        if iteration % opt['logger']['save_checkpoint_freq'] == 0:
            print('Saving the model at the end of iteration %d' % (iteration))
            model.save(iteration)

    print('end of iteration ' + str(iteration))

print('Training finished in %.1f s' % (time.time() - start_time))

  "See the documentation of nn.Upsample for details.".format(mode))


(epoch:  26, iters:        0, time: 0.765) loss: 0.000248 
(epoch:  26, iters:      100, time: 0.751) loss: 0.000295 
(epoch:  26, iters:      200, time: 0.751) loss: 0.000383 
(epoch:  26, iters:      300, time: 0.762) loss: 0.000534 
(epoch:  26, iters:      400, time: 0.753) loss: 0.000634 
(epoch:  26, iters:      500, time: 0.763) loss: 0.000256 
(epoch:  26, iters:      600, time: 0.756) loss: 0.000491 
(epoch:  26, iters:      700, time: 0.755) loss: 0.000621 
end of iteration 26
(epoch:  27, iters:        0, time: 0.754) loss: 0.000257 
(epoch:  27, iters:      100, time: 0.752) loss: 0.000336 
(epoch:  27, iters:      200, time: 0.753) loss: 0.000332 
(epoch:  27, iters:      300, time: 0.752) loss: 0.000484 
(epoch:  27, iters:      400, time: 0.753) loss: 0.000490 
(epoch:  27, iters:      500, time: 0.754) loss: 0.000280 
