In [1]:
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator
import os
from dotenv import load_dotenv
import glob
from typing import List
from const import *

In [2]:
def smooth(
    scalars: List[float], weight: float
) -> List[float]:  # Weight between 0 and 1
    """
    Exponentially smooth a series of scalar values.

    Each output is ``previous_output * weight + (1 - weight) * point``. The
    running value is seeded with the first input, so no warm-up from zero occurs.

    Parameters:
    scalars (List[float]): The raw values, in order.
    weight (float): Smoothing weight, intended to be between 0 and 1. A weight
        of 0 leaves the series unchanged; larger weights lean more heavily on
        the previous smoothed value.

    Returns:
    List[float]: The smoothed series, with one entry per input value. An empty
        input yields an empty list.
    """
    # Guard the empty case explicitly: indexing scalars[0] below would raise
    # IndexError. len() is used (not truthiness) so array-like inputs also work.
    if len(scalars) == 0:
        return []

    last = scalars[0]  # Seed the running value with the first point
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point  # Blend previous output with new point
        smoothed.append(last)

    return smoothed


def extract_data(filename, metric, smoothing=0.99):
    """
    Load one scalar series from a TensorBoard log and return it smoothed.

    Parameters:
    filename: The path passed to ``EventAccumulator``.
    metric (str): The scalar tag to read, e.g. "val/mse_loss".
    smoothing (float): The weight forwarded to ``smooth``. Default is 0.99.

    Returns:
    List[float]: The smoothed values of ``metric``, in the order the events
        are returned by the accumulator. Step numbers are not returned.

    Raises:
    KeyError: If ``metric`` is not one of the scalar tags found in the log.
    """
    # NOTE(review): the accumulator is built with its default size guidance,
    # which may subsample long scalar series - confirm that is acceptable.
    ea = event_accumulator.EventAccumulator(filename)
    ea.Reload()

    # Fail with an actionable message instead of a bare KeyError on the tag.
    scalar_tags = ea.Tags()["scalars"]
    if metric not in scalar_tags:
        raise KeyError(
            f"Scalar tag {metric!r} not found in {filename!r}; "
            f"available tags: {sorted(scalar_tags)}"
        )

    # Only the requested tag is read; the other tags are never materialised.
    values = [event.value for event in ea.Scalars(metric)]
    return smooth(values, smoothing)

In [3]:
import matplotlib.pyplot as plt
import os


def plot_and_save(
    data,
    title,
    filename,
    save_dir="img",
    fig_size=(14, 8),
    title_size=22,
    label_size=20,
    legend_size=20,
    tick_size=18,
    smoothing=0.99
):
    """
    Plot each series in ``data`` as a line on one figure and save it to disk.

    Every series is plotted against its list index, so the x axis is the
    position of each value within its list. The output format is whatever
    matplotlib infers from the extension of ``filename``.

    Parameters:
    data (dict): A dictionary containing the data to plot. Keys are the labels, values are the data lists.
    title (str): The title of the plot.
    filename (str): The name of the file to save the plot.
    save_dir (str): The directory to save the plot. Created if missing. Default is "img".
    fig_size (tuple): The size of the figure. Default is (14, 8).
    title_size (int): The font size of the title. Default is 22.
    label_size (int): The font size of the labels. Default is 20.
    legend_size (int): The font size of the legend. Default is 20.
    tick_size (int): The font size of the tick labels. Default is 18.
    smoothing (float): The smoothing weight shown in the x-axis label. It is
        only displayed; no smoothing is applied here. Default is 0.99.
    """
    # Ensure the directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Use an explicit figure/axes pair so nothing depends on pyplot's
    # "current figure" state.
    fig, ax = plt.subplots(figsize=fig_size)
    try:
        for label, data_list in data.items():
            ax.plot(data_list, label=label)
        ax.set_title(title, fontsize=title_size)
        ax.set_xlabel(f"Steps (Smoothing={smoothing})", fontsize=label_size)
        ax.set_ylabel("Value", fontsize=label_size)
        ax.tick_params(axis="both", labelsize=tick_size)
        ax.legend(fontsize=legend_size)
        ax.grid(True, which="both", linestyle="--", linewidth=0.5)
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, filename))
    finally:
        # Always release the figure, even when plotting or saving fails;
        # otherwise it stays open in the kernel.
        plt.close(fig)

In [6]:
# Smoothing weight handed to extract_data for every loss curve extracted in the next cell.
smoothing = 0.99

In [7]:
# Log locations for each experiment size, always ordered RNN, Mamba, Transformer.
small_model_logs = (RNN_SMALL_LOGS, MAMBA_SMALL_LOGS, TRANSFORMER_SMALL_LOGS)
large_model_logs = (RNN_LARGE_LOGS, MAMBA_LARGE_LOGS, TRANSFORMER_LARGE_LOGS)

# Extract MSE data for small models validation
(
    rnn_small_data_val_mse,
    mamba_small_data_val_mse,
    transformer_small_data_val_mse,
) = [
    extract_data(logs, "val/mse_loss", smoothing=smoothing)
    for logs in small_model_logs
]

# Extract MSE data for large models validation
(
    rnn_large_data_val_mse,
    mamba_large_data_val_mse,
    transformer_large_data_val_mse,
) = [
    extract_data(logs, "val/mse_loss", smoothing=smoothing)
    for logs in large_model_logs
]

# Extract MSE data for small models training
(
    rnn_small_data_train_mse,
    mamba_small_data_train_mse,
    transformer_small_data_train_mse,
) = [
    extract_data(logs, "train/mse_loss", smoothing=smoothing)
    for logs in small_model_logs
]

# Extract MSE data for large models training
(
    rnn_large_data_train_mse,
    mamba_large_data_train_mse,
    transformer_large_data_train_mse,
) = [
    extract_data(logs, "train/mse_loss", smoothing=smoothing)
    for logs in large_model_logs
]

# Define the data to plot (dict keys become the legend labels)
data_small_train_mse = {
    "RNN": rnn_small_data_train_mse,
    "Mamba": mamba_small_data_train_mse,
    "Transformer": transformer_small_data_train_mse,
}

data_small_val_mse = {
    "RNN": rnn_small_data_val_mse,
    "Mamba": mamba_small_data_val_mse,
    "Transformer": transformer_small_data_val_mse,
}

data_large_train_mse = {
    "RNN": rnn_large_data_train_mse,
    "Mamba": mamba_large_data_train_mse,
    "Transformer": transformer_large_data_train_mse,
}

data_large_val_mse = {
    "RNN": rnn_large_data_val_mse,
    "Mamba": mamba_large_data_val_mse,
    "Transformer": transformer_large_data_val_mse,
}

# One (curves, title, output file) entry per figure.
mse_plots = [
    (
        data_small_train_mse,
        "MSE Train Loss - Experiment 1 (200k Params)",
        "mse_train_loss_small_models.pdf",
    ),
    (
        data_small_val_mse,
        "MSE Val Loss - Experiment 1 (200k Params)",
        "mse_val_loss_small_models.pdf",
    ),
    (
        data_large_train_mse,
        "MSE Train Loss - Experiment 2 (600k Params)",
        "mse_train_loss_large_models.pdf",
    ),
    (
        data_large_val_mse,
        "MSE Val Loss - Experiment 2 (600k Params)",
        "mse_val_loss_large_models.pdf",
    ),
]

# Pass the smoothing weight through explicitly so the axis label reports the
# value actually used for extraction, not plot_and_save's own default.
for plot_curves, plot_title, plot_filename in mse_plots:
    plot_and_save(plot_curves, plot_title, plot_filename, smoothing=smoothing)