## Model metrics

In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import numpy as np

In [2]:
LOG_PATH = "lightning_logs"
METRICS_PATH = "../Metrics"

In [8]:
def check_folder(save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

def save_metrics_from_logger(model_id,version=0,save=True):
    metrics = pd.read_csv(f"{LOG_PATH}/{model_id}/version_{version}_train/metrics.csv")
    metrics = metrics.drop(['train_loss_step',	'train_acc_step',	'train_calibration_error_step'],axis=1)
    metrics = metrics.groupby(metrics['epoch']).first()
    if save:
        save_dir = f"{METRICS_PATH}/{model_id}/version_{version}"
        check_folder(save_dir)
        metrics.to_csv(f"{save_dir}/train_metrics.csv")
    return metrics

def get_metrics_from_csv(model_id,version=0,save=True):
    metrics = pd.read_csv(f"{METRICS_PATH}/{model_id}/version_{version}/train_metrics.csv")
    return metrics

In [11]:
def plot_metrics(model_id,version=0,show=False,save=True):
    metrics = get_metrics_from_csv(model_id,version)
    train_loss = metrics['train_loss_epoch']
    train_acc = metrics['train_acc_epoch']
    val_loss = metrics['val_loss']
    val_acc = metrics['val_acc']
    epoch = metrics.index

    fig = plt.figure(figsize=(9,4))
    ax1 = fig.add_subplot(221) 
    ax1.plot(epoch,train_loss,linestyle='-', c='orange')
    ax1.set_xlabel('Step')
    ax1.set_ylabel('Train Loss')

    ax2 = fig.add_subplot(222) 
    ax2.plot(epoch,val_loss,linestyle='-', c='violet')
    ax2.set_xlabel('Step')
    ax2.set_ylabel('Validation Loss')

    ax3 = fig.add_subplot(223) 
    ax3.plot(epoch,train_acc,linestyle='-', c='red')
    ax3.set_xlabel('Step')
    ax3.set_ylabel('Train Accuracy')

    ax4 = fig.add_subplot(224) 
    ax4.plot(epoch,val_acc,linestyle='-', c='green')
    ax4.set_xlabel('Step')
    ax4.set_ylabel('Validation Accuracy')
    #plt.tight_layout()
    #fig.suptitle(f"{run_name} (Loss: {test_loss:.2f}, Acc: {test_acc:.2%})")
    fig.suptitle(model_id)
    if save:
        save_dir = f"{METRICS_PATH}/{model_id}/version_{version}"
        check_folder(save_dir)
        plt.savefig(f'{save_dir}/metrics.png',dpi=200)
    if show:
        plt.show()
    else:
        plt.clf()
        plt.close()

## Generate cleaned up csv files for each run

In [10]:
model_ids = ['G_LeNet_cut_dataset_repeat','LeNet_cut_dataset_repeat','resnet18_cut_dataset_repeat','jiaresnet50_cut_dataset_repeat','jiaresnet50_cut_dataset_repeat','G_ResNet18_cut_dataset_repeat']

for model in model_ids:
    for run in range(5):
        #print(f"{model}, run {run}")
        try:
            save_metrics_from_logger(model,version=run)
        except:
            print(f"Error with {model}, run {run}")

Error with G_ResNet18_cut_dataset_repeat, run 3
Error with G_ResNet18_cut_dataset_repeat, run 4


## Create graphs for each run

In [12]:
for model in model_ids:
    for run in range(5):
        #print(f"{model}, run {run}")
        try:
            plot_metrics(model,version=run,show=False)
        except:
            print(f"Error with {model}, run {run}")

## Get best loss epoch for each run

In [14]:
for model in model_ids:
    best_losses = []
    best_accs = []
    best_eces = []
    best_chiralities = []
    for run in range(5):
        try:
            metrics = get_metrics_from_csv(model,version=run)
            best_loss_epoch = metrics['val_loss'].argmin()
            best_losses.append(metrics['val_loss'][best_loss_epoch])
            best_accs.append(metrics['val_acc'][best_loss_epoch])
            best_eces.append(metrics['val_calibration_error'][best_loss_epoch])
            best_chiralities.append((metrics['val_chirality_violation'][best_loss_epoch]))
        except:
            print(f"Error with {model}, run {run}")
    print(model)
    print(f"Loss: {np.average(best_losses):.4f} ± {np.std(best_losses):.4f}")
    print(f"Accuracy: {np.average(best_accs):.2%} ± {np.std(best_accs):.2%}")
    print(f"ECE: {np.average(best_eces):.4f} ± {np.std(best_eces):.4f}")
    print(f"C Viol: {np.average(best_chiralities):.4f} ± {np.std(best_chiralities):.4f}")
    print()

G_LeNet_cut_dataset_repeat
Loss: 0.7068 ± 0.0343
Accuracy: 83.39% ± 1.79%
ECE: 0.1835 ± 0.0106
C Viol: 0.5079 ± 0.4110

LeNet_cut_dataset_repeat
Loss: 0.8562 ± 0.0341
Accuracy: 70.61% ± 3.42%
ECE: 0.1421 ± 0.0146
C Viol: nan ± nan

resnet18_cut_dataset_repeat
Loss: 0.5290 ± 0.0003
Accuracy: 98.16% ± 0.12%
ECE: 0.2273 ± 0.0023
C Viol: 0.4963 ± 0.1710

jiaresnet50_cut_dataset_repeat
Loss: 0.5310 ± 0.0034
Accuracy: 97.65% ± 0.82%
ECE: 0.2232 ± 0.0066
C Viol: 0.2887 ± 0.0618

jiaresnet50_cut_dataset_repeat
Loss: 0.5310 ± 0.0034
Accuracy: 97.65% ± 0.82%
ECE: 0.2232 ± 0.0066
C Viol: 0.2887 ± 0.0618

Error with G_ResNet18_cut_dataset_repeat, run 3
Error with G_ResNet18_cut_dataset_repeat, run 4
G_ResNet18_cut_dataset_repeat
Loss: 0.6047 ± 0.1359
Accuracy: 92.19% ± 11.39%
ECE: 0.2100 ± 0.0293
C Viol: nan ± nan

