# Training of EDCNN

In [0]:
!git clone -l -s git://github.com/juanigp/CT-denoising.git cloned-repo
%cd cloned-repo
from google.colab import drive
drive.mount('/gdrive', force_remount = True)

import os
from IPython.core.debugger import set_trace
from models.EDCNN import EDCNN
from utils import utils
import torch
import torch.nn as nn
import torch.utils.data.sampler as sampler
from torch.autograd import Variable
from matplotlib import pyplot as plt
import random
import numpy as np

## Hyperparameters, model, dataset and dataloader

In [0]:
#hyperparameters:
num_epochs = 100
batch_size = 32
learning_rate = 0.00001

#instantiating the model:
model = EDCNN()

#loss function
criterion = nn.L1Loss()

#if gpu available, move the model and the loss to it.
#this is done BEFORE the optimizer is built, as recommended by the torch.optim documentation,
#so the optimizer is constructed from the parameters the model will actually train with
if torch.cuda.is_available():
    model.cuda()
    criterion.cuda()

#optimizer algorithm
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

#dataset and dataloaders
#csv file containing the directories of the lo res and ground truth patches
csv_file = r'/gdrive/My Drive/patches/100_FBPPhil.csv' 
dataset = utils.CTVolumesDataset(csv_file)

#split of data in training and testing data:
#the sample indices are shuffled (using the same seed every time for repeatability)
num_samples = len(dataset)
total_idx = list(range(num_samples))
random.seed(10)
random.shuffle(total_idx)

#pick 10% of samples to test
testing_samples_percentage = 0.1
split_index = int( num_samples * testing_samples_percentage )
#pick the first 10% of samples in the shuffled index list for testing
testing_idx = total_idx[0 : split_index]
#pick the other 90% of samples in the shuffled index list for training
training_idx = total_idx[split_index : num_samples]
#random samplers for training and testing (each one only draws from its own index subset)
training_sampler = sampler.SubsetRandomSampler(training_idx)
testing_sampler = sampler.SubsetRandomSampler(testing_idx)
#dataloaders for training and testing: both read the same dataset, the samplers keep the two subsets disjoint
training_dataloader = torch.utils.data.DataLoader(dataset = dataset, batch_size = batch_size, sampler = training_sampler)
testing_dataloader = torch.utils.data.DataLoader(dataset = dataset, batch_size = batch_size, sampler = testing_sampler)

## Training the model!

In [0]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize a checkpoint object to disk with ``torch.save``.

    Args:
        state: the object to store (the training loop passes a dict holding
            the epoch number and the model/optimizer state dicts).
        filename: destination path of the checkpoint file.
    """
    torch.save(obj=state, f=filename)

In [0]:
#directory to save the models
models_dir = r'/gdrive/My Drive/models'
#file to record metrics: one "training_loss,testing_loss" line is appended per epoch
metrics_file_name = 'training_loss.csv' 
metrics_file_dir = os.path.join(models_dir, metrics_file_name)

#loading a previously trained model
resume_checkpoint = False
#path of the checkpoint to resume from (only read when resume_checkpoint is True)
checkpoint_file_dir = None
if resume_checkpoint:
    if checkpoint_file_dir is None:
        raise ValueError('checkpoint_file_dir must be set when resume_checkpoint is True')
    checkpoint = torch.load(checkpoint_file_dir)
    #checkpoints store the number of completed epochs, so training resumes at the next one
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
else:
    start_epoch = 0


#training and testing simultaneously
total_step = len(training_dataloader)

for epoch in range(start_epoch, num_epochs):
    #training epoch
    #re-enter training mode every epoch: the testing pass below leaves the model in eval mode
    model.train()
    training_epoch_loss = 0
    num_batches = 0
    for i, (lo_res, hi_res) in enumerate(training_dataloader):
        #add an extra dimension:
        lo_res = utils.var_or_cuda( lo_res.unsqueeze(1) )
        hi_res = utils.var_or_cuda(hi_res)
        #skip the incomplete batch (fewer than batch_size samples)
        if lo_res.size()[0] != batch_size:
            continue
        num_batches += 1
        #forward pass
        outputs = model(lo_res)
        loss = criterion(outputs, hi_res.unsqueeze(1))
        training_epoch_loss += loss.item()
        #backward & optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    #mean loss over the processed batches; nan when no full batch was available
    training_epoch_loss = training_epoch_loss / num_batches if num_batches else float('nan')
    #save model after training epoch
    checkpoint_file_dir = os.path.join(models_dir, 'EDCNN_checkpoint_epoch_' + str(epoch + 1) + '.pth.tar' )
    save_checkpoint({
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
    }, checkpoint_file_dir)
    print('Training epoch [{}/{}]'.format(epoch+1, num_epochs)) 

    #testing epoch
    model.eval()
    testing_epoch_loss = 0
    num_batches = 0
    #no gradients are needed while evaluating
    with torch.no_grad():
        for batch, (lo_res, hi_res) in enumerate(testing_dataloader):
            #add an extra dimension:
            lo_res = utils.var_or_cuda( lo_res.unsqueeze(1) )
            hi_res = utils.var_or_cuda(hi_res)
            #skip the incomplete batch (fewer than batch_size samples)
            if lo_res.size()[0] != batch_size:
                continue
            num_batches += 1
            outputs = model(lo_res)
            loss = criterion(outputs, hi_res.unsqueeze(1))
            testing_epoch_loss += loss.item()

    #mean loss over the processed batches; nan when no full batch was available
    testing_epoch_loss = testing_epoch_loss / num_batches if num_batches else float('nan')
    print('Testing epoch [{}/{}]'.format(epoch+1, num_epochs) )     

    #append this epoch's training and testing loss to the metrics file
    csv_line = str(training_epoch_loss) + ',' + str(testing_epoch_loss) + '\n'
    with open(metrics_file_dir , 'a+') as file:
        file.write(csv_line)