In [1]:
import sys
sys.path.append('../')

import os
import numpy as np
import torch
import pandas as pd
from skimage import transform # io, 
import PIL
import math

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms, utils, datasets, models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
import pickle

from utils import get_bb_3D

from glob import glob
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm_notebook as tqdm
%matplotlib inline
plt.ion()   # interactive mode

%load_ext autoreload
%autoreload 2


from datasets import GeneralDataset
import Transforms as myTransforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

# Data

In [2]:
def get_data_splits(glioma_metadata_df, task='idh', mtl = False):
    '''
    glioma_metadata_df: for each sample, indicates
        - if it is in the labeled training, unlabeled training set, or valiation set, 
        - idh status, and 1p19q status
    task: 'idh' or '1p19q'
    mlt: True or False
    '''
    if task == 'idh':
        classes = ['wildtype', 'mutant']
        val_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'val') & (glioma_metadata_df[task].isin([0,1]))]
        

    elif task == '1p19q':
        classes = ['non-codel', 'oligo']
        val_df = glioma_metadata_df.loc[glioma_metadata_df['phase'] == 'val']
        
    if mtl:
        train_df = glioma_metadata_df.loc[glioma_metadata_df['phase'].isin(['train', 'unlabeled train'])]
    else:
        train_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'train') & (glioma_metadata_df[task].isin([0,1]))]

    return train_df, val_df, classes

In [3]:
mtl = True

# task = '1p19q'
task = 'idh'

# metadata for all brats (including tcia) data
best_model_loc = '../pretrained/espnet_3d_brats.pth'
glioma_metadata_df = pd.read_csv('../data/glioma_metadata.csv', index_col=0)
glioma_metadata_df.loc[['Brats18_TCIA09_462_1', 'Brats18_TCIA10_236_1'], 'idh'] = 1

train_df, val_df, classes = get_data_splits(glioma_metadata_df=glioma_metadata_df, task=task, mtl=mtl)

In [4]:
train_df.shape

(467, 9)

In [5]:
val_df.shape

(59, 9)

In [5]:


# # we are predicting either idh gene mutation or chromosome arm 1p and 19q co-deletion
glioma_metadata_df = pd.read_csv('../../miccai_clean/data/all_glioma_metadata_542x30.csv', index_col=0)
glioma_metadata_df = glioma_metadata_df.rename(columns={'OS.time':'OS', '_EVENT':'OS_EVENT'})
# if task == 'idh':
#     classes = ['wildtype', 'mutant']
#     val_df = glioma_metadata_df[(glioma_metadata_df['phase'] == 'val') & (glioma_metadata_df['inferred_subtype'] == 0)]

# elif task == '1p19q':
#     c
#     print('invalid classification given')

# metadata for all brats (including tcia) data

brats2tcia_df = pd.read_csv('../../miccai_clean/data/brats2tcia_df_542x1.csv', index_col=0)

# these are labeled files (they were paths in old dataloader) but they are dataframes
csv_files = {'train':train_df,
                 'val':val_df,
                 'data':glioma_metadata_df}

# genomic data (you won't need this)
genomic_csv_files = {'train':'../data/MGL/MGL_235x50.csv', 
                     'val':  '../data/MGL/MGL_235x50.csv',
                     'data': '../data/MGL/MGL_235x50.csv'}

image_dir = '../data/all_brats_scans/'

# labels we predict during classification 
cluster_column= task + '_cluster'


# downsample dim. 
resize_shape = (64, 64, 64)

# dataloader batch size
train_batch_size = 4
val_batch_size = 4
data_batch_size = 1

shuffle = True
shuffle_data = True

dataformat = 'modality3D_mtl'
modality = 't1ce'
channels = 1

null_genomic = False

print('Train size', len(csv_files['train']))

Train size 467


# Dataloaders

In [6]:
# minimal data augmentation (you can add more)
train_transformations = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2])),
#         myTransforms.ZeroChannel(prob_zero=0.5),
        myTransforms.ZeroSprinkle(prob_zero=0.2, prob_true=0.8),
        myTransforms.ToTensor(),
    ])

