# SASRec training curves – TensorBoard logs & job output

Plot training loss and validation metrics (Hit@10, NDCG@10) from:
1. **TensorBoard event logs** (`log_tensorboard/`)
2. **Job text logs** (PBS `.o` file or `sasrec_live.log`) as fallback

In [None]:
import os
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Locate the TRACT repo root. Candidates are tried in order: the working
# directory itself, then its parent (notebook started from notebooks/). The
# first one containing log_tensorboard/ wins; otherwise fall back to ".."
# resolved against the working directory.
TRACT_ROOT = Path(os.getcwd()).resolve()
TRACT_ROOT = next(
    (
        candidate
        for candidate in (TRACT_ROOT, TRACT_ROOT.parent)
        if (candidate / "log_tensorboard").exists()
    ),
    Path("..").resolve(),
)
LOG_TENSORBOARD = TRACT_ROOT.joinpath("log_tensorboard")
LOG_JOB = TRACT_ROOT.joinpath("sasrec_live.log")  # or e.g. sasrec_ml1m.o159742817

## 1. Load from TensorBoard event logs

In [None]:
# TensorBoard is treated as an optional dependency: HAS_TENSORBOARD records
# whether the import succeeded, and a missing package only prints an install
# hint instead of aborting the cell.
try:
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
    HAS_TENSORBOARD = True
except ImportError:
    HAS_TENSORBOARD = False
    print("Install tensorboard: pip install tensorboard")

def load_tensorboard_scalars(logdir):
    """Load every scalar series from each run directory under *logdir*.

    Args:
        logdir: Directory whose immediate sub-directories are TensorBoard run
            directories. May be a ``str`` or a ``pathlib.Path``.

    Returns:
        ``{run_name: {tag: (steps, values)}}`` where ``steps`` and ``values``
        are NumPy arrays built from the events of that tag. Runs without any
        scalar tags are omitted. An empty dict is returned when *logdir* is
        not an existing directory or TensorBoard is not importable.
    """
    # Coerce first so plain string paths work, as parse_job_log already allows.
    logdir = Path(logdir)
    # is_dir() (rather than exists()) also rejects a path that points at a
    # regular file, which would otherwise make iterdir() raise.
    if not logdir.is_dir() or not HAS_TENSORBOARD:
        return {}
    runs = {}
    # Sorted so that run order (and hence plot/legend order) is deterministic.
    for run_path in sorted(logdir.iterdir()):
        if not run_path.is_dir():
            continue
        ea = EventAccumulator(str(run_path))
        ea.Reload()
        tags = ea.Tags().get("scalars", [])
        if not tags:
            continue
        run_scalars = {}
        for tag in tags:
            events = ea.Scalars(tag)
            run_scalars[tag] = (
                np.array([e.step for e in events]),
                np.array([e.value for e in events]),
            )
        runs[run_path.name] = run_scalars
    return runs

In [None]:
# Load all runs once, then list each run together with the scalar tags it has.
tb_runs = load_tensorboard_scalars(LOG_TENSORBOARD)
print("TensorBoard runs found:", list(tb_runs))
for run_name in tb_runs:
    print(f"  {run_name}: {list(tb_runs[run_name])}")

In [None]:
def plot_tensorboard_runs(tb_runs, tags_to_plot=None):
    """Plot one subplot per scalar tag, overlaying a curve for each run.

    Args:
        tb_runs: Mapping ``{run_name: {tag: (steps, values)}}``.
        tags_to_plot: Tag name, or iterable of tag names, to plot. ``None``
            (or an empty value) plots every tag found in any run, sorted.

    Returns:
        None. Prints a message and returns early when there is nothing to
        plot; otherwise shows the figure.
    """
    if not tb_runs:
        print("No TensorBoard runs to plot.")
        return
    if isinstance(tags_to_plot, str):
        # A bare string would otherwise be iterated character by character.
        tags = [tags_to_plot]
    elif tags_to_plot:
        tags = list(tags_to_plot)
    else:
        tags = sorted({tag for scalars in tb_runs.values() for tag in scalars})
    if not tags:
        # plt.subplots(1, 0) raises, so bail out before creating the figure.
        print("No scalar tags to plot.")
        return
    n = len(tags)
    # squeeze=False always yields a 2-D axes array, so n == 1 needs no
    # special case.
    _, axes = plt.subplots(1, n, figsize=(5 * n, 4), squeeze=False)
    for ax, tag in zip(axes[0], tags):
        has_curve = False
        for run_name, scalars in tb_runs.items():
            if tag not in scalars:
                continue
            steps, values = scalars[tag]
            # Run names are truncated to keep the legend compact.
            ax.plot(steps, values, label=run_name[:40], alpha=0.8)
            has_curve = True
        ax.set_title(tag)
        # NOTE(review): the x values are the logged step numbers; the label
        # assumes one step per epoch - confirm against the training logger.
        ax.set_xlabel("Epoch")
        if has_curve:
            # Skip the legend on empty axes to avoid matplotlib's
            # "no artists with labels" warning.
            ax.legend(loc="best", fontsize=7)
        ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# tags_to_plot is left at its default (None), so every tag found is plotted.
plot_tensorboard_runs(tb_runs)

## 2. Parse job text log (fallback / comparison)

Extract epoch, train loss, Hit@10, NDCG@10 from lines like:
- `epoch 43 training [time: 21.92s, train loss: 1119.2058]`
- `hit@10 : 0.2456    ndcg@10 : 0.1256`

