In [1]:
import sys
sys.path.append('../')

import os
import numpy as np
import torch
import pandas as pd
from skimage import transform # io, 
import PIL
import math

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms, utils, datasets, models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
import pickle

from utils import get_bb_3D

from glob import glob
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm_notebook as tqdm
%matplotlib inline
plt.ion()   # interactive mode

%load_ext autoreload
%autoreload 2


from datasets import GeneralDataset
import Transforms as myTransforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

# Data

In [2]:
def get_data_splits(glioma_metadata_df, task='idh', mtl = False):
    '''
    glioma_metadata_df: for each sample, indicates
        - if it is in the labeled training, unlabeled training set, or valiation set, 
        - idh status, and 1p19q status
    task: 'idh' or '1p19q'
    mlt: True or False
    '''
    if task == 'idh':
        classes = ['wildtype', 'mutant']
        val_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'val') & (glioma_metadata_df[task].isin([0,1]))]
        

    elif task == '1p19q':
        classes = ['non-codel', 'oligo']
        val_df = glioma_metadata_df.loc[glioma_metadata_df['phase'] == 'val']
        
    if mtl:
        train_df = glioma_metadata_df.loc[glioma_metadata_df['phase'].isin(['train', 'unlabeled train'])]
    else:
        train_df = glioma_metadata_df.loc[(glioma_metadata_df['phase'] == 'train') & (glioma_metadata_df[task].isin([0,1]))]

    return train_df, val_df, classes

In [3]:
# task = '1p19q'
task = 'idh'
mtl = True
best_model_loc = '../pretrained/espnet_3d_brats.pth'

# we are predicting either idh gene mutation or chromosome arm 1p and 19q co-deletion
old_glioma_metadata_df = pd.read_csv('../../miccai_clean/data/all_glioma_metadata_542x30.csv', index_col=0)
old_glioma_metadata_df = old_glioma_metadata_df.rename(columns={'OS.time':'OS', '_EVENT':'OS_EVENT'})

glioma_metadata_df = pd.read_csv('../data/glioma_metadata.csv', index_col=0)
glioma_metadata_df.loc[['Brats18_TCIA09_462_1', 'Brats18_TCIA10_236_1'], 'idh'] = 1

glioma_metadata_df['idh_cluster'] = old_glioma_metadata_df['idh_cluster']

train_df, val_df, classes = get_data_splits(glioma_metadata_df=glioma_metadata_df, task=task, mtl=mtl)

# metadata for all brats (including tcia) data

brats2tcia_df = pd.read_csv('../../miccai_clean/data/brats2tcia_df_542x1.csv', index_col=0)


# genomic data (you won't need this)
genomic_csv_files = {'train':'../data/MGL/MGL_235x50.csv', 'val':  '../data/MGL/MGL_235x50.csv'}

image_dir = '../data/all_brats_scans/'

# labels we predict during classification 
cluster_column= task + '_cluster'

# downsample dim. 
resize_shape = (64, 64, 64)

# dataloader batch size
train_batch_size = 4
val_batch_size = 4
data_batch_size = 1

shuffle = True
shuffle_data = True

dataformat = 'modality3D_mtl'
modality = 't1ce'
channels = 1

null_genomic = False

print('Train size', len(train_df))

Train size 467


In [16]:
glioma_metadata_df.loc[glioma_metadata_df['idh'].isin([0,1])]['OS'].dropna()

Brats18_TCIA01_401_1    448.0
Brats18_TCIA01_429_1     86.0
Brats18_TCIA01_454_1    144.0
Brats18_TCIA02_117_1    268.0
Brats18_TCIA02_135_1    828.0
                        ...  
Brats18_TCIA13_651_1     15.0
Brats18_TCIA13_652_1     14.0
Brats18_TCIA13_653_1      7.0
Brats18_TCIA13_654_1      3.0
Brats18_TCIA13_655_1      3.0
Name: OS, Length: 171, dtype: float64

