In [1]:
import sys
sys.path.append('../')

import os
import numpy as np
import torch
import pandas as pd
from skimage import transform # io, 
import PIL
import math
from glob import glob

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms, utils, datasets, models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
import pickle

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
plt.ion()   # interactive mode

%load_ext autoreload
%autoreload 2


from datasets import GeneralDataset
import Transforms as myTransforms

# from visualize import show_2Dbatch, show_4channel_batch, show_3Dbatch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [2]:
def get_data_splits(metadata_df, task='idh', mtl = False):
    
    '''
    This function will make a pretty plot of an sklearn Confusion Matrix cm using a Seaborn heatmap visualization.
    Arguments
    ---------
    metadata_df:   For each sample this dataframe indicates
                        1) whether it is in the labeled training, unlabeled training set, or valiation set
                        2) its idh status, and 1p19q status
    task:          Either 'idh' or '1p19q'
    mtl:           If True, MRI data without IDH mutation or 1p/19q co-deletion labels will be included in the traing set
    
    Outputs
    ---------
    train_df:      Dataframe of training samples
    val_df:        Dataframe of validation samples
    classes:       Names of numerical labels
    '''
    
    # validation set
    if task == 'idh':
        classes = ['wildtype', 'mutant']
        val_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'val') 
                                        & (glioma_metadata_df[task].isin([0,1]))] # check whether IDH status known
    elif task == '1p19q':
        classes = ['non-codel', 'oligo']
        val_df = glioma_metadata_df.loc[glioma_metadata_df['phase'] == 'val']
    
    # training set
    if mtl:
        train_df = glioma_metadata_df.loc[glioma_metadata_df['phase'].isin(['train', 'unlabeled train'])]
    else:
        train_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'train') 
                                          & (glioma_metadata_df[task].isin([0,1]))] # only labeled data (0/1)

    return train_df, val_df, classes

In [14]:
# task = '1p19q'
task = 'idh'

if task == 'idh':
    classes = ['wildtype', 'mutant']
elif task == '1p19q':
    classes = ['non-codel', 'oligo']
else:
    print('task?')

brats2tcia_df = pd.read_csv('../../miccai_clean/data/brats2tcia_df_542x1.csv', index_col=0)

# old way to get train/val
# glioma_metadata_df = pd.read_csv('../../miccai_clean/data/all_glioma_metadata_542x30.csv', index_col=0)
# glioma_metadata_df = glioma_metadata_df.rename(columns={'OS.time':'OS', '_EVENT':'OS_EVENT'})
# val_df = glioma_metadata_df[(glioma_metadata_df['phase'] == 'val') & (glioma_metadata_df['inferred_subtype'] == 0)]
# train_df = glioma_metadata_df[(glioma_metadata_df['phase'] == 'train') & (glioma_metadata_df['inferred_subtype'] == 0)]

# new way to get train/val
glioma_metadata_df = pd.read_csv('../data/glioma_metadata.csv', index_col=0) # metadata file
glioma_metadata_df.loc[['Brats18_TCIA09_462_1', 'Brats18_TCIA10_236_1'], 'idh'] = 1 ######
# # glioma_metadata_df['idh_cluster'] = glioma_metadata_df['idh']
# train_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'train') 
#                                           & (glioma_metadata_df[task].isin([0,1]))] # only labeled data (0/1)
# val_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'val') 
#                                         & (glioma_metadata_df[task].isin([0,1]))] # check whether IDH status known

mtl = False
train_df, val_df, classes = get_data_splits(metadata_df=glioma_metadata_df, task=task, mtl=mtl)

csv_files = {'train':train_df, 'val':val_df, 'data':glioma_metadata_df}



genomic_csv_files = {'train':'../data/MGL/MGL_235x50.csv', 
             'val':'../data/MGL/MGL_235x50.csv',
             'data':'../data/MGL/MGL_235x50.csv'}

image_train_dir = '../data/mr_data/train/'
image_dir = '../data/all_brats_scans/'
# image_dir = '/data/nick/all_brats_scans/'
# image_dir = '../data/whitestrip_n4_mr_data/'

# cluster_column= task + '_cluster'

