In [1]:
from data.image_list import create_image_list_with_label
from data.dataset import create_dataset
from data.data_loader import create_dataloader
import torch
import time
import random
from collections import OrderedDict
import math
from math import log10
from utils import util, convert, metric
import os
import numpy as np
import nibabel as nib
from data.transforms import input_transform, target_transform, label_transform
import torch

In [2]:
# Experiment options: run identity, dataset layout, network hyper-parameters,
# optimisation schedule and logging cadence. The dict is post-processed by
# `parse` and handed to the image-list, model and logger factories.
opt = dict(
    name="SFT_ResNet_x4",
    model="sft_resnet",
    gpu_ids=[0],
    upscale_factor=4,
    device="cuda",
    datasets=dict(
        path=dict(
            root="data/BrainTumour",
            folders={
                "tumour_images": "imagesTr",
                "normal_images": "imagesTs",
                "tumour_labels": "labelsTr",
            },
        ),
        train=dict(
            phase="train",
            batch_size=64,
            use_shuffle=True,
            upscale_factor=4,
            scale=True,
            no_of_images={"tumour": 450, "normal": 250},
            n_workers=1,
        ),
        valid=dict(
            phase="valid",
            no_of_images={"tumour": 34, "normal": 16},
            scale=True,
            depth_padding=1,
            upscale_factor=4,
        ),
        condition=True,
        label_names_list=["edema", "non-enhancing tumour", "enhancing tumour"],
    ),
    # Output root; `parse` derives experiment/result sub-directories from it.
    path=dict(
        root="",
    ),
    network=dict(
        norm_type="batch",
        which_model_G="sr_resnet",
        ngf=64,
        ngb=16,
        input_ngc=1,
        output_ngc=1,
    ),
    train=dict(
        manual_seed=777,
        niter=30,
        val_freq=5,
        lr=1e-4,
        criterion="mse",
    ),
    logger=dict(
        print_freq=100,
        save_checkpoint_freq=5,
    ),
)

def parse(opt_path, is_train=True):
    """Resolve the derived output paths of an options dict.

    Args:
        opt_path: Either the options dict itself, or a path to a JSON options
            file (loaded with key order preserved via ``OrderedDict``).
        is_train: When True, derive the ``experiments/<name>`` layout
            (``experiments_root``, ``options``, ``trained_models``, ``log``);
            otherwise derive the ``results/<name>`` layout (``results_root``,
            ``log``, ``test_images``).

    Returns:
        The options dict with ``opt['is_train']`` set and every entry of
        ``opt['path']`` user-expanded and augmented with the derived paths.
        A dict argument is modified in place and returned.
    """
    if isinstance(opt_path, (str, os.PathLike)):
        import json
        with open(opt_path, 'r') as f:
            opt = json.load(f, object_pairs_hook=OrderedDict)
    else:
        # Work on the argument itself; the previous body silently relied on a
        # global named `opt` and ignored what the caller passed in.
        opt = opt_path

    opt['is_train'] = is_train

    paths = opt['path']
    # Snapshot the items so the values can be rewritten while iterating.
    for key, path in list(paths.items()):
        paths[key] = os.path.expanduser(path)

    root = paths['root']
    if is_train:
        experiments_root = os.path.join(root, 'experiments', opt['name'])
        paths['experiments_root'] = experiments_root
        paths['options'] = experiments_root
        paths['trained_models'] = os.path.join(experiments_root, 'trained_models')
        paths['log'] = os.path.join(experiments_root, 'log')
    else:
        results_root = os.path.join(root, 'results', opt['name'])
        paths['results_root'] = results_root
        paths['log'] = os.path.join(results_root, 'log')
        paths['test_images'] = os.path.join(results_root, 'test_images')

    return opt

opt = parse(opt)