# segmentation masks have separate transformations (you don't want to normalize)
seg_transformations = myTransforms.Compose([
      myTransforms.ScaleToFixed((1, resize_shape[0],resize_shape[1],resize_shape[2]), 
                                  interpolation=0, 
                                  channels=channels),
        myTransforms.ToTensor(),
    ])

val_transformations = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2])),
        myTransforms.ToTensor(),
    ])

data_transforms = {'train': train_transformations,
                    'val':   val_transformations,
                    'data':   val_transformations,
                   'seg': seg_transformations,
                  }



best_model_loc = '../pretrained/espnet_3d_brats.pth'

# 'jointmodel' dataformat returns entire MR images (but with the black padding cropped out)
transformed_dataset_train = GeneralDataset(csv_file=csv_files['train'],
                                           root_dir=image_dir,
                                           genomic_csv_file = genomic_csv_files['train'],
                                           transform=data_transforms['train'],
                                           seg_transform=data_transforms['seg'],
                                           seg_probs_transform=data_transforms['seg'],
                                           classes=classes,
                                           dataformat=dataformat,
                                           returndims=resize_shape,
                                           label=cluster_column,
                                           brats2tcia_df=brats2tcia_df,
                                           null_genomic = null_genomic,
                                           pretrained=best_model_loc,
                                           device=device,
                                           modality=modality)

transformed_dataset_val = GeneralDataset(csv_file=csv_files['val'],
                                         root_dir=image_dir,
                                         genomic_csv_file = genomic_csv_files['val'],
                                         transform=data_transforms['val'],
                                         seg_transform=data_transforms['seg'],
                                         seg_probs_transform=data_transforms['seg'],
                                         classes=classes,
                                         dataformat=dataformat,
                                         returndims=resize_shape,
                                         label=cluster_column,
                                         brats2tcia_df=brats2tcia_df,
                                         null_genomic = null_genomic,
                                         pretrained=best_model_loc,
                                         device=device,
                                         modality=modality)



image_datasets = {'train':transformed_dataset_train, 
                  'val':transformed_dataset_val}


dataloader_train = DataLoader(image_datasets['train'], batch_size=train_batch_size, shuffle=shuffle, num_workers=4, drop_last=True)
dataloader_val = DataLoader(image_datasets['val'], batch_size=val_batch_size, shuffle=shuffle, num_workers=4)

dataloaders = {'train':dataloader_train, 'val':dataloader_val}

dataset_sizes = {'train':len(image_datasets['train']), 
                 'val':len(image_datasets['val'])}

class_names = image_datasets['train'].classes
class_names

['wildtype', 'mutant']

In [7]:
dataset_sizes['train']

467

In [8]:
# for i, data in tqdm(enumerate(dataloaders['train'])):
#     print('itr', i)
#     (inputs, seg_image, genomic_data, seg_probs), labels,(OS, event), bratsID = data
#     inputs, labels = inputs.to(device), labels.to(device)
#     OS, event = OS.to(device), event.to(device)
#     seg_image, seg_probs, genomic_data = seg_image.to(device), seg_probs.to(device), genomic_data.to(device)
#     seg_image = seg_image.squeeze(1)
#     seg_image = seg_image.type(torch.int64)
    
# # #     from models import ESPNet as Net
# # #     best_model_loc = '../pretrained/espnet_3d_brats.pth'
# # #     classes = 4
# # #     channels = 4


# # #     espnet = Net.ESPNet(classes=classes, channels=channels)
# # #     if os.path.isfile(best_model_loc):
# # #         espnet.load_state_dict(torch.load(best_model_loc, map_location=device))


# # #     level0_weight = espnet.level0.conv.weight[:, 0].unsqueeze(1)
# # #     espnet.level0.conv = nn.Conv3d(1, 16, kernel_size=(7, 7, 7), 
# # #                                              stride=(2, 2, 2), padding=(3, 3, 3), bias=False)

# # #     espnet.level0.conv.weight = nn.Parameter(level0_weight)

# # #     espnet = espnet.to(device) 

# # #     output = espnet(inputs)

# # #     seg_img = output.mask_out.max(1)[1].data.byte().cpu().numpy()
    
#     break

# Class weights

In [9]:
# num_classes = 2
# best_model_loc = '../pretrained/espnet_3d_brats.pth'

# train_df = csv_files['train']
# cluster_df = train_df[train_df['phase'] == 'train']['IDH'].dropna()

