In [None]:
"""
analyse_log.ipynb

Parses a training log file, then visualises the training/validation loss
and HellaSwag accuracy over time, comparing against GPT-2/GPT-3 baselines.
"""

import numpy as np
import matplotlib.pyplot as plt

# For inline plots (Jupyter):
%matplotlib inline

# ---------------------------------------------------------
# 1) Configure model size and known baseline metrics
# ---------------------------------------------------------
# Model size key; selects the baselines below and appears in the plot labels.
sz = "124M"  # e.g. 124M, 350M, 774M, 1558M

# Every baseline is looked up with .get(), so a size that has no recorded
# value yields None rather than raising KeyError. The plotting code checks
# each baseline against None before drawing its reference line.

# Validation loss for the GPT-2 checkpoint (only 124M is recorded here).
loss_baseline = {
    "124M": 3.2924,
}.get(sz)

# HellaSwag baseline (GPT-2)
hella2_baseline = {
    "124M": 0.294463,
    "350M": 0.375224,
    "774M": 0.431986,
    "1558M": 0.488946,
}.get(sz)

# HellaSwag baseline (GPT-3)
hella3_baseline = {
    "124M": 0.337,
    "350M": 0.436,
    "774M": 0.510,
    "1558M": 0.547,
}.get(sz)

# ---------------------------------------------------------
# 2) Read and parse training log
# ---------------------------------------------------------
# Expecting lines of the form:
#   step  stream  value
# e.g. "100 train 3.0123"
# where 'stream' ∈ {train, val, hella}
# Blank lines are ignored; any other line that does not have exactly three
# whitespace-separated fields, or whose step/value is not numeric, raises
# ValueError.

log_filename = "logs/log.txt"

with open(log_filename, "r") as f:
    lines = f.readlines()

# Group logs by stream in a dictionary, e.g. streams['train'][step] = loss.
# If a step occurs more than once for a stream, the last occurrence wins.
streams = {}
for line in lines:
    fields = line.split()
    if not fields:
        # Whitespace-only line (e.g. a trailing empty line): nothing to parse.
        continue
    step_str, stream_name, val_str = fields
    step = int(step_str)
    val = float(val_str)
    streams.setdefault(stream_name, {})[step] = val

# ---------------------------------------------------------
# 3) Convert dictionaries to sorted (step, value) tuples
# ---------------------------------------------------------
# For each stream, build a pair of parallel NumPy arrays:
#   streams_xy[name] = (steps in ascending order, matching values)
streams_xy = {}
for name in streams:
    per_step = streams[name]
    # Steps are the dict keys; order them, then look up each step's value
    ordered_steps = sorted(per_step)
    ordered_values = [per_step[s] for s in ordered_steps]
    streams_xy[name] = (np.array(ordered_steps), np.array(ordered_values))

# ---------------------------------------------------------
# 4) Create figure with two subplots: Loss and HellaSwag
# ---------------------------------------------------------
# One wide figure holding a 1x2 grid of panels.
plt.figure(figsize=(16, 6))

# ============= Panel 1: Training & Validation Loss =============
# Take the first cell of the grid and draw on its Axes object directly.
loss_ax = plt.subplot(1, 2, 1)

# Training loss curve, plus its minimum for quick reference
train_steps, train_vals = streams_xy["train"]
loss_ax.plot(train_steps, train_vals, label=f'ArcanaGPT ({sz}) train loss')
print("Min Train Loss:", train_vals.min())

# Validation loss curve, plus its minimum
val_steps, val_vals = streams_xy["val"]
loss_ax.plot(val_steps, val_vals, label=f'ArcanaGPT ({sz}) val loss')
print("Min Validation Loss:", val_vals.min())

# Horizontal reference line for the GPT-2 baseline loss, when one is set
if loss_baseline is not None:
    loss_ax.axhline(
        y=loss_baseline,
        color='r',
        linestyle='--',
        label=f"OpenAI GPT-2 ({sz}) checkpoint val loss",
    )

# Axis cosmetics: logarithmic y-axis, capped at 12.0 at the top
loss_ax.set_xlabel("Steps")
loss_ax.set_ylabel("Loss")
loss_ax.set_yscale('log')
loss_ax.set_ylim(top=12.0)
loss_ax.set_title("Training & Validation Loss")
loss_ax.legend()

# ============= Panel 2: HellaSwag Accuracy =============
# Take the second cell of the 1x2 grid and draw on its Axes object directly.
hella_ax = plt.subplot(1, 2, 2)

# Accuracy curve for this run, plus its maximum for quick reference
hella_steps, hella_vals = streams_xy["hella"]
hella_ax.plot(hella_steps, hella_vals, label=f"ArcanaGPT ({sz})")
print("Max Hellaswag Accuracy:", hella_vals.max())

# Red dashed reference line: GPT-2 baseline accuracy, when one is set
if hella2_baseline is not None:
    hella_ax.axhline(
        y=hella2_baseline,
        color='r',
        linestyle='--',
        label=f"OpenAI GPT-2 ({sz}) checkpoint",
    )

# Green dashed reference line: GPT-3 baseline accuracy, when one is set
if hella3_baseline is not None:
    hella_ax.axhline(
        y=hella3_baseline,
        color='g',
        linestyle='--',
        label=f"OpenAI GPT-3 ({sz}) checkpoint",
    )

hella_ax.set_xlabel("Steps")
hella_ax.set_ylabel("Accuracy")
hella_ax.set_title("HellaSwag Accuracy")
hella_ax.legend()

# Tidy the spacing between panels, then render the figure
plt.tight_layout()
plt.show()