## Model metrics

In [18]:
import pandas as pd
import numpy as np
from tabulate import tabulate
from metrics_utils import *

In [3]:
# Directory handed to save_metrics_from_logger as the log location
# (presumably the PyTorch Lightning logger output — TODO confirm).
LOG_PATH = "lightning_logs"
# Directory handed to the save/plot/load helpers as the metrics location.
METRICS_PATH = "../Metrics"

## Generate cleaned-up CSV files for each run

In [4]:
# Identifiers of the models whose repeated runs are processed below; each is
# passed as the first argument to the metrics helpers.
model_ids = ['G_LeNet_cut_dataset_repeat',
             'LeNet_cut_dataset_repeat',
             'resnet18_cut_dataset_repeat',
             'resnet50_cut_dataset_repeat',
             'jiaresnet50_cut_dataset_repeat',
             'G_ResNet18_cut_dataset_repeat']

def save_run(model_ids,max_runs):
    """Export the logged metrics of every run of every model.

    Calls ``save_metrics_from_logger`` once per (model, run) pair with
    ``LOG_PATH`` and ``METRICS_PATH``, using the run index as ``version``.
    A failing run is reported and skipped so the remaining runs are still
    processed.

    Args:
        model_ids: Iterable of model identifiers.
        max_runs: Number of runs per model; versions ``0 .. max_runs - 1``
            are attempted.
    """
    for model in model_ids:
        for run in range(max_runs):
            print(f"{model}, run {run}")
            try:
                save_metrics_from_logger(model, LOG_PATH, METRICS_PATH, version=run)
            except Exception as exc:
                # Best-effort: keep going with the other runs, but catch only
                # Exception so Ctrl-C / SystemExit still stop the loop, and
                # report the cause instead of hiding it.
                print(f"Error with {model}, run {run}: {exc!r}")

#save_run(model_ids,5) #No longer needed as newer framework does this automatically

## Create graphs for each run

In [5]:
def generate_graphs_run(model_ids,max_runs):
    """Plot the training metrics of every run of every model.

    Calls ``plot_train_metrics`` once per (model, run) pair with
    ``METRICS_PATH``, the run index as ``version`` and ``show=False``.
    A failing run is reported and skipped so the remaining runs are still
    processed.

    Args:
        model_ids: Iterable of model identifiers.
        max_runs: Number of runs per model; versions ``0 .. max_runs - 1``
            are attempted.
    """
    for model in model_ids:
        for run in range(max_runs):
            try:
                plot_train_metrics(model, METRICS_PATH, version=run, show=False)
            except Exception as exc:
                # Best-effort: keep going with the other runs, but catch only
                # Exception so Ctrl-C / SystemExit still stop the loop, and
                # report the cause instead of hiding it.
                print(f"Error with {model}, run {run}: {exc!r}")
#generate_graphs_run(model_ids,5) 

## Get metrics at the best validation-loss epoch for each run

In [35]:
def get_train_results_runs(model_ids,max_runs):
    """Summarise validation metrics at the lowest-validation-loss epoch.

    For every model, each run's metrics are loaded with
    ``get_metrics_from_csv`` and the entry where ``val_loss`` is minimal is
    selected. The validation loss, accuracy, calibration error and chirality
    violation at that entry are collected across runs and reported as
    ``mean ± std`` strings (``np.average`` / ``np.std``, i.e. NumPy's default
    population standard deviation).

    A run that raises is reported and skipped. The values are appended in
    order, so if a later metric lookup fails the metrics already appended for
    that run are still counted.

    Args:
        model_ids: Iterable of model identifiers; used as the table index.
        max_runs: Number of runs per model; versions ``0 .. max_runs - 1``
            are attempted.

    Returns:
        pandas.DataFrame indexed by ``model_ids`` with the string columns
        ``Loss``, ``Accuracy``, ``ECE`` and ``C Viol``.
    """
    repeat_metrics = pd.DataFrame(columns=["Loss","Accuracy","ECE","C Viol"],index=model_ids)
    repeat_metrics.columns.name = "Model"
    for model in model_ids:
        best_losses = []
        best_accs = []
        best_eces = []
        best_chiralities = []
        for run in range(max_runs):
            try:
                metrics = get_metrics_from_csv(model, METRICS_PATH, version=run)
                # NOTE(review): argmin() gives a position while [] below is a
                # label lookup; this assumes a default 0..n-1 index — confirm.
                best_loss_epoch = metrics['val_loss'].argmin()
                best_losses.append(metrics['val_loss'][best_loss_epoch])
                best_accs.append(metrics['val_acc'][best_loss_epoch])
                best_eces.append(metrics['val_calibration_error'][best_loss_epoch])
                best_chiralities.append(metrics['val_chirality_violation'][best_loss_epoch])
            except Exception as exc:
                # Best-effort: skip the broken run, but catch only Exception so
                # Ctrl-C / SystemExit still propagate, and report the cause.
                print(f"Error with {model}, run {run}: {exc!r}")

        repeat_metrics.loc[model] = {
            "Loss": f"{np.average(best_losses):.4f} ± {np.std(best_losses):.4f}",
            "Accuracy": f"{np.average(best_accs):.2%} ± {np.std(best_accs):.2%}",
            "ECE": f"{np.average(best_eces):.4f} ± {np.std(best_eces):.4f}",
            "C Viol": f"{np.average(best_chiralities):.4f} ± {np.std(best_chiralities):.4f}",
        }
    return repeat_metrics

