In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def _load_column_series(folder_path, csv_files, column_name, secondary_column):
    """Read `column_name` (and optionally `secondary_column`) from each CSV file.

    Returns
    -------
    tuple
        ``(series, secondary_values)`` where ``series`` is a list of
        ``(file_name, values)`` pairs for every file that could be read and
        contains ``column_name``, and ``secondary_values`` is the
        ``secondary_column`` array of the last such file that has it
        (``None`` when no plotted file has it or the column is disabled).
    """
    series = []
    secondary_values = None

    for file in csv_files:
        file_path = os.path.join(folder_path, file)
        try:
            df = pd.read_csv(file_path)
        except Exception as e:
            # Best-effort: one unreadable file must not abort the whole comparison.
            print(f"❌ Error reading {file}: {e}")
            continue

        if column_name not in df.columns:
            print(f"⚠️ Skipping {file} — column '{column_name}' not found.")
            continue

        series.append((file, df[column_name].values))

        # Only take the secondary series from files that are actually plotted,
        # so the overlay always belongs to one of the curves in the figure.
        if secondary_column and secondary_column in df.columns:
            secondary_values = df[secondary_column].values

    return series, secondary_values


def compare_csv_column(folder_path: str, column_name: str, output_path: str,
                       secondary_column: str = "n_vis"):
    """
    Compare the same column across multiple CSV files in a folder and plot their values.

    Each CSV file that contains ``column_name`` becomes one line, plotted
    against its row index (labelled "Epoch"). ``secondary_column`` is drawn
    as a dashed black line on a second y-axis.

    Parameters
    ----------
    folder_path : str
        Path to the folder containing CSV files.
    column_name : str
        The name of the column to compare.
    output_path : str
        File path to save the resulting plot (e.g., 'output/comparison.png').
        A bare file name (no directory part) is saved to the current directory.
    secondary_column : str, optional
        Column drawn on the secondary y-axis (default ``'n_vis'``, the number
        of visible tokens). It is taken from the last plotted file (in sorted
        file-name order) that contains it. Pass ``None`` or ``''`` to disable
        the secondary axis.

    Raises
    ------
    ValueError
        If the folder contains no ``.csv`` files, or if none of them could be
        read and contains ``column_name``.
    """
    # Sort so legend order and the secondary-axis source are deterministic;
    # os.listdir returns entries in arbitrary order.
    csv_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.csv'))

    if not csv_files:
        raise ValueError(f"No CSV files found in {folder_path}")

    # Load everything before creating a figure so a bad input never leaves a
    # half-built figure open or an empty plot on disk.
    series, secondary_values = _load_column_series(
        folder_path, csv_files, column_name, secondary_column
    )

    if not series:
        raise ValueError(
            f"No CSV file in {folder_path} contains column '{column_name}'"
        )

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for file, values in series:
            ax.plot(values, label=file)

        handles, labels = ax.get_legend_handles_labels()

        if secondary_values is not None:
            # Secondary y-axis sharing the same x-axis.
            ax2 = ax.twinx()
            ax2.plot(secondary_values, 'k--', label=secondary_column)
            ax2.set_ylabel(secondary_column, color='k')
            ax2.tick_params(axis='y', colors='k')

            # Combine legends from both axes into a single box.
            handles2, labels2 = ax2.get_legend_handles_labels()
            handles = handles + handles2
            labels = labels + labels2
        elif secondary_column:
            print(
                f"⚠️ Column '{secondary_column}' not found in any plotted file "
                "— secondary axis omitted."
            )

        ax.legend(handles, labels, loc='best')
        ax.set_xlabel("Epoch")
        ax.set_ylabel(column_name)
        ax.grid(True)
        fig.tight_layout()

        # Ensure output directory exists; dirname is '' for a bare file name,
        # and os.makedirs('') would raise FileNotFoundError.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        fig.savefig(output_path, dpi=300)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    print(f"✅ Plot saved to {output_path}")


In [2]:
# Training curves: one figure per metric, accuracy first, then loss.
for metric_column, figure_path in (
    ("train_acc", "figs/train_accuracy.png"),
    ("train_loss", "figs/train_loss.png"),
):
    compare_csv_column(
        folder_path="results/all_training",
        column_name=metric_column,
        output_path=figure_path,
    )


✅ Plot saved to figs/train_accuracy.png
✅ Plot saved to figs/train_loss.png


In [3]:
# Validation curves: one figure per metric, accuracy first, then loss.
for metric_column, figure_path in (
    ("val_acc", "figs/val_accuracy.png"),
    ("val_loss", "figs/val_loss.png"),
):
    compare_csv_column(
        folder_path="results/all_training",
        column_name=metric_column,
        output_path=figure_path,
    )


✅ Plot saved to figs/val_accuracy.png
✅ Plot saved to figs/val_loss.png