In [22]:
ctr = 0
for i in range(len(train_df)):
    if train_df.iloc[i]['1p19q'] in [0,1]:
        ctr += 1
ctr

112

In [23]:
train_df['1p19q'].dropna()

Brats18_TCIA01_131_1   -1.0
Brats18_TCIA01_147_1   -1.0
Brats18_TCIA01_150_1   -1.0
Brats18_TCIA01_180_1   -1.0
Brats18_TCIA01_186_1   -1.0
                       ... 
Brats18_TCIA13_642_1    0.0
Brats18_TCIA13_645_1    0.0
Brats18_TCIA13_650_1    0.0
Brats18_TCIA13_653_1    0.0
Brats18_TCIA13_654_1    0.0
Name: 1p19q, Length: 160, dtype: float64

In [25]:
glioma_metadata_df = pd.read_csv('../data/glioma_metadata.csv', index_col=0)

In [73]:

glioma_metadata_df['1p19q'].value_counts()

0.0    210
1.0     28
Name: 1p19q, dtype: int64

In [28]:
glioma_metadata_df['idh'].value_counts()

 1.0    87
 0.0    84
-1.0    64
Name: idh, dtype: int64

In [77]:
87 + 84

171

In [78]:
171 - 112

59

In [71]:
glioma_metadata_df.loc[glioma_metadata_df['tciaID'].isin(oligo_idxs)].index

Index(['Brats18_TCIA09_141_1', 'Brats18_TCIA09_428_1', 'Brats18_TCIA10_106_1',
       'Brats18_TCIA10_109_1', 'Brats18_TCIA10_130_1', 'Brats18_TCIA10_220_1',
       'Brats18_TCIA10_239_1', 'Brats18_TCIA10_261_1', 'Brats18_TCIA10_266_1',
       'Brats18_TCIA10_271_1', 'Brats18_TCIA10_276_1', 'Brats18_TCIA10_325_1',
       'Brats18_TCIA10_387_1', 'Brats18_TCIA10_410_1', 'Brats18_TCIA10_614_1',
       'Brats18_TCIA10_631_1', 'Brats18_TCIA10_637_1', 'Brats18_TCIA10_647_1',
       'Brats18_TCIA12_466_1', 'Brats18_TCIA12_470_1', 'Brats18_TCIA12_613_1',
       'Brats18_TCIA12_641_1', 'Brats18_TCIA13_610_1', 'Brats18_TCIA13_616_1',
       'Brats18_TCIA13_617_1', 'Brats18_TCIA13_633_1', 'Brats18_TCIA13_638_1',
       'Brats18_TCIA13_646_1'],
      dtype='object')

In [72]:
brats_tcia_idxs = glioma_metadata_df['tciaID'].dropna()

subtype_df = pd.read_csv('../../tcga_data_cleaning/data/processed_data/santa_cruz/patient_metadata/subtype_spreadsheet_1128-36.csv', index_col=0)
df = subtype_df.loc[brats_tcia_idxs]
subtypes = ['Astro/GBM IDHwt', 'Astro/GBM IDHmut', 'Likely Astro/GBM IDHwt', 'Oligodendroglioma', 'Likely Astro/GBM IDHmut']
possible_oligo_idxs = df.loc[df['subtype'].isin(subtypes)].index

oligo_idxs = subtype_df.loc[subtype_df['subtype'] == 'Oligodendroglioma'].index
oligo_idxs = [x for x in oligo_idxs if x in possible_oligo_idxs]
non_oligo_idxs = [x for x in possible_oligo_idxs if x not in oligo_idxs]

brats_oligo_idxs = glioma_metadata_df.loc[glioma_metadata_df['tciaID'].isin(oligo_idxs)].index
brats_non_oligo_idxs = glioma_metadata_df.loc[glioma_metadata_df['tciaID'].isin(non_oligo_idxs)].index