# Aggregate the validation metrics over 5 runs per model; the trailing
# expression displays up to the first 6 rows of the resulting table.
repeat_metrics = get_train_results_runs(model_ids,5)
repeat_metrics.head(6)

Model,Loss,Accuracy,ECE,C Viol
G_LeNet_cut_dataset_repeat,0.7068 ± 0.0343,83.39% ± 1.79%,0.1835 ± 0.0106,0.5079 ± 0.4110
LeNet_cut_dataset_repeat,0.8562 ± 0.0341,70.61% ± 3.42%,0.1421 ± 0.0146,nan ± nan
resnet18_cut_dataset_repeat,0.5290 ± 0.0003,98.16% ± 0.12%,0.2273 ± 0.0023,0.4963 ± 0.1710
resnet50_cut_dataset_repeat,0.5309 ± 0.0002,97.40% ± 0.13%,0.2195 ± 0.0031,0.7647 ± 0.3068
jiaresnet50_cut_dataset_repeat,0.5310 ± 0.0034,97.65% ± 0.82%,0.2232 ± 0.0066,0.2887 ± 0.0618
G_ResNet18_cut_dataset_repeat,0.5346 ± 0.0024,97.79% ± 0.25%,0.2235 ± 0.0041,0.5458 ± 0.1820


In [38]:
def get_test_results_runs(model_ids,max_runs):
    """Summarise test metrics across the runs of each model.

    For every model, each run's metrics are loaded with
    ``get_metrics_from_csv(..., mode='test')``. The test loss, accuracy,
    calibration error and chirality violation are collected across runs and
    reported as ``mean ± std`` strings (``np.average`` / ``np.std``, i.e.
    NumPy's default population standard deviation).

    A run that raises is reported and skipped. The values are appended in
    order, so if a later metric lookup fails the metrics already appended for
    that run are still counted.

    Args:
        model_ids: Iterable of model identifiers; used as the table index.
        max_runs: Number of runs per model; versions ``0 .. max_runs - 1``
            are attempted.

    Returns:
        pandas.DataFrame indexed by ``model_ids`` with the string columns
        ``Loss``, ``Accuracy``, ``ECE`` and ``C Viol``.
    """
    repeat_metrics = pd.DataFrame(columns=["Loss","Accuracy","ECE","C Viol"],index=model_ids)
    repeat_metrics.columns.name = "Model"
    for model in model_ids:
        best_losses = []
        best_accs = []
        best_eces = []
        best_chiralities = []
        for run in range(max_runs):
            try:
                metrics = get_metrics_from_csv(model, METRICS_PATH, version=run, mode='test')
                best_losses.append(metrics['test_loss'])
                best_accs.append(metrics['test_acc'])
                best_eces.append(metrics['test_calibration_error'])
                best_chiralities.append(metrics['test_chirality_violation'])
            except Exception as exc:
                # Best-effort: skip the broken run, but catch only Exception so
                # Ctrl-C / SystemExit still propagate, and report the cause.
                print(f"Error with {model}, run {run}: {exc!r}")

        repeat_metrics.loc[model] = {
            "Loss": f"{np.average(best_losses):.4f} ± {np.std(best_losses):.4f}",
            "Accuracy": f"{np.average(best_accs):.2%} ± {np.std(best_accs):.2%}",
            "ECE": f"{np.average(best_eces):.4f} ± {np.std(best_eces):.4f}",
            "C Viol": f"{np.average(best_chiralities):.4f} ± {np.std(best_chiralities):.4f}",
        }
    return repeat_metrics

# Rebinds model_ids (replacing the *_repeat list defined earlier) to the
# non-repeat identifiers and reads a single run (version 0) of each.
model_ids = ['resnet18_cut_dataset', # TODO: replace with the *_repeat ids once all repeats have been tested
             'resnet50_cut_dataset',
             'jiaresnet50_cut_dataset',
             'G_ResNet18_cut_dataset']
repeat_metrics = get_test_results_runs(model_ids,1)
# Displays up to the first 6 rows of the resulting table.
repeat_metrics.head(6)

Model,Loss,Accuracy,ECE,C Viol
resnet18_cut_dataset,0.4603 ± 0.0000,91.51% ± 0.00%,0.1195 ± 0.0000,1.1345 ± 0.0000
resnet50_cut_dataset,0.4739 ± 0.0000,91.62% ± 0.00%,0.1371 ± 0.0000,2.1646 ± 0.0000
jiaresnet50_cut_dataset,0.4117 ± 0.0000,94.22% ± 0.00%,0.1208 ± 0.0000,-0.0001 ± 0.0000
G_ResNet18_cut_dataset,0.4150 ± 0.0000,92.84% ± 0.00%,0.1060 ± 0.0000,0.2052 ± 0.0000
