TODO
- Genomic data still needs to be reduced with PCA and split into train/val/data sets.
- Remove the accuracy (`acc`) metric.

In [1]:
import sys
sys.path.append('../')

import pandas as pd
import numpy as np

# can revome the lines that need these
import os
import pickle

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

from datasets import GeneralDataset
import Transforms as myTransforms
from utils import get_data_splits

%load_ext autoreload
%autoreload 2

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

## Choose parameters
Valid options (the cell below selects the 4-channel, full-resolution `raw3D` setup):
```python
    task \in ['idh', '1p19q']
    dataformat \in ['raw3D', 'crop3Dslice', 'modality3D']
    modality \in ['t1ce', 'flair', 't2', 't1', 't1ce-t1']
    mtl \in [True, False]
    include_genomic_data \in [True, False]
```

In [3]:
# --- experiment configuration (valid options are listed in the markdown above) ---
task = 'idh'            # 'idh' or '1p19q'
dataformat = 'raw3D'    # 'raw3D', 'crop3Dslice' or 'modality3D'
mtl = True              # True: MTL network that also uses unlabeled MRI; False: plain CNN

# MRI sequence selection (alternative setting previously toggled here: modality = None)
modality = 't1ce'

# don't include genomic data
include_genomic_data = False
null_genomic = not include_genomic_data  # downstream dataset code takes the negated flag

# Fail fast on typos instead of deep inside the dataloader / model code.
assert task in ('idh', '1p19q'), f'unknown task: {task!r}'
assert dataformat in ('raw3D', 'crop3Dslice', 'modality3D'), f'unknown dataformat: {dataformat!r}'
assert modality in (None, 't1ce', 'flair', 't2', 't1', 't1ce-t1'), f'unknown modality: {modality!r}'

In [4]:
def get_input_params(dataformat, mtl=True):
    '''
    Return the input parameters used by the dataloader for a given MRI data format.

    Arguments
    ---------
    dataformat: one of 'modality3D', 'crop3Dslice', 'raw3D'
               - 'modality3D'  <- single MRI sequence cropped to tumor boundary (1 channel volume)
               - 'crop3Dslice' <- 4 MRI sequences cropped to tumor boundary (4 channel volume)
               - 'raw3D'       <- 4 whole-brain MRI sequences (4 channel volume)
    mtl: whether or not to use unlabeled MRI sequences; i.e., a choice between simple CNNs or MTL network

    Outputs
    ---------
    dataformat:      dataformat string, renamed to its MTL variant when mtl is True
    channels:        input channels for the MRI input
    resize_shape:    hard-coded input size: (64, 64, 64) for cropped formats, (144, 144, 144) for 'raw3D'

    Raises
    ------
    ValueError: if dataformat is not one of the supported values (previously this printed a
                message and returned undefined placeholders).
    '''
    # dataformat -> (channels, resize_shape, name of the MTL variant)
    format_table = {
        'modality3D':  (1, (64, 64, 64), 'modality3D_mtl'),
        'crop3Dslice': (4, (64, 64, 64), 'mtl_cropped'),
        'raw3D':       (4, (144, 144, 144), 'raw3D_mtl'),
    }
    if dataformat not in format_table:
        raise ValueError(f'Incorrect dataformat: {dataformat!r}; expected one of {sorted(format_table)}')

    channels, resize_shape, mtl_dataformat = format_table[dataformat]
    if mtl:
        dataformat = mtl_dataformat
    return dataformat, channels, resize_shape

# Pass the configured `mtl` flag explicitly: relying on the default (mtl=True) would
# silently select the MTL dataformat even when mtl = False in the config cell.
dataformat, channels, resize_shape = get_input_params(dataformat, mtl=mtl)

In [5]:
# Echo the resolved run configuration so it is recorded in the notebook output.
for config_label, config_value in (('task:\t\t', task),
                                   ('mtl:\t\t', mtl),
                                   ('dataformat:\t', dataformat),
                                   ('channels:\t', channels),
                                   ('modality:\t', modality),
                                   ('resize_shape:\t', resize_shape),
                                   ('null_genomic:\t', null_genomic)):
    print(config_label, config_value)

task:		 idh
mtl:		 True
dataformat:	 raw3D_mtl
channels:	 4
modality:	 t1ce
resize_shape:	 (144, 144, 144)
null_genomic:	 True