train_batch_size = 4
val_batch_size = 4
data_batch_size = 1


resize_shape = (64, 64, 64)
# resize_shape = (16, 16, 16)
channels = 1
interpolation = 1

shuffle = True
shuffle_data = False

modality = 't1ce'
dataformat='modality3D'

# label = 'idh_cluster'
label = 'idh'

print('Training size:', len(csv_files['train']))
null_genomic = False

Training size: 112


In [15]:
print('task:\t\t', task)
print('mtl:\t\t', mtl)
print('dataformat:\t', dataformat)
print('channels:\t', channels)
print('modality:\t', modality)
print('resize_shape:\t', resize_shape)
print('null_genomic:\t', null_genomic)

task:		 idh
mtl:		 False
dataformat:	 modality3D
channels:	 1
modality:	 t1ce
resize_shape:	 (64, 64, 64)
null_genomic:	 False


In [4]:
dataformat

'modality3D'

In [5]:
brats2tcia_df

Unnamed: 0,tciaID
Brats18_2013_0_1,
Brats18_2013_10_1,
Brats18_2013_11_1,
Brats18_2013_12_1,
Brats18_2013_13_1,
...,...
Brats18_WashU_W051_1,
Brats18_WashU_W053_1,
Brats18_WashU_W061_1,
Brats18_WashU_W065_1,


In [6]:
# ## new files
# g_metadata_df = pd.read_csv('../data/glioma_metadata.csv', index_col=0) # metadata file
# g_metadata_df.loc[['Brats18_TCIA09_462_1', 'Brats18_TCIA10_236_1'], 'idh'] = 1 ######
# t_df = train_df = g_metadata_df.loc[(g_metadata_df['phase'] == 'train') 
#                                           & (g_metadata_df[task].isin([0,1]))] # only labeled data (0/1)

# ## old files
# glioma_metadata_df = pd.read_csv('../../miccai_clean/data/all_glioma_metadata_542x30.csv', index_col=0)
# glioma_metadata_df = glioma_metadata_df.rename(columns={'OS.time':'OS', '_EVENT':'OS_EVENT'})
# val_df = glioma_metadata_df[(glioma_metadata_df['phase'] == 'val') & (glioma_metadata_df['inferred_subtype'] == 0)]
# train_df = glioma_metadata_df[(glioma_metadata_df['phase'] == 'train') & (glioma_metadata_df['inferred_subtype'] == 0)]

# # check to see if they are the same!
# df = pd.concat([train_df['idh_cluster'], t_df['idh']], axis=1, join='inner')
# df.loc[df['idh_cluster'] != df['idh']]

# 3D cropped

In [7]:
train_transformations = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed((1, resize_shape[0],resize_shape[1],resize_shape[2]), 
                                  interpolation=interpolation, 
                                  channels=channels),
        myTransforms.ZeroSprinkle(prob_zero=0.2, prob_true=0.8),
#         myTransforms.ZeroChannel(prob_zero=0.5),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor(),
    ])

seg_transformations = myTransforms.Compose([
    myTransforms.ScaleToFixed((1, resize_shape[0],resize_shape[1],resize_shape[2]), 
                                  interpolation=0, 
                                  channels=1),
        myTransforms.ToTensor(),
    ])


val_transformations = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed((1, resize_shape[0],resize_shape[1],resize_shape[2]), 
                                  interpolation=interpolation, 
                                  channels=channels),
        myTransforms.ToTensor(),
    ])

data_transforms = {'train': train_transformations,
                    'val':   val_transformations,
                    'data':   val_transformations,
                   'seg': seg_transformations,
                  }



transformed_dataset_train = GeneralDataset(csv_file=csv_files['train'],
                                           root_dir=image_dir,
                                           genomic_csv_file = genomic_csv_files['train'],
                                           transform=data_transforms['train'],
                                           seg_transform=data_transforms['seg'],
                                           label=label,
                                           classes=classes,
                                           dataformat=dataformat,
                                           returndims=resize_shape,
                                           brats2tcia_df=brats2tcia_df,
                                           modality=modality,
                                           null_genomic=True)