glioma_metadata_df = pd.read_csv('../data/glioma_metadata.csv', index_col=0)
glioma_metadata_df.loc[brats_oligo_idxs, '1p19q'] = 1
glioma_metadata_df.loc[brats_non_oligo_idxs, '1p19q'] = 0

In [75]:
glioma_metadata_df.to_csv('../data/glioma_metadata.csv')

In [76]:
glioma_metadata_df

Unnamed: 0,tciaID,phase,idh,1p19q,BoundingBox,gt_seg,some_seg,OS,OS_EVENT
Brats18_2013_0_1,,unlabeled train,,,"(77, 109, 35, 54, 52, 48)",1,1,,
Brats18_2013_10_1,,unlabeled train,,,"(57, 76, 29, 55, 96, 74)",1,1,,
Brats18_2013_11_1,,unlabeled train,,,"(115, 72, 43, 60, 80, 76)",1,1,,
Brats18_2013_12_1,,unlabeled train,,,"(64, 49, 52, 80, 104, 74)",1,1,,
Brats18_2013_13_1,,unlabeled train,,,"(133, 148, 70, 43, 48, 41)",1,1,,
...,...,...,...,...,...,...,...,...,...
Brats18_WashU_W051_1,,unlabeled train,,,,0,0,,
Brats18_WashU_W053_1,,unlabeled train,,,,0,0,,
Brats18_WashU_W061_1,,unlabeled train,,,,0,0,,
Brats18_WashU_W065_1,,unlabeled train,,,,0,0,,


In [64]:
glioma_metadata_df['1p19q']

28

In [65]:
len(non_oligo_idxs)

210

In [66]:
210+28

238

In [67]:
143 + 64 + 28

235

In [68]:
glioma_metadata_df['1p19q']

Brats18_2013_0_1       NaN
Brats18_2013_10_1      NaN
Brats18_2013_11_1      NaN
Brats18_2013_12_1      NaN
Brats18_2013_13_1      NaN
                        ..
Brats18_WashU_W051_1   NaN
Brats18_WashU_W053_1   NaN
Brats18_WashU_W061_1   NaN
Brats18_WashU_W065_1   NaN
Brats18_WashU_W082_1   NaN
Name: 1p19q, Length: 542, dtype: float64

In [None]:
subtype_df.loc[.dropna()]

In [45]:
a = [x for x in brats_tcia_idxs if x in tcia_idxs]
len(a)

243

In [37]:
glioma_metadata_df['tciaID'].dropna().index

Index(['Brats18_TCIA01_131_1', 'Brats18_TCIA01_147_1', 'Brats18_TCIA01_150_1',
       'Brats18_TCIA01_180_1', 'Brats18_TCIA01_186_1', 'Brats18_TCIA01_190_1',
       'Brats18_TCIA01_201_1', 'Brats18_TCIA01_203_1', 'Brats18_TCIA01_215_1',
       'Brats18_TCIA01_221_1',
       ...
       'Brats18_TCIA13_645_1', 'Brats18_TCIA13_646_1', 'Brats18_TCIA13_648_1',
       'Brats18_TCIA13_649_1', 'Brats18_TCIA13_650_1', 'Brats18_TCIA13_651_1',
       'Brats18_TCIA13_652_1', 'Brats18_TCIA13_653_1', 'Brats18_TCIA13_654_1',
       'Brats18_TCIA13_655_1'],
      dtype='object', length=243)

In [46]:
glioma_metadata_df[['tciaID', 'some_seg', 'subtype']].dropna()

KeyError: "['subtype'] not in index"

In [51]:
glioma_metadata_df[['tciaID', 'some_seg', 'OS']].dropna() #['tciaID'].dropna()

