In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json

from PIL import Image
import torchvision
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from sklearn.metrics import *
import time
from datetime import datetime
import os
from torch.utils import data
import random
import copy
import itertools
import io
import uuid
from sklearn.model_selection import KFold, train_test_split

import warnings
warnings.filterwarnings('ignore')

import wandb
wandb_username = 'denizjafari'
local_username = 'denizjafari'

In [2]:
# Prefer GPU index 1 (as before), but fall back to the first GPU when fewer
# devices exist, and to the CPU when CUDA is unavailable, so the notebook also
# runs on single-GPU / CPU-only machines.
PREFERRED_GPU_INDEX = 1
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print(device)

cuda:1


In [3]:
# Shared project locations (the split CSV and the fold-split JSON live here).
root_dir = "/home/andreasabo/Documents/HNProject/"
split_file_base = "/home/andreasabo/Documents/HNUltra/"

# Per-user image directory (users: abhishekmoturu, andreasabo, denizjafari, navidkorhani).
# Keep the trailing slash: image paths are built from this prefix.
data_dir = f"/home/{local_username}/Documents/HNProject/all_label_img/"

# Read the target table, keeping only subject id, image id, view label and train/test flag.
csv_path = os.path.join(root_dir, "all_splits_1000000.csv")
wanted_columns = ['subj_id', 'image_ids', 'view_label', 'view_train']
data_df = pd.read_csv(csv_path, usecols=wanted_columns)

# Final-test switch: when True, a 'test' fold (full train set vs. held-out test set) is used.
test_data = True

### **Reading Data Indices and Labels**

In [4]:
# Encoding of the six view classes as integer targets.
label_mapping = {'Other':0, 'Saggital_Right':1, 'Transverse_Right':2, 
                 'Saggital_Left':3, 'Transverse_Left':4, 'Bladder':5}
# Inverse mapping, derived so the two dictionaries can never drift apart.
label_unmapping = {idx: name for name, idx in label_mapping.items()}

# Only encode raw (non-numeric) labels: re-running this cell on an already
# encoded column would otherwise map every label to NaN.
if not pd.api.types.is_numeric_dtype(data_df['view_label']):
    data_df['view_label'] = data_df['view_label'].map(label_mapping)

train_df = data_df[data_df.view_train == 1]
test_df = data_df[data_df.view_train == 0]

unique_subj = train_df.subj_id.unique()

# The cross-validation folds (split by subj_id) were generated once and saved
# to JSON; load them rather than regenerating.
data_split_file = split_file_base + 'data_splits.json'
print("Reading splits from file")
with open(data_split_file, 'r') as f:
    all_folds = json.load(f)

# Final-test mode: add a 'test' fold that trains on the entire training set and
# evaluates on the held-out test set.
if test_data:
    train_images = train_df.image_ids.tolist()
    test_images = test_df.image_ids.tolist()
    train_labels = train_df.view_label.tolist()
    test_labels = test_df.view_label.tolist()

    cur_fold = {'train_ids': train_images, 'test_ids': test_images,
                'train_labels': train_labels, 'test_labels': test_labels}
    all_folds['test'] = cur_fold

Reading splits from file


In [5]:
# Sanity check: sizes of the final-test split next to cross-validation fold '0'
# (train sizes first, then evaluation sizes).
for fold_key, split_key in (('test', 'train_ids'), ('0', 'train_ids'),
                            ('test', 'test_ids'), ('0', 'valid_ids')):
    print(len(all_folds[fold_key][split_key]))

13958
11081
5070
2877


In [6]:
# MODEL DIRECTORIES
# Saved VAE checkpoints; the file names suggest hidden size 800 and latent sizes 50 / 100 / 200.
_vae_root = '/home/navidkorhani/Documents/HNProject/HNUltra/'
vae50_dir = _vae_root + 'saved models/vae_model_h800_l50.pt'
vae100_dir = _vae_root + 'results/h800_l100_e30/vae_model.pt'
vae200_dir = _vae_root + 'results/h800_l200_e30/vae_model.pt'


# Model Initialization

