In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json

from PIL import Image
import torchvision
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from sklearn.metrics import *
import time
from datetime import datetime
import os
from torch.utils import data
import random
import copy
import itertools
import io
import uuid
from sklearn.model_selection import KFold, train_test_split

import warnings
warnings.filterwarnings('ignore')

import wandb
wandb_username = 'denizjafari'
local_username = 'denizjafari'


In [2]:
# Pick the compute device. The notebook was written for the second GPU
# ('cuda:1'); fall back to the first GPU when only one is installed so that
# later `.to(device)` calls do not fail with an invalid device ordinal.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device('cuda:{}'.format(gpu_index))
else:
    device = torch.device('cpu')
print(device)

cuda:1


In [3]:
# source: https://scikit-learn.org/0.16/auto_examples/model_selection/plot_confusion_matrix.html
#example-model-selection-plot-confusion-matrix-py
def plot_confusion_matrix(cm, fold, classnames, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix, save it as a PNG and log it to wandb.

    Args:
        cm: 2-D array of counts; rows are drawn as true labels and columns as
            predicted labels.
        fold: Fold identifier appended to the output file name and to the
            wandb key. Converted with ``str`` so integer ids also work.
        classnames: Tick labels, one per class.
        title: Figure title.
        cmap: Matplotlib colormap used for the cells.

    Side effects:
        Writes ``ViewNetConfMtrx<fold>.png`` to the working directory, logs the
        figure under ``confusion<fold>`` and then closes the figure so that
        repeated calls (one per fold) do not accumulate open figures.
    """
    fold = str(fold)
    fig, ax = plt.subplots(1, figsize=(14, 10))
    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classnames))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classnames, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classnames)

    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, "{:,}".format(cm[i, j]),
                horizontalalignment="center",
                color="white" if cm[i, j] > thresh else "black", fontsize=19)

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    # Lay out only after the axis labels exist so they are not clipped.
    fig.tight_layout()
    fig.savefig('ViewNetConfMtrx' + fold + '.png')
    try:
        # `fig` is still the current pyplot figure at this point.
        wandb.log({"confusion"+ fold: plt})
    finally:
        plt.close(fig)

In [4]:
# Root directory that holds the CSV describing every image and its split.
root_dir = "/home/andreasabo/Documents/HNProject/"
# Directory prefix for the cached cross-validation split file (data_splits.json).
split_file_base = "/home/andreasabo/Documents/HNUltra/"

# Image directory on the current machine; depends on `local_username`
# (one of: abhishekmoturu, andreasabo, denizjafari, navidkorhani).
data_dir = "/home/" + local_username + "/Documents/HNProject/all_label_img/"

# Load only the columns needed here: subject id, image id, view label and
# the train/test flag.
csv_path = os.path.join(root_dir, "all_splits_1000000.csv")
data_df = pd.read_csv(csv_path, usecols=['subj_id', 'image_ids', 'view_label', 'view_train'])

# Are we doing the final test?
# NOTE(review): this flag is not read anywhere in the visible code.
test_data = False

### **Reading Data Indices and Labels**

In [5]:
# Integer encoding of the view labels, plus the inverse used to name metrics.
label_mapping = {'Other':0, 'Saggital_Right':1, 'Transverse_Right':2,
                 'Saggital_Left':3, 'Transverse_Left':4, 'Bladder':5}
label_unmapping = {index: name for name, index in label_mapping.items()}

# Encode the labels only while they are still strings, so that re-running this
# cell does not map the already-encoded integers to NaN.
if data_df['view_label'].dtype == object:
    data_df['view_label'] = data_df['view_label'].map(label_mapping)

train_df = data_df[data_df.view_train == 1]
test_df = data_df[data_df.view_train == 0]

unique_subj = train_df.subj_id.unique()

# Create the splits for 5-fold cross validation based on subj_id, so that all
# images of one subject land in the same fold. The splits are cached on disk.
data_split_file = split_file_base + 'data_splits.json'
if not os.path.isfile(data_split_file):

    kf = KFold(n_splits=5, random_state=0, shuffle=True)
    all_folds = {}
    for fold, (train_subj, val_subj) in enumerate(kf.split(unique_subj)):
        train_rows = train_df[train_df.subj_id.isin(unique_subj[train_subj])]
        val_rows = train_df[train_df.subj_id.isin(unique_subj[val_subj])]
        # JSON turns keys into strings; use string keys from the start so a
        # freshly built dict behaves exactly like one loaded from the file.
        all_folds[str(fold)] = {'train_ids': train_rows.image_ids.tolist(),
                                'valid_ids': val_rows.image_ids.tolist(),
                                'train_labels': train_rows.view_label.tolist(),
                                'valid_labels': val_rows.view_label.tolist()}

    print("Saving data splits")
    with open(data_split_file, 'w') as f:
        json.dump(all_folds, f)

else: # just load from file
    print("Reading splits from file")
    with open(data_split_file, 'r') as f:
        all_folds = json.load(f)

# Held-out test images (view_train == 0); these are never part of a fold.
test_images = test_df.image_ids.tolist()
test_labels = test_df.view_label.tolist()

test_set = {'test_ids': test_images, 'test_labels': test_labels}

Reading splits from file


In [6]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    When the flag is falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [7]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build one of the supported networks with a ``num_classes``-way output layer.

    Args:
        model_name: One of "resnet", "alexnet", "vgg", "squeezenet",
            "densenet", "inception" or "viewnet".
        num_classes: Size of the replaced final layer.
        feature_extract: Passed to ``set_parameter_requires_grad``; when truthy
            the backbone parameters are frozen before the new head is attached.
            Not used for "viewnet".
        use_pretrained: Forwarded to the torchvision constructors as
            ``pretrained``. Not used for "viewnet".

    Returns:
        Tuple ``(model, input_size)`` where ``input_size`` is the image side
        length recorded for that architecture.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # Resnet18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 256

    elif model_name == "alexnet":
        # Alexnet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # Squeezenet: the classifier is a 1x1 convolution rather than a linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # Densenet
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3
        # Be careful, expects (299,299) sized images and has auxiliary output
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs,num_classes)
        input_size = 299

    elif model_name == 'viewnet':
        # Custom small CNN defined in this notebook, built with fixed hyper-parameters.
        conv1_filters = 8
        conv2_filters = 16
        conv3_filters = 32
        linear1_size = 512
        dropout = 0.25
        model_ft = ViewNet(num_classes, conv1_filters, conv2_filters, conv3_filters, linear1_size, dropout)
        input_size = 256

    else:
        # Raise instead of calling exit(), which would shut down the notebook
        # kernel and hide the reason from the caller.
        raise ValueError(
            "Invalid model name '{}'; expected one of: resnet, alexnet, vgg, "
            "squeezenet, densenet, inception, viewnet".format(model_name))

    return model_ft, input_size


class ViewNet(nn.Module):
    """Three conv/pool stages followed by two linear layers.

    Each stage is conv -> ReLU -> dropout -> max-pool. The flattened feature
    size is derived from a fixed side length of 256 divided by the pooling
    factor once per stage.
    """

    def __init__(self, num_classes, conv1_filters, conv2_filters, conv3_filters, linear1_size, dropout):
        super().__init__()
        # Keep the hyper-parameters on the instance.
        self.conv1_filters = conv1_filters
        self.conv2_filters = conv2_filters
        self.conv3_filters = conv3_filters
        self.linear1_size = linear1_size
        self.drop_percent = dropout
        self.max_pool = 4

        # Side length left after three pooling stages applied to 256.
        pooled_side = 256 / (self.max_pool ** 3)
        self.conv_output = int(self.conv3_filters * pooled_side * pooled_side)
        print("conv_output: ", self.conv_output)

        kernel, pad = 4, 2
        self.conv1 = nn.Conv2d(1, self.conv1_filters, kernel, padding=pad)
        self.conv2 = nn.Conv2d(self.conv1_filters, self.conv2_filters, kernel, padding=pad)
        self.conv3 = nn.Conv2d(self.conv2_filters, self.conv3_filters, kernel, padding=pad)
        self.pool = nn.MaxPool2d(self.max_pool, self.max_pool)
        self.dropout = nn.Dropout(self.drop_percent)
        self.linear1 = nn.Linear(self.conv_output, self.linear1_size)
        self.linear2 = nn.Linear(self.linear1_size, num_classes)

    def forward(self, x):
        """Return the raw class scores for a batch ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            activated = F.relu(conv(x))
            x = self.pool(self.dropout(activated))
        flat = x.view(-1, self.conv_output)
        hidden = self.dropout(F.relu(self.linear1(flat)))
        return self.linear2(hidden)

In [8]:
# Code from: https://gist.github.com/stefanonardo/693d96ceb2f531fa05db530f3e21517d
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Args:
        mode: 'min' if lower values are better, 'max' if higher values are.
        min_delta: Minimum change that counts as an improvement. Interpreted
            as a percentage of the best value when ``percentage`` is true,
            otherwise as an absolute amount.
        patience: Number of consecutive non-improving calls to ``step`` after
            which ``step`` returns True. ``0`` disables early stopping.
        percentage: Whether ``min_delta`` is relative (in percent) or absolute.

    Raises:
        ValueError: If ``mode`` is neither 'min' nor 'max'.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=True):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Early stopping disabled: never report a stop.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one value of the metric; return True when training should stop."""
        # Check for NaN before anything else. Otherwise a NaN on the very first
        # call would be stored as `best`, and no later value could ever compare
        # as better than it.
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            return True

        return False

    def _init_is_better(self, mode, min_delta, percentage):
        """Install the comparison used by ``step`` for the given configuration."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)

In [9]:
def _one_vs_rest_scores(labels, preds, class_index):
    """Return (f1, precision, recall) with ``class_index`` as the positive class."""
    labels_bin = (labels == class_index).astype(int)
    preds_bin = (preds == class_index).astype(int)
    f1 = f1_score(labels_bin, preds_bin)
    precision = precision_score(labels_bin, preds_bin)
    recall = recall_score(labels_bin, preds_bin)
    return f1, precision, recall