In [6]:
# def get_data_splits(metadata_df, task='idh', mtl = False):
    
#     '''
#     This function returns pandas dataframes containing training and validation indices and sample metadata
#     Arguments
#     ---------
#     metadata_df:   For each sample this dataframe indicates
#                         1) whether it is in the labeled training, unlabeled training set, or valiation set
#                         2) its idh status, and 1p19q status
#     task:          Either 'idh' or '1p19q'
#     mtl:           If True, MRI data without IDH mutation or 1p/19q co-deletion labels will be included in the traing set
    
#     Outputs
#     ---------
#     train_df:      Dataframe of training samples
#     val_df:        Dataframe of validation samples
#     classes:       Names of numerical labels
#     '''
    
#     # validation set
#     if task == 'idh':
#         classes = ['wildtype', 'mutant']
#         val_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'val') 
#                                         & (glioma_metadata_df[task].isin([0,1]))] # check whether IDH status known
#     elif task == '1p19q':
#         classes = ['non-codel', 'oligo']
#         val_df = glioma_metadata_df.loc[glioma_metadata_df['phase'] == 'val']
    
#     # training set
#     if mtl:
#         train_df = glioma_metadata_df.loc[glioma_metadata_df['phase'].isin(['train', 'unlabeled train'])]
#     else:
#         train_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'train') 
#                                           & (glioma_metadata_df[task].isin([0,1]))] # only labeled data (0/1)

#     return train_df, val_df, classes

In [7]:
# Paths to the MRI volumes and to the pretrained segmentation model weights.
image_dir = '../data/all_brats_scans/'
best_model_loc = '../pretrained/espnet_3d_brats.pth'  # segmentation model weights

# Per-sample metadata for all BraTS (including TCIA) scans, indexed by BraTS ID.
glioma_metadata_df = pd.read_csv('../data/glioma_metadata.csv', index_col=0)
# Manual label override: these two samples are forced to idh = 1.
# NOTE(review): the reason/source for this override is not recorded here — document it.
glioma_metadata_df.loc[['Brats18_TCIA09_462_1', 'Brats18_TCIA10_236_1'], 'idh'] = 1

# Train/validation split for the chosen task (MTL also pulls in unlabeled training scans).
train_df, val_df, classes = get_data_splits(metadata_df=glioma_metadata_df, task=task, mtl=mtl)

# BraTS -> TCIA ID mapping (TCIA data exists for only a subset of the BraTS patients).
brats2tcia_df = glioma_metadata_df['tciaID']

# Labeled splits; these used to be CSV paths in the old dataloader, now they are DataFrames.
labels_dict = dict(train=train_df, val=val_df, data=glioma_metadata_df)

# NOTE(review): train and val read the same genomic CSV — confirm this is intended.
genomic_data_dict = dict.fromkeys(('train', 'val'), '../data/MGL/MGL_235x50.csv')

label = task

print('Train size', len(train_df))

Train size 467


# 3D cropped

In [8]:
# def get_transformations(channels, resize_shape, voxel_zero_prob=0.8, voxel_zero_rate=0.2, prob_zero_channel=0.5):
#     if channels == 4:
#         train_transformations = myTransforms.Compose([
#                 myTransforms.MinMaxNormalize(),
#                 myTransforms.ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2]), 
#                                           interpolation=1, 
#                                           channels=channels),
#                 myTransforms.ZeroSprinkle(prob_zero=voxel_zero_rate, prob_true=voxel_zero_prob),
#                 myTransforms.ZeroChannel(prob_zero=prob_zero_channel),
#                 myTransforms.RandomFlip(),
#                 myTransforms.ToTensor(),
#             ])

#         seg_transformations = myTransforms.Compose([
#             myTransforms.ScaleToFixed((1, resize_shape[0],resize_shape[1],resize_shape[2]), 
#                                           interpolation=0, 
#                                           channels=1),
#                 myTransforms.ToTensor(),
#             ])