Unnamed: 0,tciaID,some_seg,OS
Brats18_TCIA01_131_1,TCGA-02-0070,1,762.0
Brats18_TCIA01_147_1,TCGA-02-0046,1,209.0
Brats18_TCIA01_150_1,TCGA-02-0116,1,1489.0
Brats18_TCIA01_180_1,TCGA-02-0069,1,873.0
Brats18_TCIA01_186_1,TCGA-02-0027,1,370.0
...,...,...,...
Brats18_TCIA13_651_1,TCGA-HT-7860,1,15.0
Brats18_TCIA13_652_1,TCGA-HT-8107,1,14.0
Brats18_TCIA13_653_1,TCGA-HT-8111,1,7.0
Brats18_TCIA13_654_1,TCGA-HT-7690,1,3.0


In [13]:
train_df

Unnamed: 0,tciaID,phase,idh,1p19q,BoundingBox,gt_seg,some_seg,OS,OS_EVENT,idh_cluster
Brats18_2013_0_1,,unlabeled train,,,"(77, 109, 35, 54, 52, 48)",1,1,,,
Brats18_2013_10_1,,unlabeled train,,,"(57, 76, 29, 55, 96, 74)",1,1,,,
Brats18_2013_11_1,,unlabeled train,,,"(115, 72, 43, 60, 80, 76)",1,1,,,
Brats18_2013_12_1,,unlabeled train,,,"(64, 49, 52, 80, 104, 74)",1,1,,,
Brats18_2013_13_1,,unlabeled train,,,"(133, 148, 70, 43, 48, 41)",1,1,,,
...,...,...,...,...,...,...,...,...,...,...
Brats18_WashU_W051_1,,unlabeled train,,,,0,0,,,
Brats18_WashU_W053_1,,unlabeled train,,,,0,0,,,
Brats18_WashU_W061_1,,unlabeled train,,,,0,0,,,
Brats18_WashU_W065_1,,unlabeled train,,,,0,0,,,


In [4]:
# # task = '1p19q'
# task = 'idh'

# best_model_loc = '../pretrained/espnet_3d_brats.pth'

# # we are predicting either idh gene mutation or chromosome arm 1p and 19q co-deletion
# glioma_metadata_df = pd.read_csv('../../miccai_clean/data/all_glioma_metadata_542x30.csv', index_col=0)
# glioma_metadata_df = glioma_metadata_df.rename(columns={'OS.time':'OS', '_EVENT':'OS_EVENT'})

# new_glioma_metadata_df = pd.read_csv('../data/glioma_metadata.csv', index_col=0)
# new_glioma_metadata_df.loc[['Brats18_TCIA09_462_1', 'Brats18_TCIA10_236_1'], 'idh'] = 1

# if task == 'idh':
#     classes = ['wildtype', 'mutant']
#     val_df = glioma_metadata_df[(glioma_metadata_df['phase'] == 'val') & (glioma_metadata_df['inferred_subtype'] == 0)]

# elif task == '1p19q':
#     classes = ['non-codel', 'oligo']
#     val_df = glioma_metadata_df[(glioma_metadata_df['phase'] == 'val')]

# else:
#     print('invalid classification given')

# # metadata for all brats (including tcia) data

# brats2tcia_df = pd.read_csv('../../miccai_clean/data/brats2tcia_df_542x1.csv', index_col=0)

# # these are labeled files (they were paths in old dataloader) but they are dataframes
# csv_files = {'train':glioma_metadata_df[~glioma_metadata_df['phase'].isin(['val'])],
# #               'train':glioma_metadata_df[glioma_metadata_df['phase'].isin(['train'])], 
#                  'val':val_df,
#                  'data':glioma_metadata_df}

# # genomic data (you won't need this)
# genomic_csv_files = {'train':'../data/MGL/MGL_235x50.csv', 
#                      'val':  '../data/MGL/MGL_235x50.csv',
#                      'data': '../data/MGL/MGL_235x50.csv'}

# image_dir = '../data/all_brats_scans/'

# # labels we predict during classification 
# cluster_column= task + '_cluster'

# # downsample dim. 
# resize_shape = (64, 64, 64)