In [7]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all of ``model``'s parameters when ``feature_extracting`` is truthy.

    Used before swapping in a new head so that only the new layers train.
    A falsy ``feature_extracting`` leaves the model untouched. Returns None.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [8]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a model by name and return it with its expected input image size.

    Args:
        model_name: one of "resnet", "alexnet", "vgg", "densenet", "vae50",
            "vae100", "vae200" or "viewnet".
        num_classes: size of the replacement output layer for the torchvision
            models and of the ViewNet output; ignored by the VAE variants.
        feature_extract: if True, freeze the torchvision backbone so only the
            newly created head is trainable (torchvision models only).
        use_pretrained: load pretrained torchvision weights.

    Returns:
        (model, input_size) where input_size is the side length in pixels the
        model is set up for.

    Raises:
        ValueError: if ``model_name`` is not recognised. (Previously this
            called ``exit()``, which tries to shut down a Jupyter kernel.)
    """
    if model_name == "resnet":
        # Resnet18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 256

    elif model_name == "alexnet":
        # Alexnet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "densenet":
        # Densenet121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == 'vae50':
        model_ft = VAE_50()
        input_size = 256

    elif model_name == 'vae100':
        model_ft = VAE_100()
        input_size = 256

    elif model_name == 'vae200':
        model_ft = VAE_200()
        input_size = 256

    elif model_name == 'viewnet':
        conv1_filters = 8
        conv2_filters = 16
        conv3_filters = 32
        linear1_size = 512
        dropout = 0.25
        model_ft = ViewNet(num_classes, conv1_filters, conv2_filters, conv3_filters, linear1_size, dropout)
        input_size = 256

    else:
        raise ValueError("Invalid model name: {!r}".format(model_name))

    return model_ft, input_size




# Model Architectures

In [9]:
# custom view labeller
class ViewNet(nn.Module):
    """Small CNN classifier for single-channel square images.

    Three blocks of conv(kernel 4, padding 2) -> ReLU -> dropout -> max-pool(4),
    followed by linear -> ReLU -> dropout -> linear producing class logits.

    Args:
        num_classes: number of output logits.
        conv1_filters, conv2_filters, conv3_filters: channels of each conv layer.
        linear1_size: width of the hidden fully connected layer.
        dropout: dropout probability used after every activation.
        input_size: side length of the (square) input images; default 256.
    """

    def __init__(self, num_classes, conv1_filters, conv2_filters, conv3_filters, linear1_size, dropout,
                 input_size=256):
        super(ViewNet, self).__init__()
        self.conv1_filters = conv1_filters
        self.conv2_filters = conv2_filters
        self.conv3_filters = conv3_filters
        self.linear1_size = linear1_size
        self.drop_percent = dropout
        self.max_pool = 4
        self.input_size = input_size
        # Flattened feature size entering linear1; exact for any input side
        # (equals 512 for the default 32 filters on 256x256 inputs).
        self.conv_output = self.conv3_filters * self._feature_side(input_size) ** 2
        print("conv_output: ", self.conv_output)

        self.conv1 = nn.Conv2d(1, self.conv1_filters, 4, padding=2)
        self.conv2 = nn.Conv2d(self.conv1_filters, self.conv2_filters, 4, padding=2)
        self.conv3 = nn.Conv2d(self.conv2_filters, self.conv3_filters, 4, padding=2)
        self.pool = nn.MaxPool2d(self.max_pool, self.max_pool)
        self.dropout = nn.Dropout(self.drop_percent)
        self.linear1 = nn.Linear(self.conv_output, self.linear1_size)
        self.linear2 = nn.Linear(self.linear1_size, num_classes)

    def _feature_side(self, side):
        """Spatial side length after the three conv + pool blocks."""
        for _ in range(3):
            # conv(k=4, p=2, stride 1) grows the side by 1; the pool floors side / max_pool.
            side = (side + 1) // self.max_pool
        return side

    def forward(self, x):
        x = self.pool(self.dropout(F.relu(self.conv1(x))))
        x = self.pool(self.dropout(F.relu(self.conv2(x))))
        x = self.pool(self.dropout(F.relu(self.conv3(x))))
        # Flatten per sample; a size mismatch now errors in linear1 instead of
        # silently reshaping across the batch dimension.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.linear1(x)))
        x = self.linear2(x)
        return x
    
    
class _VAEBase(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    Encoder: input_dim -> hidden_dim (ReLU) -> (mu, logvar), each latent_dim wide.
    Decoder: latent_dim -> hidden_dim (ReLU) -> input_dim (sigmoid).
    Layer names (fc1, fc21, fc22, fc3, fc4) and their creation order match the
    earlier copy-pasted VAE classes, so state_dict keys are unchanged.

    Args:
        latent_dim: size of the latent code.
        hidden_dim: width of the hidden layers (default 800).
        input_dim: flattened input size (default 65536 = 256 * 256).
    """

    def __init__(self, latent_dim, hidden_dim=800, input_dim=65536):
        super(_VAEBase, self).__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return (mu, logvar) for a batch of flattened inputs."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + std * eps with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to input space, squashed into [0, 1] by a sigmoid."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Return (reconstruction, mu, logvar); ``x`` is flattened to (-1, input_dim)."""
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


class VAE_50(_VAEBase):
    """VAE with a 50-dimensional latent space."""

    def __init__(self, hidden_dim=800, input_dim=65536):
        super(VAE_50, self).__init__(latent_dim=50, hidden_dim=hidden_dim, input_dim=input_dim)


class VAE_100(_VAEBase):
    """VAE with a 100-dimensional latent space."""

    def __init__(self, hidden_dim=800, input_dim=65536):
        super(VAE_100, self).__init__(latent_dim=100, hidden_dim=hidden_dim, input_dim=input_dim)


class VAE_200(_VAEBase):
    """VAE with a 200-dimensional latent space."""

    def __init__(self, hidden_dim=800, input_dim=65536):
        super(VAE_200, self).__init__(latent_dim=200, hidden_dim=hidden_dim, input_dim=input_dim)

In [10]:
# Code from: https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Args:
        mode: 'min' if lower values are better (e.g. a loss), 'max' otherwise.
        min_delta: minimum change that counts as an improvement.
        patience: consecutive non-improving steps tolerated before stopping;
            0 disables early stopping (``step`` always returns False).
        percentage: if True, ``min_delta`` is a percentage of the current best
            value's magnitude rather than an absolute amount.

    Raises:
        ValueError: if ``mode`` is not 'min' or 'max'.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=True):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one metric value; return True when training should stop.

        The first call only initializes ``best``. A later NaN value stops
        immediately.
        """
        if self.best is None:
            self.best = metrics
            return False

        if np.isnan(metrics):
            return True

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install ``self.is_better(a, best)`` for the given mode and delta type."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + str(mode) + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            # abs(best) keeps the threshold on the correct side of best when the
            # metric is negative; with plain ``best`` a worse value could count
            # as an improvement.
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (abs(best) * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (abs(best) * min_delta / 100)

# Model Training 

In [11]:
def _per_class_metrics(labels, preds, prefix):
    """One-vs-rest F1 / precision / recall for every class in ``label_unmapping``.

    Keys look like ``prefix + '<ClassName>_f1'``; e.g. prefix 'val_' gives 'val_Bladder_f1'.
    """
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    metrics = {}
    for c in sorted(label_unmapping):
        # Binarize: class c vs. every other class.
        labels_bin = (labels == c).astype(int)
        preds_bin = (preds == c).astype(int)
        name = prefix + label_unmapping[c]
        metrics[name + '_f1'] = f1_score(labels_bin, preds_bin)
        metrics[name + '_precision'] = precision_score(labels_bin, preds_bin)
        metrics[name + '_recall'] = recall_score(labels_bin, preds_bin)
    return metrics


def _average_metrics(labels, preds, phase):
    """Macro/micro/weighted precision, recall and F1 plus accuracy, keyed by ``phase``."""
    metrics = {}
    for av in ('macro', 'micro', 'weighted'):
        # The 4th value (support) is None whenever ``average`` is set, so it is dropped;
        # logging it produced None entries that later cannot be averaged.
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=av)
        metrics[phase + '_precision_average_' + av] = precision
        metrics[phase + '_recall_average_' + av] = recall
        metrics[phase + '_f1score_average_' + av] = f1
    metrics[phase + '_acc_average'] = accuracy_score(labels, preds)
    return metrics


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, final_testing=True):
    """Train ``model`` and keep the weights with the best test-phase accuracy.

    Each epoch runs a 'train' phase then a 'test' phase over ``dataloaders``, logs
    loss/accuracy plus per-class and averaged metrics to wandb, and, whenever the
    test accuracy improves, copies the weights and saves them to
    ``<wandb run dir>/model.pt``.

    Args:
        model: network returning class logits (``(logits, aux_logits)`` in
            train mode when ``is_inception``).
        dataloaders: dict with 'train' and 'test' loaders yielding
            ``(inputs, one_hot_labels)``.
        criterion: loss taking ``(logits, class_indices)``.
        optimizer: optimizer over the parameters being trained.
        num_epochs: maximum number of epochs.
        is_inception: add 0.4 x the auxiliary-head loss during training.
        final_testing: when True, early stopping on the test loss is disabled.

    Returns:
        (model with the best weights loaded, list of per-epoch test accuracies,
         dict of metrics from the best epoch).
    """
    es = EarlyStopping(patience=15)
    since = time.time()

    val_acc_history = []
    val_metrics_list = []
    train_metrics_list = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # Initialized so the summary below works even if test accuracy never exceeds 0.
    best_acc_train = 0.0
    train_acc = 0.0
    epoch_with_best_val_acc = 0
    epoch = 0
    stop_now = False

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 54)

        # Each epoch has a training and a validation ('test') phase.
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            running_preds = []
            running_labels = []

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                # Targets arrive one-hot; the criterion wants class indices. Converting
                # here covers every branch, including the inception one.
                labels = torch.argmax(labels.type(torch.long), 1).to(device)

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        # Auxiliary classifier: loss = main + 0.4 * aux. See
                        # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    preds = torch.argmax(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_preds += preds.tolist()
                running_labels += labels.tolist()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            n_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects.double() / n_samples

            print('{} loss:\t{:.4f} | {} acc:\t{:.4f}\n'.format(phase, epoch_loss, phase, epoch_acc))

            if phase == 'train':
                wandb.log({'epoch': epoch, 'train_acc': epoch_acc, 'train_loss': epoch_loss})
                cur_train_metrics = _per_class_metrics(running_labels, running_preds, 'train_')
                cur_train_metrics.update(_average_metrics(running_labels, running_preds, phase))
                train_metrics_list.append(cur_train_metrics)
                wandb.log(cur_train_metrics)
                print(classification_report(running_labels, running_preds))
                train_acc = epoch_acc
            else:
                wandb.log({'test_loss': epoch_loss, 'test_acc': epoch_acc, 'epoch': epoch})
                per_class = _per_class_metrics(running_labels, running_preds, 'val_')
                # The per-class values are also logged under the 'valid_' prefix.
                wandb.log({'valid_' + key[len('val_'):]: value for key, value in per_class.items()})
                cur_val_metrics = dict(per_class)
                cur_val_metrics.update(_average_metrics(running_labels, running_preds, phase))
                print(cur_val_metrics)
                wandb.log(cur_val_metrics)
                val_metrics_list.append(cur_val_metrics)

                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_acc_train = train_acc
                    epoch_with_best_val_acc = epoch
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(model.state_dict(), os.path.join(wandb.run.dir, "model.pt"))
                    print(classification_report(running_labels, running_preds))

                val_acc_history.append(epoch_acc)
                if es.step(epoch_loss) and not final_testing:
                    stop_now = True
                    print("EARLY STOPPING " + str(epoch))

        # Stop right away so ``epoch`` (reported as 'last_epoch') is the last epoch run.
        if stop_now:
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val acc: {:4f}\n'.format(best_acc))

    # Record the best results of this run directly in the wandb config.
    wandb.config.best_acc = best_acc
    wandb.config.best_epoch = epoch_with_best_val_acc
    wandb.config.val_acc_history = val_acc_history
    wandb.config.update(val_metrics_list[epoch_with_best_val_acc])
    wandb.config.update(train_metrics_list[epoch_with_best_val_acc])

    metrics_from_best_epoch = {'best_epoch': epoch_with_best_val_acc, 'last_epoch': epoch}
    metrics_from_best_epoch.update(val_metrics_list[epoch_with_best_val_acc])
    metrics_from_best_epoch.update(train_metrics_list[epoch_with_best_val_acc])
    # as_tensor also handles the plain-float 0.0 initial values.
    metrics_from_best_epoch.update({
        'val_acc': torch.as_tensor(best_acc).detach().cpu(),
        'train_acc': torch.as_tensor(best_acc_train).detach().cpu(),
    })

    # Load the best model weights before returning.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, metrics_from_best_epoch

