In [None]:
#load all desired packages

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import time
import copy
from math import exp

import torch

#MT additions
from torchvision import datasets, transforms
import h5py
from scipy.io import loadmat
import math

import os, sys
directory = os.path.abspath('')
sys.path.insert(1,os.path.join(directory,'src')) # setting path can also append directory.parent
import utils
import metrics
import datasets
import plots
import train
from UNet import UNet

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")
    
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Use the GPU when torch can see one (much faster); otherwise fall back to the CPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

print(dev)

In [None]:
# ---- Training configuration: batch sizes, paths, split, and hyperparameters ----

# Data loading
batch_size     = 1
val_batch_size = 1
num_workers    = 1

# Run bookkeeping (modifications is written to the model-info CSV)
num_it         = 1
modifications  = ''

# Display maximum passed to plots.plot_dataloader
maxDisp = 150

# Project locations
proj_path       = '/data/knee_mri9/mwtong/t1rho_map_synthesis'
log_save_dir    = f'{proj_path}/training/train_logs/'             # one log file per training run
model_save_dir  = f'{proj_path}/training/checkpoints/'            # saved models
model_info_file = f'{proj_path}/training/trained_model_info.csv'  # table of run parameters + log/model paths
split_path      = f'{proj_path}/splits/022_25to75Percent_Slices_NoKneeCoil.csv'

# Hyperparameters; list field order follows the column labels used in the model-info CSV
num_epochs           = 2
lr                   = 0.001
scheduler_options    = ['constant', 4]              # type, step, ...
loss_options         = ['12', 'NRMSE', 0.6, 0.4]    # type, metric, wSeg, wNotSeg
# NOTE(review): loss type '12' may be a typo for 'l2' — confirm against train.NetLoss/train_model.
scaleInput_options   = ['clip', 0, 150]             # method, min, max
scaleMetrics_options = ['clip', 0, 100]             # method, min, max
do_augmentation      = False

In [None]:
# Load the split table and report how many rows fall into each value of its 'set' column
# (the dataset cells select rows with set_name='train' / 'val' / 'test').
label_df = pd.read_csv(split_path)
set_counts = label_df['set'].value_counts()
print(set_counts)
label_df

In [None]:
# Optionally build a torchvision augmentation pipeline for the training set.
if do_augmentation:
    transform_type = transforms.Compose([
        transforms.RandomRotation(8, fill=0.5),                                   # random rotation within ±8 degrees
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), fill=0.5),       # random shift up to 10% per axis
        transforms.RandomResizedCrop(size=256, scale=(0.96, 1.0), ratio=(1, 1)),  # slight square crop, resized to 256
        transforms.Grayscale(num_output_channels=1),
        transforms.ConvertImageDtype(torch.float),
    ])
else:
    transform_type = None

# Load the train/val/test sets. This training set is used by the visualization cell below;
# the datasets are rebuilt (with fixed seeds) in the training-setup cell.
# NOTE(review): the scaling method 'clipAndScale' here differs from scaleInput_options
# ('clip') — confirm which one is intended.
trainset = datasets.T1RT2Data(labels_df=label_df, set_name='train', transforms=transform_type,
                              scale_options=['clipAndScale', 0, 150])
valset   = datasets.T1RT2Data(labels_df=label_df, set_name='val')
testset  = datasets.T1RT2Data(labels_df=label_df, set_name='test')
print(trainset.len, valset.len, testset.len)

In [None]:
# Visualize one training entry and time how long a single sample takes to load.
# fixed_ind selects a specific slice; set it to None (or it is out of range) to pick one at random.
fixed_ind = 277
if fixed_ind is not None and 0 <= fixed_ind < trainset.len:
    ind = fixed_ind
else:
    ind = random.randint(0, trainset.len - 1)
print(ind)

start = time.time()
image, label = trainset[ind]
end   = time.time()
print('Image loading time: ' + str(np.round(end - start, 3)) + ' seconds')
plots.plot_dataloader(trainset, ind, maxDisp)

In [None]:
# Choose this run's number N so that the log file and the checkpoint share the name run_N.
# Start from (number of entries in the log directory) + 1, then skip any N whose log or
# checkpoint already exists so an earlier run is never overwritten.
os.makedirs(log_save_dir, exist_ok=True)
os.makedirs(model_save_dir, exist_ok=True)
run_num = len(os.listdir(log_save_dir)) + 1
while (os.path.exists(os.path.join(log_save_dir, f'run_{run_num}.txt'))
       or os.path.exists(os.path.join(model_save_dir, f'run_{run_num}'))):
    run_num += 1

# File to which this run's training log will be written.
log_save   = os.path.join(log_save_dir, f'run_{run_num}.txt')
# File path to which the trained model will be saved.
model_save = os.path.join(model_save_dir, f'run_{run_num}')
print(log_save)
print(model_save)

# Build the generator CNN and move it to the selected device (a no-op on CPU).
generator_model = UNet(init_features=64, in_channels=1, out_channels=1)
generator_model.to(device)

In [None]:
# Seed every RNG the pipeline may draw from so that training is reproducible.
random.seed(14)
np.random.seed(14)
torch.manual_seed(14)