In [None]:
# Optionally signed decimal or scientific-notation number (0.2456, 5e-01, -1.5).
_JOB_LOG_NUMBER = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
# epoch N training [..., train loss: XXXX]
_JOB_LOG_EPOCH_RE = re.compile(
    r"epoch\s+(\d+)\s+training\s+.*train loss:\s+(" + _JOB_LOG_NUMBER + r")"
)
# hit@10 : X.XXXX    ndcg@10 : X.XXXX
_JOB_LOG_METRICS_RE = re.compile(
    r"hit@10\s*:\s*(" + _JOB_LOG_NUMBER + r")\s*ndcg@10\s*:\s*("
    + _JOB_LOG_NUMBER + r")",
    re.IGNORECASE,
)


def parse_job_log(log_path):
    """Parse a PBS/job text log into aligned per-epoch arrays.

    Each ``hit@10 ... ndcg@10`` line is attached to the most recent preceding
    ``epoch N training`` line, so the result stays correct when validation
    does not run after every epoch. Epochs with no following metrics line, and
    metrics lines with no preceding unpaired epoch line, are dropped.

    Args:
        log_path: Path (``str`` or ``pathlib.Path``) of the log file.

    Returns:
        Dict with equally long NumPy arrays under the keys ``"epoch"``,
        ``"train_loss"``, ``"hit@10"`` and ``"ndcg@10"``, or ``None`` when the
        path is not a regular file or no epoch/metrics pair was found.
    """
    log_path = Path(log_path)
    if not log_path.is_file():
        return None
    # Job logs can contain stray non-UTF-8 bytes; replace them rather than
    # abort, since only ASCII patterns are searched for.
    text = log_path.read_text(encoding="utf-8", errors="replace")

    # Collect both kinds of matches with their offsets so they can be replayed
    # in file order.
    events = [
        (m.start(), True, (int(m.group(1)), float(m.group(2))))
        for m in _JOB_LOG_EPOCH_RE.finditer(text)
    ]
    events.extend(
        (m.start(), False, (float(m.group(1)), float(m.group(2))))
        for m in _JOB_LOG_METRICS_RE.finditer(text)
    )
    events.sort(key=lambda event: event[0])

    rows = []
    pending = None  # (epoch, train_loss) still waiting for its metrics line
    for _, is_epoch, payload in events:
        if is_epoch:
            pending = payload
        elif pending is not None:
            rows.append(pending + payload)
            # Only the first metrics line after an epoch line is used.
            pending = None
    if not rows:
        return None

    epochs, train_losses, hit10, ndcg10 = zip(*rows)
    return {
        "epoch": np.array(epochs),
        "train_loss": np.array(train_losses),
        "hit@10": np.array(hit10),
        "ndcg@10": np.array(ndcg10),
    }

# Prefer the live log; when it is absent, take the first sasrec_ml1m.o* file
# that glob yields in the TRACT root (or keep the default path if none match).
job_log_path = LOG_JOB
if not job_log_path.exists():
    job_log_path = next(TRACT_ROOT.glob("sasrec_ml1m.o*"), job_log_path)
parsed = parse_job_log(job_log_path)
print("Job log path:", job_log_path)
# Show only the array shapes, not the full parsed contents.
print(
    "Parsed:",
    {key: arr.shape for key, arr in parsed.items()} if parsed is not None else parsed,
)

In [None]:
def plot_job_log(parsed, title="Training curves (job log)"):
    """Draw train loss, Hit@10 and NDCG@10 against epoch in three panels.

    Args:
        parsed: Dict with the keys ``"epoch"``, ``"train_loss"``, ``"hit@10"``
            and ``"ndcg@10"``, or ``None`` when nothing was parsed.
        title: Figure-level title.

    Returns:
        None. Prints a message instead of plotting when *parsed* is ``None``.
    """
    if parsed is None:
        print("No parsed job log to plot.")
        return

    # One row per panel: (key in parsed, line format, legend label, panel title).
    panels = (
        ("train_loss", "b-", "Train loss", "Train loss"),
        ("hit@10", "g-", "Hit@10", "Valid Hit@10"),
        ("ndcg@10", "orange", "NDCG@10", "Valid NDCG@10"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    epochs = parsed["epoch"]
    for panel_ax, (key, line_fmt, legend_label, heading) in zip(axes, panels):
        panel_ax.plot(epochs, parsed[key], line_fmt, label=legend_label)
        panel_ax.set_title(heading)
        panel_ax.set_xlabel("Epoch")
        panel_ax.legend()
        panel_ax.grid(True, alpha=0.3)

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

# Uses the default figure title; prints a notice instead when parsed is None.
plot_job_log(parsed)

## 3. Optional: plot a specific PBS output file

Set `job_id` to your job ID (e.g. 159742817) to plot that run's log.

In [None]:
job_id = "159742817"  # change to your job ID
job_o_file = TRACT_ROOT.joinpath(f"sasrec_ml1m.o{job_id}")
parsed_specific = parse_job_log(job_o_file)
# parse_job_log returns None both for a missing file and for a file without
# matching lines, hence the combined message.
if parsed_specific is None:
    print(f"File not found or no matches: {job_o_file}")
else:
    plot_job_log(parsed_specific, title=f"SASRec ml-1m (job {job_id})")