def _averaged_scores(labels, preds, phase, metric_names):
    """Return macro/micro/weighted averaged scores plus accuracy, keyed by ``phase``."""
    scores = {}
    for av in ['macro', 'micro', 'weighted']:
        results_tuple = precision_recall_fscore_support(labels, preds, average=av)
        for position, metric in enumerate(metric_names):
            scores[phase + '_'+ metric +'_average_' + av] = results_tuple[position]
    scores[phase + '_acc_average'] = accuracy_score(labels, preds)
    return scores


def _to_cpu(value):
    """Move a tensor to the CPU; return non-tensor values unchanged."""
    return value.data.cpu() if torch.is_tensor(value) else value


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, final_testing=False):
    """Train ``model`` with a train and a validation pass per epoch, logging to wandb.

    Relies on the notebook globals ``device``, ``label_unmapping``, ``wandb``
    and ``EarlyStopping``. Loaders must yield ``(inputs, one_hot_labels, ids)``.

    Args:
        model: Network to train; updated in place and returned with the
            weights of the epoch that had the best validation accuracy.
        dataloaders: Dict with the keys 'train' and 'val'.
        criterion: Loss taking ``(outputs, class_indices)``.
        optimizer: Optimizer stepping the model parameters.
        num_epochs: Maximum number of epochs.
        is_inception: When true, the training forward pass is expected to
            return ``(outputs, aux_outputs)`` and the loss is
            ``loss(outputs) + 0.4 * loss(aux_outputs)``.
        final_testing: When true, early stopping is ignored.

    Returns:
        ``(model, val_acc_history, metrics_from_best_epoch)``. The last item
        is a dict with 'best_epoch', 'last_epoch', the per-class and averaged
        train/validation metrics of the best epoch, and 'val_acc'/'train_acc'.

    Side effects:
        Logs metrics to wandb, writes ``model.pt`` into ``wandb.run.dir``
        whenever validation accuracy improves, and prints progress.
    """
    es = EarlyStopping(patience = 15)
    stop_now = False

    since = time.time()
    val_acc_history = []

    val_metrics_list = []
    train_metrics_list = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # Defined up front so the summary below works even if validation accuracy
    # never rises above zero.
    best_acc_train = 0.0
    train_acc = 0.0
    epoch_with_best_val_acc = 0
    last_epoch = -1
    class_ids = range(len(label_unmapping))

    for epoch in range(num_epochs):
        last_epoch = epoch
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 54)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            running_preds = []
            running_labels = []

            # Iterate over data.
            for inputs, labels, _ in dataloaders[phase]:
                inputs = inputs.to(device)
                # The dataset yields one-hot labels; the loss and the metrics
                # need class indices. Convert once, for every branch below.
                labels = labels.type(torch.long).to(device)
                labels = torch.argmax(labels, 1)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if is_inception and phase == 'train':
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    preds = torch.argmax(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics (recorded for every branch, including inception)
                running_preds += preds.tolist()
                running_labels += labels.tolist()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} loss:\t{:.4f} | {} acc:\t{:.4f}\n'.format(phase, epoch_loss, phase, epoch_acc))

            running_labels = np.asarray(running_labels)
            running_preds = np.asarray(running_preds)

            if phase == 'train':
                wandb.log({'epoch': epoch, 'train_acc':epoch_acc, 'train_loss':epoch_loss})

                cur_train_metrics = {}
                # compute f1, precision, recall for each class
                for c in class_ids:
                    f1, precision, recall = _one_vs_rest_scores(running_labels, running_preds, c)
                    cur_train_metrics['train_' + label_unmapping[c] + '_f1'] = f1
                    cur_train_metrics['train_' + label_unmapping[c] + '_precision'] = precision
                    cur_train_metrics['train_' + label_unmapping[c] + '_recall'] = recall

                train_metrics_list.append(cur_train_metrics)

                # 'support' is kept so the logged keys stay the same; its value
                # is whatever precision_recall_fscore_support returns in the
                # fourth position for an averaged call.
                cur_train_metrics.update(_averaged_scores(
                    running_labels, running_preds, phase,
                    ['precision', 'recall', 'f1score', 'support']))
                wandb.log(cur_train_metrics)

                print(classification_report(running_labels, running_preds))
                train_acc = epoch_acc

            if phase == 'val':
                wandb.log({'valid_loss':epoch_loss, 'valid_acc':epoch_acc, 'epoch': epoch})

                cur_val_metrics = {}
                # compute and log f1, precision, recall for each class
                for c in class_ids:
                    f1, precision, recall = _one_vs_rest_scores(running_labels, running_preds, c)
                    wandb.log({'valid_' + label_unmapping[c] + '_f1': f1})
                    wandb.log({'valid_' + label_unmapping[c] + '_precision': precision})
                    wandb.log({'valid_' + label_unmapping[c] + '_recall': recall})

                    cur_val_metrics['val_' + label_unmapping[c] + '_f1'] = f1
                    cur_val_metrics['val_' + label_unmapping[c] + '_precision'] = precision
                    cur_val_metrics['val_' + label_unmapping[c] + '_recall'] = recall

                cur_val_metrics.update(_averaged_scores(
                    running_labels, running_preds, phase,
                    ['precision', 'recall', 'f1score']))
                print(cur_val_metrics)
                wandb.log(cur_val_metrics)

                val_metrics_list.append(cur_val_metrics)

                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_acc_train = train_acc
                    epoch_with_best_val_acc = epoch
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(model.state_dict(), os.path.join(wandb.run.dir, "model.pt"))
                    print(classification_report(running_labels, running_preds))

                val_acc_history.append(epoch_acc)
                if es.step(epoch_loss) and not final_testing:
                    stop_now = True
                    print("EARLY STOPPING " + str(epoch))

        # Stop right after the epoch that triggered early stopping, so that
        # `last_epoch` names that epoch and no extra epoch header is printed.
        if stop_now:
            break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val acc: {:4f}\n'.format(best_acc))

    # Directly save the best results in this fold
    wandb.config.best_acc = best_acc
    wandb.config.best_epoch = epoch_with_best_val_acc
    wandb.config.val_acc_history = val_acc_history

    metrics_from_best_epoch = {'best_epoch': epoch_with_best_val_acc, 'last_epoch': last_epoch}
    if val_metrics_list:
        wandb.config.update(val_metrics_list[epoch_with_best_val_acc])
        wandb.config.update(train_metrics_list[epoch_with_best_val_acc])
        metrics_from_best_epoch.update( val_metrics_list[epoch_with_best_val_acc] )
        metrics_from_best_epoch.update( train_metrics_list[epoch_with_best_val_acc] )
    metrics_from_best_epoch.update( {'val_acc': _to_cpu(best_acc), 'train_acc': _to_cpu(best_acc_train)} )

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, metrics_from_best_epoch

In [10]:
class Dataset(data.Dataset):
    """PyTorch dataset that loads grayscale JPEGs and one-hot view labels.

    Images are read from ``data_dir + <ID> + '.jpg'`` (``data_dir`` is a
    notebook global).

    Args:
        list_IDs: Image identifiers (file names without extension).
        labels: Integer class index for each entry of ``list_IDs``.
        transformations: Optional callable applied to the PIL image before it
            is converted to a tensor.
        num_classes: Length of the one-hot label vector (default 6).
    """

    def __init__(self, list_IDs, labels, transformations=None, num_classes=6):
        'Initialization'
        self.labels = labels
        self.list_IDs = list_IDs
        self.transformations = transformations
        self.num_classes = num_classes

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Return ``(image_tensor, one_hot_label, ID)`` for the sample at ``index``."""
        # Select sample
        ID = self.list_IDs[index]

        # Load the image as single-channel grayscale. convert() returns a new
        # in-memory image, so the file handle can be closed right away instead
        # of staying open until garbage collection.
        img_path = data_dir + ID + '.jpg'
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('L')

        if self.transformations:
            image = self.transformations(image)

        image = ToTensor()(image)

        # One-hot encode the label as a float vector.
        y = torch.zeros(self.num_classes, dtype=torch.float32)
        y[int(self.labels[index])] = 1

        return image, y, ID

In [11]:
# Name of the architecture for this run; recorded in the wandb config and used
# to enable the inception-specific training path.
# Options: [resnet, alexnet, vgg, squeezenet, densenet, inception, viewnet]
model_name = "viewnet"

# Number of classes in the dataset: other, right sagittal, right transverse,
# left sagittal, left transverse, bladder
num_classes = 6

# Batch size for training (change depending on how much memory you have)
batch_size = 100

# Maximum number of epochs to train for (early stopping may end a run sooner)
num_epochs = 100

# Flag for feature extracting. When False, we finetune the whole model; when True we only update the reshaped layer params
feature_extract = False

# Flag for whether or not to use pretrained model
pretrain = False

In [12]:
def _predict_on_loader(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns ``(ids, labels, preds, mean_loss, accuracy)``. Loss and accuracy
    are averaged over the number of samples in ``loader.dataset``.
    """
    ids, targets, predictions = [], [], []
    loss_sum = 0.0
    n_correct = 0

    model.eval()
    with torch.no_grad():
        for inputs, labels, batch_ids in loader:
            ids += np.array(batch_ids).tolist()

            inputs = inputs.to(device)
            # One-hot labels -> class indices.
            labels = labels.type(torch.long).to(device)
            labels = torch.argmax(labels, 1)

            outputs = model(inputs)
            preds = torch.argmax(outputs, 1)
            loss = criterion(outputs, labels)

            predictions += preds.tolist()
            targets += labels.tolist()
            loss_sum += loss.item() * inputs.size(0)
            n_correct += torch.sum(preds == labels).item()

    n_samples = max(len(loader.dataset), 1)
    return ids, targets, predictions, loss_sum / n_samples, n_correct / n_samples