# # dataloader batch size
# train_batch_size = 4
# val_batch_size = 4
# data_batch_size = 1

# shuffle = True
# shuffle_data = True

# dataformat = 'modality3D_mtl'
# modality = 't1ce'
# channels = 1

# null_genomic = False

# print('Train size', len(csv_files['train']))

In [5]:
glioma_metadata_df

Unnamed: 0,tciaID,phase,idh,1p19q,BoundingBox,gt_seg,some_seg,OS,OS_EVENT,idh_cluster
Brats18_2013_0_1,,unlabeled train,,,"(77, 109, 35, 54, 52, 48)",1,1,,,
Brats18_2013_10_1,,unlabeled train,,,"(57, 76, 29, 55, 96, 74)",1,1,,,
Brats18_2013_11_1,,unlabeled train,,,"(115, 72, 43, 60, 80, 76)",1,1,,,
Brats18_2013_12_1,,unlabeled train,,,"(64, 49, 52, 80, 104, 74)",1,1,,,
Brats18_2013_13_1,,unlabeled train,,,"(133, 148, 70, 43, 48, 41)",1,1,,,
...,...,...,...,...,...,...,...,...,...,...
Brats18_WashU_W051_1,,unlabeled train,,,,0,0,,,
Brats18_WashU_W053_1,,unlabeled train,,,,0,0,,,
Brats18_WashU_W061_1,,unlabeled train,,,,0,0,,,
Brats18_WashU_W065_1,,unlabeled train,,,,0,0,,,


# Dataloaders

In [6]:
# minimal data augmentation (you can add more)
train_transformations = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2])),
#         myTransforms.ZeroChannel(prob_zero=0.5),
        myTransforms.ZeroSprinkle(prob_zero=0.2, prob_true=0.8),
        myTransforms.ToTensor(),
    ])

# segmentation masks have separate transformations (you don't want to normalize)
seg_transformations = myTransforms.Compose([
      myTransforms.ScaleToFixed((1, resize_shape[0],resize_shape[1],resize_shape[2]), 
                                  interpolation=0, 
                                  channels=channels),
        myTransforms.ToTensor(),
    ])

val_transformations = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed((channels, resize_shape[0],resize_shape[1],resize_shape[2])),
        myTransforms.ToTensor(),
    ])

data_transforms = {'train': train_transformations,
                    'val':   val_transformations,
                   'seg': seg_transformations }



best_model_loc = '../pretrained/espnet_3d_brats.pth'

# 'jointmodel' dataformat returns entire MR images (but with the black padding cropped out)
transformed_dataset_train = GeneralDataset(csv_file=train_df,
                                           root_dir=image_dir,
                                           genomic_csv_file = genomic_csv_files['train'],
                                           transform=data_transforms['train'],
                                           seg_transform=data_transforms['seg'],
                                           seg_probs_transform=data_transforms['seg'],
                                           classes=classes,
                                           dataformat=dataformat,
                                           returndims=resize_shape,
                                           label=cluster_column,
                                           brats2tcia_df=brats2tcia_df,
                                           null_genomic = null_genomic,
                                           pretrained=best_model_loc,
                                           device=device,
                                           modality=modality)

transformed_dataset_val = GeneralDataset(csv_file=val_df,
                                         root_dir=image_dir,
                                         genomic_csv_file = genomic_csv_files['val'],
                                         transform=data_transforms['val'],
                                         seg_transform=data_transforms['seg'],
                                         seg_probs_transform=data_transforms['seg'],
                                         classes=classes,
                                         dataformat=dataformat,
                                         returndims=resize_shape,
                                         label=cluster_column,
                                         brats2tcia_df=brats2tcia_df,
                                         null_genomic = null_genomic,
                                         pretrained=best_model_loc,
                                         device=device,
                                         modality=modality)


image_datasets = {'train':transformed_dataset_train, 
                  'val':transformed_dataset_val}


