In [37]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [38]:
# Load metrics files
# Directory containing the train_*.json / val_*.json metric dumps.
metrics_path = "metrics"
# Number of classes to report per-class accuracy for; adjust if different.
N_CLASSES = 2


def _load_metrics(prefix):
    """Load every ``<prefix>*.json`` file in ``metrics_path``, in file-name order.

    Each file is opened in a ``with`` block so its handle is closed as soon as
    it has been parsed (instead of whenever the garbage collector runs).

    Returns a list with one parsed JSON object per matching file.
    """
    records = []
    for fname in sorted(os.listdir(metrics_path)):
        if fname.startswith(prefix) and fname.endswith(".json"):
            with open(os.path.join(metrics_path, fname)) as fh:
                records.append(json.load(fh))
    return records


def _safe_ratio(num, den):
    """Return ``num / den``, or 0.0 when ``den`` is not positive."""
    return num / den if den > 0 else 0.0


def _metrics_to_df(records, split):
    """Flatten raw metric dicts into a DataFrame with one row per record.

    Each record must provide ``step``, ``epoch``, ``true_positives`` and
    ``total``, plus, for every class index ``k`` in ``range(N_CLASSES)``, a
    ``str(k)`` entry holding that class's own ``true_positives`` / ``total``.

    Args:
        records: list of parsed metric dicts.
        split: label stored in the ``type`` column (e.g. 'training').

    Returns:
        DataFrame with columns step, epoch, accuracy, type, class_<k>_acc,
        sorted by (epoch, step) with a fresh 0..n-1 index.
    """
    columns = (['step', 'epoch', 'accuracy', 'type']
               + [f'class_{k}_acc' for k in range(N_CLASSES)])
    rows = [
        {
            'step': m['step'],
            'epoch': m['epoch'],
            'accuracy': _safe_ratio(m['true_positives'], m['total']),
            'type': split,
            **{f'class_{k}_acc': _safe_ratio(m[str(k)]['true_positives'],
                                             m[str(k)]['total'])
               for k in range(N_CLASSES)},
        }
        for m in records
    ]
    # Explicit columns keep the schema even when no files matched. Sorting
    # undoes lexicographic file order (e.g. "train_10" before "train_2" if
    # the file numbers are not zero-padded).
    return (pd.DataFrame(rows, columns=columns)
            .sort_values(['epoch', 'step'])
            .reset_index(drop=True))


training_metrics = _load_metrics("train_")
validation_metrics = _load_metrics("val_")

# Process training metrics
training_df = _metrics_to_df(training_metrics, 'training')

# Process validation metrics
validation_df = _metrics_to_df(validation_metrics, 'validation')

In [None]:
# Preview the first five rows of the training metrics table.
training_df.head(n=5)

In [40]:
# Per-class accuracy column names (e.g. 'class_0_acc'), taken from the
# training frame; the validation frame is assumed to share the same set.
class_cols = list(filter(lambda name: name.startswith('class_'), training_df.columns))

In [None]:
# Shared plotting helpers


def _format_axis(ax, title, xlabel):
    """Apply the title / axis-label / grid styling shared by every panel."""
    ax.set_title(title, fontsize=26)
    ax.set_xlabel(xlabel, fontsize=20)
    ax.set_ylabel('Accuracy', fontsize=20)
    ax.grid(True)


def _plot_class_accuracy(ax, data, x, title, xlabel, palette):
    """Draw one line per class of ``data['class_accuracy']`` against column ``x``.

    Args:
        ax: matplotlib Axes to draw on.
        data: long-format frame with columns ``x``, 'class', 'class_accuracy'.
        x: name of the x-axis column (e.g. 'step' or 'epoch').
        title, xlabel: panel title and x-axis label.
        palette: list of colours, one per class.
    """
    # seaborn maps markers from the ``style`` variable, not ``hue``; without
    # style='class', markers=True silently draws no markers. dashes=False keeps
    # every line solid despite the style mapping.
    sns.lineplot(data=data, x=x, y='class_accuracy',
                 hue='class', style='class', lw=3,
                 markers=True, dashes=False, ax=ax, palette=palette)
    _format_axis(ax, title, xlabel)
    # loc='upper left' pins the legend's upper-left corner to the anchor so it
    # sits just outside the axes rather than wherever 'best' places it.
    ax.legend(title='Class', bbox_to_anchor=(1.05, 1), loc='upper left')


# Plot overall accuracy
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6))

# Training accuracy barplot
sns.barplot(data=training_df, x='step', y='accuracy', ax=ax1, color='steelblue')
_format_axis(ax1, 'Training Overall Accuracy', 'Training Steps')

# Validation accuracy barplot
sns.barplot(data=validation_df, x='epoch', y='accuracy', ax=ax2, color='darkred')
_format_axis(ax2, 'Validation Overall Accuracy', 'Epoch')

plt.tight_layout()
plt.show()

# Plot per-class accuracy
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# Reshape to long format: one row per (x value, class column) pair
train_melted = pd.melt(training_df,
                       id_vars=['step'],
                       value_vars=class_cols,
                       var_name='class',
                       value_name='class_accuracy')

val_melted = pd.melt(validation_df,
                     id_vars=['epoch'],
                     value_vars=class_cols,
                     var_name='class',
                     value_name='class_accuracy')

color_palette = sns.color_palette("deep", len(class_cols))

# Training per-class plot
_plot_class_accuracy(ax1, train_melted, 'step',
                     'Training Per-class Accuracy', 'Training Steps', color_palette)

# Validation per-class plot
_plot_class_accuracy(ax2, val_melted, 'epoch',
                     'Validation Per-class Accuracy', 'Epoch', color_palette)

plt.tight_layout()
plt.show()