# Loss and Accuracy over Epochs

In [None]:
import glob
import matplotlib.pyplot as plt
import json
import os
from itertools import cycle

# Plot training/validation loss and accuracy curves for every run saved as
# results/metrics_<label>.json. Each file must contain the lists 'train_loss',
# 'val_loss', 'train_acc' and 'val_acc'; values are plotted against epochs
# 1..N, where N is the length of the corresponding 'train_*' list.
METRICS_PREFIX = 'metrics_'

# glob returns matches in arbitrary, filesystem-dependent order; sort them so
# the color assignment and legend order are reproducible between runs.
file_list = sorted(glob.glob(f'results/{METRICS_PREFIX}*.json'))

# Color iterator over the active matplotlib style's color cycle.
# NOTE(review): the cycle wraps, so with more runs than colors in the style,
# two runs will share a color.
color_cycle = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

# Run label -> color, so a run's loss and accuracy curves use the same color.
color_map = {}

# Load every metrics file exactly once; both subplots reuse the parsed data.
runs = {}
for file in file_list:
    filename = os.path.basename(file)
    # Drop the leading 'metrics_' and the '.json' extension. Slicing (instead of
    # str.replace) only removes the prefix, so labels that themselves contain
    # 'metrics_' stay distinct and never collide.
    model_label = os.path.splitext(filename)[0][len(METRICS_PREFIX):]
    color_map[model_label] = next(color_cycle)
    with open(file, 'r', encoding='utf-8') as f:
        runs[model_label] = json.load(f)

if not runs:
    print(f'No files matched results/{METRICS_PREFIX}*.json; nothing to plot.')

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 7))

# Left panel: loss. Solid line = training, dashed line = validation.
for model_label, data in runs.items():
    color = color_map[model_label]
    epochs = list(range(1, len(data['train_loss']) + 1))
    ax_loss.plot(epochs, data['train_loss'], label=f'Train Loss ({model_label})', color=color)
    ax_loss.plot(epochs, data['val_loss'], label=f'Val Loss ({model_label})', linestyle='--', color=color)

ax_loss.set_title('Loss over Epochs')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')

# Right panel: accuracy, same color per run as in the loss panel.
for model_label, data in runs.items():
    color = color_map[model_label]
    epochs = list(range(1, len(data['train_acc']) + 1))
    ax_acc.plot(epochs, data['train_acc'], label=f'Train Acc ({model_label})', color=color)
    ax_acc.plot(epochs, data['val_acc'], label=f'Val Acc ({model_label})', linestyle='--', color=color)

ax_acc.set_title('Accuracy over Epochs')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')

# legend() on axes without labelled artists only emits a warning, so skip it
# when there was nothing to plot.
if runs:
    ax_loss.legend()
    ax_acc.legend()

fig.suptitle('Loss and Accuracy over Epochs')
# Reserve the top 5% of the figure for the suptitle so it is not overlapped.
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
