In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json

from PIL import Image
import torchvision
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from sklearn.metrics import *
import time
from datetime import datetime
import os
from torch.utils import data
import random
import copy
import itertools
import io
import uuid
from sklearn.model_selection import KFold, train_test_split
import warnings
warnings.filterwarnings('ignore')

import wandb
# W&B entity the runs are logged under, and the local account name used to build
# machine-specific data paths.
wandb_username, local_username = 'andreasabo', 'andreasabo'

In [2]:
# Prefer the second GPU (index 1) when the machine has more than one; on a single-GPU
# machine 'cuda:1' would only fail later when a tensor is moved, so fall back to 'cuda:0'.
# Without CUDA, run on the CPU.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

print(device)

cuda:1


In [3]:
# Project roots on this machine
root_dir = "/home/andreasabo/Documents/HNProject/"
split_file_base = "/home/andreasabo/Documents/HNUltra/"

# Image folder; the /home/<user>/ segment differs per machine
# (abhishekmoturu, andreasabo, denizjafari, navidkorhani)
data_dir = "/home/" + local_username + "/Documents/HNProject/all_label_img/"

# Load only the columns used downstream from the master split table
csv_path = os.path.join(root_dir, "all_splits_1000000.csv")
data_df = pd.read_csv(
    csv_path,
    usecols=['subj_id', 'scan_num', 'view_label', 'image_ids',
             'reflux_label', 'function_label', 'surgery_label', 'outcome_train'],
)

In [4]:
# Peek at the first rows, then leave the total row count as the cell's output value.
print(data_df.head(n=5))
data_df.shape[0]

  function_label image_ids reflux_label surgery_label view_label  subj_id  \
0        Missing  1323_2_1      Missing       Missing    Missing     1323   
1        Missing  1323_2_2      Missing       Missing    Missing     1323   
2        Missing  1323_2_3      Missing       Missing    Missing     1323   
3        Missing  1323_2_4      Missing       Missing    Missing     1323   
4        Missing  1323_2_5      Missing       Missing    Missing     1323   

   scan_num  outcome_train  
0         2            NaN  
1         2            NaN  
2         2            NaN  
3         2            NaN  
4         2            NaN  


72459

### **Reading Data Indices and Labels**

In [5]:
# Drop the images for which we do not have view labels or the label is "Other",
# then split on the outcome_train flag (1 -> cross-validation pool, 0 -> held-out test set).
data_df = data_df[~data_df.view_label.isin(["Missing", "Other"])]
train_df = data_df[data_df.outcome_train == 1]
test_df = data_df[data_df.outcome_train == 0]

print(f"We have {len(test_df) + len(train_df)} images ({len(train_df)} train and {len(test_df)} test)")
unique_subj = train_df.subj_id.unique()


def _fold_columns(subset_df, prefix):
    """Return the per-image columns of `subset_df` as lists keyed '<prefix>_<name>'."""
    return {
        prefix + '_images': subset_df.image_ids.tolist(),
        prefix + '_scan': subset_df.scan_num.tolist(),
        prefix + '_views': subset_df.view_label.tolist(),
        prefix + '_function': subset_df.function_label.tolist(),
        prefix + '_reflux': subset_df.reflux_label.tolist(),
        prefix + '_surgery': subset_df.surgery_label.tolist(),
    }


# Create the splits for 5-fold cross validation based on subj_id, so every image of a
# subject falls in the same fold (no subject appears in both train and validation).
data_split_file = split_file_base + 'data_splits_outcome.json'
if not os.path.isfile(data_split_file):
    kf = KFold(n_splits=5, random_state=0, shuffle=True)
    all_folds = {}
    for fold, (train_subj, val_subj) in enumerate(kf.split(unique_subj)):
        fold_train_df = train_df[train_df.subj_id.isin(unique_subj[train_subj])]
        fold_val_df = train_df[train_df.subj_id.isin(unique_subj[val_subj])]
        # String keys: JSON object keys are always strings, so this matches what
        # json.load returns on later runs; downstream code indexes folds with '0'..'4'.
        all_folds[str(fold)] = {**_fold_columns(fold_train_df, 'train'),
                                **_fold_columns(fold_val_df, 'val')}

    print("Saving data splits")
    with open(data_split_file, 'w') as f:
        json.dump(all_folds, f)

else:  # just load from file
    print("Reading splits from file")
    with open(data_split_file, 'r') as f:
        all_folds = json.load(f)

We have 9581 images (7230 train and 2351 test)
Reading splits from file


In [6]:
# Backbone architecture; one of: resnet, alexnet, vgg, squeezenet, densenet, inception
model_name = "resnet"

# Number of model outputs: one logit per outcome (reflux, surgery, function)
num_classes = 3

# Items per batch (lower this if the GPU runs out of memory)
batch_size = 100

# Number of epochs to train for
num_epochs = 75

# False -> fine-tune every layer; True -> freeze the backbone and train only the replaced layers
feature_extract = False

# Whether to start from pretrained weights
# NOTE(review): train5fold currently passes use_pretrained=True regardless of this flag.
pretrain = False

# True -> one item per scan, so every scan is seen the same number of times per epoch;
# False -> one item per image
sample_by_scan = True


In [7]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of `model` when `feature_extracting` is truthy.

    When `feature_extracting` is falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for parameter in model.parameters():
        parameter.requires_grad = False

In [8]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True, num_input_channels = 1):
    """Build a torchvision backbone whose head outputs `num_classes` values.

    Args:
        model_name: one of "resnet", "alexnet", "vgg", "squeezenet", "densenet", "inception".
        num_classes: number of output units of the replaced head.
        feature_extract: if True, freeze the backbone parameters (set_parameter_requires_grad)
            so only the newly created layers are trained.
        use_pretrained: forwarded to torchvision as `pretrained=`.
        num_input_channels: input channels of the replacement first conv layer. Only the
            "resnet" branch uses it; the other architectures keep their original stem.

    Returns:
        (model, input_size), where input_size is the image side length the branch is set up for.

    Raises:
        ValueError: if `model_name` is not a supported architecture.
    """
    if model_name == "resnet":
        # ResNet-18 with a new first conv layer accepting `num_input_channels` (e.g. grayscale).
        # source for conv1: https://discuss.pytorch.org/t/grayscale-images-for-resenet-and-deeplabv3/48693/2
        # NOTE: the replacement conv1 is freshly initialised, so pretrained conv1 weights are not reused.
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.conv1 = nn.Conv2d(num_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Head emits raw logits (no sigmoid); the training loop applies BCEWithLogitsLoss.
        model_ft.fc = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(num_ftrs, num_classes),
        )
        input_size = 256

    elif model_name == "alexnet":
        # AlexNet: replace the last classifier layer.
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11 with batch norm: replace the last classifier layer.
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0: the classifier is a 1x1 conv rather than a linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121: replace the classifier.
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3: expects (299, 299) images and has an auxiliary output head.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        # Raise instead of exit(): exit() would terminate the notebook kernel.
        raise ValueError(
            f"Invalid model name {model_name!r}; expected one of "
            "resnet, alexnet, vgg, squeezenet, densenet, inception"
        )

    return model_ft, input_size