# Build the train/val/test datasets. The training set receives transform_type, so the
# do_augmentation setting (which is also logged to the model-info CSV) actually affects training.
# transform_type is None when augmentation is off.
# NOTE(review): assumes T1RT2Data applies the same random spatial transform to the input and
# the target map — confirm before enabling augmentation.
# NOTE(review): scaleInput_options is logged but not passed here, so the dataset's default
# input scaling is used — confirm this matches what is recorded.
trainset = datasets.T1RT2Data(labels_df=label_df, set_name='train', transforms=transform_type)
valset   = datasets.T1RT2Data(labels_df=label_df, set_name='val')
testset  = datasets.T1RT2Data(labels_df=label_df, set_name='test')
print(trainset.len, valset.len, testset.len)

# Shuffle and drop the last incomplete batch only for training; evaluation loaders keep every
# sample in order.
trainset_loader = datasets.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                      drop_last=True)
valset_loader   = datasets.DataLoader(valset, batch_size=val_batch_size, shuffle=False, num_workers=num_workers,
                                      drop_last=False)
testset_loader  = datasets.DataLoader(testset, batch_size=val_batch_size, shuffle=False, num_workers=num_workers,
                                      drop_last=False)

In [None]:
# Set up the loss, optimizer and learning-rate schedule, then train.
# (`train` is imported in the first cell and kept current by %autoreload.)
# gamma_ is the StepLR decay factor: scheduler_options[2] when a 'constant' schedule supplies a
# third entry, otherwise 0.1.
if scheduler_options[0] == 'constant' and len(scheduler_options) == 3:
    gamma_ = scheduler_options[2]
else:
    gamma_ = 0.1
criterion = train.NetLoss()
optimizer = torch.optim.Adam(generator_model.parameters(), lr=lr, weight_decay=0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_options[1], gamma=gamma_)

dataloaders = {'train': trainset_loader, 'val': valset_loader}
generator_model, losses = train.train_model(dataloaders, generator_model, criterion, optimizer, scheduler, log_save,
                                            loss_options, scaleMetrics_options, scheduler_options, num_epochs,
                                            return_stats=True)

# Save the trained model. torch.save on the module pickles the whole object, so loading it
# later requires the UNet class to be importable.
torch.save(generator_model, model_save)

# Build a one-row record of this run's parameters and output paths for the model-info CSV.
# 'Ref' is the run number N parsed from the 'run_N' checkpoint name.
# NOTE(review): confirm the run number is the intended 'Ref' value.
run_ref = int(os.path.basename(model_save).rsplit('_', 1)[-1])
d = {'Ref': run_ref, 'Model_Path': model_save, 'Log_Path': log_save, 'Split_Path': split_path,
     'epochs': [num_epochs], 'lr': [lr], 'scheduler:type,step,%lossChange,%lrDecrease': [scheduler_options],
     'loss:type,metric,wSeg,wNotSeg': [loss_options], 'scaleInputs:method,min,max': [scaleInput_options],
     'scaleMetrics:method,min,max': [scaleMetrics_options], 'augment': do_augmentation,
     'modifications': modifications,
     'NRMSE': '', 'NRMSE_seg': '', 'best_loss': '', 'observations': ''}
df = pd.DataFrame(data=d)

# Append the new row to the existing table (DataFrame.append was removed in pandas 2.0).
df_full = pd.concat([pd.read_csv(model_info_file), df], ignore_index=True)

# Writing the updated table back is switched off; set to True to record this run in the CSV.
save_model_info = False
if save_model_info:
    df_full.to_csv(model_info_file, index=False)


In [None]:
# Run the trained model over the validation loader and collect predictions and targets.
# Without an import, `eval` is Python's builtin eval(), which has no `test_metrics` attribute.
# NOTE(review): assumes the project's src/ directory provides an `eval` module defining
# test_metrics — confirm the module name.
import eval
preds, labels = eval.test_metrics(valset_loader, generator_model)

In [None]:
# Show a random validation prediction next to its ground truth on a shared colour scale
# [0, maxDisp]. Indexing [ind, 0, :, :] assumes preds/labels are (N, C, H, W) tensors.
ind = random.randint(0, preds.shape[0] - 1)

plt.rcParams.update({'font.size': 12})
fig, (ax_pred, ax_true) = plt.subplots(1, 2, figsize=(12, 5))

im1 = ax_pred.imshow(preds[ind, 0, :, :].cpu().detach().numpy(), cmap='jet', vmin=0, vmax=maxDisp)
ax_pred.axis('off')
ax_pred.set_title('Predicted T1rho Map')

im2 = ax_true.imshow(labels[ind, 0, :, :].cpu().detach().numpy(), cmap='jet', vmin=0, vmax=maxDisp)
ax_true.axis('off')
ax_true.set_title('Ground Truth T1rho Map')

# Leave room on the right for a single colorbar shared by both panels.
fig.subplots_adjust(right=0.82)
cbar_ax = fig.add_axes([0.85, 0.128, 0.01, 0.75])
cbar = fig.colorbar(im2, cax=cbar_ax)
cbar.set_label('T1rho relaxation time (ms)', rotation=270, labelpad=20)
plt.show()