# Apply the configured seed to every RNG the pipeline may touch, before the
# datasets and the model are built, so that weight initialisation and any
# sampling are reproducible. `manual_seed` was declared in the config but
# never used.
seed = opt['train']['manual_seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # no-op without CUDA


In [3]:
# import json
# with open('data/BrainTumour/dataset.json') as f:
#     dataset_json = json.load(f)
# dataset_json

In [4]:
# Collect the training and validation image lists from the configured dataset
# folders. The numeric counters in the output below are printed while this
# cell runs.
training_image_list, valid_image_list = (
    create_image_list_with_label(opt, train=True),
    create_image_list_with_label(opt, valid=True),
)


0
50
100
150
200
250
300
350
400
0
50
100
150
200
0
0


In [5]:
# Wrap the image lists in datasets. Whether conditioning inputs are attached
# now follows opt["datasets"]["condition"] instead of a hard-coded True, so
# the config flag actually takes effect.
training_set = create_dataset(opt["datasets"]["train"], training_image_list,
                              use_condition=opt["datasets"]["condition"])
valid_set = create_dataset(opt["datasets"]["valid"], valid_image_list,
                           use_condition=opt["datasets"]["condition"])

In [7]:
# Build the model wrapper from the parsed options; the summary printed below
# (parameter count, "Model [SRSFTResNet] is created.") comes from this call.
# NOTE(review): this import belongs in the top imports cell.
from models.models import create_model
model = create_model(opt)

---------- Model initialized -------------
Number of parameters in G: 1822371
-----------------------------------------------
Model [SRSFTResNet] is created.


In [8]:
# Logger whose `print_results` reports training losses and validation metrics
# in the cells below.
# NOTE(review): this import belongs in the top imports cell.
from utils.logger import Logger
logger = Logger(opt)

In [9]:
def validate(val_loader, val_size, depth, model, logger, epoch, current_step):
    """Run slice-wise validation and log the mean PSNR.

    Each item of ``val_loader`` is a dict whose ``'H'`` (ground truth) and
    ``'L'`` (input) tensors are indexed by slice along the first axis; every
    slice is fed to the model as a batch of one.

    Args:
        val_loader: Iterable of validation samples (here the dataset itself).
        val_size: Number of validation samples. Kept for backward
            compatibility; the mean is taken over the slices actually scored.
        depth: Slices to evaluate per volume, or ``None`` to evaluate every
            slice (``val_data['H'].shape[0]``).
        model: Object providing ``eval``/``train``/``feed_data``/``val``/
            ``get_current_visuals`` and a ``criterion`` returning an MSE tensor.
        logger: Object providing ``print_results(results, epoch, step,
            duration, tag)``.
        epoch: Epoch index forwarded to the logger.
        current_step: Step index forwarded to the logger.
    """
    print('Start validation phase ...')
    val_start_time = time.time()
    model.eval()  # Change to eval mode. It is important for BN layers.

    psnr_sum = 0.0
    n_scored = 0
    try:
        # Validation needs no autograd graph; keeping one alive across
        # hundreds of slices only wastes GPU memory.
        with torch.no_grad():
            for val_data in val_loader:
                n_slices = val_data['H'].shape[0] if depth is None else depth
                for slice_no in range(n_slices):
                    val_data_slice = dict(H=val_data['H'][slice_no, :, :, :].unsqueeze(0),
                                          L=val_data['L'][slice_no, :, :, :].unsqueeze(0))
                    model.feed_data(val_data_slice)
                    model.val()

                    visuals = model.get_current_visuals()
                    sr_img = visuals['super-resolution']
                    gt_img = visuals['ground-truth']

                    mse = model.criterion(sr_img, gt_img).item()
                    if mse == 0:
                        # PSNR is unbounded for an exact match; skip the slice
                        # instead of dividing by zero.
                        continue
                    # Peak value 1 assumes intensities scaled to [0, 1] — TODO confirm.
                    psnr_sum += 10 * log10(1 / mse)
                    n_scored += 1
    finally:
        model.train()  # Change back to train mode, even if validation raised.

    val_results = OrderedDict()
    val_results['psnr'] = psnr_sum / n_scored if n_scored else float('nan')

    val_duration = time.time() - val_start_time
    # Save to log
    logger.print_results(val_results, epoch, current_step, val_duration, 'val')


In [10]:
# Train for `niter` epochs (numbered 1..niter). Losses are logged every
# `print_freq` steps; validation and checkpointing run every `val_freq` /
# `save_checkpoint_freq` epochs.
model.train()

start_time = time.time()

val_size = len(valid_set)
# Slices evaluated per validation volume.
# NOTE(review): magic number, possibly tied to depth_padding / upscale_factor
# in opt['datasets']['valid'] — confirm and derive it from the data instead.
depth = 624

# NOTE(review): the Dataset is iterated directly, so `batch_size`,
# `use_shuffle` and `n_workers` from opt['datasets']['train'] are never
# applied, and `create_dataloader` (imported above) is unused. Confirm the
# per-sample tensor shape before wrapping the dataset in a DataLoader.
i = 0  # last step index; stays defined even if training_set yields nothing
for iteration in range(1, opt['train']['niter'] + 1):
    for i, train_data in enumerate(training_set):
        train_start_time = time.time()
        # training
        model.feed_data(train_data)
        model.optimize_parameters(i)
        train_duration = time.time() - train_start_time

        # print losses
        if i % opt['logger']['print_freq'] == 0:
            losses = model.get_current_losses()
            logger.print_results(losses, iteration, i, train_duration, 'loss')

    # validation
    if iteration % opt['train']['val_freq'] == 0:
        validate(valid_set, val_size, depth, model, logger, iteration, i)

    # save
    if iteration % opt['logger']['save_checkpoint_freq'] == 0:
        print('Saving the model at the end of iteration %d' % (iteration))
        model.save(iteration)

    print('end of iteration ' + str(iteration))

print('Training finished in %.1f s' % (time.time() - start_time))

  "See the documentation of nn.Upsample for details.".format(mode))


RuntimeError: CUDA out of memory. Tried to allocate 58.00 MiB (GPU 0; 6.00 GiB total capacity; 4.44 GiB already allocated; 44.82 MiB free; 4.57 GiB reserved in total by PyTorch)