# Run Example (Local)

In [None]:
%%bash
# On the first run, uncomment the `python3 -m venv` and `pip install` lines below before running this cell.
# %%bash runs this cell in its own shell process, so the activation applies only to commands in this cell.
# python3 -m venv venv # Create venv
source venv/bin/activate # Activate virtual environment
# pip install -r requirements.txt

In [None]:
# Move the kernel's working directory into `code`; later cells resolve relative paths from there.
# The path is relative, so re-running this cell without a restart tries to enter `code/code`.
%cd code

In [None]:
%pwd # Check if the path is right

## CLI Command

In [None]:
%%bash
# The earlier `%cd code` cell may already have moved the kernel into code/, in which case
# code/main.py no longer resolves; fall back to main.py in the current directory.
MAIN=code/main.py
[ -f "$MAIN" ] || MAIN=main.py

# OMP/MKL thread counts are set inline, so they apply to this python process only.
OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 \
python "$MAIN" \
  --dataset gowalla \
  --model lgn \
  --recdim 32 \
  --layer 1 \
  --epochs 30 \
  --bpr_batch 4096 \
  --testbatch 512 \
  --optimizer cluster \
  --alpha 0.25 \
  --num_clusters 64 \
  --recluster_interval 128 \
  --cluster_warmup 1024 \
  --tensorboard 1

In [21]:
!tensorboard --logdir=runs


NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.11.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C


## Visualization

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def plot_metrics(files, labels=None, title="Training Metrics Comparison"):
    """
    Plot training metrics from one or more CSV logs, one subplot per metric.

    The metrics looked for are loss, precision@20, recall@20, ndcg@20, hr@20 and
    convergence_speed. Which of them exist, and which column holds each one, is
    decided from the first CSV that loads; a later file lacking one of those
    columns is skipped for that subplot only.

    files: file path list (missing or unreadable files are reported and skipped)
    labels: file label list (defaults to each file's base name)
    title: figure-level title

    Returns None. The figure is shown with plt.show(); nothing is drawn when no
    CSV loads or the first CSV has none of the known metric columns.
    """
    files = list(files)
    if labels is None:
        labels = [os.path.basename(f) for f in files]
    else:
        labels = list(labels)
        if len(labels) < len(files):
            # zip() below would silently drop the unlabeled files; name them after the file instead.
            print(f"[WARN] Only {len(labels)} labels for {len(files)} files; using file names for the rest")
            labels.extend(os.path.basename(f) for f in files[len(labels):])

    dfs = []
    valid_labels = []
    for f, label in zip(files, labels):
        if not os.path.exists(f):
            print(f"[WARN] File not found: {f}")
            continue
        try:
            df = pd.read_csv(f)
        except Exception as e:
            print(f"[ERROR] Failed to read csv file: {f} → {e}")
            continue
        if "epoch" not in df.columns:
            print(f"[WARN] {label} file doesn't have epoch column; using row index as x-axis")
        dfs.append(df)
        valid_labels.append(label)
        print(f"[OK] Loaded: {f}")

    if not dfs:
        print("[ERROR] No csv loaded.")
        return

    metric_keys = ["loss", "precision@20", "recall@20", "ndcg@20", "hr@20", "convergence_speed"]

    def find_columns(df):
        """Map each known metric to the first column whose lowercased name contains it, else None."""
        found = {}
        for m in metric_keys:
            matches = [c for c in df.columns if m in str(c).lower()]
            found[m] = matches[0] if matches else None
        return found

    base_metrics = find_columns(dfs[0])
    metric_list = [m for m, col in base_metrics.items() if col is not None]

    if not metric_list:
        # plt.subplots() rejects a grid with zero rows, so stop before building the figure.
        print("[ERROR] No known metric columns found in the first csv.")
        return

    num_metrics = len(metric_list)
    cols = 2 if num_metrics > 1 else 1
    rows = (num_metrics + cols - 1) // cols

    # squeeze=False always yields a 2-D array, so the single-subplot case needs no special handling.
    fig, axes = plt.subplots(rows, cols, figsize=(12, rows * 3), squeeze=False)
    axes = axes.flatten()

    legend_handles = {}  # index into dfs -> first line drawn for that file
    for ax, metric_name in zip(axes, metric_list):
        metric_col = base_metrics[metric_name]

        for idx, (df, label) in enumerate(zip(dfs, valid_labels)):
            if metric_col not in df.columns:
                print(f"[WARN] {label} file doesn't have {metric_name} column")
                continue

            x_values = df["epoch"] if "epoch" in df.columns else df.index
            # Pin the colour to the file's position so a file keeps its colour on every
            # subplot even when it is skipped on some of them.
            (line,) = ax.plot(x_values, df[metric_col], marker="o", label=label, color=f"C{idx}")
            legend_handles.setdefault(idx, line)

        ax.set_title(metric_name)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric_name)
        ax.grid(True)

    # An odd metric count leaves one unused cell in the grid; hide it.
    for ax in axes[num_metrics:]:
        ax.set_visible(False)

    # Build the legend from every file that was drawn anywhere, not only on the first subplot.
    drawn = sorted(legend_handles)
    fig.legend(
        [legend_handles[i] for i in drawn],
        [valid_labels[i] for i in drawn],
        loc="upper right",
        ncol=3,
    )

    fig.suptitle(title, fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In [None]:
# CSV logs of the two runs to compare, resolved relative to the current working directory.
path_adam, path_ccadam = (
    "logs_gowalla_adam_lgn-12-02-03h38m56s.csv",
    "logs_gowalla_cluster_lgn-12-02-03h21m22s.csv",
)

# Labels are matched to files by position.
plot_metrics(files=[path_adam, path_ccadam], labels=["Adam", "CCA"])