transformed_dataset_val = GeneralDataset(csv_file=csv_files['val'],
                                         root_dir=image_dir,
                                         genomic_csv_file = genomic_csv_files['val'],
                                         transform=data_transforms['val'],
                                         seg_transform=data_transforms['seg'],
                                         classes=classes,
                                         label=label,
                                         dataformat=dataformat,
                                         returndims=resize_shape,
                                         brats2tcia_df=brats2tcia_df,
                                         modality=modality,
                                         null_genomic=True)


transformed_dataset_data = GeneralDataset(csv_file=csv_files['data'],
                                         root_dir=image_dir,
                                          genomic_csv_file = genomic_csv_files['data'],
                                         transform=data_transforms['data'],
                                          seg_transform=data_transforms['seg'],
                                         classes=classes,
                                          label=label,
                                         dataformat=dataformat,
                                         returndims=resize_shape,
                                         brats2tcia_df=brats2tcia_df,
                                          modality=modality,
                                         null_genomic=True)

image_datasets = {'train':transformed_dataset_train, 
                  'val':transformed_dataset_val,
                  'data':transformed_dataset_data}


dataloader_train = DataLoader(image_datasets['train'], batch_size=train_batch_size, shuffle=shuffle, num_workers=4)
dataloader_val = DataLoader(image_datasets['val'], batch_size=val_batch_size, shuffle=shuffle, num_workers=4)
dataloader_data = DataLoader(image_datasets['data'], batch_size=data_batch_size, shuffle=shuffle_data, num_workers=4)

dataloaders = {'train':dataloader_train, 'val':dataloader_val, 'data':dataloader_data}

dataset_sizes = {'train':len(image_datasets['train']), 
                 'val':len(image_datasets['val']), 
                 'data':len(image_datasets['data'])}

class_names = image_datasets['train'].classes
class_names

['wildtype', 'mutant']

In [8]:
# subtype_dict = {0:'wildtype', 1:'val'}
# for i, data in enumerate(dataloaders['data']):
# #     (image_slices, image_volumes), labels = data
#     (image, seg_image, genomic_data), cluster, bratsID = data
    
#     shape = image.shape

#     img = image[:,:,:,int(shape[-1]/2)].squeeze()
#     img = utils.make_grid(img)
#     img = img.detach().cpu().numpy()
    
#     plt.figure(figsize=(15, 8))
#     plt.imshow(np.hstack([img[0].T, img[1].T, img[2].T, img[3]]), cmap='Greys_r')
    
    
#     print('**', image.shape)
#     break

In [9]:
np.unique(train_df['idh'], return_counts=True)

(array([0., 1.]), array([55, 57]))

In [10]:
from torch.utils.tensorboard import SummaryWriter
img_dims = str(resize_shape[0]) + 'x' + str(resize_shape[1]) + 'x' + str(resize_shape[2])
model_outfile_dir = '3D_' + task + '_' + modality + '_' + img_dims +'_bs-' + str(train_batch_size) + '_newdataloader_only-gt-idh'
print('tensorboad:', model_outfile_dir)
writer = SummaryWriter('runs1/'+model_outfile_dir)

tensorboad: 3D_idh_t1ce_64x64x64_bs-4_newdataloader_only-gt-idh


In [11]:
channels

1

In [12]:
num_classes = 2
best_model_loc = '../pretrained/espnet_3d_brats.pth'

# cluster_df = csv_files['train'][task + '_cluster']
cluster_df = csv_files['train'][task]
_, cnts = np.unique(cluster_df, return_counts=True)
loss_weights = (np.ones(num_classes)/cnts)*np.max(cnts)
loss_weights = torch.FloatTensor(loss_weights).to(device)
criterion = nn.CrossEntropyLoss(weight=loss_weights)
print('loss weights:', loss_weights)