In [12]:
def train5fold(network_configs, criterion_used, model_ft, lr, wd, amsgrad, i, folds=('test',)):
    """Train a fresh copy of ``model_ft`` on each fold, then log a summary wandb run.

    Args:
        network_configs: architecture hyper-parameters added to each run's wandb config.
        criterion_used: loss module passed to ``train_model``.
        model_ft: template model; every fold trains its own deep copy.
        lr, wd, amsgrad: Adam learning rate, weight decay and AMSGrad flag.
        i: repetition index, recorded in the wandb config.
        folds: keys of ``all_folds`` to train on. Each entry must provide
            'train_ids', 'train_labels', 'test_ids' and 'test_labels'.
            Defaults to the final 'test' split. (The old code assigned this list
            to ``fold`` but looped over an undefined ``folds``.)
    """
    random_str = str(uuid.uuid4()).split("-")[0]
    best_metrics_per_fold = []
    model_base = copy.deepcopy(model_ft)
    config_dict = {}
    for fold in folds:
        wandb.init(project='hnultra_test', entity=wandb_username, name=local_username + '_fold_' + fold, group=random_str)
        partition = all_folds[fold]

        model_ft = copy.deepcopy(model_base).to(device)
        wandb.watch(model_ft)

        # List the trainable parameters. When feature extracting, only those are
        # optimized; otherwise all parameters are handed to the optimizer.
        trainable = [(name, param) for name, param in model_ft.named_parameters() if param.requires_grad]
        for name, _ in trainable:
            print("\t", name)
        if feature_extract:
            params_to_update = [param for _, param in trainable]
        else:
            params_to_update = model_ft.parameters()

        optimizer_ft = optim.Adam(params_to_update, lr=lr, weight_decay=wd, amsgrad=amsgrad)
        criterion = criterion_used

        shuffle = True
        num_workers = 0
        params = {'batch_size': batch_size,
                  'shuffle': shuffle,
                  'num_workers': num_workers}

        config_dict = {'i': i, 'batch_size': batch_size, 'shuffle': shuffle, 'num_workers': num_workers, 'fold': fold,
                       'lr': lr, 'wd': wd, 'amsgrad': amsgrad, 'model_name': model_name, 'criterion': criterion, 'num_classes': num_classes,
                       'num_epochs': num_epochs, 'feature_extract': feature_extract, "pretrain": pretrain}
        wandb.config.update(config_dict)
        wandb.config.update(network_configs)

        # Augmentation is applied to the training images only.
        trans = transforms.Compose([transforms.RandomAffine(degrees=8, translate=(0.1, 0.1), scale=(0.95, 1.25))])

        training_set = Dataset(partition['train_ids'], partition['train_labels'], transformations=trans)
        training_generator = data.DataLoader(training_set, **params)

        validation_set = Dataset(partition['test_ids'], partition['test_labels'])
        validation_generator = data.DataLoader(validation_set, **params)

        dataloaders_dict = {'train': training_generator, 'test': validation_generator}

        model_ft, hist, metrics_from_best_epoch = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs, is_inception=(model_name == "inception"))
        best_metrics_per_fold.append(metrics_from_best_epoch)

    # Summary run: mean / stdev of each fold's best-epoch metrics.
    # NOTE(review): the fold runs go to project 'hnultra_test' but this run goes to
    # 'hnultra', so the shared wandb group spans two projects — confirm intended.
    wandb.init(project='hnultra', entity=wandb_username, name=local_username + '_ALL', group=random_str)
    config_dict['fold'] = -1
    wandb.config.update(config_dict)
    wandb.config.update(network_configs)

    metrics_all = {}
    for fold_metrics in best_metrics_per_fold:
        for key, value in fold_metrics.items():
            metrics_all.setdefault(key, []).append(value)

    metrics_to_log = {}
    for key, values in metrics_all.items():
        # Metrics without a value (e.g. sklearn's None 'support' for averaged
        # scores) cannot be averaged; skip them instead of crashing.
        if any(v is None for v in values):
            continue
        metric_list = np.asarray(values)
        metrics_to_log[key + '_mean'] = metric_list.mean()
        metrics_to_log[key + '_stdev'] = metric_list.std()

    wandb.config.update(metrics_to_log)
  