def train5fold(network_configs, model_ft, lr, wd, amsgrad, i):
    """Train and test one model configuration on each of the five folds.

    For every fold a fresh copy of ``model_ft`` is trained with
    ``train_model``, evaluated on the held-out test set, and its metrics are
    written to a separate wandb run. A final wandb run stores the mean and
    standard deviation of each metric across folds.

    Relies on notebook globals, among them ``all_folds``, ``test_set``,
    ``device``, ``batch_size``, ``num_epochs``, ``feature_extract`` and
    ``label_unmapping``.

    Args:
        network_configs: Dict of hyper-parameters added to each run's config.
        model_ft: Untrained model used as the template for every fold.
        lr: Adam learning rate.
        wd: Adam weight decay.
        amsgrad: Adam ``amsgrad`` flag.
        i: Run index; only stored in the wandb config.

    Returns:
        DataFrame with one ``IDs_iter_<fold>``, ``labels_iter_<fold>`` and
        ``preds_iter_<fold>`` column per fold for the test set. It is also
        written to ``final_viewnet_data.csv``.
    """
    folds = ['0', '1', '2', '3', '4']
    classnames = ['Other', 'Saggital_Right', 'Transverse_Right', 'Saggital_Left','Transverse_Left', 'Bladder']
    project_name = 'hnultra_test_apr6_vae'
    random_str = str(uuid.uuid4()).split("-")[0]
    best_metrics_per_fold = []
    test_metrics_per_fold = []
    model_base = copy.deepcopy(model_ft)
    final_df = pd.DataFrame()
    for fold in folds:

        wandb.init(project=project_name, entity=wandb_username, name=local_username + '_fold_' + fold, group=random_str)
        # Splits loaded from JSON have string keys; splits built in memory in
        # the same session may still have integer keys.
        partition = all_folds[fold] if fold in all_folds else all_folds[int(fold)]

        model_ft = copy.deepcopy(model_base)
        model_ft = model_ft.to(device)
        wandb.watch(model_ft)

        # Gather the parameters to be optimized/updated in this run. If we are
        #  finetuning we will be updating all parameters. However, if we are
        #  doing feature extract method, we will only update the parameters
        #  that we have just initialized, i.e. the parameters with requires_grad
        #  is True.
        trainable = [(name, param) for name, param in model_ft.named_parameters()
                     if param.requires_grad]
        for name, _ in trainable:
            print("\t",name)
        if feature_extract:
            params_to_update = [param for _, param in trainable]
        else:
            params_to_update = model_ft.parameters()

        optimizer_ft = optim.Adam(params_to_update, lr=lr, weight_decay=wd, amsgrad=amsgrad)

        # Setup the loss fxn
        criterion = nn.CrossEntropyLoss()

        shuffle = True
        num_workers = 0
        params = {'batch_size': batch_size,
                  'shuffle': shuffle,
                  'num_workers': num_workers}
        # Evaluation does not need shuffling; a fixed order also keeps the
        # test rows aligned across the per-fold columns of `final_df`.
        eval_params = dict(params, shuffle=False)

        config_dict = {'i': i, 'batch_size': batch_size, 'shuffle': shuffle, 'num_workers': num_workers, 'fold': fold,
                       'lr': lr, 'wd': wd, 'amsgrad': amsgrad, 'model_name': model_name, 'num_classes': num_classes, 
                       'num_epochs': num_epochs, 'feature_extract': feature_extract, "pretrain": pretrain }

        wandb.config.update(config_dict)
        wandb.config.update(network_configs)
        # Transforms (augmentation is applied to the training set only)
        trans = transforms.Compose([transforms.RandomAffine(degrees=8, translate=(0.1, 0.1), scale=(0.95,1.25))])

        # Generators
        training_set = Dataset(partition['train_ids'], partition['train_labels'], transformations=trans)
        training_generator = data.DataLoader(training_set, **params)

        validation_set = Dataset(partition['valid_ids'], partition['valid_labels'])
        validation_generator = data.DataLoader(validation_set, **eval_params)

        testing_set = Dataset(test_set['test_ids'], test_set['test_labels'])
        test_generator = data.DataLoader(testing_set, **eval_params)

        dataloaders_dict = {'train':training_generator, 'val':validation_generator}

        # Train & Evaluate
        model_ft, hist, metrics_from_best_epoch = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs, is_inception=(model_name=="inception"))
        best_metrics_per_fold.append(metrics_from_best_epoch)

        # Test the model on test data
        # saving the results in df for further analysis on machine and year
        run_ids, run_y, run_pred, test_loss, test_acc = _predict_on_loader(model_ft, test_generator, criterion)

        print('Test loss:\t{:.4f} | Test acc:\t{:.4f}\n'.format( test_loss, test_acc))
        wandb.log({'test_loss':test_loss, 'test_acc':test_acc})

        run_ids = np.asarray(run_ids)
        run_y = np.asarray(run_y)
        run_pred = np.asarray(run_pred)
        # Fix the label set so the matrix always has one row/column per class
        # name, even if a class is absent from the test predictions.
        conf_mx = confusion_matrix(run_y, run_pred, labels=list(range(len(classnames))))

        # calculating the metrics
        test_metrics = {'fold': fold}
        # compute and log f1, precision, recall for each class (one-vs-rest)
        for c in range(len(classnames)):
            cur_c_labs_bin = (run_y == c).astype(int)
            cur_c_preds_bin = (run_pred == c).astype(int)
            f1 = f1_score(cur_c_labs_bin, cur_c_preds_bin)
            precision = precision_score(cur_c_labs_bin, cur_c_preds_bin)
            recall = recall_score(cur_c_labs_bin, cur_c_preds_bin)

            test_metrics['test_' + label_unmapping[c] + '_f1'] = f1
            test_metrics['test_' + label_unmapping[c] + '_precision'] = precision
            test_metrics['test_' + label_unmapping[c] + '_recall'] = recall

            wandb.log({'test_' + label_unmapping[c] + '_f1': f1})
            wandb.log({'test_' + label_unmapping[c] + '_precision': precision})
            wandb.log({'test_' + label_unmapping[c] + '_recall': recall})

        plot_confusion_matrix(conf_mx,fold,classnames, title='Test Data Confusion matrix - ViewNet', 
                                      cmap=plt.cm.Blues)

        average_types = ['macro', 'micro', 'weighted']
        average_metrics_to_log = ['precision', 'recall', 'f1score', 'support']
        for av in average_types:
            results_tuple = precision_recall_fscore_support(run_y, run_pred, average=av)
            for m in range(len(average_metrics_to_log)):
                test_metrics['test_'+ average_metrics_to_log[m] +'_average_' + av] = results_tuple[m]
        test_metrics['test_acc_average'] = accuracy_score(run_y, run_pred)
        wandb.config.update(test_metrics)

        # saving results in df
        final_df['IDs_iter_' + str(fold)] = np.array(run_ids)
        final_df['labels_iter_' + str(fold)] = np.array(run_y)
        final_df['preds_iter_' + str(fold)] = np.array(run_pred)

        test_metrics_per_fold.append(test_metrics)

    # Calculate the performance metrics on the best model in each fold
    wandb.init(project=project_name, entity=wandb_username, name=local_username + '_ALL', group=random_str)
    final_df.to_csv (r'final_viewnet_data.csv', index = False, header=True)

    # Collect every metric across folds: train/val metrics first, then test.
    metrics_all = {}
    for fold_metrics in best_metrics_per_fold + test_metrics_per_fold:
        for key, value in fold_metrics.items():
            metrics_all.setdefault(key, []).append(value)

    metrics_to_log = {}
    for m, values in metrics_all.items():
        try:
            metric_list = np.asarray(values)
            mean_value = metric_list.mean()
            stdev_value = metric_list.std()
        except Exception as err:
            # Non-numeric entries (e.g. the fold name) cannot be averaged.
            # Skip them, but say so instead of silently swallowing the error.
            print("Skipping aggregate for '{}': {}".format(m, err))
            continue
        metrics_to_log[m + '_mean'] = mean_value
        metrics_to_log[m + '_stdev'] = stdev_value
    wandb.config.update(metrics_to_log)
    return final_df
  

In [None]:
# Run index; passed to train5fold as `i`, where it is only stored in the wandb config.
repetitions = 1

# ViewNet architecture hyper-parameters.
conv1_filters = 8
conv2_filters = 16
conv3_filters = 32
linear1_size = 512

# Regularisation and Adam optimizer settings.
dropout = 0.25
lr = 0.0005
wd = 0.001
amsgrad = False

# Compact description of this configuration.
# NOTE(review): config_string is not used anywhere in the visible code.
config_string = f"{conv1_filters}_{conv2_filters}_{conv3_filters}_{linear1_size}_{dropout}_{lr}_{wd}_{amsgrad}"
model_ft = ViewNet(num_classes, conv1_filters, conv2_filters, conv3_filters, linear1_size, dropout)
# Hyper-parameters that train5fold adds to each fold's wandb config.
run_configs = {'lr': lr, 'wd': wd, 'amsgrad': amsgrad,'dropout': dropout, 
              'conv1_filters': conv1_filters, 'conv2_filters': conv2_filters, 
              'conv3_filters': conv3_filters, 'linear1_size': linear1_size }

# Train and test on all five folds; returns the per-fold test predictions.
final_df = train5fold(run_configs, model_ft, lr, wd, amsgrad, repetitions)

# Optional: inspect the architecture and count its parameters.
# model_ft = ViewNet(num_classes, conv1_filters, conv2_filters, conv3_filters, linear1_size, dropout)
# print(model_ft)
# num_parameters = sum(p.numel() for p in model_ft.parameters())
# print(num_parameters)

conv_output:  512


wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


	 conv1.weight
	 conv1.bias
	 conv2.weight
	 conv2.bias
	 conv3.weight
	 conv3.bias
	 linear1.weight
	 linear1.bias
	 linear2.weight
	 linear2.bias
Epoch 1/100
------------------------------------------------------
train loss:	1.6040 | train acc:	0.3600

              precision    recall  f1-score   support

           0       0.36      0.93      0.52      3724
           1       0.24      0.03      0.05      1813
           2       0.00      0.00      0.00      1218
           3       0.33      0.18      0.23      1895
           4       0.31      0.00      0.01      1359
           5       0.56      0.12      0.20      1072

    accuracy                           0.36     11081
   macro avg       0.30      0.21      0.17     11081