from train import train
best_auc_list, best_acc_list, best_auc_acc_list = [], [], []
epochs = 50
iterations = 10
for i in range(iterations):
    print('Iteration', i)
    
    
    # resize_shape = (64, 64, 64)

    from models.Models import SegModel
    
    esp_model = SegModel(best_model_loc=best_model_loc, 
                         inp_res = resize_shape, 
                         num_classes=num_classes, 
                         channels=4)
    
    level0_weight = esp_model.espnet.level0.conv.weight[:, 0].unsqueeze(1)
    esp_model.espnet.level0.conv = nn.Conv3d(1, 16, kernel_size=(7, 7, 7), 
                                             stride=(2, 2, 2), padding=(3, 3, 3), bias=False)

    esp_model.espnet.level0.conv.weight = nn.Parameter(level0_weight)
    esp_model = esp_model.to(device=device)

    optimizer_esp = optim.Adam(esp_model.parameters(), lr=0.0005) # change to adam
#     exp_scheduler = optim.lr_scheduler.StepLR(optimizer_esp, step_size=7, gamma=0.1)

    exp_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_esp, 
                                                         mode='min', 
                                                         factor=0.1, 
                                                         patience=10, # number of epochs with no change 
                                                         verbose=True, 
                                                         threshold=0.0001, 
                                                         threshold_mode='rel', 
                                                         cooldown=0, 
                                                         min_lr=0, 
                                                         eps=1e-08)


    model, best_wts, best_auc, best_acc, best_auc_acc = train(model=esp_model, 
                       dataloaders=dataloaders,
                       data_transforms=data_transforms,
                       criterion = criterion, 
                       optimizer=optimizer_esp, 
                       scheduler=exp_scheduler,
                       writer=writer,
                       num_epochs=epochs, 
                       verbose=False, 
                       device=device,
                       dataset_sizes=dataset_sizes,
                       channels=1,
                       resize_shape=resize_shape,
                       classes=class_names,
#                        use_probs=True,
                       volume_val=False,
#                        percentile=50,
                       weight_outfile_prefix=model_outfile_dir)
    del esp_model
    del model
    
    best_auc_list.append(best_auc)
    best_acc_list.append(best_acc)
    best_auc_acc_list.append(best_auc_acc)
    
    
    if not os.path.exists('../model_weights/results/'):
        os.makedirs('../model_weights/results/')
    
    results_outfile_dir = model_outfile_dir + '_epochs-' + str(epochs) +'_iterations-' + str(iterations)
    with open('../model_weights/results/auc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_auc_list, fp)
    with open('../model_weights/results/acc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_acc_list, fp)
    with open('../model_weights/results/avg_auc_acc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_auc_acc_list, fp)

loss weights: tensor([1.0364, 1.0000], device='cuda:0')
Iteration 0
training . . . 


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

New Best ACC:	 0.5283564814814814 	in epoch 0
New Best AUC-acc aveage:	 0.6663773148148148 	in epoch 0
New Best ACC:	 0.7552083333333333 	in epoch 1
New Best AUC-acc aveage:	 0.7769097222222221 	in epoch 1
New Best ACC:	 0.7795138888888888 	in epoch 3
New Best AUC-acc aveage:	 0.7925347222222221 	in epoch 3
New Best ACC:	 0.8049768518518519 	in epoch 6
New Best AUC:	 0.863425925925926 	in epoch 6
New Best AUC-acc aveage:	 0.834201388888889 	in epoch 6



KeyboardInterrupt: 

---
Modality: T1ce (March 8th)

50 epoch

10 loops

3D_idh_t1ce_64x64x64_bs-4_newdataloader

    best_auc_list std: 0.015810795227100584
    best_auc_list mean: 0.8700231481481481
    
    acc std: 0.02270886719744711
    best_acc_list mean: 0.8197337962962962
    
    
 
---
3D_idh_t1ce_64x64x64_bs-4_newdataloader_only-gt-idh



3D_idh_t1ce_64x64x64_bs-4_newdataloader_only-gt-idh

    best_auc_list std: 0.016007827330301965
    best_auc_list mean: 0.8934027777777779
    
    acc std: 0.02130698705091663
    best_acc_list mean: 0.842013888888889
    
    
 
---

tensorboad: 3D_1p19q_t1ce_64x64x64_bs-4_newdataloader_only-gt-idh

    best_auc_list std: 0.05383139511559983
    best_auc_list mean: 0.7228787878787879
    
    acc std: 0.04590909090909092
    best_acc_list mean: 0.5153030303030304

In [None]:
np.mean(best_acc_list)