In [9]:
def set_parameter_requires_grad(model, feature_extracting):
    # NOTE(review): identical to the helper defined in cell [7]; this redefinition shadows it,
    # so one of the two definitions can be removed.
    # Freezes all parameters of `model` when `feature_extracting` is truthy; otherwise a no-op.
    if feature_extracting:
        for weight in model.parameters():
            weight.requires_grad_(False)

In [10]:
# Code from: https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Call `step(metric)` once per evaluation. It returns True when training should stop:
    the metric is NaN, or it has not improved for `patience` consecutive calls.

    Args:
        mode: 'min' if lower values are better, 'max' if higher values are better.
        min_delta: minimum change that counts as an improvement; absolute, or a
            percentage of |best| when `percentage` is True.
        patience: non-improving steps tolerated before stopping; 0 disables stopping.
        percentage: interpret `min_delta` as a percentage of the best value's magnitude.

    Raises:
        ValueError: if `mode` is not 'min' or 'max'.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=True):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Early stopping disabled: every value is an improvement and step never stops.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one metric value and return True if training should stop."""
        # Check NaN first so a NaN first value cannot become `best` (every later
        # comparison against NaN would be False).
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Set `self.is_better(a, best)` according to mode, min_delta and percentage."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if percentage:
            # Scale by |best| so the required margin points the right way for negative metrics.
            margin = lambda best: abs(best) * min_delta / 100
        else:
            margin = lambda best: min_delta
        if mode == 'min':
            self.is_better = lambda a, best: a < best - margin(best)
        else:
            self.is_better = lambda a, best: a > best + margin(best)

In [11]:
import random
class ScanDataset(data.Dataset):
    """Dataset yielding one image per view of a scan together with its outcome labels.

    Each item is `(images, y)`:
      * `images`: tensor stacked over `image_return_order` (5 views). Each entry is the
        (optionally transformed) grayscale image, or a (1, 256, 256) NaN tensor when that
        view is absent. Stacking assumes real images also come out as (1, 256, 256) (TODO confirm).
      * `y`: FloatTensor [reflux, surgery, function]. "No" -> 0, "Yes" -> 1, "Missing" -> NaN.
        Any other value is converted with float(); when `binarize_labels` is True it becomes
        1 if outside [40, 60] and 0 otherwise.

    Sampling modes (also the basis of `__len__`):
      * sample_by_scan=True: one item per unique (subject, scan); a random image is drawn per
        view. If `target_view` is set, labels are taken from the drawn image of that view.
      * sample_by_scan=False, target_view set: one item per image of `target_view`; only that
        view is loaded.
      * sample_by_scan=False, target_view None: one item per image. Other views are filled
        with random images of the same scan.
    """

    def __init__(self, image_list, reflux_labels, surgery_labels, function_labels, view_labels, scan_labels, binarize_labels = True, transformations=None, sample_by_scan=True , target_view=None):
        'Initialization'
        self.image_list = image_list
        # Subject id = first four characters of the image id (e.g. "1323_2_1" -> 1323).
        self.subj_id = np.asarray([int(s[0:4]) for s in image_list])

        self.reflux_labels = reflux_labels
        self.surgery_labels = surgery_labels
        self.function_labels = function_labels
        self.view_labels = view_labels
        self.scan_labels = np.asarray(scan_labels)
        self.binarize_labels = binarize_labels
        self.transformations = transformations
        self.sample_by_scan = sample_by_scan
        self.target_view = target_view
        self.image_return_order = ['Saggital_Right', 'Transverse_Right', 'Saggital_Left', 'Transverse_Left', 'Bladder']

        # Seeds Python's global RNG, which drives the per-view random.choice below
        # (and anything else in the process that uses `random`).
        random.seed(0)
        # Access order of images; the DataLoader does the shuffling.
        self.index_order = list(range(len(image_list)))
        self.all_view_names = list(set(view_labels))

        # Scan id "<subj>_<scan>" for every image. Sorting the unique ids keeps item order
        # stable across processes (set order of strings depends on hash randomisation).
        self.all_scan_ids = np.asarray([(str(self.subj_id[i]) + "_" + str(self.scan_labels[i])) for i in range(len(self.scan_labels))])
        self.unique_scans = sorted(set(self.all_scan_ids))
        # scan id -> ascending image indices, so __getitem__ needs no full-array scan.
        self._scan_to_indices = {}
        for i, scan_id in enumerate(self.all_scan_ids):
            self._scan_to_indices.setdefault(scan_id, []).append(i)

        if target_view is not None:
            self.images_of_view = [i for i, v in enumerate(view_labels) if v == target_view]

    def __len__(self):
        'Denotes the total number of image_list or scan_list'
        # Must mirror the dispatch in __getitem__, where scan-based sampling takes precedence.
        if self.sample_by_scan:
            return len(self.unique_scans)
        if self.target_view is not None:
            return len(self.images_of_view)
        return len(self.image_list)

    def __getitem__(self, ind):
        'Generates one sample of data'
        if not self.sample_by_scan and self.target_view is not None:
            index = self.images_of_view[ind]
            output_image_path_dict = {self.target_view: data_dir + self.image_list[index] + '.jpg'}
        else:
            output_image_path_dict, index = self._select_scan_images(ind)

        output_images = [self._load_view(output_image_path_dict.get(view)) for view in self.image_return_order]

        y = [self.reflux_labels[index], self.surgery_labels[index], self.function_labels[index]]
        y = torch.FloatTensor([self._encode_outcome(outcome) for outcome in y])
        return torch.stack(output_images), y

    def _select_scan_images(self, ind):
        """Pick one image path per view from the scan behind item `ind`; return (paths, label_index)."""
        if self.sample_by_scan:
            indexes_in_scan = self._scan_to_indices[self.unique_scans[ind]]
            index = indexes_in_scan[0]  # For now, just take the first index
        else:
            # Use the selected image plus the other images of the same subject and scan.
            index = self.index_order[ind]
            images_of_scan = (self.subj_id == self.subj_id[index]) & (self.scan_labels == self.scan_labels[index])
            indexes_in_scan = np.flatnonzero(images_of_scan).tolist()

        # Group the scan's images by view (insertion order = ascending image index).
        dict_of_scans_by_view = {}
        for i in indexes_in_scan:
            dict_of_scans_by_view.setdefault(self.view_labels[i], []).append(i)

        # If we directly picked an image, make it the only candidate for its view.
        if not self.sample_by_scan:
            dict_of_scans_by_view[self.view_labels[index]] = [index]

        output_image_path_dict = {}
        for view, candidates in dict_of_scans_by_view.items():
            chosen = random.choice(candidates)
            output_image_path_dict[view] = data_dir + self.image_list[chosen] + '.jpg'
            # Value comparison: `is` on strings only works by accident of interning.
            if self.sample_by_scan and self.target_view == view:
                index = chosen
        return output_image_path_dict, index

    def _load_view(self, path):
        """Load one grayscale image as a tensor, or a (1, 256, 256) NaN placeholder if `path` is None."""
        if path is None:
            return torch.full((1, 256, 256), float('nan'))
        # The context manager closes the file handle; convert() forces the pixel data to load.
        with Image.open(path) as img:
            image = img.convert('L')
        if self.transformations:
            image = self.transformations(image)
        return ToTensor()(image)

    def _encode_outcome(self, outcome):
        """Map one raw outcome label to a number ("No" -> 0, "Yes" -> 1, "Missing" -> NaN)."""
        if outcome == "No":
            return 0
        if outcome == "Yes":
            return 1
        if outcome == "Missing":
            return np.nan
        if self.binarize_labels:
            value = float(outcome)
            return 1 if value > 60 or value < 40 else 0
        return float(outcome)



In [12]:
def train_model(model, dataloaders, criterion, optimizer, view_to_use, num_epochs=2, is_inception=False, final_testing=False, min_epochs=30):
    """Train `model` on the images of one view and log per-epoch metrics to wandb.

    Args:
        model: network mapping the per-view image batch to one logit per outcome
            (Reflux, Surgery, Function).
        dataloaders: dict with 'train' and 'val' iterables of (images, labels) batches;
            presumably built from ScanDataset (images stacked per view) — confirm with caller.
        criterion: used only for the Inception auxiliary-output training path; otherwise a
            masked, per-column weighted BCE-with-logits loss is computed per batch.
        optimizer: optimizer over the model parameters.
        view_to_use: one of 'Saggital_Right', 'Transverse_Right', 'Saggital_Left',
            'Transverse_Left', 'Bladder'.
        num_epochs: maximum number of epochs.
        is_inception: model returns (outputs, aux_outputs) in train mode.
        final_testing: if True, early stopping never stops training.
        min_epochs: validation epochs before this index are ignored for best-epoch selection
            and early stopping.

    Returns:
        A deep copy of the logged metrics at the validation epoch (>= min_epochs) with the
        lowest loss, or {} if there was none; None if `view_to_use` is unknown.
    """
    all_views = ['Saggital_Right', 'Transverse_Right', 'Saggital_Left', 'Transverse_Left', 'Bladder']
    if view_to_use not in all_views:
        print(f"View {view_to_use} does not exist")
        return
    view_index = all_views.index(view_to_use)

    es = EarlyStopping(patience = 20)
    classnames = ['Reflux', 'Surgery', 'Function']
    results_dict = {}
    lowest_loss = 999999
    best_epoch = 0
    # Initialised so the function can always return (e.g. when num_epochs <= min_epochs).
    metrics_from_best_epoch = {}
    stop_now = False

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 54)

        if stop_now:
            break
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_num_samples = 0
            running_preds = []
            running_labels = []

            for inputs, labels in dataloaders[phase]:
                # Only use the image for the specified view; assumes inputs are
                # (batch, view, channel, H, W) as stacked by the dataset (TODO confirm).
                inputs = inputs[:, view_index, :, :]
                # Remove the samples whose image for this view is missing (NaN placeholder).
                has_image = torch.isnan(inputs).sum(dim=(1, 2, 3)) < 1
                inputs = inputs[has_image].to(device)
                labels = labels[has_image].to(device)
                if inputs.size(0) == 0:
                    continue  # nothing of this view in the batch; an empty loss would be NaN
                label_missing = torch.isnan(labels)

                optimizer.zero_grad()
                # Track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        # Inception's auxiliary output: https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
                    else:
                        outputs = model(inputs)
                        running_preds += outputs.tolist()
                        running_labels += labels.tolist()
                        loss = _masked_bce_loss(outputs, labels, label_missing, inputs.size(0))

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_num_samples += inputs.size(0)

            # Epoch is done
            epoch_loss = running_loss / running_num_samples if running_num_samples else float('nan')
            results_dict[phase + '_loss'] = epoch_loss
            phase_metrics, sum_auc = _outcome_metrics(
                phase, np.asarray(running_preds), np.asarray(running_labels), classnames)
            results_dict.update(phase_metrics)
            wandb.log(results_dict, step=epoch)

            if phase == "val" and epoch_loss < lowest_loss and epoch >= min_epochs:
                lowest_loss = epoch_loss
                best_epoch = epoch
                results_dict.update({"best_epoch": best_epoch})
                wandb.log(results_dict, step=epoch)

                metrics_from_best_epoch = copy.deepcopy(results_dict)
                print(metrics_from_best_epoch)
            print('{} loss:\t{:.4f} | {} sum of auc:\t{:.4f} \n'.format(phase, epoch_loss, phase, sum_auc))

            if phase == 'val' and epoch >= min_epochs:
                if es.step(epoch_loss) and not final_testing:
                    stop_now = True
                    print("EARLY STOPPING " + str(epoch))
                    break

    return metrics_from_best_epoch


def _masked_bce_loss(outputs, labels, label_missing, batch_size):
    """BCE-with-logits over the labelled entries, up-weighting sparsely labelled outcome columns."""
    # Zero predictions/targets where the label is missing so they stay finite.
    outputs_for_loss = outputs * (~label_missing)
    outputs_for_loss[abs(outputs_for_loss) < 0.000001] = 0
    labels_for_loss = labels.clone()
    labels_for_loss[labels_for_loss != labels_for_loss] = 0  # NaN != NaN: clears missing labels

    # Weights: 0 for missing labels, else batch_size / (#labelled entries in that column);
    # columns with fewer labels (e.g. function) end up > 1 and are capped to 1.25.
    loss_weights = torch.ones_like(outputs)
    loss_weights[label_missing] = 0
    # clamp avoids 0/0 = NaN weights (and a NaN loss) for a column with no labels in the batch.
    col_sum = torch.sum(loss_weights, 0).clamp(min=1)
    loss_weights = (loss_weights / col_sum) * batch_size
    loss_weights[loss_weights > 1] = 1.25
    return nn.BCEWithLogitsLoss(weight=loss_weights)(outputs_for_loss, labels_for_loss)


def _outcome_metrics(phase, preds, true_labels, classnames):
    """Per-outcome AUROC/AUPRC (NaN labels ignored) and their sums; returns (metrics, sum_auroc)."""
    metrics = {}
    sum_auc = 0
    sum_ave_prec = 0
    for ind, cl in enumerate(classnames):
        try:
            true_outcome_labels = true_labels[:, ind]
            predicted_outcome_labels = preds[:, ind]
            keep = ~np.isnan(true_outcome_labels)
            auroc = roc_auc_score(true_outcome_labels[keep], predicted_outcome_labels[keep])
            auprc = average_precision_score(true_outcome_labels[keep], predicted_outcome_labels[keep])
        except (ValueError, IndexError):
            # Undefined metric: a single class present, or no predictions collected this phase.
            continue
        metrics[phase + "_" + cl + "_auroc"] = auroc
        metrics[phase + "_" + cl + "_auprc"] = auprc
        sum_auc += auroc
        sum_ave_prec += auprc
    metrics[phase + "_sum_auroc"] = sum_auc
    metrics[phase + "_sum_auprc"] = sum_ave_prec
    return metrics, sum_auc


# all_views = ['Saggital_Right', 'Transverse_Right', 'Saggital_Left', 'Transverse_Left', 'Bladder']
# view_to_use = 'Saggital_Left' 
# fold = '0'
# # Use the first fold for now
# partition = all_folds[fold]

# # Test out dataloaders
# shuffle = True
# num_workers = 0
# binarize_outcomes = True
# params = {'batch_size': batch_size,
#           'shuffle': shuffle,
#           'num_workers': num_workers}

# # Tranforms
# trans = transforms.Compose([transforms.RandomAffine(degrees=8, translate=(0.1, 0.1), scale=(0.95,1.25))])

# # Generators
# training_set = ScanDataset(partition['train_images'], partition['train_reflux'], partition['train_surgery'],
#                        partition['train_function'], partition['train_views'],partition['train_scan'], binarize_outcomes, trans, sample_by_scan)
# val_set = ScanDataset(partition['val_images'], partition['val_reflux'], partition['val_surgery'],
#                        partition['val_function'], partition['val_views'],partition['val_scan'], binarize_outcomes, trans, sample_by_scan, target_view=view_to_use)

# training_generator = data.DataLoader(training_set, **params)
# val_generator = data.DataLoader(val_set, **params)


# dataloaders_dict = {'train': training_generator, 'val': val_generator}


# random_str = str(uuid.uuid4()).split("-")[0]

# wandb.init(project='hnultra_outcome', entity=wandb_username, name=view_to_use + '_fold_' + fold, group=random_str)



# lr = 0.001
# wd = 0.001


# model_ft, _ = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
# model_ft = model_ft.to(device)
# wandb.watch(model_ft)

# # print(model)
# criterion = nn.BCEWithLogitsLoss()
# optimizer = torch.optim.Adam(model_ft.parameters(), lr=lr, weight_decay=wd)
# train_model(model_ft, dataloaders_dict, criterion, optimizer, view_to_use, num_epochs=200, is_inception=False, final_testing=False)


In [13]:
def _trainable_parameters(model, feature_extract):
    """Return the parameters the optimizer should update.

    When ``feature_extract`` is True, only parameters with ``requires_grad``
    set are returned (which ones those are depends on how ``initialize_model``
    freezes the network); otherwise every parameter of ``model`` is returned.
    """
    if feature_extract:
        return [param for param in model.parameters() if param.requires_grad]
    return list(model.parameters())


def _aggregate_fold_metrics(best_metrics_per_fold):
    """Summarize a list of per-fold metric dicts as '<key>_mean' / '<key>_stdev'.

    A key missing from some folds is summarized over only the folds that
    report it. The stdev is numpy's default population standard deviation
    (ddof=0).
    """
    metrics_all = {}
    for fold_metrics in best_metrics_per_fold:
        for key, value in fold_metrics.items():
            metrics_all.setdefault(key, []).append(value)

    metrics_to_log = {}
    for key, values in metrics_all.items():
        metric_array = np.asarray(values)
        metrics_to_log[key + '_mean'] = metric_array.mean()
        metrics_to_log[key + '_stdev'] = metric_array.std()
    return metrics_to_log


def train5fold(view_to_use, network_configs, model_name, lr, wd, amsgrad, feature_extract, i):
    """Cross-validate one hyperparameter configuration over folds '0'..'4'.

    For each fold of ``all_folds`` a fresh model is built with
    ``initialize_model`` and trained/evaluated with ``train_model``. The
    validation ScanDataset is built with ``target_view=view_to_use``. Every fold
    is logged as its own wandb run. A final run named 'ALL' records the mean
    and standard deviation of each metric returned by ``train_model`` across
    folds.

    Relies on notebook globals: all_folds, initialize_model, ScanDataset,
    train_model, device, batch_size, num_classes, num_epochs, pretrain,
    wandb_username.

    Args:
        view_to_use: view name used for the validation set and train_model.
        network_configs: extra config dict, logged only on the 'ALL' run.
        model_name: architecture name forwarded to initialize_model.
        lr: Adam learning rate.
        wd: Adam weight decay.
        amsgrad: whether Adam uses the AMSGrad variant.
        feature_extract: if True, only parameters with requires_grad are optimized.
        i: index of this configuration within the sweep (logged to wandb).

    Returns:
        dict of '<metric>_mean' / '<metric>_stdev' values (also written to the
        'ALL' run's wandb config).
    """
    random_str = str(uuid.uuid4()).split("-")[0]
    group = random_str + '_' + view_to_use
    project = 'hnultra_weigh_function_0.25_no_data_augment_min30_epochs'
    folds = ['0', '1', '2', '3', '4']
    binarize_outcomes = True

    # DataLoader settings are identical for every fold.
    shuffle = True
    num_workers = 0
    params = {'batch_size': batch_size,
              'shuffle': shuffle,
              'num_workers': num_workers}

    best_metrics_per_fold = []
    for fold in folds:
        wandb.init(project=project, entity=wandb_username, name='fold_' + fold, group=group)
        partition = all_folds[fold]

        model_ft, _ = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
        model_ft = model_ft.to(device)
        wandb.watch(model_ft)

        criterion = nn.BCEWithLogitsLoss()
        # Pass amsgrad through so the optimizer matches the config logged below,
        # and hand Adam only the parameters that are meant to be trained.
        optimizer = torch.optim.Adam(_trainable_parameters(model_ft, feature_extract),
                                     lr=lr, weight_decay=wd, amsgrad=amsgrad)

        config_dict = {'i': i, 'batch_size': batch_size, 'shuffle': shuffle, 'num_workers': num_workers, 'fold': int(fold),
                       'lr': lr, 'wd': wd, 'amsgrad': amsgrad, 'model_name': model_name, 'num_classes': num_classes,
                       'num_epochs': num_epochs, 'feature_extract': feature_extract, "pretrain": pretrain, 'view': view_to_use }
        wandb.config.update(config_dict)

        # No data augmentation: both datasets receive transform=None.
        training_set = ScanDataset(partition['train_images'], partition['train_reflux'], partition['train_surgery'],
                                   partition['train_function'], partition['train_views'], partition['train_scan'],
                                   binarize_outcomes, None, sample_by_scan=True)
        val_set = ScanDataset(partition['val_images'], partition['val_reflux'], partition['val_surgery'],
                              partition['val_function'], partition['val_views'], partition['val_scan'],
                              binarize_outcomes, None, False, target_view=view_to_use)

        dataloaders_dict = {'train': data.DataLoader(training_set, **params),
                            'val': data.DataLoader(val_set, **params)}

        # Train & evaluate; train_model returns the metrics of its best epoch.
        metrics_from_best_epoch = train_model(model_ft, dataloaders_dict, criterion, optimizer, view_to_use, num_epochs,
                                              is_inception=(model_name == "inception"), final_testing=False)
        best_metrics_per_fold.append(metrics_from_best_epoch)

    # Summary run across all folds (fold=-1 marks the aggregate).
    wandb.init(project=project, entity=wandb_username, name='ALL', group=group)
    config_dict['fold'] = -1
    wandb.config.update(config_dict)
    wandb.config.update(network_configs)

    metrics_to_log = _aggregate_fold_metrics(best_metrics_per_fold)
    wandb.config.update(metrics_to_log)
    return metrics_to_log

In [None]:
# Hyperparameter sweep: cross-validate every (lr, weight decay,
# feature-extract, view) combination with train5fold. Each configuration gets
# a sequential index i; set START_INDEX > 0 to resume an interrupted sweep.
amsgrad = False
START_INDEX = 0
lrs = [1e-3, 1e-4, 5e-4]  # each value once; a repeated entry re-runs identical configs
weight_decays = [1e-4, 5e-4, 1e-5]
feature_extracts = [False, True]
views = ['Saggital_Right', 'Transverse_Right', 'Saggital_Left', 'Transverse_Left']

# product() iterates in the same nested order as lr > wd > feature_extract > view.
sweep = itertools.product(lrs, weight_decays, feature_extracts, views)
for i, (lr, wd, feature_extract, view_to_use) in enumerate(sweep):
    if i < START_INDEX:
        continue
    train5fold(view_to_use, {'feature_extract': feature_extract}, 'resnet', lr, wd, amsgrad, feature_extract, i)

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.4624 | train sum of auc:	1.4179 

val loss:	0.4868 | val sum of auc:	1.3903 

Epoch 2/75
------------------------------------------------------
train loss:	0.3807 | train sum of auc:	1.8673 

val loss:	0.5327 | val sum of auc:	1.3273 

Epoch 3/75
------------------------------------------------------
train loss:	0.3158 | train sum of auc:	2.1924 

val loss:	0.5196 | val sum of auc:	1.7810 

Epoch 4/75
------------------------------------------------------
train loss:	0.2857 | train sum of auc:	2.3536 

val loss:	0.6901 | val sum of auc:	1.1361 

Epoch 5/75
------------------------------------------------------
train loss:	0.2559 | train sum of auc:	2.4365 

val loss:	0.9215 | val sum of auc:	1.3497 

Epoch 6/75
------------------------------------------------------
train loss:	0.2262 | train sum of auc:	2.6831 

val loss:	0.8711 | val sum of auc:	1.7833 

Epoch 7/75
-----------------------------------------

train loss:	0.0494 | train sum of auc:	2.9888 

val loss:	1.1125 | val sum of auc:	1.7789 

Epoch 49/75
------------------------------------------------------
train loss:	0.0396 | train sum of auc:	2.9929 

val loss:	0.9996 | val sum of auc:	1.9560 

Epoch 50/75
------------------------------------------------------
train loss:	0.0379 | train sum of auc:	2.9928 

val loss:	1.0292 | val sum of auc:	2.0633 

Epoch 51/75
------------------------------------------------------
train loss:	0.0481 | train sum of auc:	2.9915 

val loss:	1.0836 | val sum of auc:	2.0699 

EARLY STOPPING 50
Epoch 52/75
------------------------------------------------------


wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.6265 | train sum of auc:	1.4567 

val loss:	0.4877 | val sum of auc:	1.3015 

Epoch 2/75
------------------------------------------------------
train loss:	0.4440 | train sum of auc:	1.8422 

val loss:	0.4879 | val sum of auc:	1.2699 

Epoch 3/75
------------------------------------------------------
train loss:	0.3525 | train sum of auc:	2.2495 

val loss:	0.5184 | val sum of auc:	1.4560 

Epoch 4/75
------------------------------------------------------
train loss:	0.3171 | train sum of auc:	2.3798 

val loss:	0.5055 | val sum of auc:	1.4629 

Epoch 5/75
------------------------------------------------------
train loss:	0.3221 | train sum of auc:	2.2952 

val loss:	0.6055 | val sum of auc:	1.1958 

Epoch 6/75
------------------------------------------------------
train loss:	0.2879 | train sum of auc:	2.5621 

val loss:	0.7401 | val sum of auc:	1.3464 

Epoch 7/75
-----------------------------------------

train loss:	0.0620 | train sum of auc:	2.9668 

val loss:	0.6217 | val sum of auc:	1.6160 

Epoch 40/75
------------------------------------------------------
train loss:	0.0507 | train sum of auc:	2.9976 

val loss:	0.7625 | val sum of auc:	1.7228 

Epoch 41/75
------------------------------------------------------
train loss:	0.0754 | train sum of auc:	2.9834 

val loss:	0.7289 | val sum of auc:	1.6571 

Epoch 42/75
------------------------------------------------------
train loss:	0.0521 | train sum of auc:	2.9940 

val loss:	0.8318 | val sum of auc:	1.4423 

Epoch 43/75
------------------------------------------------------
train loss:	0.0552 | train sum of auc:	2.9765 

val loss:	1.0721 | val sum of auc:	1.3387 

Epoch 44/75
------------------------------------------------------
train loss:	0.0583 | train sum of auc:	2.9882 

val loss:	1.0204 | val sum of auc:	1.3570 

Epoch 45/75
------------------------------------------------------
train loss:	0.0439 | train sum of auc:	2.9940 

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.4940 | train sum of auc:	1.7852 

val loss:	0.3885 | val sum of auc:	0.6283 

Epoch 2/75
------------------------------------------------------
train loss:	0.4033 | train sum of auc:	1.8469 

val loss:	0.3682 | val sum of auc:	0.7353 

Epoch 3/75
------------------------------------------------------
train loss:	0.3450 | train sum of auc:	2.1342 

val loss:	0.3848 | val sum of auc:	0.8656 

Epoch 4/75
------------------------------------------------------
train loss:	0.3856 | train sum of auc:	2.0671 

val loss:	0.4306 | val sum of auc:	0.7499 

Epoch 5/75
------------------------------------------------------
train loss:	0.2864 | train sum of auc:	2.4879 

val loss:	0.3760 | val sum of auc:	0.7462 

Epoch 6/75
------------------------------------------------------
train loss:	0.2738 | train sum of auc:	2.5685 

val loss:	0.3656 | val sum of auc:	0.8158 

Epoch 7/75
-----------------------------------------

{'train_loss': 0.05312777047178575, 'train_Reflux_auroc': 0.9904099736274274, 'train_Reflux_auprc': 0.9849211481333427, 'train_sum_auroc': 2.989876640294094, 'train_sum_auprc': 2.980754481466676, 'train_Surgery_auroc': 0.9994666666666666, 'train_Surgery_auprc': 0.9958333333333333, 'train_Function_auroc': 1.0, 'train_Function_auprc': 1.0, 'val_loss': 0.7299654516013893, 'val_Reflux_auroc': 0.37647118506493504, 'val_Reflux_auprc': 0.17845459308181014, 'val_sum_auroc': 0.8264711850649351, 'val_sum_auprc': 0.9307258385511044, 'val_Function_auroc': 0.45000000000000007, 'val_Function_auprc': 0.7522712454692942, 'best_epoch': 43}
val loss:	0.7300 | val sum of auc:	0.8265 

Epoch 45/75
------------------------------------------------------
train loss:	0.0358 | train sum of auc:	2.9962 

val loss:	0.7746 | val sum of auc:	0.7731 

Epoch 46/75
------------------------------------------------------
train loss:	0.0455 | train sum of auc:	2.9941 

val loss:	0.7619 | val sum of auc:	0.7354 

Epoch 4

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.5485 | train sum of auc:	1.5931 

val loss:	0.5066 | val sum of auc:	1.8288 

Epoch 2/75
------------------------------------------------------
train loss:	0.3999 | train sum of auc:	1.9654 

val loss:	0.7481 | val sum of auc:	1.2230 

Epoch 3/75
------------------------------------------------------
train loss:	0.3276 | train sum of auc:	2.1594 

val loss:	0.7997 | val sum of auc:	1.3742 

Epoch 4/75
------------------------------------------------------
train loss:	0.3407 | train sum of auc:	2.1383 

val loss:	0.7929 | val sum of auc:	1.7396 

Epoch 5/75
------------------------------------------------------
train loss:	0.2790 | train sum of auc:	2.3992 

val loss:	1.1993 | val sum of auc:	1.5560 

Epoch 6/75
------------------------------------------------------
train loss:	0.2697 | train sum of auc:	2.4873 

val loss:	0.6579 | val sum of auc:	1.4992 

Epoch 7/75
-----------------------------------------

train loss:	0.0850 | train sum of auc:	2.9651 

{'train_loss': 0.0850221212697393, 'train_Reflux_auroc': 0.9847953216374268, 'train_Reflux_auprc': 0.9700084518944907, 'train_sum_auroc': 2.965142915220315, 'train_sum_auprc': 2.876924641765226, 'train_Surgery_auroc': 0.9901515151515152, 'train_Surgery_auprc': 0.9202708907254362, 'train_Function_auroc': 0.9901960784313726, 'train_Function_auprc': 0.9866452991452993, 'val_loss': 0.8577154609428975, 'val_Reflux_auroc': 0.41405895691609973, 'val_Reflux_auprc': 0.27712312318542487, 'val_sum_auroc': 1.199669170774943, 'val_sum_auprc': 0.47070277948365924, 'val_Surgery_auroc': 0.6134288413098237, 'val_Surgery_auprc': 0.10988958987929254, 'val_Function_auroc': 0.1721813725490196, 'val_Function_auprc': 0.08369006641894185, 'best_epoch': 34}
val loss:	0.8577 | val sum of auc:	1.1997 

Epoch 36/75
------------------------------------------------------
train loss:	0.0857 | train sum of auc:	2.9522 

val loss:	0.8776 | val sum of auc:	1.1712 

Epoch 

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.4982 | train sum of auc:	1.5397 

val loss:	0.8632 | val sum of auc:	1.5600 

Epoch 2/75
------------------------------------------------------
train loss:	0.3635 | train sum of auc:	2.0955 

val loss:	0.5273 | val sum of auc:	1.5123 

Epoch 3/75
------------------------------------------------------
train loss:	0.3444 | train sum of auc:	2.0654 

val loss:	0.3968 | val sum of auc:	1.8389 

Epoch 4/75
------------------------------------------------------
train loss:	0.3280 | train sum of auc:	2.2931 

val loss:	0.4114 | val sum of auc:	1.6063 

Epoch 5/75
------------------------------------------------------
train loss:	0.3069 | train sum of auc:	2.3918 

val loss:	0.4644 | val sum of auc:	1.8948 

Epoch 6/75
------------------------------------------------------
train loss:	0.2639 | train sum of auc:	2.3888 

val loss:	0.6533 | val sum of auc:	1.6827 

Epoch 7/75
-----------------------------------------

train loss:	0.0546 | train sum of auc:	2.9758 

val loss:	0.8447 | val sum of auc:	1.7465 

Epoch 40/75
------------------------------------------------------
train loss:	0.0737 | train sum of auc:	2.9755 

val loss:	0.8126 | val sum of auc:	1.6828 

Epoch 41/75
------------------------------------------------------
train loss:	0.0759 | train sum of auc:	2.9691 

{'train_loss': 0.07593273374197237, 'train_Reflux_auroc': 0.991600790513834, 'train_Reflux_auprc': 0.9822479110928379, 'train_sum_auroc': 2.9691159420289854, 'train_sum_auprc': 2.9114226999167494, 'train_Surgery_auroc': 0.9941818181818182, 'train_Surgery_auprc': 0.9617224880382775, 'train_Function_auroc': 0.9833333333333334, 'train_Function_auprc': 0.9674523007856342, 'val_loss': 0.69013731113889, 'val_Reflux_auroc': 0.4048021235521235, 'val_Reflux_auprc': 0.1299524657256346, 'val_sum_auroc': 1.6406783717135691, 'val_sum_auprc': 0.7761944491920826, 'val_Surgery_auroc': 0.6828459451311426, 'val_Surgery_auprc': 0.207207304832508

train loss:	0.0456 | train sum of auc:	2.9811 

val loss:	0.9597 | val sum of auc:	1.4574 

Epoch 74/75
------------------------------------------------------
train loss:	0.0395 | train sum of auc:	2.9901 

val loss:	0.8820 | val sum of auc:	1.6269 

Epoch 75/75
------------------------------------------------------
train loss:	0.0470 | train sum of auc:	2.9738 

val loss:	0.8658 | val sum of auc:	1.7193 



wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.5817 | train sum of auc:	1.4684 

val loss:	0.4679 | val sum of auc:	1.8521 

Epoch 2/75
------------------------------------------------------
train loss:	0.3837 | train sum of auc:	2.1270 

val loss:	0.4824 | val sum of auc:	1.7099 

Epoch 3/75
------------------------------------------------------
train loss:	0.2748 | train sum of auc:	2.5253 

val loss:	0.5539 | val sum of auc:	1.9350 

Epoch 4/75
------------------------------------------------------
train loss:	0.3050 | train sum of auc:	2.4417 

val loss:	0.5732 | val sum of auc:	1.9420 

Epoch 5/75
------------------------------------------------------
train loss:	0.2175 | train sum of auc:	2.7649 

val loss:	0.5547 | val sum of auc:	1.8257 

Epoch 6/75
------------------------------------------------------
train loss:	0.1846 | train sum of auc:	2.7600 

val loss:	0.6307 | val sum of auc:	1.3575 

Epoch 7/75
-----------------------------------------

train loss:	0.0154 | train sum of auc:	3.0000 

{'train_loss': 0.015406752664623035, 'train_Reflux_auroc': 1.0, 'train_Reflux_auprc': 1.0, 'train_sum_auroc': 3.0, 'train_sum_auprc': 3.0, 'train_Surgery_auroc': 1.0, 'train_Surgery_auprc': 1.0, 'train_Function_auroc': 1.0, 'train_Function_auprc': 0.9999999999999998, 'val_loss': 1.063867175579071, 'val_Reflux_auroc': 0.4337858606557377, 'val_Reflux_auprc': 0.4614727867700874, 'val_sum_auroc': 1.209503198873076, 'val_sum_auprc': 0.9123136210388599, 'val_Surgery_auroc': 0.6765109890109889, 'val_Surgery_auprc': 0.34758490635094924, 'val_Function_auroc': 0.09920634920634921, 'val_Function_auprc': 0.10325592791782323, 'best_epoch': 42}
val loss:	1.0639 | val sum of auc:	1.2095 

Epoch 44/75
------------------------------------------------------
train loss:	0.0250 | train sum of auc:	2.9993 

{'train_loss': 0.024953660330850714, 'train_Reflux_auroc': 1.0, 'train_Reflux_auprc': 1.0, 'train_sum_auroc': 2.9992784992784993, 'train_sum_auprc': 2.992

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.4633 | train sum of auc:	1.3407 

val loss:	0.6676 | val sum of auc:	1.5355 

Epoch 2/75
------------------------------------------------------
train loss:	0.4076 | train sum of auc:	1.7029 

val loss:	0.4840 | val sum of auc:	1.3434 

Epoch 3/75
------------------------------------------------------
train loss:	0.3455 | train sum of auc:	2.0833 

val loss:	0.5543 | val sum of auc:	1.6583 

Epoch 4/75
------------------------------------------------------
train loss:	0.2980 | train sum of auc:	2.4696 

val loss:	0.7832 | val sum of auc:	1.4686 

Epoch 5/75
------------------------------------------------------
train loss:	0.3168 | train sum of auc:	2.3370 

val loss:	0.5912 | val sum of auc:	1.3387 

Epoch 6/75
------------------------------------------------------
train loss:	0.2395 | train sum of auc:	2.6965 

val loss:	0.5485 | val sum of auc:	1.1144 

Epoch 7/75
-----------------------------------------

train loss:	0.0234 | train sum of auc:	3.0000 

val loss:	0.6601 | val sum of auc:	1.9005 

Epoch 44/75
------------------------------------------------------
train loss:	0.0396 | train sum of auc:	2.9859 

val loss:	0.7859 | val sum of auc:	1.8096 

Epoch 45/75
------------------------------------------------------
train loss:	0.0273 | train sum of auc:	2.9977 

val loss:	0.8181 | val sum of auc:	1.7768 

Epoch 46/75
------------------------------------------------------
train loss:	0.0429 | train sum of auc:	2.9880 

val loss:	0.8573 | val sum of auc:	1.8148 

Epoch 47/75
------------------------------------------------------
train loss:	0.0247 | train sum of auc:	2.9994 

val loss:	0.8715 | val sum of auc:	1.7778 

Epoch 48/75
------------------------------------------------------
train loss:	0.0408 | train sum of auc:	2.9980 

val loss:	0.9299 | val sum of auc:	1.7831 

Epoch 49/75
------------------------------------------------------
train loss:	0.0369 | train sum of auc:	2.9975 

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.5645 | train sum of auc:	1.7065 

val loss:	0.4313 | val sum of auc:	0.8584 

Epoch 2/75
------------------------------------------------------
train loss:	0.4003 | train sum of auc:	1.9210 

val loss:	0.3450 | val sum of auc:	1.0143 

Epoch 3/75
------------------------------------------------------
train loss:	0.3349 | train sum of auc:	2.2873 

val loss:	0.4245 | val sum of auc:	1.0332 

Epoch 4/75
------------------------------------------------------
train loss:	0.2990 | train sum of auc:	2.4477 

val loss:	0.4860 | val sum of auc:	1.1297 

Epoch 5/75
------------------------------------------------------
train loss:	0.2926 | train sum of auc:	2.4771 

val loss:	0.3312 | val sum of auc:	1.0425 

Epoch 6/75
------------------------------------------------------
train loss:	0.2483 | train sum of auc:	2.7058 

val loss:	nan | val sum of auc:	1.1797 

Epoch 7/75
--------------------------------------------

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.5131 | train sum of auc:	1.2367 

val loss:	0.4434 | val sum of auc:	1.4236 

Epoch 2/75
------------------------------------------------------
train loss:	0.3415 | train sum of auc:	2.3687 

val loss:	0.4555 | val sum of auc:	1.7983 

Epoch 3/75
------------------------------------------------------
train loss:	0.3326 | train sum of auc:	2.2722 

val loss:	0.4855 | val sum of auc:	1.7730 

Epoch 4/75
------------------------------------------------------
train loss:	0.2710 | train sum of auc:	2.4916 

val loss:	0.5371 | val sum of auc:	1.4897 

Epoch 5/75
------------------------------------------------------
train loss:	0.2757 | train sum of auc:	2.5296 

val loss:	0.9556 | val sum of auc:	1.6495 

Epoch 6/75
------------------------------------------------------
train loss:	0.2559 | train sum of auc:	2.5296 

val loss:	0.9934 | val sum of auc:	1.7237 

Epoch 7/75
-----------------------------------------

train loss:	0.0425 | train sum of auc:	2.9953 

val loss:	1.0386 | val sum of auc:	1.5141 

Epoch 49/75
------------------------------------------------------
train loss:	0.0758 | train sum of auc:	2.9701 

val loss:	1.0053 | val sum of auc:	1.5810 

Epoch 50/75
------------------------------------------------------
train loss:	0.0403 | train sum of auc:	2.9936 

val loss:	1.2164 | val sum of auc:	1.6069 

Epoch 51/75
------------------------------------------------------
train loss:	0.0610 | train sum of auc:	2.9808 

val loss:	0.9488 | val sum of auc:	1.6629 

EARLY STOPPING 50
Epoch 52/75
------------------------------------------------------


wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.5801 | train sum of auc:	1.3597 

val loss:	0.6832 | val sum of auc:	1.9750 

Epoch 2/75
------------------------------------------------------
train loss:	0.4030 | train sum of auc:	2.0917 

val loss:	0.3473 | val sum of auc:	1.6547 

Epoch 3/75
------------------------------------------------------
train loss:	0.3580 | train sum of auc:	2.2448 

val loss:	0.5400 | val sum of auc:	1.7338 

Epoch 4/75
------------------------------------------------------
train loss:	0.2780 | train sum of auc:	2.5672 

val loss:	0.4869 | val sum of auc:	1.7680 

Epoch 5/75
------------------------------------------------------
train loss:	0.2648 | train sum of auc:	2.5715 

val loss:	0.4326 | val sum of auc:	1.7009 

Epoch 6/75
------------------------------------------------------
train loss:	0.2453 | train sum of auc:	2.6369 

val loss:	0.3723 | val sum of auc:	1.4873 

Epoch 7/75
-----------------------------------------

train loss:	0.0403 | train sum of auc:	2.9920 

val loss:	0.6510 | val sum of auc:	1.7338 

Epoch 47/75
------------------------------------------------------
train loss:	0.0422 | train sum of auc:	2.9968 

val loss:	0.6548 | val sum of auc:	1.6733 

Epoch 48/75
------------------------------------------------------
train loss:	0.0306 | train sum of auc:	2.9960 

val loss:	0.7286 | val sum of auc:	1.5444 

Epoch 49/75
------------------------------------------------------
train loss:	0.0320 | train sum of auc:	2.9971 

val loss:	0.9215 | val sum of auc:	1.4925 

Epoch 50/75
------------------------------------------------------
train loss:	0.0420 | train sum of auc:	2.9866 

val loss:	1.0811 | val sum of auc:	1.4930 

Epoch 51/75
------------------------------------------------------
train loss:	0.0374 | train sum of auc:	2.9936 

val loss:	0.9420 | val sum of auc:	1.6166 

Epoch 52/75
------------------------------------------------------
train loss:	0.0246 | train sum of auc:	2.9997 

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.4527 | train sum of auc:	1.2425 

val loss:	0.4288 | val sum of auc:	1.5485 

Epoch 2/75
------------------------------------------------------
train loss:	0.3560 | train sum of auc:	2.0737 

val loss:	0.5129 | val sum of auc:	1.5590 

Epoch 3/75
------------------------------------------------------
train loss:	0.3260 | train sum of auc:	2.1828 

val loss:	0.5724 | val sum of auc:	1.4799 

Epoch 4/75
------------------------------------------------------
train loss:	0.2798 | train sum of auc:	2.4708 

val loss:	0.6922 | val sum of auc:	1.4709 

Epoch 5/75
------------------------------------------------------
train loss:	0.2745 | train sum of auc:	2.5447 

val loss:	1.0550 | val sum of auc:	1.3783 

Epoch 6/75
------------------------------------------------------
train loss:	0.3000 | train sum of auc:	2.3019 

val loss:	1.1814 | val sum of auc:	1.6129 

Epoch 7/75
-----------------------------------------

train loss:	0.0768 | train sum of auc:	2.9436 

val loss:	2.1871 | val sum of auc:	1.4296 

Epoch 49/75
------------------------------------------------------
train loss:	0.0763 | train sum of auc:	2.9746 

val loss:	2.2546 | val sum of auc:	1.4537 

Epoch 50/75
------------------------------------------------------
train loss:	0.0691 | train sum of auc:	2.9795 

val loss:	2.2194 | val sum of auc:	1.5339 

Epoch 51/75
------------------------------------------------------
train loss:	0.0824 | train sum of auc:	2.9569 

val loss:	2.4590 | val sum of auc:	1.4964 

EARLY STOPPING 50
Epoch 52/75
------------------------------------------------------


wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.5014 | train sum of auc:	1.3212 

val loss:	0.5851 | val sum of auc:	1.7167 

Epoch 2/75
------------------------------------------------------
train loss:	0.3941 | train sum of auc:	1.7854 

val loss:	0.3671 | val sum of auc:	1.2282 

Epoch 3/75
------------------------------------------------------
train loss:	0.3456 | train sum of auc:	2.0678 

val loss:	0.4942 | val sum of auc:	1.2194 

Epoch 4/75
------------------------------------------------------
train loss:	0.3254 | train sum of auc:	2.2435 

val loss:	0.6658 | val sum of auc:	1.5522 

Epoch 5/75
------------------------------------------------------
train loss:	0.3437 | train sum of auc:	2.3195 

val loss:	0.7682 | val sum of auc:	1.7421 

Epoch 6/75
------------------------------------------------------
train loss:	0.2649 | train sum of auc:	2.6335 

val loss:	0.6961 | val sum of auc:	1.7180 

Epoch 7/75
-----------------------------------------

train loss:	0.0535 | train sum of auc:	2.9953 

val loss:	0.5800 | val sum of auc:	1.5063 

Epoch 49/75
------------------------------------------------------
train loss:	0.0684 | train sum of auc:	2.9822 

val loss:	0.5391 | val sum of auc:	1.4261 

Epoch 50/75
------------------------------------------------------
train loss:	0.0666 | train sum of auc:	2.9839 

val loss:	0.5375 | val sum of auc:	1.5014 

Epoch 51/75
------------------------------------------------------
train loss:	0.0426 | train sum of auc:	2.9958 

val loss:	0.6050 | val sum of auc:	1.5005 

EARLY STOPPING 50
Epoch 52/75
------------------------------------------------------


wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 1/75
------------------------------------------------------
train loss:	0.5084 | train sum of auc:	1.5805 

val loss:	0.4512 | val sum of auc:	1.8987 

Epoch 2/75
------------------------------------------------------
train loss:	0.4168 | train sum of auc:	1.8255 

val loss:	0.3819 | val sum of auc:	1.9085 

Epoch 3/75
------------------------------------------------------
train loss:	0.3546 | train sum of auc:	2.1092 

val loss:	0.8697 | val sum of auc:	1.9442 

Epoch 4/75
------------------------------------------------------
train loss:	0.3373 | train sum of auc:	2.3805 

val loss:	0.4569 | val sum of auc:	1.9493 

Epoch 5/75
------------------------------------------------------
train loss:	0.3515 | train sum of auc:	2.2498 

val loss:	0.6228 | val sum of auc:	1.8722 

Epoch 6/75
------------------------------------------------------
train loss:	0.2953 | train sum of auc:	2.4401 

val loss:	1.3613 | val sum of auc:	1.7919 

Epoch 7/75
-----------------------------------------

In [None]:
# This is for visualizing images, 
# MAKE SURE THIS IS COMMENTED OUT WHEN COMMITTING!!
# image_return_order = ['Saggital_Right', 'Transverse_Right', 'Saggital_Left', 'Transverse_Left', 'Bladder']

# for inputs, labs in training_generator:
#     plt.figure(figsize=(20,10)) 
#     first_scan = inputs[0]
#     for i in range(5):
#         im = first_scan[i,:, :, :]

#         im_np = np.asarray(im).squeeze()
#         plt.subplot(2,5,i+ 1)
#         plt.imshow(im_np, cmap='gray')
#         frame1 = plt.gca()
#         frame1.axes.get_xaxis().set_visible(False)
#         frame1.axes.get_yaxis().set_visible(False)

#         plt.title(image_return_order[i])
        
        
#     break