In [13]:
class Dataset(data.Dataset):
    """Dataset yielding (grayscale image tensor, one-hot label vector) pairs.

    Args:
        list_IDs: image identifiers; item ``k`` is loaded from '<image_dir>/<ID>.jpg'.
        labels: class indices aligned with ``list_IDs`` (anything ``int()`` accepts).
        transformations: optional transform applied to the PIL image before ToTensor.
        num_classes: length of the one-hot target vector (default 6).
        image_dir: directory holding the images; when None, the notebook-level
            ``data_dir`` is used (looked up each time an item is loaded).
    """

    def __init__(self, list_IDs, labels, transformations=None, num_classes=6, image_dir=None):
        self.labels = labels
        self.list_IDs = list_IDs
        self.transformations = transformations
        self.num_classes = num_classes
        self.image_dir = image_dir

    def __len__(self):
        """Number of samples."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Load sample ``index`` as (1 x H x W float tensor, one-hot float vector)."""
        ID = self.list_IDs[index]

        image_dir = data_dir if self.image_dir is None else self.image_dir
        img_path = os.path.join(image_dir, ID + '.jpg')
        # The context manager closes the file handle; convert('L') returns a new,
        # independent grayscale image.
        with Image.open(img_path) as img:
            image = img.convert('L')

        if self.transformations:
            image = self.transformations(image)

        image = ToTensor()(image)

        y = torch.zeros(self.num_classes)
        y[int(self.labels[index])] = 1

        return image, y