#         val_transformations = myTransforms.Compose([
#                 myTransforms.MinMaxNormalize(),
#                 myTransforms.ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2]), 
#                                           interpolation=1, 
#                                           channels=channels),
#                 myTransforms.ToTensor(),
#             ])
#     elif channels == 1:
#         # minimal data augmentation (you can add more)
#         train_transformations = myTransforms.Compose([
#                 myTransforms.MinMaxNormalize(),
#                 myTransforms.ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2])),
#                 myTransforms.ZeroSprinkle(prob_zero=voxel_zero_rate, prob_true=voxel_zero_prob),
#                 myTransforms.ToTensor(),
#             ])

#         # segmentation masks have separate transformations (don't want to normalize)
#         seg_transformations = myTransforms.Compose([
#               myTransforms.ScaleToFixed((1, resize_shape[0],resize_shape[1],resize_shape[2]), 
#                                           interpolation=0, 
#                                           channels=channels),
#                 myTransforms.ToTensor(),
#             ])

#         val_transformations = myTransforms.Compose([
#                 myTransforms.MinMaxNormalize(),
#                 myTransforms.ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2])),
#                 myTransforms.ToTensor(),
#             ])
        
#     return train_transformations, seg_transformations, val_transformations

In [9]:
# batch size
train_batch_size, val_batch_size = 4, 4


def _build_transforms(channels, resize_shape, mtl):
    '''
    Build the train / val / segmentation transform pipelines for the MRI volumes.

    Arguments
    ---------
    channels:      number of MRI input channels; only 1 and 4 are supported
    resize_shape:  3-tuple every volume is rescaled to
    mtl:           for 4-channel input, MTL training skips RandomFlip

    Outputs
    ---------
    dict with keys 'train', 'val', 'seg' mapping to myTransforms.Compose pipelines

    Raises
    ------
    ValueError: if channels is not 1 or 4 (previously this surfaced later as a NameError).
    '''
    if channels not in (1, 4):
        raise ValueError(f'Unsupported number of input channels: {channels!r} (expected 1 or 4)')

    def normalize_and_resize():
        # Shared image prefix: min-max intensity normalisation, then resize with interpolation=1.
        return [myTransforms.MinMaxNormalize(),
                myTransforms.ScaleToFixed((channels, *resize_shape), interpolation=1, channels=channels)]

    train_steps = normalize_and_resize()
    train_steps.append(myTransforms.ZeroSprinkle(prob_zero=0.2, prob_true=0.8))
    if channels == 4:
        train_steps.append(myTransforms.ZeroChannel(prob_zero=0.5))
    # "Don't flip with MTL" applies to 4-channel input only.
    # NOTE(review): 1-channel training flips even when mtl=True — confirm this is intended.
    if channels == 1 or not mtl:
        train_steps.append(myTransforms.RandomFlip())
    train_steps.append(myTransforms.ToTensor())

    # Segmentation masks have separate transformations (no normalisation, interpolation=0).
    seg_steps = [myTransforms.ScaleToFixed((1, *resize_shape), interpolation=0, channels=1),
                 myTransforms.ToTensor()]

    val_steps = normalize_and_resize() + [myTransforms.ToTensor()]

    return {'train': myTransforms.Compose(train_steps),
            'val':   myTransforms.Compose(val_steps),
            'seg':   myTransforms.Compose(seg_steps)}


data_transforms = _build_transforms(channels, resize_shape, mtl)
# Keep the individual pipeline names available for any cells that use them directly.
train_transformations = data_transforms['train']
val_transformations = data_transforms['val']
seg_transformations = data_transforms['seg']



def _make_general_dataset(split_df, phase):
    '''Build a GeneralDataset for one split ('train' or 'val') using the shared settings above.'''
    return GeneralDataset(csv_file=split_df,  # receives a DataFrame despite the kwarg name
                          root_dir=image_dir,
                          genomic_csv_file=genomic_data_dict[phase],
                          transform=data_transforms[phase],
                          seg_transform=data_transforms['seg'],
                          label=label,
                          classes=classes,
                          dataformat=dataformat,
                          returndims=resize_shape,
                          brats2tcia_df=brats2tcia_df,
                          null_genomic=null_genomic,
                          pretrained=best_model_loc,
                          modality=modality)


transformed_dataset_train = _make_general_dataset(train_df, 'train')
transformed_dataset_val = _make_general_dataset(val_df, 'val')
image_datasets = {'train': transformed_dataset_train,
                  'val': transformed_dataset_val}

