TODO
- Genomic data needs to be reduced with PCA and split up into train/val/data.
- Get rid of `acc` (accuracy tracking).

```
GeneralDataset
get_transformations
get_data_splits, get_input_params
mtl_experiment
```

In [1]:
import sys
sys.path.append('../')

import pandas as pd
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets import GeneralDataset
from Transforms import get_transformations
from utils import get_data_splits, get_input_params
from train_mtl import mtl_experiment

import matplotlib.pyplot as plt
%matplotlib inline

%load_ext autoreload
%autoreload 2

In [2]:
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

Choose parameters
- 4 channel full resolution:
```python
    task \in ['idh', '1p19q']
    dataformat \in ['raw3D', 'crop3Dslice', 'modality3D']
    modality \in ['all', 't1ce', 'flair', 't2', 't1', 't1ce-t1']
    include_genomic_data \in [True, False]
```

In [4]:
# Experiment configuration; the allowed values are listed in the markdown cell above.
task = 'idh'
dataformat = 'crop3Dslice'
modality = 'all' # only relevant for the 'modality3D' dataformat
include_genomic_data = False # don't include genomic data

# Expand the chosen dataformat into (dataformat, channels, resize_shape).
# NOTE: `dataformat` is overwritten here with the value returned by get_input_params.
dataformat, channels, resize_shape = get_input_params(dataformat)

In [5]:
# Echo the resolved configuration, one "caption value" line per setting.
# Each caption carries its own tab padding, so the printed text is unchanged.
print('\n'.join(
    f'{caption} {value}'
    for caption, value in (
        ('task:\t\t', task),
        ('dataformat:\t', dataformat),
        ('channels:\t', channels),
        ('modality:\t', modality),
        ('resize_shape:\t', resize_shape),
        ('include_genomic_data:\t', include_genomic_data),
    )
))

task:		 idh
dataformat:	 cropped3D_mtl
channels:	 4
modality:	 all
resize_shape:	 (64, 64, 64)
include_genomic_data:	 False


In [6]:
# Locations of the MRI volumes and of the segmentation model weights.
image_dir = '../data/all_brats_scans/'
best_model_loc = '../pretrained/espnet_3d_brats.pth'

# Metadata for all BraTS (including TCIA) data: the map from bratsIDs to tciaIDs,
# the IDH & 1p/19q labels, and the survival labels. Three columns are dropped
# right after loading.
glioma_metadata_df = (
    pd.read_csv('../data/glioma_metadata.csv', index_col=0)
    .drop(columns=['BoundingBox', 'some_seg', 'subtype'])
)

# Training / validation splits (and `classes`) for the selected task.
train_df, val_df, classes = get_data_splits(metadata_df=glioma_metadata_df, task=task, mtl=False)

# Map between the BraTS dataset and TCIA data
# (TCIA data is available for a subset of the BraTS patients).
brats2tcia_df = glioma_metadata_df['tciaID']

# Labelled tables per split; these were file paths in the old dataloader,
# but they are DataFrames now.
labels_dict = {'train': train_df, 'val': val_df, 'data': glioma_metadata_df}

# Genomic CSV per split (both splits currently point at the same file).
genomic_data_dict = dict.fromkeys(('train', 'val'), '../data/MGL/MGL_235x50.csv')

# The task name doubles as the `label` argument handed to the datasets.
label = task

print('Train size', len(train_df))

Train size 112


In [7]:
# Display the metadata table (the rendered output is truncated to the first and last rows).
glioma_metadata_df

Unnamed: 0,tciaID,phase,idh,1p19q,gt_seg,OS,OS_EVENT
Brats18_2013_0_1,,unlabeled train,,,1,,
Brats18_2013_10_1,,unlabeled train,,,1,,
Brats18_2013_11_1,,unlabeled train,,,1,,
Brats18_2013_12_1,,unlabeled train,,,1,,
Brats18_2013_13_1,,unlabeled train,,,1,,
...,...,...,...,...,...,...,...
Brats18_WashU_W051_1,,unlabeled train,,,0,,
Brats18_WashU_W053_1,,unlabeled train,,,0,,
Brats18_WashU_W061_1,,unlabeled train,,,0,,
Brats18_WashU_W065_1,,unlabeled train,,,0,,


# Dataloader

In [7]:
## get transformations
# MinMaxNormalize, Scale, "zero sprinkle", "zero channel"
train_transformations, seg_transformations, val_transformations = get_transformations(
    channels=channels,
    resize_shape=resize_shape,
    prob_voxel_zero=0.2,
    prob_true=0.8,
    prob_channel_zero=0.5,
    mtl=False,
)
data_transforms = {'train': train_transformations, 'val': val_transformations}

# Keyword arguments shared by the train and validation datasets. Only the metadata
# table, the genomic CSV and the image transform differ between the two splits, so
# everything else is stated once to keep the two constructions in sync.
shared_dataset_kwargs = dict(
    root_dir=image_dir,
    seg_transform=seg_transformations,
    label=label,
    classes=classes,
    dataformat=dataformat,
    returndims=resize_shape,
    brats2tcia_df=brats2tcia_df,
    include_genomic_data=include_genomic_data,
    pretrained=best_model_loc,
    modality=modality,
)

transformed_dataset_train = GeneralDataset(metadata_df=train_df,
                                           genomic_csv_file=genomic_data_dict['train'],
                                           transform=train_transformations,
                                           **shared_dataset_kwargs)

transformed_dataset_val = GeneralDataset(metadata_df=val_df,
                                         genomic_csv_file=genomic_data_dict['val'],
                                         transform=val_transformations,
                                         **shared_dataset_kwargs)

image_datasets = {'train': transformed_dataset_train, 'val': transformed_dataset_val}

train_batch_size, val_batch_size = 4, 4
# Training batches are reshuffled every epoch.
dataloader_train = DataLoader(image_datasets['train'], batch_size=train_batch_size, shuffle=True, num_workers=4)
# Validation is iterated in a fixed order: shuffling brings no benefit at
# evaluation time and only makes the evaluation order non-repeatable.
dataloader_val = DataLoader(image_datasets['val'], batch_size=val_batch_size, shuffle=False, num_workers=4)

dataloaders = {'train': dataloader_train, 'val': dataloader_val}
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

# Class names as exposed by the training dataset (displayed as the cell output).
class_names = image_datasets['train'].classes
class_names

['wildtype', 'mutant']

# Visualize Data

In [8]:
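# NOTE: this visualization cell is disabled. Before re-enabling it, be aware that
#   (1) `utils.make_grid` needs a `utils` module that this notebook never imports
#       (presumably `torchvision.utils` -- confirm), and
#   (2) unpacking the batch into `label` would overwrite the `label` variable that
#       the dataset cell passes to GeneralDataset.
# # visualize training (or validation) data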
# for i, data in enumerate(dataloaders['train']):
#     # data batch
#     (image, seg_image, genomic_data, seg_probs), label, (OS, OS_EVENT), bratsID = data
#     # print scan ID
#     print(bratsID[0])
    
#     # format MRI images (slices of volumetric input)
#     img = image[0,:, :, :, int(image.shape[-1]/2)].squeeze()
#     img = utils.make_grid(img)
#     img = img.detach().cpu().numpy()
    
#     # plot images
#     plt.figure(figsize=(15, 8))
#     img_list = [img[i].T for i in range(channels)] # 1 image per channel
#     plt.imshow(np.hstack(img_list), cmap='Greys_r')
#     plt.show()

#     ## plot segmentation mask ##
#     seg_img = seg_image[0, :, :, :, int(seg_image.shape[-1]/2)].squeeze()
#     seg_img = utils.make_grid(seg_img).detach().cpu().numpy()

#     plt.figure(figsize=(4, 4))
#     plt.imshow(np.hstack([seg_img[0].T]), cmap='Greys_r')
#     plt.show()

#     break

# Train

In [None]:
# tensorboard
# The run name encodes the task, modality, volume size and the genomic-data flag.
img_dims = 'x'.join(str(dim) for dim in resize_shape)
weight_outfile_prefix = '3D_' + task + '_' + modality + '_' \
                    + img_dims + '_genomic-' + str(include_genomic_data)
print('tensorboad:', weight_outfile_prefix)
writer = SummaryWriter('runs1/'+weight_outfile_prefix)


# get loss weights
# Only training rows whose label is 0 or 1 are counted; each class is then
# weighted by (largest class count / its own count).
num_classes = 2
label_df = train_df.loc[train_df[task].isin([0, 1]), task]
_, cnts = np.unique(label_df, return_counts=True)
loss_weights = (np.ones(num_classes)/cnts)*np.max(cnts)
loss_weights = torch.tensor(loss_weights, dtype=torch.float32, device=device)

# train
best_auc_list, best_acc_list = [], []
epochs = 50
iterations = 10
seg_loss_weight = 0.1
surv_loss_weight = None

# samples with ground truth segmentations
brats_seg_ids = glioma_metadata_df[glioma_metadata_df['gt_seg'] == 1].index

# segmentation class weights (stored as .npy arrays, moved to `device` as float32)
seg_4class_weights = torch.tensor(np.load('../data/segmentation_4-class_weights.npy'),
                                  dtype=torch.float32, device=device)

seg_2class_weights = torch.tensor(np.load('../data/segmentation_2-class_weights.npy'),
                                  dtype=torch.float32, device=device)


# Run the experiment `iterations` times; the best-AUC / best-acc lists are fed back
# into every call and replaced by the lists it returns.
# The loop is wrapped in try/finally so the TensorBoard event file is flushed and
# closed even when training is interrupted or raises part-way through.
try:
    for i in range(iterations):
        print('Iteration', i)

        best_auc_list, best_acc_list = mtl_experiment(
            dataloaders=dataloaders,
            data_transforms=data_transforms,
            dataset_sizes=dataset_sizes,
            best_model_loc=best_model_loc,
            best_auc_list=best_auc_list,
            best_acc_list=best_acc_list,
            weight_outfile_prefix=weight_outfile_prefix,
            class_names=class_names,
            channels=channels,
            loss_weights=loss_weights,
            seg_4class_weights=seg_4class_weights,
            seg_2class_weights=seg_2class_weights,
            seg_loss_weight=seg_loss_weight,
            surv_loss_weight=surv_loss_weight,
            device=device,
            brats_seg_ids=brats_seg_ids,
            writer=writer,
            model_weights_dir='../model_weights/results/',
            epochs=epochs,
            iterations=iterations,
            standard_unlabled_loss=True,
            include_genomic_data=include_genomic_data,
            modality=modality,
            take_surv_loss=False,
            # presumably the 50 genomic features of MGL_235x50.csv -- TODO confirm
            g_in_features=50,
            g_out_features=128,
        )
finally:
    writer.close()


tensorboad: 3D_idh_all_64x64x64_genomic-False
Iteration 0


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

New Best AUC:	 0.6909722222222222 	in epoch 6
New Best AUC:	 0.7465277777777778 	in epoch 7
New Best AUC:	 0.8101851851851851 	in epoch 8
New Best AUC:	 0.8449074074074074 	in epoch 9
New Best AUC:	 0.8657407407407407 	in epoch 14
Epoch    42: reducing learning rate of group 0 to 5.0000e-05.

Finished Training
Iteration 1


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

New Best AUC:	 0.8368055555555556 	in epoch 6
New Best AUC:	 0.8472222222222223 	in epoch 10
Epoch    22: reducing learning rate of group 0 to 5.0000e-05.
Epoch    33: reducing learning rate of group 0 to 5.0000e-06.
Epoch    44: reducing learning rate of group 0 to 5.0000e-07.

Finished Training
Iteration 2


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

New Best AUC:	 0.8171296296296297 	in epoch 6
New Best AUC:	 0.8541666666666666 	in epoch 9
New Best AUC:	 0.8657407407407408 	in epoch 11
New Best AUC:	 0.875 	in epoch 14
New Best AUC:	 0.8784722222222223 	in epoch 17
Epoch    33: reducing learning rate of group 0 to 5.0000e-05.
Epoch    44: reducing learning rate of group 0 to 5.0000e-06.

Finished Training
Iteration 3


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

New Best AUC:	 0.8055555555555555 	in epoch 6
New Best AUC:	 0.857638888888889 	in epoch 9
New Best AUC:	 0.8611111111111112 	in epoch 16
New Best AUC:	 0.8634259259259259 	in epoch 30
Epoch    36: reducing learning rate of group 0 to 5.0000e-05.
Epoch    47: reducing learning rate of group 0 to 5.0000e-06.

Finished Training
Iteration 4


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

New Best AUC:	 0.7222222222222222 	in epoch 6
New Best AUC:	 0.7731481481481481 	in epoch 7
New Best AUC:	 0.8055555555555556 	in epoch 11
Epoch    16: reducing learning rate of group 0 to 5.0000e-05.
New Best AUC:	 0.8090277777777778 	in epoch 21
Epoch    27: reducing learning rate of group 0 to 5.0000e-06.