In [14]:
# Architectures accepted by initialize_model:
#   resnet, alexnet, vgg, densenet, vae50, vae100, vae200, viewnet
model_name = "viewnet"

num_classes = 6          # views: other, right sag/trans, left sag/trans, bladder
batch_size = 100         # lower this if the GPU runs out of memory
num_epochs = 46          # training epochs per fold

# False -> fine-tune every layer; True -> train only the newly added head
feature_extract = False
# Whether to start from pretrained weights
pretrain = False

# Loss on logits vs. integer class targets
criterion_used = nn.CrossEntropyLoss()

In [15]:
repetitions = 5

# ViewNet architecture
conv1_filters = 8
conv2_filters = 16
conv3_filters = 32
linear1_size = 512
dropout = 0.25

# Adam optimizer settings
lr = 0.0005
wd = 0.001
amsgrad = False

# Each repetition seeds the Python, NumPy and torch RNGs with base_seed + i, so
# every run is reproducible while still differing from the other repetitions.
base_seed = 0

for i in range(repetitions):
    seed = base_seed + i
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    config_string = f"{conv1_filters}_{conv2_filters}_{conv3_filters}_{linear1_size}_{dropout}_{lr}_{wd}_{amsgrad}"
    model_ft = ViewNet(num_classes, conv1_filters, conv2_filters, conv3_filters, linear1_size, dropout)
    run_configs = {'lr': lr, 'wd': wd, 'amsgrad': amsgrad, 'dropout': dropout,
                   'conv1_filters': conv1_filters, 'conv2_filters': conv2_filters,
                   'conv3_filters': conv3_filters, 'linear1_size': linear1_size,
                   'seed': seed}

    train5fold(run_configs, criterion_used, model_ft, lr, wd, amsgrad, i)

