# Training of REDCNN
### Comments:
 


In [25]:
using_colab = True

if using_colab :
  !git clone -l -s git://github.com/juanigp/CT-denoising.git cloned-repo
  %cd cloned-repo
  from google.colab import drive
  drive.mount('/gdrive')


import os
from IPython.core.debugger import set_trace
from models.Mini_REDCNN import REDCNN
from utils import utils
import torch
import torch.nn as nn
import torch.utils.data.sampler as sampler
from torch.autograd import Variable
from matplotlib import pyplot as plt
import random
import numpy as np

Cloning into 'cloned-repo'...
remote: Enumerating objects: 175, done.
remote: Counting objects:  21% ... [output truncated]

## Hyperparameters, model, dataset and dataloader

In [0]:
#hyperparameters:
num_epochs = 1000
batch_size = 16
learning_rate = 0.0001

#instantiating the model:
model = REDCNN()
#model.double()

#loss function
criterion = nn.MSELoss()

#if gpu available, move the model (and the loss) there BEFORE building the optimizer:
#torch.optim recommends that the parameters already live on their final device
#when the optimizer is constructed
if torch.cuda.is_available():
    model.cuda()
    criterion.cuda()

#optimizer algorithm
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

#dataset
if using_colab:
    csv_file = r'/gdrive/My Drive/patches/1.csv'
else:
    #should be XCT instead of 500FBP!!
    csv_file = r'C:\Users\Juan Pisula\Desktop\ct_images\patches\100_FBPPhil_500FBP.csv'

#dataset, dataloader
dataset = utils.CTVolumesDataset(csv_file)
#NOTE(review): shuffle=False feeds the patches in the same order every epoch - confirm
#this is intentional for training. The training loop skips a final batch smaller than
#batch_size by hand; drop_last=True would do the same inside the loader.
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle = False)

## Inspecting data

In [0]:
"""
#inspect training examples

batches = list(dataloader)

print(len(batches)) #how many batches
batch = batches[20]
print( len( batch ) ) #length of the batches (2 = lo res, hi res)
print( batch[0].size() ) #size of the lo res volumes of the batch: batch_size volumes, size of volume
plt.imshow(batch[1][0][10][:][:], cmap = 'gray' )

#enu = enumerate(dataloader)
#len(dataloader) # = amount of patches / batch size

(lo_res, hi_res) = batch
print(lo_res.size())
lo_res = lo_res.unsqueeze(1)
print(lo_res.size())
"""

## Training the model!

In [0]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Any object ``torch.save`` can serialize. The training loop
            passes a dict holding the epoch number plus the model and
            optimizer state dicts.
        filename: Destination path, or an already-open writable binary
            file-like object. Defaults to ``'checkpoint.pth.tar'`` in the
            current working directory.

    When ``filename`` is a path, missing parent directories are created
    first, so a finished training epoch is not lost to a
    ``FileNotFoundError`` at save time. File-like objects are handed to
    ``torch.save`` untouched.
    """
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filename))
        if parent_dir:  # '' means the current directory: nothing to create
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, filename)

In [0]:
#directory to save the model
if using_colab:
    models_dir = r'/gdrive/My Drive/models'
else:
    models_dir = r'C:\Users\Juan Pisula\Desktop\ct_images'
#create it up front so the first checkpoint/metrics write cannot fail after a full epoch
os.makedirs(models_dir, exist_ok=True)

#file to record metrics: one "loss,epoch" line is appended per finished epoch
metrics_file_name = 'metrics.csv'
metrics_file_dir = os.path.join(models_dir, metrics_file_name)

#print the loss every `log_every` steps (1 = every step)
log_every = 1

#loading a previously trained model
resume_checkpoint = False
#NOTE(review): the loop below only writes checkpoints numbered from 1, so point this
#at an existing checkpoint file before setting resume_checkpoint = True
checkpoint_file_dir = os.path.join(models_dir,'REDCNN_checkpoint_epoch_0.pth.tar')
if resume_checkpoint:
    #map_location='cpu' lets a checkpoint written on a GPU machine be resumed anywhere;
    #load_state_dict then copies the values onto whatever device the model already uses
    checkpoint = torch.load(checkpoint_file_dir, map_location='cpu')
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
else:
    start_epoch = 0


#training
total_step = len(dataloader)
model.train()

for epoch in range(start_epoch, num_epochs):
    #loss of the last batch actually trained on in this epoch (None until there is one)
    last_loss = None

    for i, (lo_res, hi_res) in enumerate(dataloader):
        #add an extra dimension:
        lo_res = utils.var_or_cuda( lo_res.unsqueeze(1) )
        hi_res = utils.var_or_cuda(hi_res)
        #only full batches are trained on; a smaller trailing batch is skipped
        if lo_res.size()[0] != batch_size:
            print("batch_size != {} drop last incompatible batch".format( batch_size ))
            continue

        #forward pass
        outputs = model(lo_res)
        loss = criterion(outputs, hi_res.unsqueeze(1))
        #backward & optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        last_loss = loss.item()

        if (i+1) % log_every == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                .format(epoch+1, num_epochs, i+1, total_step, last_loss))

    #without this guard the metrics line below would hit an undefined `loss`
    #(first epoch) or silently re-log a value from an earlier epoch
    if last_loss is None:
        raise RuntimeError(
            'no full batch of size {} was produced in epoch {}; '
            'nothing was trained'.format(batch_size, epoch + 1))

    #save model after epoch
    checkpoint_file_dir = os.path.join(models_dir, 'REDCNN_checkpoint_epoch_' + str(epoch + 1) + '.pth.tar' )

    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer' : optimizer.state_dict(),
    }, checkpoint_file_dir)

    #the logged value is the last trained batch's loss, and the epoch column is 0-based
    #(the checkpoint above stores epoch + 1); both kept so existing metrics files stay comparable
    csv_line = str(last_loss) + ',' + str(epoch) + '\n'
    with open(metrics_file_dir , 'a+') as file:
        file.write(csv_line)