# Only the training loader shuffles; both loaders use 4 worker processes.
dataloader_train = DataLoader(transformed_dataset_train, batch_size=train_batch_size, shuffle=True, num_workers=4)
dataloader_val = DataLoader(transformed_dataset_val, batch_size=val_batch_size, shuffle=False, num_workers=4)
dataloaders = {'train': dataloader_train, 'val': dataloader_val}

dataset_sizes = {phase: len(split_dataset) for phase, split_dataset in image_datasets.items()}

class_names = transformed_dataset_train.classes
class_names

['wildtype', 'mutant']

In [10]:
# subtype_dict = {0:'wildtype', 1:'val'}
# for i, data in enumerate(dataloaders['train']):
#     (inputs, seg_image, genomic_data, seg_probs), labels,(OS, event), bratsID = data
#     inputs, labels = inputs.to(device), labels.to(device)
#     print(np.unique(seg_probs))
    
# #     seg_image2 = seg_probs[0].max(1)[1].data.byte().cpu().numpy()
    
#     print('**', inputs.shape)
#     break

In [11]:
from torch.utils.tensorboard import SummaryWriter

# Run name encodes the experiment configuration, e.g. 3D_idh_t1ce_mtl-True_144x144x144_genomic-True.
# NOTE: the "genomic-" field records null_genomic, i.e. "genomic-True" means genomic data is NOT used.
# The name is also the weight/result file prefix, so its format is kept for compatibility.
img_dims = 'x'.join(str(dim) for dim in resize_shape)
# f-string so that modality = None yields "None" instead of raising TypeError on str + None.
model_outfile_dir = f'3D_{task}_{modality}_mtl-{mtl}_{img_dims}_genomic-{null_genomic}'
print('tensorboad:', model_outfile_dir)
writer = SummaryWriter('runs1/' + model_outfile_dir)

tensorboad: 3D_idh_t1ce_mtl-True_144x144x144_genomic-True


In [12]:
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [13]:
# num_classes = 2
# best_model_loc = '../pretrained/espnet_3d_brats.pth'

# cluster_df = train_df['idh']
# _, cnts = np.unique(cluster_df, return_counts=True)
# loss_weights = (np.ones(num_classes)/cnts)*np.max(cnts) ## 
# loss_weights = torch.FloatTensor(loss_weights).to(device)
# print('subtype class weights:', loss_weights)

In [14]:
# num_classes = 2
# best_model_loc = '../pretrained/espnet_3d_brats.pth'

# label_df = train_df['idh']
# _, cnts = np.unique(label_df, return_counts=True)
# loss_weights = (np.ones(num_classes)/cnts)*np.max(cnts)
# loss_weights = torch.FloatTensor(loss_weights).to(device)
# criterion = nn.CrossEntropyLoss(weight=loss_weights)
# print('loss weights:', loss_weights)


# from train import train
# best_auc_list, best_acc_list, best_auc_acc_list = [], [], []
# epochs = 50
# iterations = 10
# for i in range(iterations):
#     print('Iteration', i)
    
    
#     # resize_shape = (64, 64, 64)

#     from models.Models import SegModel
    
#     esp_model = SegModel(best_model_loc=best_model_loc, 
#                          inp_res = resize_shape, 
#                          num_classes=num_classes, 
#                          channels=4)
    
#     level0_weight = esp_model.espnet.level0.conv.weight[:, 0].unsqueeze(1)
#     esp_model.espnet.level0.conv = nn.Conv3d(1, 16, kernel_size=(7, 7, 7), 
#                                              stride=(2, 2, 2), padding=(3, 3, 3), bias=False)

#     esp_model.espnet.level0.conv.weight = nn.Parameter(level0_weight)
#     esp_model = esp_model.to(device=device)

#     optimizer_esp = optim.Adam(esp_model.parameters(), lr=0.0005) # change to adam
# #     exp_scheduler = optim.lr_scheduler.StepLR(optimizer_esp, step_size=7, gamma=0.1)

#     exp_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_esp, 
#                                                          mode='min', 
#                                                          factor=0.1, 
#                                                          patience=10, # number of epochs with no change 
#                                                          verbose=True, 
#                                                          threshold=0.0001, 
#                                                          threshold_mode='rel', 
#                                                          cooldown=0, 
#                                                          min_lr=0, 
#                                                          eps=1e-08)