dataloader_train = DataLoader(image_datasets['train'], batch_size=train_batch_size, shuffle=shuffle, num_workers=4, drop_last=True)
dataloader_val = DataLoader(image_datasets['val'], batch_size=val_batch_size, shuffle=shuffle, num_workers=4)

dataloaders = {'train':dataloader_train, 'val':dataloader_val}

dataset_sizes = {'train':len(image_datasets['train']), 'val':len(image_datasets['val'])}

class_names = image_datasets['train'].classes
class_names

['wildtype', 'mutant']

In [7]:
dataset_sizes['train']

467

In [8]:
# for i, data in tqdm(enumerate(dataloaders['train'])):
#     print('itr', i)
#     (inputs, seg_image, genomic_data, seg_probs), labels,(OS, event), bratsID = data
#     inputs, labels = inputs.to(device), labels.to(device)
#     OS, event = OS.to(device), event.to(device)
#     seg_image, seg_probs, genomic_data = seg_image.to(device), seg_probs.to(device), genomic_data.to(device)
#     seg_image = seg_image.squeeze(1)
#     seg_image = seg_image.type(torch.int64)
    
# # #     from models import ESPNet as Net
# # #     best_model_loc = '../pretrained/espnet_3d_brats.pth'
# # #     classes = 4
# # #     channels = 4


# # #     espnet = Net.ESPNet(classes=classes, channels=channels)
# # #     if os.path.isfile(best_model_loc):
# # #         espnet.load_state_dict(torch.load(best_model_loc, map_location=device))


# # #     level0_weight = espnet.level0.conv.weight[:, 0].unsqueeze(1)
# # #     espnet.level0.conv = nn.Conv3d(1, 16, kernel_size=(7, 7, 7), 
# # #                                              stride=(2, 2, 2), padding=(3, 3, 3), bias=False)

# # #     espnet.level0.conv.weight = nn.Parameter(level0_weight)

# # #     espnet = espnet.to(device) 

# # #     output = espnet(inputs)

# # #     seg_img = output.mask_out.max(1)[1].data.byte().cpu().numpy()
    
#     break

# Class weights

In [9]:
# num_classes = 2
# best_model_loc = '../pretrained/espnet_3d_brats.pth'

# train_df = csv_files['train']
# cluster_df = train_df[train_df['phase'] == 'train']['idh_cluster']

# _, cnts = np.unique(cluster_df, return_counts=True)
# loss_weights = (np.ones(num_classes)/cnts)*np.max(cnts) ## 
# loss_weights = torch.FloatTensor(loss_weights).to(device)
# print('subtype class weights:', loss_weights)

In [10]:
num_classes = 2
loss_weights = [1.0000, 1.6667]
loss_weights = torch.FloatTensor(loss_weights).to(device)
print('subtype class weights:', loss_weights)

subtype class weights: tensor([1.0000, 1.6667], device='cuda:0')


# Training

In [11]:
seg_loss_weight = 1
surv_loss_weight = 1

from torch.utils.tensorboard import SummaryWriter
img_dims = str(resize_shape[0]) + 'x' + str(resize_shape[1]) + 'x' + str(resize_shape[2])
model_outfile_dir = '3D_' + task + '_'+modality+'_mtl_' + img_dims + '_genomic'
print('tensorboad:', model_outfile_dir)
writer = SummaryWriter('runs1/'+model_outfile_dir)

tensorboad: 3D_idh_t1ce_mtl_64x64x64_genomic


In [12]:
best_auc_list, best_acc_list, best_auc_acc_list = [], [], []
epochs = 20
iterations = 10

brats_seg_ids = glioma_metadata_df[glioma_metadata_df['gt_seg'] == 1].index

seg_4class_weights = np.load('../data/segmentation_notcropped_4-class_weights.npy')
seg_4class_weights = torch.FloatTensor(seg_4class_weights).to(device)
print('seg_4class_weights:', seg_4class_weights)