weighted avg       0.31      0.36      0.24     11081

val loss:	1.5856 | val acc:	0.3903

{'val_Other_f1': 0.5222300635144672, 'val_Other_precision': 0.4582043343653251, 'val_Other_recall': 0.6070549630844955, 'val_Saggital_Right_f1': 0.0, 'val_Saggital_Ri

val loss:	1.4828 | val acc:	0.4432

{'val_Other_f1': 0.537052456286428, 'val_Other_precision': 0.5452240067624683, 'val_Other_recall': 0.5291222313371616, 'val_Saggital_Right_f1': 0.3433333333333333, 'val_Saggital_Right_precision': 0.5478723404255319, 'val_Saggital_Right_recall': 0.25, 'val_Transverse_Right_f1': 0.0650887573964497, 'val_Transverse_Right_precision': 0.6470588235294118, 'val_Transverse_Right_recall': 0.03426791277258567, 'val_Saggital_Left_f1': 0.4362469927826785, 'val_Saggital_Left_precision': 0.317016317016317, 'val_Saggital_Left_recall': 0.699228791773779, 'val_Transverse_Left_f1': 0.2464646464646465, 'val_Transverse_Left_precision': 0.32972972972972975, 'val_Transverse_Left_recall': 0.1967741935483871, 'val_Bladder_f1': 0.5446428571428572, 'val_Bladder_precision': 0.4103139013452915, 'val_Bladder_recall': 0.8097345132743363, 'val_precision_average_macro': 0.4662025198014584, 'val_recall_average_macro': 0.4198546071177083, 'val_f1score_average_macro': 0.36213817390106

train loss:	1.2192 | train acc:	0.5329

              precision    recall  f1-score   support

           0       0.55      0.74      0.63      3724
           1       0.52      0.40      0.45      1813
           2       0.49      0.22      0.30      1218
           3       0.46      0.49      0.48      1895
           4       0.45      0.33      0.38      1359
           5       0.69      0.73      0.71      1072

    accuracy                           0.53     11081
   macro avg       0.53      0.48      0.49     11081
weighted avg       0.52      0.53      0.52     11081

val loss:	1.4363 | val acc:	0.4720

{'val_Other_f1': 0.49758919961427195, 'val_Other_precision': 0.6035087719298246, 'val_Other_recall': 0.42329778506972926, 'val_Saggital_Right_f1': 0.5320334261838441, 'val_Saggital_Right_precision': 0.6241830065359477, 'val_Saggital_Right_recall': 0.46359223300970875, 'val_Transverse_Right_f1': 0.375249500998004, 'val_Transverse_Right_precision': 0.5222222222222223, 'val_Transve

train loss:	1.1515 | train acc:	0.5579

              precision    recall  f1-score   support

           0       0.56      0.74      0.64      3724
           1       0.57      0.41      0.48      1813
           2       0.48      0.27      0.34      1218
           3       0.50      0.53      0.52      1895
           4       0.50      0.38      0.43      1359
           5       0.73      0.76      0.74      1072

    accuracy                           0.56     11081
   macro avg       0.56      0.52      0.53     11081
weighted avg       0.55      0.56      0.54     11081

val loss:	1.3434 | val acc:	0.5252

{'val_Other_f1': 0.5653679653679654, 'val_Other_precision': 0.5985334555453712, 'val_Other_recall': 0.5356849876948319, 'val_Saggital_Right_f1': 0.592686002522068, 'val_Saggital_Right_precision': 0.6167979002624672, 'val_Saggital_Right_recall': 0.5703883495145631, 'val_Transverse_Right_f1': 0.35827664399092973, 'val_Transverse_Right_precision': 0.6583333333333333, 'val_Transvers

train loss:	1.0827 | train acc:	0.5880

              precision    recall  f1-score   support

           0       0.59      0.75      0.66      3724
           1       0.58      0.46      0.51      1813
           2       0.55      0.33      0.41      1218
           3       0.53      0.55      0.54      1895
           4       0.53      0.45      0.49      1359
           5       0.77      0.79      0.78      1072

    accuracy                           0.59     11081
   macro avg       0.59      0.55      0.57     11081
weighted avg       0.58      0.59      0.58     11081

val loss:	1.3244 | val acc:	0.5137

{'val_Other_f1': 0.5335144927536232, 'val_Other_precision': 0.5955510616784631, 'val_Other_recall': 0.4831829368334701, 'val_Saggital_Right_f1': 0.576530612244898, 'val_Saggital_Right_precision': 0.6075268817204301, 'val_Saggital_Right_recall': 0.5485436893203883, 'val_Transverse_Right_f1': 0.444, 'val_Transverse_Right_precision': 0.6201117318435754, 'val_Transverse_Right_recall

train loss:	1.0443 | train acc:	0.6035

              precision    recall  f1-score   support

           0       0.60      0.75      0.67      3724
           1       0.62      0.47      0.54      1813
           2       0.58      0.38      0.46      1218
           3       0.54      0.56      0.55      1895
           4       0.56      0.47      0.51      1359
           5       0.78      0.81      0.79      1072

    accuracy                           0.60     11081
   macro avg       0.61      0.57      0.59     11081
weighted avg       0.60      0.60      0.60     11081

val loss:	1.2849 | val acc:	0.5308

{'val_Other_f1': 0.5501084598698481, 'val_Other_precision': 0.583793738489871, 'val_Other_recall': 0.5200984413453651, 'val_Saggital_Right_f1': 0.615979381443299, 'val_Saggital_Right_precision': 0.6565934065934066, 'val_Saggital_Right_recall': 0.5800970873786407, 'val_Transverse_Right_f1': 0.4765342960288809, 'val_Transverse_Right_precision': 0.5665236051502146, 'val_Transverse_

train loss:	0.9973 | train acc:	0.6197

              precision    recall  f1-score   support

           0       0.62      0.76      0.68      3724
           1       0.64      0.49      0.56      1813
           2       0.60      0.41      0.48      1218
           3       0.56      0.59      0.57      1895
           4       0.56      0.49      0.53      1359
           5       0.79      0.82      0.80      1072

    accuracy                           0.62     11081
   macro avg       0.63      0.59      0.60     11081
weighted avg       0.62      0.62      0.61     11081

val loss:	1.2933 | val acc:	0.5189

{'val_Other_f1': 0.5150501672240804, 'val_Other_precision': 0.6167048054919908, 'val_Other_recall': 0.4421657095980312, 'val_Saggital_Right_f1': 0.5659863945578231, 'val_Saggital_Right_precision': 0.6439628482972136, 'val_Saggital_Right_recall': 0.5048543689320388, 'val_Transverse_Right_f1': 0.49683544303797467, 'val_Transverse_Right_precision': 0.5048231511254019, 'val_Transver

train loss:	0.9673 | train acc:	0.6304

              precision    recall  f1-score   support

           0       0.63      0.76      0.69      3724
           1       0.65      0.51      0.57      1813
           2       0.61      0.44      0.51      1218
           3       0.58      0.60      0.59      1895
           4       0.56      0.50      0.53      1359
           5       0.79      0.82      0.80      1072

    accuracy                           0.63     11081
   macro avg       0.64      0.60      0.61     11081
weighted avg       0.63      0.63      0.62     11081

val loss:	1.2321 | val acc:	0.5540

{'val_Other_f1': 0.5765993265993264, 'val_Other_precision': 0.592048401037165, 'val_Other_recall': 0.5619360131255127, 'val_Saggital_Right_f1': 0.6297760210803689, 'val_Saggital_Right_precision': 0.6887608069164265, 'val_Saggital_Right_recall': 0.5800970873786407, 'val_Transverse_Right_f1': 0.5280289330922243, 'val_Transverse_Right_precision': 0.6293103448275862, 'val_Transverse

val loss:	1.2209 | val acc:	0.5447

{'val_Other_f1': 0.5506849315068494, 'val_Other_precision': 0.6210092687950567, 'val_Other_recall': 0.4946677604593929, 'val_Saggital_Right_f1': 0.6002372479240806, 'val_Saggital_Right_precision': 0.5870069605568445, 'val_Saggital_Right_recall': 0.6140776699029126, 'val_Transverse_Right_f1': 0.5101351351351352, 'val_Transverse_Right_precision': 0.5571955719557196, 'val_Transverse_Right_recall': 0.470404984423676, 'val_Saggital_Left_f1': 0.49568965517241376, 'val_Saggital_Left_precision': 0.4267161410018553, 'val_Saggital_Left_recall': 0.5912596401028277, 'val_Transverse_Left_f1': 0.4228571428571429, 'val_Transverse_Left_precision': 0.37948717948717947, 'val_Transverse_Left_recall': 0.4774193548387097, 'val_Bladder_f1': 0.7265469061876249, 'val_Bladder_precision': 0.6618181818181819, 'val_Bladder_recall': 0.8053097345132744, 'val_precision_average_macro': 0.5388722172691396, 'val_recall_average_macro': 0.5755231907067989, 'val_f1score_average_macro': 

train loss:	0.8777 | train acc:	0.6748

              precision    recall  f1-score   support

           0       0.67      0.78      0.72      3724
           1       0.71      0.59      0.64      1813
           2       0.66      0.51      0.58      1218
           3       0.62      0.63      0.62      1895
           4       0.62      0.57      0.59      1359
           5       0.83      0.85      0.84      1072

    accuracy                           0.67     11081
   macro avg       0.68      0.66      0.67     11081
weighted avg       0.67      0.67      0.67     11081

val loss:	1.2326 | val acc:	0.5401

{'val_Other_f1': 0.5676741130091985, 'val_Other_precision': 0.6090225563909775, 'val_Other_recall': 0.5315832649712879, 'val_Saggital_Right_f1': 0.5345622119815668, 'val_Saggital_Right_precision': 0.7280334728033473, 'val_Saggital_Right_recall': 0.4223300970873786, 'val_Transverse_Right_f1': 0.5211505922165821, 'val_Transverse_Right_precision': 0.5703703703703704, 'val_Transvers

val loss:	1.1967 | val acc:	0.5551

{'val_Other_f1': 0.5865344898802148, 'val_Other_precision': 0.5906821963394343, 'val_Other_recall': 0.5824446267432322, 'val_Saggital_Right_f1': 0.5681818181818182, 'val_Saggital_Right_precision': 0.684931506849315, 'val_Saggital_Right_recall': 0.4854368932038835, 'val_Transverse_Right_f1': 0.488245931283906, 'val_Transverse_Right_precision': 0.5818965517241379, 'val_Transverse_Right_recall': 0.4205607476635514, 'val_Saggital_Left_f1': 0.5075268817204301, 'val_Saggital_Left_precision': 0.43622920517560076, 'val_Saggital_Left_recall': 0.6066838046272494, 'val_Transverse_Left_f1': 0.39999999999999997, 'val_Transverse_Left_precision': 0.39076923076923076, 'val_Transverse_Left_recall': 0.4096774193548387, 'val_Bladder_f1': 0.7397260273972601, 'val_Bladder_precision': 0.6631578947368421, 'val_Bladder_recall': 0.8362831858407079, 'val_precision_average_macro': 0.5579444309324267, 'val_recall_average_macro': 0.5568477795722439, 'val_f1score_average_macro': 

train loss:	0.8048 | train acc:	0.7000

              precision    recall  f1-score   support

           0       0.70      0.79      0.74      3724
           1       0.72      0.62      0.67      1813
           2       0.70      0.58      0.63      1218
           3       0.65      0.66      0.66      1895
           4       0.64      0.59      0.62      1359
           5       0.84      0.86      0.85      1072

    accuracy                           0.70     11081
   macro avg       0.71      0.69      0.69     11081
weighted avg       0.70      0.70      0.70     11081

val loss:	1.1752 | val acc:	0.5568

{'val_Other_f1': 0.5660036166365282, 'val_Other_precision': 0.6304128902316214, 'val_Other_recall': 0.5135356849876949, 'val_Saggital_Right_f1': 0.600715137067938, 'val_Saggital_Right_precision': 0.5901639344262295, 'val_Saggital_Right_recall': 0.6116504854368932, 'val_Transverse_Right_f1': 0.5401709401709403, 'val_Transverse_Right_precision': 0.5984848484848485, 'val_Transverse

val loss:	1.1906 | val acc:	0.5440

{'val_Other_f1': 0.5539906103286386, 'val_Other_precision': 0.6476399560922064, 'val_Other_recall': 0.48400328137817883, 'val_Saggital_Right_f1': 0.5883870967741937, 'val_Saggital_Right_precision': 0.628099173553719, 'val_Saggital_Right_recall': 0.5533980582524272, 'val_Transverse_Right_f1': 0.5186385737439222, 'val_Transverse_Right_precision': 0.5405405405405406, 'val_Transverse_Right_recall': 0.4984423676012461, 'val_Saggital_Left_f1': 0.4952741020793951, 'val_Saggital_Left_precision': 0.39162929745889385, 'val_Saggital_Left_recall': 0.6735218508997429, 'val_Transverse_Left_f1': 0.43514644351464427, 'val_Transverse_Left_precision': 0.3832923832923833, 'val_Transverse_Left_recall': 0.5032258064516129, 'val_Bladder_f1': 0.7396061269146609, 'val_Bladder_precision': 0.7316017316017316, 'val_Bladder_recall': 0.7477876106194691, 'val_precision_average_macro': 0.553800513756579, 'val_recall_average_macro': 0.5767298292004461, 'val_f1score_average_macro': 

train loss:	0.7264 | train acc:	0.7282

              precision    recall  f1-score   support

           0       0.73      0.79      0.76      3724
           1       0.74      0.65      0.69      1813
           2       0.71      0.62      0.66      1218
           3       0.68      0.70      0.69      1895
           4       0.69      0.65      0.67      1359
           5       0.86      0.88      0.87      1072

    accuracy                           0.73     11081
   macro avg       0.74      0.72      0.73     11081
weighted avg       0.73      0.73      0.73     11081

val loss:	1.1408 | val acc:	0.5582

{'val_Other_f1': 0.5886730053741216, 'val_Other_precision': 0.5933333333333334, 'val_Other_recall': 0.5840853158326497, 'val_Saggital_Right_f1': 0.5756756756756757, 'val_Saggital_Right_precision': 0.649390243902439, 'val_Saggital_Right_recall': 0.5169902912621359, 'val_Transverse_Right_f1': 0.5251396648044693, 'val_Transverse_Right_precision': 0.6527777777777778, 'val_Transverse

train loss:	0.7083 | train acc:	0.7335

              precision    recall  f1-score   support

           0       0.73      0.80      0.76      3724
           1       0.76      0.66      0.71      1813
           2       0.72      0.63      0.67      1218
           3       0.69      0.71      0.70      1895
           4       0.69      0.66      0.68      1359
           5       0.85      0.88      0.87      1072

    accuracy                           0.73     11081
   macro avg       0.74      0.72      0.73     11081
weighted avg       0.73      0.73      0.73     11081

val loss:	1.1749 | val acc:	0.5534

{'val_Other_f1': 0.5769403319874383, 'val_Other_precision': 0.6366336633663366, 'val_Other_recall': 0.527481542247744, 'val_Saggital_Right_f1': 0.5622119815668204, 'val_Saggital_Right_precision': 0.5350877192982456, 'val_Saggital_Right_recall': 0.5922330097087378, 'val_Transverse_Right_f1': 0.503448275862069, 'val_Transverse_Right_precision': 0.5637065637065637, 'val_Transverse_

val loss:	1.1614 | val acc:	0.5537

{'val_Other_f1': 0.5835229858898497, 'val_Other_precision': 0.6554192229038854, 'val_Other_recall': 0.5258408531583265, 'val_Saggital_Right_f1': 0.5592515592515593, 'val_Saggital_Right_precision': 0.4890909090909091, 'val_Saggital_Right_recall': 0.6529126213592233, 'val_Transverse_Right_f1': 0.527736131934033, 'val_Transverse_Right_precision': 0.5086705202312138, 'val_Transverse_Right_recall': 0.5482866043613707, 'val_Saggital_Left_f1': 0.465, 'val_Saggital_Left_precision': 0.45255474452554745, 'val_Saggital_Left_recall': 0.4781491002570694, 'val_Transverse_Left_f1': 0.4457652303120357, 'val_Transverse_Left_precision': 0.4132231404958678, 'val_Transverse_Left_recall': 0.4838709677419355, 'val_Bladder_f1': 0.7516483516483516, 'val_Bladder_precision': 0.7467248908296943, 'val_Bladder_recall': 0.7566371681415929, 'val_precision_average_macro': 0.5442805713461863, 'val_recall_average_macro': 0.5742828858365864, 'val_f1score_average_macro': 0.555487376505

train loss:	0.6529 | train acc:	0.7531

              precision    recall  f1-score   support

           0       0.75      0.81      0.77      3724
           1       0.76      0.69      0.72      1813
           2       0.75      0.67      0.71      1218
           3       0.72      0.72      0.72      1895
           4       0.71      0.70      0.71      1359
           5       0.88      0.89      0.89      1072

    accuracy                           0.75     11081
   macro avg       0.76      0.75      0.75     11081
weighted avg       0.75      0.75      0.75     11081

val loss:	1.1834 | val acc:	0.5391

{'val_Other_f1': 0.5273972602739726, 'val_Other_precision': 0.6533333333333333, 'val_Other_recall': 0.4421657095980312, 'val_Saggital_Right_f1': 0.5798045602605864, 'val_Saggital_Right_precision': 0.5245579567779961, 'val_Saggital_Right_recall': 0.6480582524271845, 'val_Transverse_Right_f1': 0.5273010920436817, 'val_Transverse_Right_precision': 0.528125, 'val_Transverse_Right_re

val loss:	1.1675 | val acc:	0.5471

{'val_Other_f1': 0.5471609572970437, 'val_Other_precision': 0.6392543859649122, 'val_Other_recall': 0.4782608695652174, 'val_Saggital_Right_f1': 0.5854241338112306, 'val_Saggital_Right_precision': 0.5764705882352941, 'val_Saggital_Right_recall': 0.5946601941747572, 'val_Transverse_Right_f1': 0.550354609929078, 'val_Transverse_Right_precision': 0.5052083333333334, 'val_Transverse_Right_recall': 0.6043613707165109, 'val_Saggital_Left_f1': 0.48556149732620324, 'val_Saggital_Left_precision': 0.4157509157509158, 'val_Saggital_Left_recall': 0.583547557840617, 'val_Transverse_Left_f1': 0.43797856049004597, 'val_Transverse_Left_precision': 0.41690962099125367, 'val_Transverse_Left_recall': 0.4612903225806452, 'val_Bladder_f1': 0.7383367139959431, 'val_Bladder_precision': 0.6816479400749064, 'val_Bladder_recall': 0.8053097345132744, 'val_precision_average_macro': 0.5392069640584358, 'val_recall_average_macro': 0.587905008231837, 'val_f1score_average_macro': 0

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


	 conv1.weight
	 conv1.bias
	 conv2.weight
	 conv2.bias
	 conv3.weight
	 conv3.bias
	 linear1.weight
	 linear1.bias
	 linear2.weight
	 linear2.bias
Epoch 1/100
------------------------------------------------------
train loss:	1.6251 | train acc:	0.3537

              precision    recall  f1-score   support

           0       0.35      0.97      0.52      3756
           1       0.27      0.02      0.04      1813
           2       0.30      0.00      0.00      1284
           3       0.29      0.06      0.10      1717
           4       0.00      0.00      0.00      1286
           5       0.61      0.08      0.13      1009

    accuracy                           0.35     10865
   macro avg       0.30      0.19      0.13     10865
weighted avg       0.30      0.35      0.21     10865

val loss:	1.5367 | val acc:	0.4536

{'val_Other_f1': 0.5918023582257159, 'val_Other_precision': 0.4437894736842105, 'val_Other_recall': 0.8879528222409435, 'val_Saggital_Right_f1': 0.0, 'val_Saggital_Ri

val loss:	1.4088 | val acc:	0.5587

{'val_Other_f1': 0.6516690856313498, 'val_Other_precision': 0.5723390694710007, 'val_Other_recall': 0.7565290648694187, 'val_Saggital_Right_f1': 0.3735144312393888, 'val_Saggital_Right_precision': 0.6214689265536724, 'val_Saggital_Right_recall': 0.2669902912621359, 'val_Transverse_Right_f1': 0.33248081841432225, 'val_Transverse_Right_precision': 0.47794117647058826, 'val_Transverse_Right_recall': 0.2549019607843137, 'val_Saggital_Left_f1': 0.5341196293176074, 'val_Saggital_Left_precision': 0.5112903225806451, 'val_Saggital_Left_recall': 0.5590828924162258, 'val_Transverse_Left_f1': 0.3268892794376098, 'val_Transverse_Left_precision': 0.5, 'val_Transverse_Left_recall': 0.24281984334203655, 'val_Bladder_f1': 0.7060518731988473, 'val_Bladder_precision': 0.6049382716049383, 'val_Bladder_recall': 0.8477508650519031, 'val_precision_average_macro': 0.5479962944468074, 'val_recall_average_macro': 0.4880124862876723, 'val_f1score_average_macro': 0.48745418620

train loss:	1.2648 | train acc:	0.5047

              precision    recall  f1-score   support

           0       0.50      0.74      0.60      3756
           1       0.51      0.41      0.46      1813
           2       0.46      0.21      0.29      1284
           3       0.43      0.40      0.41      1717
           4       0.43      0.23      0.30      1286
           5       0.71      0.68      0.69      1009

    accuracy                           0.50     10865
   macro avg       0.51      0.45      0.46     10865
weighted avg       0.50      0.50      0.48     10865

val loss:	1.3191 | val acc:	0.5930

{'val_Other_f1': 0.6787193973634651, 'val_Other_precision': 0.6137602179836512, 'val_Other_recall': 0.7590564448188711, 'val_Saggital_Right_f1': 0.5452513966480447, 'val_Saggital_Right_precision': 0.505175983436853, 'val_Saggital_Right_recall': 0.5922330097087378, 'val_Transverse_Right_f1': 0.44776119402985076, 'val_Transverse_Right_precision': 0.49065420560747663, 'val_Transver

val loss:	1.2634 | val acc:	0.6104

{'val_Other_f1': 0.6762534006995725, 'val_Other_precision': 0.6277056277056277, 'val_Other_recall': 0.7329401853411963, 'val_Saggital_Right_f1': 0.5720620842572063, 'val_Saggital_Right_precision': 0.5265306122448979, 'val_Saggital_Right_recall': 0.6262135922330098, 'val_Transverse_Right_f1': 0.4196185286103542, 'val_Transverse_Right_precision': 0.6875, 'val_Transverse_Right_recall': 0.30196078431372547, 'val_Saggital_Left_f1': 0.5567765567765567, 'val_Saggital_Left_precision': 0.579047619047619, 'val_Saggital_Left_recall': 0.5361552028218695, 'val_Transverse_Left_f1': 0.4530744336569579, 'val_Transverse_Left_precision': 0.5957446808510638, 'val_Transverse_Left_recall': 0.36553524804177545, 'val_Bladder_f1': 0.7539432176656151, 'val_Bladder_precision': 0.6927536231884058, 'val_Bladder_recall': 0.8269896193771626, 'val_precision_average_macro': 0.6182136938396025, 'val_recall_average_macro': 0.5649657720214565, 'val_f1score_average_macro': 0.5719547036

train loss:	1.1545 | train acc:	0.5504

              precision    recall  f1-score   support

           0       0.54      0.74      0.62      3756
           1       0.61      0.44      0.51      1813
           2       0.53      0.32      0.40      1284
           3       0.48      0.49      0.49      1717
           4       0.46      0.32      0.38      1286
           5       0.76      0.75      0.75      1009

    accuracy                           0.55     10865
   macro avg       0.56      0.51      0.52     10865
weighted avg       0.55      0.55      0.54     10865

val loss:	1.2648 | val acc:	0.5878

{'val_Other_f1': 0.6606012658227848, 'val_Other_precision': 0.6226696495152871, 'val_Other_recall': 0.7034540859309183, 'val_Saggital_Right_f1': 0.5556650246305418, 'val_Saggital_Right_precision': 0.46766169154228854, 'val_Saggital_Right_recall': 0.6844660194174758, 'val_Transverse_Right_f1': 0.49569707401032703, 'val_Transverse_Right_precision': 0.44171779141104295, 'val_Transv

val loss:	1.2145 | val acc:	0.6198

{'val_Other_f1': 0.6707086017430846, 'val_Other_precision': 0.609504132231405, 'val_Other_recall': 0.7455770850884583, 'val_Saggital_Right_f1': 0.5823095823095823, 'val_Saggital_Right_precision': 0.5895522388059702, 'val_Saggital_Right_recall': 0.5752427184466019, 'val_Transverse_Right_f1': 0.5521472392638036, 'val_Transverse_Right_precision': 0.5769230769230769, 'val_Transverse_Right_recall': 0.5294117647058824, 'val_Saggital_Left_f1': 0.5310621242484971, 'val_Saggital_Left_precision': 0.6148491879350348, 'val_Saggital_Left_recall': 0.4673721340388007, 'val_Transverse_Left_f1': 0.5372714486638538, 'val_Transverse_Left_precision': 0.5823170731707317, 'val_Transverse_Left_recall': 0.49869451697127937, 'val_Bladder_f1': 0.7626168224299065, 'val_Bladder_precision': 0.8292682926829268, 'val_Bladder_recall': 0.7058823529411765, 'val_precision_average_macro': 0.633735666958191, 'val_recall_average_macro': 0.5870300953653665, 'val_f1score_average_macro': 0.

train loss:	1.0541 | train acc:	0.5974

              precision    recall  f1-score   support

           0       0.58      0.74      0.65      3756
           1       0.65      0.50      0.56      1813
           2       0.59      0.40      0.48      1284
           3       0.53      0.54      0.53      1717
           4       0.52      0.43      0.47      1286
           5       0.79      0.81      0.80      1009

    accuracy                           0.60     10865
   macro avg       0.61      0.57      0.58     10865
weighted avg       0.60      0.60      0.59     10865

val loss:	1.1825 | val acc:	0.6094

{'val_Other_f1': 0.6709863210943124, 'val_Other_precision': 0.5857950974230044, 'val_Other_recall': 0.785172704296546, 'val_Saggital_Right_f1': 0.5627705627705628, 'val_Saggital_Right_precision': 0.693950177935943, 'val_Saggital_Right_recall': 0.4733009708737864, 'val_Transverse_Right_f1': 0.513888888888889, 'val_Transverse_Right_precision': 0.6271186440677966, 'val_Transverse_R

train loss:	1.0229 | train acc:	0.6060

              precision    recall  f1-score   support

           0       0.59      0.74      0.66      3756
           1       0.66      0.50      0.57      1813
           2       0.60      0.43      0.50      1284
           3       0.54      0.54      0.54      1717
           4       0.53      0.45      0.49      1286
           5       0.80      0.82      0.81      1009

    accuracy                           0.61     10865
   macro avg       0.62      0.58      0.59     10865
weighted avg       0.61      0.61      0.60     10865

val loss:	1.1729 | val acc:	0.6256

{'val_Other_f1': 0.6752000000000001, 'val_Other_precision': 0.6428027418126429, 'val_Other_recall': 0.7110362257792755, 'val_Saggital_Right_f1': 0.5766423357664233, 'val_Saggital_Right_precision': 0.5780487804878048, 'val_Saggital_Right_recall': 0.5752427184466019, 'val_Transverse_Right_f1': 0.5596868884540117, 'val_Transverse_Right_precision': 0.55859375, 'val_Transverse_Right_

train loss:	0.9821 | train acc:	0.6223

              precision    recall  f1-score   support

           0       0.60      0.75      0.67      3756
           1       0.67      0.51      0.58      1813
           2       0.63      0.49      0.55      1284
           3       0.56      0.56      0.56      1717
           4       0.56      0.45      0.50      1286
           5       0.80      0.82      0.81      1009

    accuracy                           0.62     10865
   macro avg       0.64      0.60      0.61     10865
weighted avg       0.62      0.62      0.62     10865

val loss:	1.1260 | val acc:	0.6182

{'val_Other_f1': 0.6684698608964452, 'val_Other_precision': 0.6174161313347609, 'val_Other_recall': 0.7287278854254423, 'val_Saggital_Right_f1': 0.5937161430119177, 'val_Saggital_Right_precision': 0.5362035225048923, 'val_Saggital_Right_recall': 0.6650485436893204, 'val_Transverse_Right_f1': 0.5845511482254697, 'val_Transverse_Right_precision': 0.625, 'val_Transverse_Right_recal

train loss:	0.9459 | train acc:	0.6341

              precision    recall  f1-score   support

           0       0.62      0.75      0.68      3756
           1       0.68      0.54      0.60      1813
           2       0.64      0.50      0.56      1284
           3       0.58      0.58      0.58      1717
           4       0.57      0.49      0.52      1286
           5       0.81      0.81      0.81      1009

    accuracy                           0.63     10865
   macro avg       0.65      0.61      0.63     10865
weighted avg       0.64      0.63      0.63     10865

val loss:	1.1296 | val acc:	0.6275

{'val_Other_f1': 0.6821583300512013, 'val_Other_precision': 0.6405325443786982, 'val_Other_recall': 0.7295703454085931, 'val_Saggital_Right_f1': 0.5891647855530475, 'val_Saggital_Right_precision': 0.5506329113924051, 'val_Saggital_Right_recall': 0.633495145631068, 'val_Transverse_Right_f1': 0.605072463768116, 'val_Transverse_Right_precision': 0.5622895622895623, 'val_Transverse_

train loss:	0.9117 | train acc:	0.6499

              precision    recall  f1-score   support

           0       0.63      0.75      0.69      3756
           1       0.70      0.57      0.62      1813
           2       0.65      0.53      0.58      1284
           3       0.59      0.59      0.59      1717
           4       0.60      0.53      0.56      1286
           5       0.82      0.83      0.82      1009

    accuracy                           0.65     10865
   macro avg       0.66      0.63      0.64     10865
weighted avg       0.65      0.65      0.65     10865

val loss:	1.1025 | val acc:	0.6405

{'val_Other_f1': 0.684997978164173, 'val_Other_precision': 0.6586314152410575, 'val_Other_recall': 0.7135636057287279, 'val_Saggital_Right_f1': 0.6074766355140188, 'val_Saggital_Right_precision': 0.5855855855855856, 'val_Saggital_Right_recall': 0.6310679611650486, 'val_Transverse_Right_f1': 0.6088631984585742, 'val_Transverse_Right_precision': 0.5984848484848485, 'val_Transverse

train loss:	0.8760 | train acc:	0.6627

              precision    recall  f1-score   support

           0       0.65      0.75      0.70      3756
           1       0.70      0.59      0.64      1813
           2       0.67      0.54      0.60      1284
           3       0.61      0.62      0.61      1717
           4       0.60      0.55      0.58      1286
           5       0.82      0.84      0.83      1009

    accuracy                           0.66     10865
   macro avg       0.68      0.65      0.66     10865
weighted avg       0.66      0.66      0.66     10865

val loss:	1.0909 | val acc:	0.6350

{'val_Other_f1': 0.6813532651455547, 'val_Other_precision': 0.6391143911439114, 'val_Other_recall': 0.7295703454085931, 'val_Saggital_Right_f1': 0.5988700564971753, 'val_Saggital_Right_precision': 0.5602536997885835, 'val_Saggital_Right_recall': 0.6432038834951457, 'val_Transverse_Right_f1': 0.6098081023454157, 'val_Transverse_Right_precision': 0.6682242990654206, 'val_Transvers

val loss:	1.0550 | val acc:	0.6431

{'val_Other_f1': 0.6945975744211687, 'val_Other_precision': 0.6160365058670143, 'val_Other_recall': 0.7961246840775064, 'val_Saggital_Right_f1': 0.5908496732026143, 'val_Saggital_Right_precision': 0.6402266288951841, 'val_Saggital_Right_recall': 0.5485436893203883, 'val_Transverse_Right_f1': 0.5700934579439252, 'val_Transverse_Right_precision': 0.7052023121387283, 'val_Transverse_Right_recall': 0.47843137254901963, 'val_Saggital_Left_f1': 0.564, 'val_Saggital_Left_precision': 0.651270207852194, 'val_Saggital_Left_recall': 0.4973544973544973, 'val_Transverse_Left_f1': 0.5140324963072377, 'val_Transverse_Left_precision': 0.5918367346938775, 'val_Transverse_Left_recall': 0.45430809399477806, 'val_Bladder_f1': 0.8067226890756302, 'val_Bladder_precision': 0.7843137254901961, 'val_Bladder_recall': 0.8304498269896193, 'val_precision_average_macro': 0.6648143524895324, 'val_recall_average_macro': 0.6008686940476348, 'val_f1score_average_macro': 0.62338264849

val loss:	1.0851 | val acc:	0.6243

{'val_Other_f1': 0.6616915422885573, 'val_Other_precision': 0.6514285714285715, 'val_Other_recall': 0.6722830665543387, 'val_Saggital_Right_f1': 0.5952380952380952, 'val_Saggital_Right_precision': 0.5841121495327103, 'val_Saggital_Right_recall': 0.6067961165048543, 'val_Transverse_Right_f1': 0.5755102040816328, 'val_Transverse_Right_precision': 0.6, 'val_Transverse_Right_recall': 0.5529411764705883, 'val_Saggital_Left_f1': 0.549808429118774, 'val_Saggital_Left_precision': 0.6016771488469602, 'val_Saggital_Left_recall': 0.5061728395061729, 'val_Transverse_Left_f1': 0.5383678440925701, 'val_Transverse_Left_precision': 0.5045662100456622, 'val_Transverse_Left_recall': 0.577023498694517, 'val_Bladder_f1': 0.8082901554404145, 'val_Bladder_precision': 0.8068965517241379, 'val_Bladder_recall': 0.8096885813148789, 'val_precision_average_macro': 0.624780105263007, 'val_recall_average_macro': 0.6208175465075584, 'val_f1score_average_macro': 0.621484378376674, 

train loss:	0.7794 | train acc:	0.6971

              precision    recall  f1-score   support

           0       0.68      0.77      0.72      3756
           1       0.74      0.64      0.68      1813
           2       0.71      0.60      0.65      1284
           3       0.65      0.64      0.65      1717
           4       0.65      0.60      0.63      1286
           5       0.84      0.85      0.85      1009

    accuracy                           0.70     10865
   macro avg       0.71      0.69      0.70     10865
weighted avg       0.70      0.70      0.70     10865

val loss:	1.0649 | val acc:	0.6376

{'val_Other_f1': 0.6802070888092393, 'val_Other_precision': 0.6450151057401813, 'val_Other_recall': 0.7194608256107835, 'val_Saggital_Right_f1': 0.594413407821229, 'val_Saggital_Right_precision': 0.5507246376811594, 'val_Saggital_Right_recall': 0.6456310679611651, 'val_Transverse_Right_f1': 0.6119733924611972, 'val_Transverse_Right_precision': 0.7040816326530612, 'val_Transverse

val loss:	1.0759 | val acc:	0.6275

{'val_Other_f1': 0.6666666666666667, 'val_Other_precision': 0.6493610223642172, 'val_Other_recall': 0.6849199663016007, 'val_Saggital_Right_f1': 0.5891829689298044, 'val_Saggital_Right_precision': 0.5601750547045952, 'val_Saggital_Right_recall': 0.6213592233009708, 'val_Transverse_Right_f1': 0.6115384615384615, 'val_Transverse_Right_precision': 0.6, 'val_Transverse_Right_recall': 0.6235294117647059, 'val_Saggital_Left_f1': 0.560303893637227, 'val_Saggital_Left_precision': 0.6069958847736625, 'val_Saggital_Left_recall': 0.5202821869488536, 'val_Transverse_Left_f1': 0.5115606936416185, 'val_Transverse_Left_precision': 0.5728155339805825, 'val_Transverse_Left_recall': 0.4621409921671018, 'val_Bladder_f1': 0.7862969004893964, 'val_Bladder_precision': 0.7438271604938271, 'val_Bladder_recall': 0.8339100346020761, 'val_precision_average_macro': 0.6221957760528142, 'val_recall_average_macro': 0.6243569691808849, 'val_f1score_average_macro': 0.620924930817195

train loss:	0.7361 | train acc:	0.7170

              precision    recall  f1-score   support

           0       0.71      0.79      0.75      3756
           1       0.73      0.66      0.69      1813
           2       0.70      0.63      0.66      1284
           3       0.68      0.66      0.67      1717
           4       0.68      0.62      0.65      1286
           5       0.86      0.89      0.87      1009

    accuracy                           0.72     10865
   macro avg       0.73      0.71      0.72     10865
weighted avg       0.72      0.72      0.72     10865

val loss:	1.0436 | val acc:	0.6382

{'val_Other_f1': 0.6675148430873621, 'val_Other_precision': 0.6720751494449189, 'val_Other_recall': 0.6630160067396799, 'val_Saggital_Right_f1': 0.5857642940490081, 'val_Saggital_Right_precision': 0.5640449438202247, 'val_Saggital_Right_recall': 0.6092233009708737, 'val_Transverse_Right_f1': 0.639344262295082, 'val_Transverse_Right_precision': 0.6695278969957081, 'val_Transverse

val loss:	1.0066 | val acc:	0.6424

{'val_Other_f1': 0.684354272337105, 'val_Other_precision': 0.6373546511627907, 'val_Other_recall': 0.7388374052232519, 'val_Saggital_Right_f1': 0.5841836734693877, 'val_Saggital_Right_precision': 0.6155913978494624, 'val_Saggital_Right_recall': 0.5558252427184466, 'val_Transverse_Right_f1': 0.617336152219873, 'val_Transverse_Right_precision': 0.6697247706422018, 'val_Transverse_Right_recall': 0.5725490196078431, 'val_Saggital_Left_f1': 0.552434456928839, 'val_Saggital_Left_precision': 0.5888223552894212, 'val_Saggital_Left_recall': 0.5202821869488536, 'val_Transverse_Left_f1': 0.5612104539202201, 'val_Transverse_Left_precision': 0.5930232558139535, 'val_Transverse_Left_recall': 0.5326370757180157, 'val_Bladder_f1': 0.8266199649737302, 'val_Bladder_precision': 0.8368794326241135, 'val_Bladder_recall': 0.8166089965397924, 'val_precision_average_macro': 0.6568993105636572, 'val_recall_average_macro': 0.6227899877927006, 'val_f1score_average_macro': 0.63

train loss:	0.6880 | train acc:	0.7414

              precision    recall  f1-score   support

           0       0.73      0.79      0.76      3756
           1       0.75      0.69      0.72      1813
           2       0.73      0.68      0.71      1284
           3       0.70      0.70      0.70      1717
           4       0.71      0.68      0.69      1286
           5       0.88      0.90      0.89      1009

    accuracy                           0.74     10865
   macro avg       0.75      0.74      0.74     10865
weighted avg       0.74      0.74      0.74     10865

val loss:	1.0498 | val acc:	0.6308

{'val_Other_f1': 0.675902602854744, 'val_Other_precision': 0.6736401673640168, 'val_Other_recall': 0.6781802864363943, 'val_Saggital_Right_f1': 0.5940828402366864, 'val_Saggital_Right_precision': 0.5796766743648961, 'val_Saggital_Right_recall': 0.6092233009708737, 'val_Transverse_Right_f1': 0.5927209705372617, 'val_Transverse_Right_precision': 0.531055900621118, 'val_Transverse_

val loss:	1.0243 | val acc:	0.6321

{'val_Other_f1': 0.6895463510848127, 'val_Other_precision': 0.6483679525222552, 'val_Other_recall': 0.7363100252737995, 'val_Saggital_Right_f1': 0.5807127882599581, 'val_Saggital_Right_precision': 0.511070110701107, 'val_Saggital_Right_recall': 0.6723300970873787, 'val_Transverse_Right_f1': 0.6025641025641025, 'val_Transverse_Right_precision': 0.6619718309859155, 'val_Transverse_Right_recall': 0.5529411764705883, 'val_Saggital_Left_f1': 0.500517063081696, 'val_Saggital_Left_precision': 0.605, 'val_Saggital_Left_recall': 0.42680776014109345, 'val_Transverse_Left_f1': 0.5470085470085471, 'val_Transverse_Left_precision': 0.6018808777429467, 'val_Transverse_Left_recall': 0.5013054830287206, 'val_Bladder_f1': 0.817857142857143, 'val_Bladder_precision': 0.8450184501845018, 'val_Bladder_recall': 0.7923875432525952, 'val_precision_average_macro': 0.6455515370227877, 'val_recall_average_macro': 0.6136803475423626, 'val_f1score_average_macro': 0.62303433247604

train loss:	0.6579 | train acc:	0.7542

              precision    recall  f1-score   support

           0       0.74      0.81      0.77      3756
           1       0.78      0.70      0.73      1813
           2       0.76      0.69      0.72      1284
           3       0.71      0.71      0.71      1717
           4       0.71      0.71      0.71      1286
           5       0.88      0.88      0.88      1009

    accuracy                           0.75     10865
   macro avg       0.76      0.75      0.76     10865
weighted avg       0.75      0.75      0.75     10865

val loss:	1.0112 | val acc:	0.6450

{'val_Other_f1': 0.6856425282544998, 'val_Other_precision': 0.6813643926788685, 'val_Other_recall': 0.6899747262005055, 'val_Saggital_Right_f1': 0.6149732620320856, 'val_Saggital_Right_precision': 0.6845238095238095, 'val_Saggital_Right_recall': 0.558252427184466, 'val_Transverse_Right_f1': 0.6086956521739131, 'val_Transverse_Right_precision': 0.5875912408759124, 'val_Transverse

wandb: Wandb version 0.8.31 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


	 conv1.weight
	 conv1.bias
	 conv2.weight
	 conv2.bias
	 conv3.weight
	 conv3.bias
	 linear1.weight
	 linear1.bias
	 linear2.weight
	 linear2.bias
Epoch 1/100
------------------------------------------------------
train loss:	1.5983 | train acc:	0.3868

              precision    recall  f1-score   support

           0       0.39      0.99      0.56      4432
           1       0.20      0.00      0.00      1751
           2       0.00      0.00      0.00      1195
           3       0.39      0.03      0.05      1871
           4       0.00      0.00      0.00      1296
           5       0.65      0.04      0.08      1057

    accuracy                           0.39     11602
   macro avg       0.27      0.18      0.11     11602
weighted avg       0.30      0.39      0.23     11602

val loss:	1.6617 | val acc:	0.3294

{'val_Other_f1': 0.43816254416961126, 'val_Other_precision': 0.29523809523809524, 'val_Other_recall': 0.8493150684931506, 'val_Saggital_Right_f1': 0.0, 'val_Saggital_

val loss:	1.5405 | val acc:	0.3514

{'val_Other_f1': 0.4359093012494216, 'val_Other_precision': 0.28545454545454546, 'val_Other_recall': 0.9217221135029354, 'val_Saggital_Right_f1': 0.09542743538767395, 'val_Saggital_Right_precision': 0.8275862068965517, 'val_Saggital_Right_recall': 0.05063291139240506, 'val_Transverse_Right_f1': 0.017291066282420747, 'val_Transverse_Right_precision': 1.0, 'val_Transverse_Right_recall': 0.00872093023255814, 'val_Saggital_Left_f1': 0.39662447257383965, 'val_Saggital_Left_precision': 0.47315436241610737, 'val_Saggital_Left_recall': 0.3414043583535109, 'val_Transverse_Left_f1': 0.0053475935828877, 'val_Transverse_Left_precision': 1.0, 'val_Transverse_Left_recall': 0.002680965147453083, 'val_Bladder_f1': 0.6103896103896104, 'val_Bladder_precision': 0.5013333333333333, 'val_Bladder_recall': 0.7800829875518672, 'val_precision_average_macro': 0.6812547413500897, 'val_recall_average_macro': 0.350874044363455, 'val_f1score_average_macro': 0.26016491324430896, '

train loss:	1.2300 | train acc:	0.5283

              precision    recall  f1-score   support

           0       0.53      0.78      0.63      4432
           1       0.53      0.40      0.46      1751
           2       0.49      0.14      0.21      1195
           3       0.46      0.44      0.45      1871
           4       0.45      0.20      0.28      1296
           5       0.70      0.67      0.68      1057

    accuracy                           0.53     11602
   macro avg       0.52      0.44      0.45     11602
weighted avg       0.52      0.53      0.50     11602

val loss:	1.4406 | val acc:	0.4724

{'val_Other_f1': 0.5074812967581047, 'val_Other_precision': 0.37236962488563585, 'val_Other_recall': 0.7964774951076321, 'val_Saggital_Right_f1': 0.42524005486968447, 'val_Saggital_Right_precision': 0.6078431372549019, 'val_Saggital_Right_recall': 0.3270042194092827, 'val_Transverse_Right_f1': 0.327433628318584, 'val_Transverse_Right_precision': 0.6851851851851852, 'val_Transver

train loss:	1.1667 | train acc:	0.5486

              precision    recall  f1-score   support

           0       0.55      0.77      0.64      4432
           1       0.56      0.44      0.49      1751
           2       0.51      0.22      0.31      1195
           3       0.47      0.44      0.46      1871
           4       0.48      0.28      0.36      1296
           5       0.74      0.69      0.72      1057

    accuracy                           0.55     11602
   macro avg       0.55      0.47      0.50     11602
weighted avg       0.54      0.55      0.53     11602

val loss:	1.3872 | val acc:	0.4728

{'val_Other_f1': 0.49691358024691357, 'val_Other_precision': 0.4101910828025478, 'val_Other_recall': 0.6301369863013698, 'val_Saggital_Right_f1': 0.4522502744237102, 'val_Saggital_Right_precision': 0.47139588100686497, 'val_Saggital_Right_recall': 0.4345991561181435, 'val_Transverse_Right_f1': 0.26229508196721313, 'val_Transverse_Right_precision': 0.6746987951807228, 'val_Transv

train loss:	1.1174 | train acc:	0.5735

              precision    recall  f1-score   support

           0       0.57      0.77      0.66      4432
           1       0.60      0.46      0.52      1751
           2       0.55      0.28      0.37      1195
           3       0.51      0.48      0.50      1871
           4       0.49      0.31      0.38      1296
           5       0.75      0.72      0.74      1057

    accuracy                           0.57     11602
   macro avg       0.58      0.51      0.53     11602
weighted avg       0.57      0.57      0.56     11602

val loss:	1.3512 | val acc:	0.5030

{'val_Other_f1': 0.5120101137800253, 'val_Other_precision': 0.37815126050420167, 'val_Other_recall': 0.7925636007827789, 'val_Saggital_Right_f1': 0.3777089783281734, 'val_Saggital_Right_precision': 0.7093023255813954, 'val_Saggital_Right_recall': 0.25738396624472576, 'val_Transverse_Right_f1': 0.36324786324786323, 'val_Transverse_Right_precision': 0.6854838709677419, 'val_Transv

val loss:	1.2927 | val acc:	0.5000

{'val_Other_f1': 0.495575221238938, 'val_Other_precision': 0.3454124903623747, 'val_Other_recall': 0.8767123287671232, 'val_Saggital_Right_f1': 0.5071151358344114, 'val_Saggital_Right_precision': 0.6555183946488294, 'val_Saggital_Right_recall': 0.41350210970464135, 'val_Transverse_Right_f1': 0.31775700934579443, 'val_Transverse_Right_precision': 0.8095238095238095, 'val_Transverse_Right_recall': 0.19767441860465115, 'val_Saggital_Left_f1': 0.5110782865583456, 'val_Saggital_Left_precision': 0.6553030303030303, 'val_Saggital_Left_recall': 0.4188861985472155, 'val_Transverse_Left_f1': 0.3784786641929499, 'val_Transverse_Left_precision': 0.6144578313253012, 'val_Transverse_Left_recall': 0.2734584450402145, 'val_Bladder_f1': 0.7843942505133471, 'val_Bladder_precision': 0.7764227642276422, 'val_Bladder_recall': 0.7925311203319502, 'val_precision_average_macro': 0.6427730533984979, 'val_recall_average_macro': 0.49546077016596596, 'val_f1score_average_macro'

train loss:	1.0304 | train acc:	0.6041

              precision    recall  f1-score   support

           0       0.60      0.77      0.67      4432
           1       0.65      0.49      0.56      1751
           2       0.60      0.36      0.45      1195
           3       0.52      0.52      0.52      1871
           4       0.56      0.41      0.47      1296
           5       0.78      0.77      0.77      1057

    accuracy                           0.60     11602
   macro avg       0.62      0.55      0.57     11602
weighted avg       0.61      0.60      0.59     11602

val loss:	1.2692 | val acc:	0.5378

{'val_Other_f1': 0.5287994448299791, 'val_Other_precision': 0.4096774193548387, 'val_Other_recall': 0.7455968688845401, 'val_Saggital_Right_f1': 0.5057766367137355, 'val_Saggital_Right_precision': 0.6459016393442623, 'val_Saggital_Right_recall': 0.41561181434599154, 'val_Transverse_Right_f1': 0.428, 'val_Transverse_Right_precision': 0.6858974358974359, 'val_Transverse_Right_reca

In [None]:
# Preview the first rows of `final_df` as this cell's displayed output.
# NOTE(review): `final_df` is defined in an earlier cell (presumably a pandas
# DataFrame - TODO confirm). This cell has no execution count, so it was not
# run in the saved notebook; make sure the defining cell runs before it.
final_df.head()