#     model, best_wts, best_auc, best_acc, best_auc_acc = train(model=esp_model, 
#                        dataloaders=dataloaders,
#                        data_transforms=data_transforms,
#                        criterion = criterion, 
#                        optimizer=optimizer_esp, 
#                        scheduler=exp_scheduler,
#                        writer=writer,
#                        num_epochs=epochs, 
#                        verbose=False, 
#                        device=device,
#                        dataset_sizes=dataset_sizes,
#                        channels=1,
#                        resize_shape=resize_shape,
#                        classes=class_names,
#                        volume_val=False,
#                        weight_outfile_prefix=model_outfile_dir)
#     del esp_model
#     del model
    
#     best_auc_list.append(best_auc)
#     best_acc_list.append(best_acc)
#     best_auc_acc_list.append(best_auc_acc)
    
    
#     if not os.path.exists('../model_weights/results/'):
#         os.makedirs('../model_weights/results/')
    
#     results_outfile_dir = model_outfile_dir + '_epochs-' + str(epochs) +'_iterations-' + str(iterations)
#     with open('../model_weights/results/auc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
#         pickle.dump(best_auc_list, fp)
#     with open('../model_weights/results/acc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
#         pickle.dump(best_acc_list, fp)
#     with open('../model_weights/results/avg_auc_acc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
#         pickle.dump(best_auc_acc_list, fp)

In [15]:
# Class-balanced loss weights: class c gets weight max_count / count_c, computed over
# labeled training rows only (rows whose label is not a valid class index, e.g. the
# unlabeled MTL samples, are excluded).
num_classes = len(class_names)
# best_model_loc = '../pretrained/espnet_3d_brats.pth'

label_df = train_df.loc[train_df[task].isin(list(range(num_classes))), task]
# bincount keeps counts aligned with class indices even if a class is absent;
# np.unique would silently return fewer counts and broadcast to wrong weights.
cnts = np.bincount(label_df.astype(int), minlength=num_classes)
if (cnts == 0).any():
    raise ValueError(f'No labeled training samples for class index(es) '
                     f'{np.flatnonzero(cnts == 0).tolist()}; cannot compute loss weights')
loss_weights = (np.ones(num_classes) / cnts) * np.max(cnts)
loss_weights = torch.FloatTensor(loss_weights).to(device)
criterion = nn.CrossEntropyLoss(weight=loss_weights)
print('loss weights:', loss_weights)

loss weights: tensor([1.0364, 1.0000], device='cuda:0')


In [16]:
# Per-run best metrics, accumulated over repeated training runs (the MTL cell appends here too).
best_auc_list, best_acc_list, best_auc_acc_list = [], [], []
epochs = 50
iterations = 10  # number of independent training runs


def _make_plateau_scheduler(optimizer):
    '''Reduce the LR 10x after 10 epochs without improvement of the monitored (mode='min') value.'''
    return optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                mode='min',
                                                factor=0.1,
                                                patience=10,  # number of epochs with no change
                                                verbose=True,
                                                threshold=0.0001,
                                                threshold_mode='rel',
                                                cooldown=0,
                                                min_lr=0,
                                                eps=1e-08)