seg_2class_weights = np.load('../data/segmentation_notcropped_2-class_weights.npy')
seg_2class_weights = torch.FloatTensor(seg_2class_weights).to(device)
print('seg_2class_weights:', seg_2class_weights)

for i in range(iterations):
    print('Iteration', i)
    

    from models.nick_mtl_model import GBMNetMTL
    gbm_net = GBMNetMTL(g_in_features=50, 
                        g_out_features=128, 
                        n_classes=num_classes, 
                        n_volumes=1, 
                        seg_classes=4, 
                        pretrained=best_model_loc, 
                        class_loss_weights = loss_weights,
                        seg_4class_weights=seg_4class_weights,
                        seg_2class_weights=seg_2class_weights,
                        seg_loss_scale=seg_loss_weight,
                        surv_loss_scale=surv_loss_weight,
                        device = device,
                        brats_seg_ids=brats_seg_ids,
                        standard_unlabled_loss=False,
                        fusion_net_flag=True,
                        modality=modality,
                        take_surv_loss=False) # might have to put loss_weights on device.

    gbm_net = gbm_net.to(device)

    optimizer_gbmnet = optim.Adam(gbm_net.parameters(), lr=0.0005) # change to adami
#     exp_scheduler = optim.lr_scheduler.StepLR(optimizer_gbmnet, step_size=7, gamma=0.1)
    
    exp_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_gbmnet, 
                                                         mode='min', 
                                                         factor=0.1, 
                                                         patience=10, # number of epochs with no change 
                                                         verbose=True, 
                                                         threshold=0.0001, 
                                                         threshold_mode='rel', 
                                                         cooldown=0, 
                                                         min_lr=0, 
                                                         eps=1e-08)

    from train_mtl import train
    model, best_wts, best_auc, best_acc, best_auc_acc = train(model=gbm_net, 
                   dataloaders=dataloaders,
                   data_transforms=data_transforms,
                   optimizer=optimizer_gbmnet, 
                   scheduler=exp_scheduler,
                   writer=writer,
                   num_epochs=epochs, 
                   verbose=False, 
                   device=device,
                   dataset_sizes=dataset_sizes,
                   channels=1,
                   classes=class_names,
                   weight_outfile_prefix=model_outfile_dir,
                   pad=0)

    
    del gbm_net
    del model
    
    best_auc_list.append(best_auc)
    best_acc_list.append(best_acc)
    best_auc_acc_list.append(best_auc_acc)
    
    if not os.path.exists('../model_weights/results/'):
        os.makedirs('../model_weights/results/')
    
    results_outfile_dir = model_outfile_dir + '_epochs-' + str(epochs) +'_iterations-' + str(iterations)
    with open('../model_weights/results/auc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_auc_list, fp)
    with open('../model_weights/results/acc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_acc_list, fp)
    with open('../model_weights/results/avg_auc_acc_' + results_outfile_dir + '.txt', "wb") as fp:   #Pickling
        pickle.dump(best_auc_acc_list, fp)

  0%|          | 0/20 [00:00<?, ?it/s]

seg_4class_weights: tensor([  1.0000, 124.3847,  60.4277, 188.0025], device='cuda:0')
seg_2class_weights: tensor([ 1.0000, 33.4366], device='cuda:0')
Iteration 0
GBMNet!
Segmentation model will only use one modality (channel)
training . . . 
before epochs
  >> val_loss 0.3297903113445993 epoch 0
  >> val AUC  0.9664351851851852 | mean acc auc 0.9490740740740742 | acc 0.931712962962963 | epoch 0
New Best AUC-acc average:	 0.9490740740740742 	in epoch 0
New Best Dice:	 0.4640630003493564 	in epoch 0


  5%|▌         | 1/20 [01:14<23:41, 74.81s/it]

New Best ACC:	 0.931712962962963 	in epoch 0


  5%|▌         | 1/20 [01:19<25:17, 79.87s/it]


KeyboardInterrupt: 