# _, cnts = np.unique(cluster_df, return_counts=True)
# loss_weights = (np.ones(num_classes)/cnts)*np.max(cnts) ## 
# loss_weights = torch.FloatTensor(loss_weights).to(device)
# print('subtype class weights:', loss_weights)

In [10]:
# label_df = csv_files['train']
# label_df = label_df.loc[(label_df['phase'] == 'train') & (label_df[task] != -1)][task]
# # label_df = labels_dict['train'][task]

# _, cnts = np.unique(label_df, return_counts=True)
# loss_weights = (np.ones(num_classes)/cnts)*np.max(cnts) ## 
# loss_weights = torch.FloatTensor(loss_weights).to(device)
# print('subtype class weights:', loss_weights)

In [11]:
# csv_files['train']['IDH'].dropna()

In [15]:
num_classes = 2
loss_weights = [1.0000, 1.6667]
loss_weights = torch.FloatTensor(loss_weights).to(device)
print('subtype class weights:', loss_weights)

subtype class weights: tensor([1.0000, 1.6667], device='cuda:0')


# Training

In [16]:
seg_loss_weight = 1
surv_loss_weight = 1

from torch.utils.tensorboard import SummaryWriter
img_dims = str(resize_shape[0]) + 'x' + str(resize_shape[1]) + 'x' + str(resize_shape[2])
model_outfile_dir = '3D_' + task + '_'+modality+'_mtl=' + str(mtl) + '_' + img_dims + '_genomic=' + str(null_genomic)
print('tensorboad:', model_outfile_dir)
writer = SummaryWriter('runs1/'+model_outfile_dir)

tensorboad: 3D_idh_t1ce_mtl=True_64x64x64_genomic=False


In [17]:
best_auc_list, best_acc_list, best_auc_acc_list = [], [], []
epochs = 20
iterations = 10

brats_seg_ids = glioma_metadata_df[glioma_metadata_df['gt_seg'] == 1].index

seg_4class_weights = np.load('../data/segmentation_notcropped_4-class_weights.npy')
seg_4class_weights = torch.FloatTensor(seg_4class_weights).to(device)
print('seg_4class_weights:', seg_4class_weights)

seg_2class_weights = np.load('../data/segmentation_notcropped_2-class_weights.npy')
seg_2class_weights = torch.FloatTensor(seg_2class_weights).to(device)
print('seg_2class_weights:', seg_2class_weights)

for i in range(iterations):
    print('Iteration', i)
    

    from models.nick_mtl_model import GBMNetMTL
    gbm_net = GBMNetMTL(g_in_features=50, 
                        g_out_features=128, 
                        n_classes=num_classes, 
                        n_volumes=1, 
                        seg_classes=4, 
                        pretrained=best_model_loc, 
                        class_loss_weights = loss_weights,
                        seg_4class_weights=seg_4class_weights,
                        seg_2class_weights=seg_2class_weights,
                        seg_loss_scale=seg_loss_weight,
                        surv_loss_scale=surv_loss_weight,
                        device = device,
                        brats_seg_ids=brats_seg_ids,
                        standard_unlabled_loss=False,
                        fusion_net_flag=True,
                        modality=modality,
                        take_surv_loss=False) # might have to put loss_weights on device.

    gbm_net = gbm_net.to(device)

    optimizer_gbmnet = optim.Adam(gbm_net.parameters(), lr=0.0005) # change to adami