def _build_esp_model():
    '''Create a fresh pretrained SegModel classifier for the current `channels` setting, on `device`.'''
    from models.Models import SegModel
    if channels == 1:
        esp = SegModel(best_model_loc=best_model_loc,
                       inp_res=resize_shape,
                       num_classes=num_classes,
                       channels=4)
        # Replace the 4-channel stem convolution by a 1-channel one, initialised
        # from the pretrained weights of input channel 0.
        level0_weight = esp.espnet.level0.conv.weight[:, 0].unsqueeze(1)
        esp.espnet.level0.conv = nn.Conv3d(1, 16, kernel_size=(7, 7, 7),
                                           stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
        esp.espnet.level0.conv.weight = nn.Parameter(level0_weight)
    else:
        esp = SegModel(best_model_loc=best_model_loc, inp_res=resize_shape, num_classes=num_classes)
    return esp.to(device=device)


def _save_results():
    '''Pickle the per-run metric lists; rewritten after every run so partial results survive.'''
    results_dir = '../model_weights/results/'
    os.makedirs(results_dir, exist_ok=True)
    results_outfile_dir = model_outfile_dir + '_epochs-' + str(epochs) + '_iterations-' + str(iterations)
    # NOTE: these are pickle files despite the .txt extension; names kept for compatibility.
    for prefix, values in (('auc_', best_auc_list),
                           ('acc_', best_acc_list),
                           ('avg_auc_acc_', best_auc_acc_list)):
        with open(results_dir + prefix + results_outfile_dir + '.txt', "wb") as fp:
            pickle.dump(values, fp)


if not mtl:
    print('no mtl')
    from train import train
    if channels not in (1, 4):
        raise ValueError(f'Unsupported number of input channels for the non-MTL run: {channels!r}')

    # Only the 1-channel run passes volume_val=False; the 4-channel run uses train()'s default.
    extra_train_kwargs = {'volume_val': False} if channels == 1 else {}

    for i in range(iterations):
        print('Iteration', i)
        esp_model = _build_esp_model()
        optimizer_esp = optim.Adam(esp_model.parameters(), lr=0.0005)
        exp_scheduler = _make_plateau_scheduler(optimizer_esp)

        # NOTE(review): channels=1 is passed to train() for 4-channel input as well — confirm intended.
        model, best_wts, best_auc, best_acc, best_auc_acc = train(model=esp_model,
                                                                  dataloaders=dataloaders,
                                                                  data_transforms=data_transforms,
                                                                  criterion=criterion,
                                                                  optimizer=optimizer_esp,
                                                                  scheduler=exp_scheduler,
                                                                  writer=writer,
                                                                  num_epochs=epochs,
                                                                  verbose=False,
                                                                  device=device,
                                                                  dataset_sizes=dataset_sizes,
                                                                  channels=1,
                                                                  resize_shape=resize_shape,
                                                                  classes=class_names,
                                                                  weight_outfile_prefix=model_outfile_dir,
                                                                  **extra_train_kwargs)
        del esp_model
        del model

        best_auc_list.append(best_auc)
        best_acc_list.append(best_acc)
        best_auc_acc_list.append(best_auc_acc)
        _save_results()
            


In [None]:
if mtl:
    print('mtl')
    # Scales applied to the auxiliary segmentation and survival losses.
    seg_loss_weight = 1
    surv_loss_weight = 1

    # BraTS IDs of the samples that have a ground-truth segmentation (gt_seg == 1).
    brats_seg_ids = glioma_metadata_df.loc[glioma_metadata_df['gt_seg'] == 1].index

    # Pre-computed class weights for the 4-class and 2-class segmentation losses.
    seg_4class_weights = torch.FloatTensor(
        np.load('../data/segmentation_notcropped_4-class_weights.npy')).to(device)
    print('seg_4class_weights:', seg_4class_weights)

    seg_2class_weights = torch.FloatTensor(
        np.load('../data/segmentation_notcropped_2-class_weights.npy')).to(device)
    print('seg_2class_weights:', seg_2class_weights)

    for i in range(iterations):
        print('Iteration', i)

        # Fresh MTL network for every run (survival loss switched off via take_surv_loss=False).
        from models.nick_mtl_model import GBMNetMTL
        gbm_net = GBMNetMTL(g_in_features=50,
                            g_out_features=128,
                            n_classes=num_classes,
                            n_volumes=channels,
                            seg_classes=4,
                            pretrained=best_model_loc,
                            class_loss_weights=loss_weights,
                            seg_4class_weights=seg_4class_weights,
                            seg_2class_weights=seg_2class_weights,
                            seg_loss_scale=seg_loss_weight,
                            surv_loss_scale=surv_loss_weight,
                            device=device,
                            brats_seg_ids=brats_seg_ids,
                            standard_unlabled_loss=False,
                            fusion_net_flag=True,
                            modality=modality,
                            take_surv_loss=False).to(device)

        optimizer_gbmnet = optim.Adam(gbm_net.parameters(), lr=0.0005)
        # Reduce the LR 10x after 10 epochs without improvement (mode='min').
        exp_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_gbmnet,
                                                             mode='min',
                                                             factor=0.1,
                                                             patience=10,  # number of epochs with no change
                                                             verbose=True,
                                                             threshold=0.0001,
                                                             threshold_mode='rel',
                                                             cooldown=0,
                                                             min_lr=0,
                                                             eps=1e-08)

        from train_mtl import train
        model, best_wts, best_auc, best_acc, best_auc_acc = train(model=gbm_net,
                                                                  dataloaders=dataloaders,
                                                                  data_transforms=data_transforms,
                                                                  optimizer=optimizer_gbmnet,
                                                                  scheduler=exp_scheduler,
                                                                  writer=writer,
                                                                  num_epochs=epochs,
                                                                  verbose=False,
                                                                  device=device,
                                                                  dataset_sizes=dataset_sizes,
                                                                  channels=channels,
                                                                  classes=class_names,
                                                                  weight_outfile_prefix=model_outfile_dir,
                                                                  pad=0)
        del gbm_net
        del model

        best_auc_list.append(best_auc)
        best_acc_list.append(best_acc)
        best_auc_acc_list.append(best_auc_acc)

        # Rewrite the pickled metric lists after every run (files are pickles despite .txt).
        if not os.path.exists('../model_weights/results/'):
            os.makedirs('../model_weights/results/')
        results_outfile_dir = model_outfile_dir + '_epochs-' + str(epochs) + '_iterations-' + str(iterations)
        for metric_prefix, metric_values in (('auc_', best_auc_list),
                                             ('acc_', best_acc_list),
                                             ('avg_auc_acc_', best_auc_acc_list)):
            with open('../model_weights/results/' + metric_prefix + results_outfile_dir + '.txt', "wb") as fp:
                pickle.dump(metric_values, fp)