conv_output:  512


wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


	 conv1.weight
	 conv1.bias
	 conv2.weight
	 conv2.bias
	 conv3.weight
	 conv3.bias
	 linear1.weight
	 linear1.bias
	 linear2.weight
	 linear2.bias
Epoch 1/46
------------------------------------------------------
train loss:	1.5708 | train acc:	0.3827

              precision    recall  f1-score   support

           0       0.38      0.90      0.53      4943
           1       0.34      0.07      0.12      2225
           2       0.00      0.00      0.00      1539
           3       0.31      0.15      0.20      2284
           4       0.19      0.00      0.00      1669
           5       0.58      0.30      0.39      1298

    accuracy                           0.38     13958
   macro avg       0.30      0.24      0.21     13958
weighted avg       0.32      0.38      0.28     13958

test loss:	1.4995 | test acc:	0.4617

{'val_Other_f1': 0.5904466287290913, 'val_Other_precision': 0.44946180099763716, 'val_Other_recall': 0.8603015075376884, 'val_Saggital_Right_f1': 0.2158119658119658,

test loss:	1.4031 | test acc:	0.5101

{'val_Other_f1': 0.6018581463856787, 'val_Other_precision': 0.5480808914568717, 'val_Other_recall': 0.6673366834170854, 'val_Saggital_Right_f1': 0.41059602649006627, 'val_Saggital_Right_precision': 0.5933014354066986, 'val_Saggital_Right_recall': 0.3139240506329114, 'val_Transverse_Right_f1': 0.3111668757841907, 'val_Transverse_Right_precision': 0.656084656084656, 'val_Transverse_Right_recall': 0.20394736842105263, 'val_Saggital_Left_f1': 0.4097610574478902, 'val_Saggital_Left_precision': 0.3141075604053001, 'val_Saggital_Left_recall': 0.5891812865497076, 'val_Transverse_Left_f1': 0.25771812080536916, 'val_Transverse_Left_precision': 0.44036697247706424, 'val_Transverse_Left_recall': 0.18216318785578747, 'val_Bladder_f1': 0.7663366336633664, 'val_Bladder_precision': 0.7179962894248608, 'val_Bladder_recall': 0.821656050955414, 'test_precision_average_macro': 0.5449896342092418, 'test_recall_average_macro': 0.46303477130532644, 'test_f1score_average_

train loss:	1.2019 | train acc:	0.5363

              precision    recall  f1-score   support

           0       0.53      0.74      0.62      4943
           1       0.57      0.42      0.48      2225
           2       0.50      0.24      0.33      1539
           3       0.47      0.47      0.47      2284
           4       0.47      0.29      0.36      1669
           5       0.75      0.73      0.74      1298

    accuracy                           0.54     13958
   macro avg       0.55      0.48      0.50     13958
weighted avg       0.53      0.54      0.52     13958

test loss:	1.3245 | test acc:	0.5464

{'val_Other_f1': 0.6297996661101837, 'val_Other_precision': 0.538543897216274, 'val_Other_recall': 0.7582914572864322, 'val_Saggital_Right_f1': 0.4751381215469613, 'val_Saggital_Right_precision': 0.6310272536687631, 'val_Saggital_Right_recall': 0.3810126582278481, 'val_Transverse_Right_f1': 0.282258064516129, 'val_Transverse_Right_precision': 0.7720588235294118, 'val_Transvers

train loss:	1.1265 | train acc:	0.5671

              precision    recall  f1-score   support

           0       0.55      0.74      0.63      4943
           1       0.61      0.46      0.52      2225
           2       0.57      0.32      0.41      1539
           3       0.51      0.50      0.51      2284
           4       0.51      0.37      0.43      1669
           5       0.77      0.76      0.76      1298

    accuracy                           0.57     13958
   macro avg       0.59      0.52      0.54     13958
weighted avg       0.57      0.57      0.56     13958

test loss:	1.2492 | test acc:	0.5657

{'val_Other_f1': 0.6224535989135356, 'val_Other_precision': 0.5663097199341022, 'val_Other_recall': 0.6909547738693468, 'val_Saggital_Right_f1': 0.5146757679180887, 'val_Saggital_Right_precision': 0.5585185185185185, 'val_Saggital_Right_recall': 0.4772151898734177, 'val_Transverse_Right_f1': 0.4917715392061956, 'val_Transverse_Right_precision': 0.5976470588235294, 'val_Transve

train loss:	1.0758 | train acc:	0.5836

              precision    recall  f1-score   support

           0       0.56      0.74      0.64      4943
           1       0.64      0.48      0.55      2225
           2       0.58      0.37      0.45      1539
           3       0.52      0.51      0.52      2284
           4       0.52      0.41      0.46      1669
           5       0.77      0.77      0.77      1298

    accuracy                           0.58     13958
   macro avg       0.60      0.55      0.57     13958
weighted avg       0.59      0.58      0.58     13958

test loss:	1.2039 | test acc:	0.5817

{'val_Other_f1': 0.6459652706843719, 'val_Other_precision': 0.5442340791738383, 'val_Other_recall': 0.7944723618090452, 'val_Saggital_Right_f1': 0.5027667984189722, 'val_Saggital_Right_precision': 0.6694736842105263, 'val_Saggital_Right_recall': 0.40253164556962023, 'val_Transverse_Right_f1': 0.4776444929116685, 'val_Transverse_Right_precision': 0.7087378640776699, 'val_Transv

train loss:	1.0316 | train acc:	0.6034

              precision    recall  f1-score   support

           0       0.59      0.74      0.66      4943
           1       0.65      0.50      0.56      2225
           2       0.61      0.41      0.49      1539
           3       0.54      0.54      0.54      2284
           4       0.56      0.45      0.50      1669
           5       0.78      0.79      0.79      1298

    accuracy                           0.60     13958
   macro avg       0.62      0.57      0.59     13958
weighted avg       0.61      0.60      0.60     13958

test loss:	1.1721 | test acc:	0.5876

{'val_Other_f1': 0.6363008971704623, 'val_Other_precision': 0.5867628341111583, 'val_Other_recall': 0.6949748743718593, 'val_Saggital_Right_f1': 0.5410764872521246, 'val_Saggital_Right_precision': 0.6141479099678456, 'val_Saggital_Right_recall': 0.4835443037974684, 'val_Transverse_Right_f1': 0.5528301886792453, 'val_Transverse_Right_precision': 0.6482300884955752, 'val_Transve

train loss:	0.9913 | train acc:	0.6201

              precision    recall  f1-score   support

           0       0.60      0.75      0.67      4943
           1       0.67      0.51      0.58      2225
           2       0.63      0.44      0.52      1539
           3       0.56      0.55      0.55      2284
           4       0.58      0.49      0.53      1669
           5       0.80      0.81      0.80      1298

    accuracy                           0.62     13958
   macro avg       0.64      0.59      0.61     13958
weighted avg       0.62      0.62      0.61     13958

test loss:	1.1419 | test acc:	0.5860

{'val_Other_f1': 0.6485484867201976, 'val_Other_precision': 0.5493547261946286, 'val_Other_recall': 0.7914572864321608, 'val_Saggital_Right_f1': 0.528328611898017, 'val_Saggital_Right_precision': 0.5996784565916399, 'val_Saggital_Right_recall': 0.47215189873417723, 'val_Transverse_Right_f1': 0.4646924829157176, 'val_Transverse_Right_precision': 0.7555555555555555, 'val_Transve

test loss:	1.1627 | test acc:	0.5929

{'val_Other_f1': 0.6456725755995829, 'val_Other_precision': 0.5518716577540107, 'val_Other_recall': 0.7778894472361809, 'val_Saggital_Right_f1': 0.5416974169741697, 'val_Saggital_Right_precision': 0.6495575221238938, 'val_Saggital_Right_recall': 0.46455696202531643, 'val_Transverse_Right_f1': 0.5380493033226154, 'val_Transverse_Right_precision': 0.7723076923076924, 'val_Transverse_Right_recall': 0.4128289473684211, 'val_Saggital_Left_f1': 0.4404567699836868, 'val_Saggital_Left_precision': 0.4981549815498155, 'val_Saggital_Left_recall': 0.39473684210526316, 'val_Transverse_Left_f1': 0.4514038876889849, 'val_Transverse_Left_precision': 0.5238095238095238, 'val_Transverse_Left_recall': 0.396584440227704, 'val_Bladder_f1': 0.7977900552486188, 'val_Bladder_precision': 0.8317972350230415, 'val_Bladder_recall': 0.7664543524416136, 'test_precision_average_macro': 0.6379164354279963, 'test_recall_average_macro': 0.5355084985674164, 'test_f1score_average_mac

test loss:	1.1392 | test acc:	0.5949

{'val_Other_f1': 0.6296470588235294, 'val_Other_precision': 0.5920353982300885, 'val_Other_recall': 0.6723618090452261, 'val_Saggital_Right_f1': 0.5612316305108467, 'val_Saggital_Right_precision': 0.6275430359937402, 'val_Saggital_Right_recall': 0.5075949367088608, 'val_Transverse_Right_f1': 0.595581171950048, 'val_Transverse_Right_precision': 0.7159353348729792, 'val_Transverse_Right_recall': 0.5098684210526315, 'val_Saggital_Left_f1': 0.4668508287292818, 'val_Saggital_Left_precision': 0.4424083769633508, 'val_Saggital_Left_recall': 0.49415204678362573, 'val_Transverse_Left_f1': 0.4888457807953443, 'val_Transverse_Left_precision': 0.5, 'val_Transverse_Left_recall': 0.4781783681214421, 'val_Bladder_f1': 0.8012752391073327, 'val_Bladder_precision': 0.8021276595744681, 'val_Bladder_recall': 0.8004246284501062, 'test_precision_average_macro': 0.6133416342724377, 'test_recall_average_macro': 0.5770967016936487, 'test_f1score_average_macro': 0.590571951

train loss:	0.8968 | train acc:	0.6560

              precision    recall  f1-score   support

           0       0.63      0.76      0.69      4943
           1       0.70      0.57      0.63      2225
           2       0.67      0.52      0.59      1539
           3       0.61      0.61      0.61      2284
           4       0.62      0.53      0.57      1669
           5       0.82      0.82      0.82      1298

    accuracy                           0.66     13958
   macro avg       0.67      0.64      0.65     13958
weighted avg       0.66      0.66      0.65     13958

test loss:	1.1278 | test acc:	0.5905

{'val_Other_f1': 0.6322429906542056, 'val_Other_precision': 0.5908296943231441, 'val_Other_recall': 0.6798994974874372, 'val_Saggital_Right_f1': 0.536869340232859, 'val_Saggital_Right_precision': 0.548941798941799, 'val_Saggital_Right_recall': 0.5253164556962026, 'val_Transverse_Right_f1': 0.5487674169346195, 'val_Transverse_Right_precision': 0.7876923076923077, 'val_Transvers

test loss:	1.1222 | test acc:	0.5974

{'val_Other_f1': 0.6236399604352127, 'val_Other_precision': 0.6139240506329114, 'val_Other_recall': 0.6336683417085427, 'val_Saggital_Right_f1': 0.5829652996845426, 'val_Saggital_Right_precision': 0.5811320754716981, 'val_Saggital_Right_recall': 0.5848101265822785, 'val_Transverse_Right_f1': 0.6082289803220036, 'val_Transverse_Right_precision': 0.6666666666666666, 'val_Transverse_Right_recall': 0.5592105263157895, 'val_Saggital_Left_f1': 0.46219081272084805, 'val_Saggital_Left_precision': 0.4473324213406293, 'val_Saggital_Left_recall': 0.4780701754385965, 'val_Transverse_Left_f1': 0.4855721393034826, 'val_Transverse_Left_precision': 0.5104602510460251, 'val_Transverse_Left_recall': 0.4629981024667932, 'val_Bladder_f1': 0.8119218910585817, 'val_Bladder_precision': 0.7868525896414342, 'val_Bladder_recall': 0.8386411889596603, 'test_precision_average_macro': 0.6010613424665608, 'test_recall_average_macro': 0.5928997435786101, 'test_f1score_average_mac

train loss:	0.8336 | train acc:	0.6792

              precision    recall  f1-score   support

           0       0.66      0.77      0.71      4943
           1       0.71      0.60      0.65      2225
           2       0.69      0.56      0.62      1539
           3       0.63      0.62      0.63      2284
           4       0.64      0.57      0.61      1669
           5       0.84      0.84      0.84      1298

    accuracy                           0.68     13958
   macro avg       0.70      0.66      0.68     13958
weighted avg       0.68      0.68      0.68     13958

test loss:	1.0883 | test acc:	0.6024

{'val_Other_f1': 0.6444035006909259, 'val_Other_precision': 0.594812925170068, 'val_Other_recall': 0.7030150753768845, 'val_Saggital_Right_f1': 0.5617647058823529, 'val_Saggital_Right_precision': 0.6701754385964912, 'val_Saggital_Right_recall': 0.4835443037974684, 'val_Transverse_Right_f1': 0.5466377440347072, 'val_Transverse_Right_precision': 0.802547770700637, 'val_Transvers

OSError: [Errno 12] Cannot allocate memory