#     exp_scheduler = optim.lr_scheduler.StepLR(optimizer_gbmnet, step_size=7, gamma=0.1)
    
    exp_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_gbmnet, 
                                                         mode='min', 
                                                         factor=0.1, 
                                                         patience=10, # number of epochs with no change 
                                                         verbose=True, 
                                                         threshold=0.0001, 
                                                         threshold_mode='rel', 
                                                         cooldown=0, 
                                                         min_lr=0, 
                                                         eps=1e-08)

    from train_mtl import train
    model, best_wts, best_auc, best_acc, best_auc_acc = train(model=gbm_net, 
                   dataloaders=dataloaders,
                   data_transforms=data_transforms,
                   optimizer=optimizer_gbmnet, 
                   scheduler=exp_scheduler,
                   writer=writer,
                   num_epochs=epochs, 
                   verbose=False, 
                   device=device,
                   dataset_sizes=dataset_sizes,
                   channels=1,
                   classes=class_names,
                   weight_outfile_prefix=model_outfile_dir,
                   pad=0)

    
    del gbm_net
    del model
    
    best_auc_list.append(best_auc)
    best_acc_list.append(best_acc)
    best_auc_acc_list.append(best_auc_acc)
    
    if not os.path.exists('../model_weights/results/'):
        os.makedirs('../model_weights/results/')
    
    results_outfile_dir = model_outfile_dir + '_epochs-' + str(epochs) +'_iterations-' + str(iterations)
    with open('../model_weights/results/auc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_auc_list, fp)
    with open('../model_weights/results/acc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_acc_list, fp)
    with open('../model_weights/results/avg_auc_acc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_auc_acc_list, fp)

  0%|          | 0/20 [00:00<?, ?it/s]

seg_4class_weights: tensor([  1.0000, 124.3847,  60.4277, 188.0025], device='cuda:0')
seg_2class_weights: tensor([ 1.0000, 33.4366], device='cuda:0')
Iteration 0
GBMNet!
Segmentation model will only use one modality (channel)
training . . . 
before epochs


  0%|          | 0/20 [00:00<?, ?it/s]


KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/app/software/SciPy-bundle/2020.03-foss-2020a-Python-3.8.2/lib/python3.8/site-packages/pandas/core/indexes/base.py", line 4410, in get_value
    return libindex.get_value_at(s, key)
  File "pandas/_libs/index.pyx", line 44, in pandas._libs.index.get_value_at
  File "pandas/_libs/index.pyx", line 45, in pandas._libs.index.get_value_at
  File "pandas/_libs/util.pxd", line 98, in pandas._libs.util.get_value_at
  File "pandas/_libs/util.pxd", line 83, in pandas._libs.util.validate_indexer
TypeError: 'str' object cannot be interpreted as an integer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/app/software/fhPython/3.8.2-foss-2020a-Python-3.8.2/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/app/software/fhPython/3.8.2-foss-2020a-Python-3.8.2/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/app/software/fhPython/3.8.2-foss-2020a-Python-3.8.2/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "../datasets.py", line 161, in __getitem__
    label = self.metadata_df.iloc[idx][self.label] # 'cluster'
  File "/app/software/SciPy-bundle/2020.03-foss-2020a-Python-3.8.2/lib/python3.8/site-packages/pandas/core/series.py", line 871, in __getitem__
    result = self.index.get_value(self, key)
  File "/app/software/SciPy-bundle/2020.03-foss-2020a-Python-3.8.2/lib/python3.8/site-packages/pandas/core/indexes/base.py", line 4418, in get_value
    raise e1
  File "/app/software/SciPy-bundle/2020.03-foss-2020a-Python-3.8.2/lib/python3.8/site-packages/pandas/core/indexes/base.py", line 4404, in get_value
    return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
  File "pandas/_libs/index.pyx", line 80, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 90, in pandas._libs.index.IndexEngine.get_value
  File "pandas/_libs/index.pyx", line 138, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 1619, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 1627, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'idh_cluster'


In [None]:
np.mean(best_acc_list)

---
T1ce: 50 epochs, 10 itr (This was MTL without the surival loss). This is the number we wanted the highest)

3D_idh_t1ce_mtl_64x64x64_null-genomic_seglossweight-1_zero-sprinkle-channel

    - auc mean: 0.8840277777777779
    - auc std: 0.015952080698166733

    - acc mean: 0.779050925925926
    - acc std: 0.05753653711830886
---

3D_1p19q_t1ce_mtl_64x64x64_segLoss1_survLoss1

    - auc mean: 0.8054545454545454
    - auc std: 0.025454545454545462

    - acc mean: 0.5198484848484848
    - acc std: 0.030906491275110404
    
3D_1p19q_t1ce-t1_mtl_64x64x64_segLoss1_survLoss1

    - auc mean: 0.7426666666666668
    - auc std: 0.05036214529795515

    - acc mean: 0.5
    - acc std: 0.0

In [None]:
np.mean(best_auc_list)

In [None]:
best_auc_list