mtl
seg_4class_weights: tensor([  1.0000, 124.3847,  60.4277, 188.0025], device='cuda:0')
seg_2class_weights: tensor([ 1.0000, 33.4366], device='cuda:0')
Iteration 0
GBMNet!


  0%|          | 0/50 [00:00<?, ?it/s]

training . . . 
before epochs
  >> val_loss 0.41625828258061814 epoch 0
  >> val AUC  0.5486111111111112 | mean acc auc 0.5243055555555556 | acc 0.5 | epoch 0
New Best AUC-acc average:	 0.5243055555555556 	in epoch 0
New Best Dice:	 0.5014230102839139 	in epoch 0
New Best ACC:	 0.5 	in epoch 0


  2%|▏         | 1/50 [12:07<9:54:29, 727.94s/it]

New Best AUC-acc average:	 0.5277777777777778 	in epoch 1
New Best Dice:	 0.7100452012633308 	in epoch 1


  4%|▍         | 2/50 [24:17<9:42:51, 728.57s/it]

  >> val_loss 0.4393914251004235 epoch 2
  >> val AUC  0.5995370370370371 | mean acc auc 0.5497685185185186 | acc 0.5 | epoch 2
New Best AUC-acc average:	 0.5497685185185186 	in epoch 2


 10%|█         | 5/50 [1:00:36<9:05:10, 726.90s/it]

  >> val_loss 0.4352293196371046 epoch 4
  >> val AUC  0.5393518518518519 | mean acc auc 0.5196759259259259 | acc 0.5 | epoch 4


 12%|█▏        | 6/50 [1:12:42<8:52:56, 726.74s/it]

  >> val_loss 0.40448559744883394 epoch 6
  >> val AUC  0.681712962962963 | mean acc auc 0.6417824074074074 | acc 0.6018518518518519 | epoch 6
New Best AUC-acc average:	 0.6417824074074074 	in epoch 6
New Best AUC:	 0.681712962962963 	in epoch 6
New Best ACC:	 0.6018518518518519 	in epoch 6


 18%|█▊        | 9/50 [1:49:07<8:17:15, 727.69s/it]

  >> val_loss 0.3932077015860606 epoch 8
  >> val AUC  0.5844907407407408 | mean acc auc 0.5422453703703705 | acc 0.5 | epoch 8


 22%|██▏       | 11/50 [2:13:24<7:53:17, 728.15s/it]

  >> val_loss 0.39004761283680545 epoch 10
  >> val AUC  0.6736111111111112 | mean acc auc 0.5868055555555556 | acc 0.5 | epoch 10


 24%|██▍       | 12/50 [2:25:34<7:41:22, 728.49s/it]

# TODO: merge these modules into one
- nick_mtl_model
- train_mtl
- joint_model2

### 