# GATv2-GCN NBA Player Performance Prediction

## Full Reproduction on Google Colab (A100)

> **Paper:** *Predicting NBA Player Performance via Graph Attention Networks with Temporal Convolutions* (Luo & Krishnamurthy 2023)



This notebook implements the complete pipeline end-to-end:

1. **Environment setup** — install all dependencies on Colab

2. **Mount Google Drive** — load helper scripts and save all outputs

3. **Data acquisition** — scrape NBA box-scores via `nba_api` (2022-23 through 2025-26)

4. **Preprocessing** — forward-fill, Z-score normalise, build graph sequence

5. **Model definition** — GATv2Conv + Temporal Conv2D architecture (17-D input → 6-stat output)

6. **Training** — Adam, MSE, 50/25/25 chronological split, 300 epochs

7. **Baseline comparison** — Naïve, TCN-only, ASTGCN

8. **Evaluation & visualisation** — RMSE, MAE, MAPE, CORR + rich figures

9. **Case study** — Underdog Fantasy prop-line pick'em replay


## 0 · Runtime Check
Make sure you have selected **A100** (or at least T4) under *Runtime → Change runtime type*.

In [None]:
import subprocess, sys
gpu = subprocess.run(['nvidia-smi','--query-gpu=name','--format=csv,noheader'],
                    capture_output=True, text=True).stdout.strip()
print('GPU:', gpu)
print('Python:', sys.version)

## 1 · Install Dependencies
Colab ships with PyTorch; we only need to add `torch-geometric` and `nba-api`.

In [None]:
# Install extra packages (takes ~60 s on first run)
!pip install -q torch-geometric nba-api
# Confirm versions
import torch, torch_geometric
print('PyTorch:', torch.__version__)
print('PyG:    ', torch_geometric.__version__)
print('CUDA:   ', torch.cuda.is_available())

## 2 · Mount Google Drive

Upload the `reproduction/` and `NBA-GNN-prediction/` folders (from `knowball/networks/` in the project) into a `knowball/` folder at the top level of your Drive, so the paths match those used below.

Expected layout inside Drive:

```
MyDrive/
  knowball/
    NBA-GNN-prediction/
      gatv2tcn.py
      tcn.py            ← stub included in the upload folder
      player_id2name.pkl
      player_id2team.pkl
      player_id2position.pkl
      data/
        X_seq.pkl
        G_seq.pkl
    reproduction/
      01_data_pipeline.py
      02_train.py
      03_baselines_and_analysis.py
```
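Once Drive is mounted, a quick check that the upload matches this layout can save a confusing `ImportError` later. A minimal stdlib sketch: the `EXPECTED` list is transcribed from the tree above, and `missing_files` is a hypothetical helper, not part of the project.

```python
from pathlib import Path

# Relative paths transcribed from the expected Drive layout; adjust if your upload differs.
EXPECTED = [
    "NBA-GNN-prediction/gatv2tcn.py",
    "NBA-GNN-prediction/tcn.py",
    "NBA-GNN-prediction/player_id2name.pkl",
    "NBA-GNN-prediction/player_id2team.pkl",
    "NBA-GNN-prediction/player_id2position.pkl",
    "NBA-GNN-prediction/data/X_seq.pkl",
    "NBA-GNN-prediction/data/G_seq.pkl",
    "reproduction/01_data_pipeline.py",
    "reproduction/02_train.py",
    "reproduction/03_baselines_and_analysis.py",
]

def missing_files(root):
    """Return the expected files that do not exist under `root` (e.g. MyDrive/knowball)."""
    root = Path(root)
    return [rel for rel in EXPECTED if not (root / rel).is_file()]
```

After the mount cell, `missing_files('/content/drive/MyDrive/knowball')` should return an empty list.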


In [None]:
from google.colab import drive
drive.mount('/content/drive')

import sys, os
DRIVE_ROOT = '/content/drive/MyDrive/knowball'
REPO_ROOT  = f'{DRIVE_ROOT}/NBA-GNN-prediction'
REPRO_ROOT = f'{DRIVE_ROOT}/reproduction'
sys.path.insert(0, REPO_ROOT)
sys.path.insert(0, REPRO_ROOT)

# Output directories (written back to Drive so nothing is lost if runtime disconnects)
MODEL_DIR  = f'{DRIVE_ROOT}/outputs/model'
FIG_DIR    = f'{DRIVE_ROOT}/outputs/figures'
DATA_DIR   = f'{DRIVE_ROOT}/outputs/data'
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(FIG_DIR,   exist_ok=True)
os.makedirs(DATA_DIR,  exist_ok=True)
print("Drive mounted. Paths ready.")

## 3 · Shared Imports & Constants

In [None]:
import copy, itertools, json, logging, pickle, time, warnings
from pathlib import Path
warnings.filterwarnings('ignore')

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import networkx as nx
from scipy import stats
from sklearn import preprocessing
from sklearn.metrics import (mean_squared_error, mean_absolute_error,
                             mean_absolute_percentage_error)
from numpy.lib.stride_tricks import sliding_window_view

import torch, torch.nn as nn, torch.nn.functional as F
from torch.autograd import Variable

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)

# ── Feature / prediction constants ───────────────────────────
FEATURE_COLS    = ['PTS','AST','REB','TO','STL','BLK','PLUS_MINUS',
                   'TCHS','PASS','DIST','PACE','USG_PCT','TS_PCT']
PREDICTION_COLS = ['PTS','AST','REB','TO','STL','BLK']
PRED_INDICES    = [FEATURE_COLS.index(c) for c in PREDICTION_COLS]
MIN_MINUTES     = 10.0
SEQ_LENGTH      = 10
OFFSET          = 1

## 4 · Data Acquisition

We call the `nba_api` for traditional, advanced, and player-tracking box-scores across four seasons.

**Time estimate on Colab:** ~2-4 hours (rate-limited: a 0.7 s delay × 3 calls per game × ~1 230 regular-season games per season, over three full seasons plus the current partial one).

If you already ran this and saved `raw_boxscores.parquet` to Drive, the cell will reload it instantly.
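The game IDs returned by `LeagueGameFinder` are 10-digit strings. A minimal decoding sketch, assuming the commonly observed layout `00` + season-type digit + two-digit season start year + five-digit game number (e.g. `0022200001`); `decode_game_id` is a hypothetical helper, not part of `nba_api`. The point that matters for preprocessing is that the ID carries no calendar date.

```python
def decode_game_id(game_id) -> dict:
    """Split a 10-digit NBA stats GAME_ID into its parts (assumed layout)."""
    gid = str(game_id).zfill(10)          # restore leading zeros if the ID was stored as an int
    if len(gid) != 10 or not gid.isdigit():
        raise ValueError(f"unexpected GAME_ID: {game_id!r}")
    season_type = {"2": "Regular Season", "4": "Playoffs"}.get(gid[2], "Other")
    start_year = 2000 + int(gid[3:5])     # '22' -> 2022
    return {
        "season_type": season_type,
        "season": f"{start_year}-{str(start_year + 1)[-2:]}",  # 2022 -> '2022-23'
        "game_number": int(gid[5:]),
    }

print(decode_game_id("0022200001"))
```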


In [None]:
from nba_api.stats.endpoints import (
    leaguegamefinder, boxscoretraditionalv2,
    boxscoreadvancedv2, boxscoreplayertrackv3,
)
from datetime import date

SEASONS = [
    ('2022-23', '2022-10-18', '2023-04-09'),
    ('2023-24', '2023-10-24', '2024-04-14'),
    ('2024-25', '2024-10-22', '2025-04-13'),
    ('2025-26', '2025-10-28', str(date.today())),
]
API_DELAY   = 0.7
TRAD_COLS   = ['GAME_ID','PLAYER_ID','PLAYER_NAME','TEAM_ID','TEAM_ABBREVIATION',
               'MIN','PTS','AST','REB','TO','STL','BLK','PLUS_MINUS']
ADV_COLS    = ['GAME_ID','PLAYER_ID','PACE','USG_PCT','TS_PCT']
TRACK_COLS  = ['GAME_ID','PLAYER_ID','DIST','TCHS','PASS']

In [None]:
def parse_min(m):
    if pd.isna(m): return 0.0
    if isinstance(m, str) and ':' in m:
        p = m.split(':'); return float(p[0]) + float(p[1])/60
    return float(m)

def get_game_ids(season, d_from, d_to):
    gf = leaguegamefinder.LeagueGameFinder(
        season_nullable=season, season_type_nullable='Regular Season',
        date_from_nullable=d_from, date_to_nullable=d_to, league_id_nullable='00')
    df = gf.get_data_frames()[0]
    return sorted(df['GAME_ID'].drop_duplicates().tolist())

def fetch_game(game_id):
    try:
        time.sleep(API_DELAY)
        dt = boxscoretraditionalv2.BoxScoreTraditionalV2(game_id=game_id).get_data_frames()[0]
        time.sleep(API_DELAY)
        da = boxscoreadvancedv2.BoxScoreAdvancedV2(game_id=game_id).get_data_frames()[0]
        time.sleep(API_DELAY)
        dk = boxscoreplayertrackv3.BoxScorePlayerTrackV3(game_id=game_id).get_data_frames()[0]
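        # PlayerTrackV3 returns camelCase columns (gameId, personId, distance, touches, passes);
        # upper-case them and map onto the V2 names used by the merges below.
        dk.columns = dk.columns.str.upper()
        for want, variants in [('GAME_ID',['GAME_ID','GAMEID']), ('PLAYER_ID',['PLAYER_ID','PERSONID']),
                               ('DIST',['DIST','DISTANCE']),('TCHS',['TCHS','TOUCHES']),('PASS',['PASS','PASSES'])]:
            for v in variants:
                if v in dk.columns: dk = dk.rename(columns={v: want}); break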
        avt = [c for c in TRAD_COLS  if c in dt.columns]
        ava = [c for c in ADV_COLS   if c in da.columns]
        avk = ['GAME_ID','PLAYER_ID'] + [c for c in ['DIST','TCHS','PASS'] if c in dk.columns]
        df  = dt[avt].merge(da[ava], on=['GAME_ID','PLAYER_ID'], how='left')
        df  = df.merge(dk[avk],     on=['GAME_ID','PLAYER_ID'], how='left')
        df['MIN'] = df['MIN'].apply(parse_min)
        return df[df['MIN'] >= MIN_MINUTES]
    except Exception as e:
        print(f'  skip {game_id}: {e}'); return None

In [None]:
RAW_PATH = f'{DATA_DIR}/raw_boxscores.parquet'

if os.path.exists(RAW_PATH):
    print('Loading cached raw data…')
    raw_df = pd.read_parquet(RAW_PATH)
else:
    print('Starting data acquisition (this will take a few hours)…')
    frames = []
    for season, d_from, d_to in SEASONS:
        print(f'\nSeason {season}')
        gids = get_game_ids(season, d_from, d_to)
        print(f'  {len(gids)} games')
        for k, gid in enumerate(gids):
            if k % 50 == 0: print(f'  [{k}/{len(gids)}]')
            df = fetch_game(gid)
            if df is not None and len(df):
                df['SEASON'] = season
                frames.append(df)
    raw_df = pd.concat(frames, ignore_index=True)
    raw_df.to_parquet(RAW_PATH, index=False)
    print(f'Saved {len(raw_df):,} rows → {RAW_PATH}')

print(f'Total player-game rows: {len(raw_df):,}')
raw_df.head(3)

## 5 · Preprocessing & Graph Construction

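- Attach a calendar date to every game. NBA `GAME_ID`s have the form `00` + season-type digit + two-digit season year + five-digit game number (e.g. `0022200001`), so the date cannot be parsed from the ID itself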

- Forward-fill zeros (players who sat out carry their last stats)

- Z-score normalise each feature

- Build an undirected graph per game-day: edge ↔ two players competed in the same game
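The forward-fill and normalisation steps are easiest to see on one column. A stdlib sketch with made-up numbers: `forward_fill_zeros` reproduces what `fill_zeros_with_last()` does (via a loop instead of the `np.maximum.accumulate` index trick), and `z_score` uses the population standard deviation plus `1e-8`, as `preprocess()` does when it pools each feature over all game-days and players.

```python
from statistics import mean, pstdev

def forward_fill_zeros(column):
    """Replace each 0 with the most recent non-zero value; leading zeros stay 0."""
    filled, last = [], 0
    for v in column:
        if v != 0:
            last = v
        filled.append(last)
    return filled

def z_score(values, eps=1e-8):
    """Standardise with the population std (ddof=0, like np.std) plus eps."""
    mu, sd = mean(values), pstdev(values) + eps
    return [(v - mu) / sd for v in values]

pts = [0, 18, 0, 0, 25, 0]       # hypothetical PTS for one player over six game-days
print(forward_fill_zeros(pts))    # [0, 18, 18, 18, 25, 25]
```

One consequence of treating 0 as "missing": a genuine zero (for example a 0 PLUS_MINUS night) is also overwritten by the previous value.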


In [None]:
def fill_zeros_with_last(seq):
    seq_ff = np.zeros_like(seq)
    for i in range(seq.shape[1]):
        arr = seq[:, i]
        prev = np.arange(len(arr))
        prev[arr == 0] = 0
        prev = np.maximum.accumulate(prev)
        seq_ff[:, i] = arr[prev]
    return seq_ff

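def preprocess(df):
    df = df.copy()
    for c in FEATURE_COLS:
        if c not in df.columns: df[c] = 0.0
    # Coerce to numbers and zero-fill NaNs (e.g. players without tracking rows after the
    # left merge); a single NaN would otherwise poison the mean/std computed below.
    df[FEATURE_COLS] = df[FEATURE_COLS].apply(pd.to_numeric, errors='coerce').fillna(0.0)
    if 'GAME_DATE' not in df.columns:
        # GAME_ID = '00' + season type + season year + game number (e.g. '0022200001'),
        # so it carries no calendar date; look the dates up via LeagueGameFinder instead.
        date_map = {}
        for season, _, _ in SEASONS:
            time.sleep(API_DELAY)
            g = leaguegamefinder.LeagueGameFinder(
                season_nullable=season, season_type_nullable='Regular Season',
                league_id_nullable='00').get_data_frames()[0]
            date_map.update(zip(g['GAME_ID'].astype(str), g['GAME_DATE']))
        df['GAME_DATE'] = df['GAME_ID'].astype(str).str.zfill(10).map(date_map)
    df['GAME_DATE'] = pd.to_datetime(df['GAME_DATE'])
    df = df.dropna(subset=['GAME_DATE']).sort_values('GAME_DATE')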

    player_ids  = sorted(df['PLAYER_ID'].unique())
    player_index = {p: i for i, p in enumerate(player_ids)}
    N = len(player_ids)
    game_dates  = sorted(df['GAME_DATE'].dt.date.unique())
    D = len(game_dates)
    print(f'Players: {N}   Game-days: {D}')

    X_raw = np.zeros((D, N, len(FEATURE_COLS)), dtype=np.float32)
    G_raw = []
    for d, gdate in enumerate(game_dates):
        day = df[df['GAME_DATE'].dt.date == gdate]
        G = nx.Graph(); G.add_nodes_from(player_ids)
        for gid, grp in day.groupby('GAME_ID'):
            active = grp['PLAYER_ID'].tolist()
            for pid in active:
                if pid in player_index:
                    row = grp[grp['PLAYER_ID']==pid].iloc[0]
                    X_raw[d, player_index[pid]] = [float(row.get(c,0) or 0) for c in FEATURE_COLS]
            for pA, pB in itertools.combinations(active, 2):
                if pA in player_index and pB in player_index: G.add_edge(pA, pB)
        G_raw.append(G)

    X_ff = np.zeros_like(X_raw)
    for p in range(N):
        X_ff[:, p, :] = fill_zeros_with_last(X_raw[:, p, :])
    mu = X_ff.mean(axis=(0,1), keepdims=True)
    sd = X_ff.std(axis=(0,1),  keepdims=True) + 1e-8
    X_norm = (X_ff - mu) / sd
    return X_norm, G_raw, [str(d) for d in game_dates], player_ids

In [None]:
X_PKL = f'{DATA_DIR}/X_seq.pkl'
G_PKL = f'{DATA_DIR}/G_seq.pkl'
P_PKL = f'{DATA_DIR}/player_ids.pkl'

if os.path.exists(X_PKL):
    print('Loading cached tensors…')
    X_seq      = pickle.load(open(X_PKL,'rb'))
    G_seq      = pickle.load(open(G_PKL,'rb'))
    player_ids = pickle.load(open(P_PKL,'rb'))
else:
    print('Building tensors…')
    X_seq, G_seq, game_dates, player_ids = preprocess(raw_df)
    pickle.dump(X_seq,      open(X_PKL,'wb'))
    pickle.dump(G_seq,      open(G_PKL,'wb'))
    pickle.dump(player_ids, open(P_PKL,'wb'))
    print('Saved.')
print('X_seq shape:', X_seq.shape)

## 6 · Team & Position Embeddings

The model takes a 17-D input per player per time-step:

- 13 normalised statistics

- 2-D projected team one-hot

- 2-D projected position one-hot
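"Projected" here means a learned linear layer (`nn.Linear(n_teams, 2)` and `nn.Linear(n_pos, 2)` in the model cell). Applied to a one-hot vector, a linear layer simply selects one column of its weight matrix and adds the bias, so each team ends up with its own learned 2-D vector. A stdlib sketch of how the 17-D input is assembled; the 3-team setup and all weight values are made up for illustration.

```python
def linear(x, weight, bias):
    """y = W x + b, with `weight` as a list of output rows (same layout as nn.Linear.weight)."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b for row, b in zip(weight, bias)]

def one_hot(index, size):
    return [1.0 if i == index else 0.0 for i in range(size)]

# Toy team projection: 3 teams -> 2-D (weight shape (2, n_teams), like nn.Linear(n_teams, 2))
W_team = [[0.5, -1.0, 2.0],
          [0.0,  3.0, 1.0]]
b_team = [0.1, -0.1]

team_vec = linear(one_hot(2, 3), W_team, b_team)   # picks column 2 of W_team, then adds the bias
pos_vec = linear([1.0, 0.0, 1.0],                  # multi-hot 'F-C' position
                 [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [0.0, 0.0])
stats13 = [0.0] * 13                                # 13 Z-scored box-score features
features = stats13 + team_vec + pos_vec
print(len(features))  # 17
```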


In [None]:
# Team one-hot from each player's most frequent team in raw_df; position multi-hot from nba_api
def build_embeddings(player_ids, raw_df):
    N = len(player_ids)
    pid_set = set(player_ids)

    # Team: most common team per player, then label-encode → one-hot
    pid2team_str = {}
    for pid, grp in raw_df.groupby('PLAYER_ID'):
        if pid in pid_set:
            pid2team_str[pid] = grp['TEAM_ABBREVIATION'].mode()[0]
    teams = sorted(set(pid2team_str.values()))
    team2idx = {t:i for i,t in enumerate(teams)}
    n_teams  = len(teams)
    team_oh  = np.zeros((N, n_teams), dtype=np.float32)
    for idx, pid in enumerate(player_ids):
        team_oh[idx, team2idx.get(pid2team_str.get(pid,''), 0)] = 1.0

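    # Position: the static nba_api player list has no position field, so read POSITION
    # (e.g. 'G', 'F-C') from the PlayerIndex endpoint for each season; zeros on failure.
    pid2pos = {}
    try:
        from nba_api.stats.endpoints import playerindex
        for season, _, _ in SEASONS:
            time.sleep(API_DELAY)
            pi = playerindex.PlayerIndex(season=season).get_data_frames()[0]
            pid2pos.update(zip(pi['PERSON_ID'], pi['POSITION'].fillna('')))
    except Exception as e:
        print('Position lookup failed; using zero vectors:', e)
    def enc_pos(pid):
        # Multi-hot over (F, G, C), the same column order as the original pos_map
        toks = str(pid2pos.get(pid, '')).upper().replace(' ', '-').split('-')
        return [float(any(t.startswith(k) for t in toks)) for k in ('F', 'G', 'C')]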
    pos_arr = np.array([enc_pos(pid) for pid in player_ids], dtype=np.float32)
    n_pos   = pos_arr.shape[1]
    return team_oh, pos_arr, n_teams, n_pos

team_oh, pos_arr, n_teams, n_pos = build_embeddings(player_ids, raw_df)
team_tensor = torch.FloatTensor(team_oh).to(DEVICE)
pos_tensor  = torch.FloatTensor(pos_arr).to(DEVICE)
print(f'Team one-hot: {team_oh.shape}   Position: {pos_arr.shape}')

## 7 · Convert Graphs to Edge-Index Tensors
PyTorch Geometric expects `edge_index` as a `(2, E)` LongTensor.
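For an undirected graph, each edge is stored twice (once per direction). A stdlib sketch of going from per-game rosters straight to that layout; it mirrors the "forward edges, then reversed" ordering used by `graphs_to_edges()`, but `rosters_to_edge_index` and the player IDs are hypothetical. In the notebook the two lists would be wrapped as `torch.LongTensor`; PyG's `GATv2Conv` also adds self-loops by default (`add_self_loops=True`).

```python
from itertools import combinations

def rosters_to_edge_index(games, player_ids):
    """Symmetric (2, 2E) edge list: one undirected edge per pair of players
    who appeared in the same game, stored in both directions."""
    node = {pid: i for i, pid in enumerate(player_ids)}   # player ID -> row index in X
    src, dst = [], []
    for roster in games:
        for a, b in combinations(roster, 2):
            src.append(node[a]); dst.append(node[b])
    return [src + dst, dst + src]

player_ids = [201939, 203507, 1629029, 1628983]   # hypothetical IDs
games = [[201939, 203507], [1629029, 1628983]]    # two games, two active players each
print(rosters_to_edge_index(games, player_ids))    # [[0, 2, 1, 3], [1, 3, 0, 2]]
```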

In [None]:
def graphs_to_edges(G_seq, player_ids):
    nd = {pid:i for i,pid in enumerate(player_ids)}
    out = []
    for G in G_seq:
        edges = list(G.edges())
        if not edges:
            # No games that day: fall back to self-loops, on the same device as the other tensors
            n = len(player_ids)
            out.append(torch.stack([torch.arange(n),torch.arange(n)]).long().to(DEVICE))
        else:
            s,d = zip(*edges)
            s = [nd.get(x,0) for x in s]; d = [nd.get(x,0) for x in d]
            out.append(torch.stack([
                torch.LongTensor(s+d), torch.LongTensor(d+s)]).to(DEVICE))
    return out

G_edges = graphs_to_edges(G_seq, player_ids)
print(f'Built {len(G_edges)} edge tensors.')

## 8 · Sliding Windows & Train/Val/Test Split

Following the paper: **50 % train / 25 % val / 25 % test** (strictly chronological).

A window of `SEQ_LENGTH=10` game-days predicts the next day's stats.
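The indexing in the next cell can be checked without tensors. This stdlib sketch (hypothetical helper names) lists which game-days feed each window and which day it predicts, matching `sliding_window_view(Xs[:-OFFSET], SEQ_LENGTH)` paired with `Xs[SEQ_LENGTH + OFFSET - 1:]`, and then computes the 50/25/25 boundaries the same way (`int(T*0.50)`, `int(T*0.75)`).

```python
def window_indices(num_days, seq_len=10, offset=1):
    """(input_days, target_day) for every window over num_days game-days."""
    T = num_days - seq_len - offset + 1
    return [(list(range(t, t + seq_len)), t + seq_len + offset - 1) for t in range(T)]

def chronological_split(T, train=0.50, val=0.25):
    """Boundaries (t1, t2): train = [0, t1), val = [t1, t2), test = [t2, T)."""
    return int(T * train), int(T * (train + val))

windows = window_indices(100)            # e.g. 100 game-days
t1, t2 = chronological_split(len(windows))
print(len(windows), t1, t2)              # 90 45 67
print(windows[0])                        # ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 10)
```

The targets of the three splits are disjoint, but the first validation and test windows still draw their inputs from the preceding split's days.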


In [None]:
Xs = np.zeros_like(X_seq)
for i in range(X_seq.shape[1]):
    Xs[:, i, :] = fill_zeros_with_last(X_seq[:, i, :])

X_in  = sliding_window_view(Xs[:-OFFSET],      SEQ_LENGTH, axis=0)  # (T, N, F, 10)
X_out = Xs[SEQ_LENGTH + OFFSET - 1:]                                  # (T, N, F)
G_in  = [G_edges[t:t+SEQ_LENGTH] for t in range(len(G_edges)-SEQ_LENGTH-OFFSET+1)]
G_out = G_edges[SEQ_LENGTH+OFFSET-1:]

X_in  = torch.FloatTensor(np.array(X_in))
X_out = torch.FloatTensor(np.array(X_out))

T  = len(G_in)
t1 = int(T * 0.50)
t2 = int(T * 0.75)
print(f'Total windows: {T}  →  train:{t1}  val:{t2-t1}  test:{T-t2}')

X_tr, y_tr, G_tr = X_in[:t1], X_out[:t1], G_in[:t1]
X_va, y_va, G_va = X_in[t1:t2], X_out[t1:t2], G_in[t1:t2]
X_te, y_te, G_te = X_in[t2:],  X_out[t2:],  G_in[t2:]

## 9 · Model Definition — GATv2-TCN

We import `GATv2TCN` from the repo's `gatv2tcn.py`. The architecture:

| Layer | Details |
|---|---|
| Input | 17-D per player (13 stats + 2 team emb + 2 pos emb) |
| **GATv2Conv** | `in=17, out=32, heads=4` → 128-D spatial rep |
| **Temporal Conv2D** | `(128→64)`, kernel `(1,1)` |
| Residual Conv2D | `(17→64)`, adds skip connection |
| LayerNorm + ReLU | stabilise gradients |
| Final Conv2D | maps `seq_len` → 6-stat prediction |
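The GATv2Conv row is the spatial step. GATv2 scores each neighbour by applying the attention vector *after* the LeakyReLU, which lets the ranking of neighbours depend on the receiving node. A single-head, stdlib-only sketch of one node update, loosely following PyG's `GATv2Conv` formulation (separate source/target weights, negative slope 0.2); it is an illustration, not the code in `gatv2tcn.py`. With `heads=4` and `out_gatv2conv=32`, four such 32-D outputs are concatenated into the 128-D representation in the table.

```python
import math

def leaky_relu(v, slope=0.2):
    return [x if x > 0 else slope * x for x in v]

def matvec(M, x):
    """Matrix (list of rows) times vector."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def gatv2_node(i, h, neighbors, W_src, W_dst, a, slope=0.2):
    """Single-head GATv2 update for node i (include i in `neighbors` for a self-loop):
    e_ij = a . LeakyReLU(W_dst h_i + W_src h_j),  alpha_ij = softmax_j(e_ij),
    out_i = sum_j alpha_ij * (W_src h_j)."""
    target = matvec(W_dst, h[i])
    msgs = {j: matvec(W_src, h[j]) for j in neighbors}
    scores = {}
    for j, m in msgs.items():
        z = leaky_relu([t + s for t, s in zip(target, m)], slope)
        scores[j] = sum(ak * zk for ak, zk in zip(a, z))   # attention applied after the nonlinearity
    top = max(scores.values())                              # subtract the max for a stable softmax
    weights = {j: math.exp(s - top) for j, s in scores.items()}
    total = sum(weights.values())
    alpha = {j: w / total for j, w in weights.items()}
    out = [sum(alpha[j] * msgs[j][k] for j in msgs) for k in range(len(target))]
    return out, alpha

I2 = [[1.0, 0.0], [0.0, 1.0]]
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # three toy players with 2-D features
out, alpha = gatv2_node(0, h, [0, 1, 2], I2, I2, a=[1.0, 1.0])
print({j: round(w, 3) for j, w in alpha.items()})
```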


In [None]:
from gatv2tcn import GATv2TCN

model_in = len(FEATURE_COLS) + 2 + 2  # 17

team_emb = nn.Linear(n_teams, 2).to(DEVICE)
pos_emb  = nn.Linear(n_pos, 2).to(DEVICE)
model    = GATv2TCN(
    in_channels=model_in, out_channels=6,
    len_input=SEQ_LENGTH, len_output=1,
    temporal_filter=64, out_gatv2conv=32,
    dropout_tcn=0.25, dropout_gatv2conv=0.5, head_gatv2conv=4,
).to(DEVICE)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {total_params:,}')
print(model)

## 10 · Training

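- **Optimiser:** Adam, `lr=0.001`, `weight_decay=0.001`, with a cosine-annealed learning rate (`CosineAnnealingLR`, `T_max=300`)

- **Loss:** MSE averaged over all players in the graph; there is no mask, so players who sat out are scored against their forward-filled stats

- **Epochs:** 300 | **Mini-batch:** one optimiser step per epoch on the summed loss of 20 randomly sampled training windows

- Best checkpoint (lowest validation loss) saved to Drive on every improvement.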
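The training cell attaches `CosineAnnealingLR(optimizer, T_max=EPOCHS)` and steps it once per epoch. Under that usage the learning rate follows the closed form sketched below (stdlib only; `cosine_annealing_lr` is a hypothetical helper, and `eta_min=0` is PyTorch's default).

```python
import math

def cosine_annealing_lr(epoch, base_lr=1e-3, t_max=300, eta_min=0.0):
    """LR in effect during `epoch` when scheduler.step() runs once per epoch:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

schedule = [cosine_annealing_lr(e) for e in range(300)]
print(f"{schedule[0]:.1e} {schedule[150]:.1e}")   # 1.0e-03 5.0e-04
```

The rate halves by epoch 150 and is close to zero for the final epochs, so late validation improvements come from very small steps.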


In [None]:
EPOCHS     = 300
BATCH_SIZE = 20
LR         = 1e-3
WD         = 1e-3

params = (list(model.parameters()) +
          list(team_emb.parameters()) +
          list(pos_emb.parameters()))
optimizer  = torch.optim.Adam(params, lr=LR, weight_decay=WD)
scheduler  = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

train_hist, val_hist = [], []
best_val   = float('inf')
best_epoch = -1

for epoch in range(EPOCHS):
    # ── train ──
    model.train(); team_emb.train(); pos_emb.train()
    tv = team_emb(team_tensor); pv = pos_emb(pos_tensor)
    t_loss = torch.tensor(0.0, device=DEVICE)
    idx = np.random.choice(len(X_tr), size=min(BATCH_SIZE,len(X_tr)), replace=False)
    for i in idx:
        Xl = [torch.cat([X_tr[i,:,:,t].to(DEVICE), tv, pv], 1) for t in range(SEQ_LENGTH)]
        x  = torch.stack(Xl,-1)[None,...]
        p  = model(x, G_tr[i])[0]
        t_loss = t_loss + F.mse_loss(p, y_tr[i,:,PRED_INDICES].to(DEVICE))
    t_loss.backward(); optimizer.step(); optimizer.zero_grad(); scheduler.step()

    # ── validate ──
    model.eval(); team_emb.eval(); pos_emb.eval()
    with torch.no_grad():
        tv2=team_emb(team_tensor); pv2=pos_emb(pos_tensor)
        v_loss = 0.0
        for i in range(len(X_va)):
            Xl = [torch.cat([X_va[i,:,:,t].to(DEVICE), tv2, pv2], 1) for t in range(SEQ_LENGTH)]
            x  = torch.stack(Xl,-1)[None,...]
            p  = model(x, G_va[i])[0]
            v_loss += F.mse_loss(p, y_va[i,:,PRED_INDICES].to(DEVICE)).item()

    # Store per-window averages so the train curve (20 windows) and the val curve
    # (every validation window) share a scale; this does not change which epoch is best.
    t_avg  = t_loss.item() / len(idx)
    v_loss = v_loss / len(X_va)
    train_hist.append(t_avg); val_hist.append(v_loss)
    if v_loss < best_val:
        best_val=v_loss; best_epoch=epoch
        torch.save(model.state_dict(),    f'{MODEL_DIR}/model.pth')
        torch.save(team_emb.state_dict(), f'{MODEL_DIR}/team_emb.pth')
        torch.save(pos_emb.state_dict(),  f'{MODEL_DIR}/pos_emb.pth')
        print(f'  ↓ val {v_loss:.4f} (epoch {epoch})')
    if epoch % 50 == 0:
        print(f'Epoch {epoch:03d} | train {t_avg:.4f} | val {v_loss:.4f}')

print(f'\nBest val loss: {best_val:.4f} at epoch {best_epoch}')
np.save(f'{MODEL_DIR}/train_hist.npy', np.array(train_hist))
np.save(f'{MODEL_DIR}/val_hist.npy',   np.array(val_hist))

## 11 · Test Set Evaluation
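Load the best checkpoint and compute RMSE, MAE, MAPE, and CORR (per-statistic Pearson correlations averaged through Fisher's z-transform). All metrics are computed on the Z-normalised targets, so MAPE is unstable whenever a true value is close to zero.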
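Correlations are not averaged directly: each r is mapped to z = atanh(r), the z values are averaged, and the mean is mapped back with tanh. A stdlib sketch of that step (hypothetical helper names); like the evaluation cell, it skips NaNs and |r| ≈ 1, where atanh diverges.

```python
import math
from statistics import mean

def pearson_r(x, y):
    mx, my = mean(x), mean(y)
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    return sxy / math.sqrt(sxx * syy)

def fisher_z_mean(rs, eps=1e-7):
    """tanh(mean(atanh(r))) over the usable correlations; NaN if none are usable."""
    zs = [math.atanh(r) for r in rs if not math.isnan(r) and abs(r) < 1 - eps]
    return math.tanh(mean(zs)) if zs else float("nan")

print(fisher_z_mean([0.2, 0.8]))   # larger than the plain mean of 0.5
```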

In [None]:
model.load_state_dict(torch.load(f'{MODEL_DIR}/model.pth', map_location=DEVICE))
team_emb.load_state_dict(torch.load(f'{MODEL_DIR}/team_emb.pth', map_location=DEVICE))
pos_emb.load_state_dict(torch.load(f'{MODEL_DIR}/pos_emb.pth', map_location=DEVICE))
model.eval(); team_emb.eval(); pos_emb.eval()

all_preds, all_trues = [], []
with torch.no_grad():
    tv=team_emb(team_tensor); pv=pos_emb(pos_tensor)
    for i in range(len(X_te)):
        Xl=[torch.cat([X_te[i,:,:,t].to(DEVICE),tv,pv],1) for t in range(SEQ_LENGTH)]
        x =torch.stack(Xl,-1)[None,...]
        p =model(x, G_te[i])[0].cpu().numpy()
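        y_true=y_te[i,:,PRED_INDICES].numpy()   # not `t2`, which already holds the split index
        all_preds.append(p); all_trues.append(y_true)

AP=np.concatenate(all_preds); AT=np.concatenate(all_trues)
# Mean of per-statistic RMSEs (what squared=False returned; that argument was removed in scikit-learn 1.6)
rmse_v = np.sqrt(((AT - AP) ** 2).mean(axis=0)).mean()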
mae_v  = mean_absolute_error(AT, AP)
mape_v = mean_absolute_percentage_error(AT, AP)
corr_z = []
for mi in range(6):
    r=np.corrcoef(AP[:,mi],AT[:,mi])[0,1]
    if not np.isnan(r) and abs(r)<1-1e-7: corr_z.append(np.arctanh(r))
corr_v = np.tanh(np.mean(corr_z)) if corr_z else float('nan')

repro_metrics = {'RMSE':float(rmse_v),'MAE':float(mae_v),
                 'MAPE':float(mape_v),'CORR':float(corr_v)}
print('── Test Metrics ──────────────────')
for k,v in repro_metrics.items(): print(f'  {k}: {v:.4f}')
with open(f'{MODEL_DIR}/test_metrics.json','w') as f: json.dump(repro_metrics,f,indent=2)

## 12 · Visualisations
All figures are saved to your Drive. Run sequentially.

### 12a · Learning Curves

In [None]:
fig, ax = plt.subplots(figsize=(10,5))
ax.plot(train_hist, label='Train Loss', color='#5b8fc7', lw=2)
ax.plot(val_hist,   label='Val Loss',   color='#e07b54', lw=2)
ax.axvline(best_epoch, color='gray', ls='--', lw=1.5, label=f'Best ({best_epoch})')
ax.set_xlabel('Epoch'); ax.set_ylabel('MSE Loss')
ax.set_title('GATv2-TCN Learning Curves', fontweight='bold')
ax.legend(); ax.grid(alpha=0.25)
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/loss_curves.png', dpi=150)
plt.show()

### 12b · Model Benchmark Comparison

In [None]:
PAPER = {
    'N-BEATS':     {'RMSE':5.112,'MAE':4.552,'MAPE':3.701,'CORR':0.366},
    'DeepVAR':     {'RMSE':2.896,'MAE':2.151,'MAPE':1.754,'CORR':0.396},
    'TCN':         {'RMSE':2.414,'MAE':1.780,'MAPE':0.551,'CORR':0.418},
    'ASTGCN':      {'RMSE':2.293,'MAE':1.699,'MAPE':0.455,'CORR':0.453},
    'GATv2 (Paper)':{'RMSE':2.222,'MAE':1.642,'MAPE':0.513,'CORR':0.508},
    'GATv2 (Repro)': repro_metrics,
}
COLORS = ['#e07b54','#5b8fc7','#70b472','#c07ec9','#f0c040','#e84393']
mets   = ['RMSE','MAE','CORR']
fig, axes = plt.subplots(1,3,figsize=(16,5))
for ax, met in zip(axes, mets):
    models = list(PAPER.keys())
    vals   = [PAPER[m][met] for m in models]
    bars   = ax.bar(range(len(models)), vals, color=COLORS, edgecolor='white', width=0.65)
    for b,v in zip(bars,vals):
        ax.text(b.get_x()+b.get_width()/2, b.get_height()+0.01*max(vals),
                f'{v:.3f}', ha='center', fontsize=8.5, fontweight='bold')
    ax.set_xticks(range(len(models)))
    ax.set_xticklabels(models, rotation=35, ha='right', fontsize=9)
    ax.set_title(met, fontweight='bold'); ax.grid(axis='y',alpha=0.2)
plt.suptitle('Model Benchmark Comparison', fontweight='bold')
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

### 12c · Predicted vs. Actual (per statistic)

In [None]:
colors=['#5b8fc7','#e07b54','#70b472','#c07ec9','#f0c040','#e84393']
fig, axes = plt.subplots(2,3,figsize=(15,9))
for i,(ax,stat) in enumerate(zip(axes.ravel(), PREDICTION_COLS)):
    p,t = AP[:,i], AT[:,i]
    ax.scatter(t, p, alpha=0.25, s=8, color=colors[i], rasterized=True)
    lo,hi = min(t.min(),p.min()), max(t.max(),p.max())
    ax.plot([lo,hi],[lo,hi],'k--',lw=1.5,alpha=0.6,label='Perfect')
    m,b,r,*_ = stats.linregress(t,p)
    xs=np.array([lo,hi]); ax.plot(xs,m*xs+b,'r-',lw=2,label=f'r={r:.3f}')
    ax.set_title(f'{stat} | MAE={mean_absolute_error(t,p):.2f}', fontweight='bold')
    ax.set_xlabel('Actual'); ax.set_ylabel('Predicted')
    ax.legend(fontsize=8); ax.grid(alpha=0.15)
plt.suptitle('GATv2-TCN: Predicted vs. Actual (Test Set)', fontweight='bold')
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/pred_vs_actual.png', dpi=150, bbox_inches='tight')
plt.show()

### 12d · Per-Statistic Error Bar Chart

In [None]:
mae_s  = np.abs(AP-AT).mean(0)
rmse_s = np.sqrt(((AP-AT)**2).mean(0))
corr_s = [np.corrcoef(AP[:,i],AT[:,i])[0,1] for i in range(6)]
x=np.arange(6); w=0.28
fig,ax=plt.subplots(figsize=(12,5))
for off, vals, lbl, col in [(-w,mae_s,'MAE','#5b8fc7'),(0,rmse_s,'RMSE','#e07b54'),(w,corr_s,'CORR','#70b472')]:
    bars=ax.bar(x+off, vals, w, label=lbl, color=col, edgecolor='white')
    for bar in bars:
        ax.text(bar.get_x()+bar.get_width()/2, bar.get_height()+0.01,
                f'{bar.get_height():.2f}', ha='center', fontsize=8)
ax.set_xticks(x); ax.set_xticklabels(PREDICTION_COLS, fontsize=12)
ax.set_title('Per-Statistic Error Metrics (Test Set)', fontweight='bold')
ax.legend(); ax.grid(axis='y',alpha=0.2)
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/per_stat_errors.png', dpi=150)
plt.show()

### 12e · Residual Distributions

In [None]:
res = AP - AT
fig,axes=plt.subplots(2,3,figsize=(15,8))
for i,(ax,stat) in enumerate(zip(axes.ravel(),PREDICTION_COLS)):
    r=res[:,i]
    ax.hist(r, bins=60, color=colors[i], edgecolor='white', alpha=0.85, density=True)
    xs=np.linspace(r.min(),r.max(),300)
    ax.plot(xs, stats.norm.pdf(xs,r.mean(),r.std()), 'k-', lw=2)
    ax.axvline(0,          color='red',    ls='--', lw=1.5)
    ax.axvline(r.mean(),   color='orange', ls='-',  lw=1.5, label=f'μ={r.mean():.2f}')
    ax.set_title(stat, fontweight='bold')
    ax.legend(fontsize=8); ax.grid(axis='y',alpha=0.2)
plt.suptitle('Residual Error Distributions', fontweight='bold')
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/residuals.png', dpi=150, bbox_inches='tight')
plt.show()

### 12f · Correlation Heatmap (Features)

In [None]:
STAT_LBL=['PTS','AST','REB','TO','STL','BLK','±','TCHS','PASS','DIST','PACE','USG%','TS%']
n=X_seq.shape[-1]
flat=X_seq.reshape(-1,n)
flat=flat[(flat!=0).any(1)]
corr_mat=np.corrcoef(flat.T)
fig,ax=plt.subplots(figsize=(9,7))
im=ax.imshow(corr_mat, cmap='RdBu_r', vmin=-1, vmax=1)
ax.set_xticks(range(n)); ax.set_xticklabels(STAT_LBL[:n],rotation=45,ha='right')
ax.set_yticks(range(n)); ax.set_yticklabels(STAT_LBL[:n])
for i in range(n):
    for j in range(n):
        ax.text(j,i,f'{corr_mat[i,j]:.2f}',ha='center',va='center',fontsize=7,
               color='white' if abs(corr_mat[i,j])>0.5 else 'black')
plt.colorbar(im,ax=ax,shrink=0.8,label='Pearson r')
ax.set_title('Feature Correlation Matrix', fontweight='bold')
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/correlation_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

### 12g · Graph Topology Sample

In [None]:
# Visualise one game-day graph (mid-dataset)
id2name = {k:v for k,v in zip(player_ids, [str(p) for p in player_ids])}
try:
    id2name_pkl = f'{REPO_ROOT}/player_id2name.pkl'
    if os.path.exists(id2name_pkl):
        id2name = pickle.load(open(id2name_pkl,'rb'))
except Exception: pass

G_sample = G_seq[len(G_seq)//2]
active   = [n for n in G_sample.nodes() if G_sample.degree(n)>0][:60]
subG     = G_sample.subgraph(active)
comps    = list(nx.connected_components(subG))
cmap     = plt.get_cmap('tab20')   # plt.cm.get_cmap was removed in Matplotlib 3.9
nc = {}
for k,comp in enumerate(comps):
    for n in comp: nc[n]=cmap(k%20)

fig,ax=plt.subplots(figsize=(13,10))
pos=nx.kamada_kawai_layout(subG)
labels={n:str(id2name.get(n,n))[:10] for n in subG.nodes()}
nx.draw_networkx_nodes(subG,pos,ax=ax,node_color=[nc.get(n,'#ccc') for n in subG.nodes()],
                       node_size=250,alpha=0.93,linewidths=1.3,edgecolors='white')
nx.draw_networkx_edges(subG,pos,ax=ax,alpha=0.3,edge_color='#888',width=1.2)
nx.draw_networkx_labels(subG,pos,labels=labels,ax=ax,font_size=6.5)
ax.set_title('Game-Day Graph Topology (colours = game clusters)', fontweight='bold')
ax.axis('off'); plt.tight_layout()
plt.savefig(f'{FIG_DIR}/graph_topology.png', dpi=150, bbox_inches='tight')
plt.show()

### 12h · Performance Radar Chart

In [None]:
all_r = {**PAPER, 'GATv2 (Repro)': repro_metrics}
def to_radar(r): return {'1/RMSE':1/max(r['RMSE'],1e-6),'1/MAE':1/max(r['MAE'],1e-6),
                          '1/MAPE':1/max(r['MAPE'],1e-6),'CORR':r['CORR']}
normed={m:to_radar(v) for m,v in all_r.items()}
cats=list(next(iter(normed.values())).keys())
angles=np.linspace(0,2*np.pi,len(cats),endpoint=False).tolist()+[0]
RCOLS=['#e07b54','#5b8fc7','#70b472','#c07ec9','#f0c040','#e84393']
fig,ax=plt.subplots(figsize=(8,8),subplot_kw={'polar':True})
for (name,m),col in zip(normed.items(),RCOLS):   # not `model`, which would overwrite the trained network
    vals=list(m.values())+[list(m.values())[0]]
    ax.plot(angles,vals,lw=2.5,color=col,label=name)
    ax.fill(angles,vals,alpha=0.07,color=col)
ax.set_xticks(angles[:-1]); ax.set_xticklabels(cats,fontsize=12)
ax.set_yticklabels([])
ax.set_title('Performance Radar (higher=better on all axes)',fontweight='bold',pad=25)
ax.legend(loc='upper right',bbox_to_anchor=(1.45,1.1),fontsize=9)
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/radar.png', dpi=150, bbox_inches='tight')
plt.show()

### 12i · Population-Average PTS Forecast Trend

In [None]:
T_te=len(all_preds)
mean_p=[all_preds[i][:,0].mean() for i in range(T_te)]
mean_t=[all_trues[i][:,0].mean() for i in range(T_te)]
std_p =[all_preds[i][:,0].std()  for i in range(T_te)]
x=np.arange(T_te)
fig,ax=plt.subplots(figsize=(14,4.5))
ax.plot(x,mean_t,color='#e07b54',lw=2.5,ls='--',label='Actual PTS')
ax.plot(x,mean_p,color='#5b8fc7',lw=2.5,label='Predicted PTS')
ax.fill_between(x,np.array(mean_p)-np.array(std_p),
                  np.array(mean_p)+np.array(std_p),alpha=0.18,color='#5b8fc7',label='±1σ')
ax.set_xlabel('Test-Set Game-Day Index'); ax.set_ylabel('PTS (Z-Normalised)')
ax.set_title('Population-Average PTS Forecast vs. Actuals',fontweight='bold')
ax.legend(); ax.grid(axis='y',alpha=0.2)
plt.tight_layout()
plt.savefig(f'{FIG_DIR}/pts_trend.png', dpi=150)
plt.show()

## 13 · Case Study — Underdog Fantasy Prop Pick'em Replay

The paper evaluated the trained model on the January 20, 2023 slate.

A pick is **correct** if the model's prediction and the true outcome are on the **same side** of the prop line.

The paper achieved **35/59 = 59.3 %**, well above the 52.4 % breakeven for -110 juice.
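The grading rule and the breakeven figure reduce to a few lines. A stdlib sketch with hypothetical helpers: `american_breakeven` converts American odds to the win rate needed to break even (110/210 for -110), and `grade_pick` applies the same-side rule. Treating an exact hit on a whole-number line as void is an assumption made here; the case-study cell counts it as an "under".

```python
def american_breakeven(odds):
    """Win probability needed to break even at American odds (e.g. -110 -> 110/210)."""
    return -odds / (-odds + 100) if odds < 0 else 100 / (odds + 100)

def grade_pick(pred, actual, line):
    """True if prediction and outcome fall on the same side of the line; None on a push."""
    if actual == line:
        return None        # exact hit on a whole-number line: treated as void (assumption)
    return (pred > line) == (actual > line)

print(f"{american_breakeven(-110):.3f}")   # 0.524
print(f"{35 / 59:.3f}")                    # 0.593
```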


In [None]:
# Load original X_seq / G_seq from the repo (covers the 2022-23 season only)
ORIG_X = f'{REPO_ROOT}/data/X_seq.pkl'
ORIG_G = f'{REPO_ROOT}/data/G_seq.pkl'
ORIG_N = f'{REPO_ROOT}/player_id2name.pkl'

if all(os.path.exists(p) for p in [ORIG_X,ORIG_G,ORIG_N]):
    ox = pickle.load(open(ORIG_X,'rb'))
    og = pickle.load(open(ORIG_G,'rb'))
    on = pickle.load(open(ORIG_N,'rb'))

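    pid_list = list(on.keys())
    og_edges = graphs_to_edges(og, pid_list)
    ox_f = np.asarray(ox, dtype=np.float32)
    Xs2 = np.zeros_like(ox_f)
    for i in range(ox_f.shape[1]): Xs2[:, i, :] = fill_zeros_with_last(ox_f[:, i, :])

    # Assumption: the repo's X_seq holds raw (un-normalised) stats in FEATURE_COLS order.
    # Z-score it like preprocess() does, then map predictions back to raw units so they
    # are comparable with the prop lines below.
    mu2 = Xs2.mean(axis=(0, 1)); sd2 = Xs2.std(axis=(0, 1)) + 1e-8
    Xn2 = (Xs2 - mu2) / sd2

    # Team/position encodings must cover the repo's player set, not the scraped one
    # (assumes both player sets map onto the same number of teams).
    team_oh2, pos_arr2, _, _ = build_embeddings(pid_list, raw_df)

    # Input window = days D-11 .. D-2; the target is the final day in the repo data
    with torch.no_grad():
        tv = team_emb(torch.FloatTensor(team_oh2).to(DEVICE))
        pv = pos_emb(torch.FloatTensor(pos_arr2).to(DEVICE))
        Xl = [torch.cat([torch.FloatTensor(Xn2[-11 + t]).to(DEVICE), tv, pv], 1)
              for t in range(SEQ_LENGTH)]
        x = torch.stack(Xl, -1)[None, ...]
        preds_z = model(x, og_edges[-11:-1])[0].cpu().numpy()
    preds_jan20 = preds_z * sd2[PRED_INDICES] + mu2[PRED_INDICES]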

    TODAY_PROPS = {
        'CJ McCollum':         {'PTS':25.5},
        'Jonas Valanciunas':   {'PTS':19.0,'REB':11.5},
        'Paolo Banchero':      {'PTS':9.5,'REB':6.0},
        'Luka Doncic':         {'PTS':34.5,'REB':10.5,'AST':9.0},
        'Jimmy Butler':        {'PTS':20.5,'REB':6.0,'STL':1.5},
        'Bam Adebayo':         {'PTS':21.5,'REB':10.5},
        'Darius Garland':      {'PTS':25.5,'AST':8.5},
        'Evan Mobley':         {'PTS':16.5,'REB':9.0},
    }
    pid_list   = list(on.keys())
    name_list  = list(on.values())
    correct=0; total=0; rows=[]
    for player,props in TODAY_PROPS.items():
        if player not in name_list: continue
        pidx = name_list.index(player)
        true_row = ox[-1][pidx] if pidx < ox.shape[1] else None
        for stat, line in props.items():
            if stat not in PREDICTION_COLS: continue
            si = PREDICTION_COLS.index(stat)
            pred_val = float(preds_jan20[pidx, si])
            true_val = float(ox[-1][pidx, si]) if true_row is not None else float('nan')
            over_pred = pred_val > line
            over_true = true_val > line
            ok = over_pred == over_true
            if not np.isnan(true_val):
                correct += int(ok); total += 1
            rows.append({'Player':player,'Stat':stat,'Line':line,
                         'Predicted':round(pred_val,2),'Actual':round(true_val,2),
                         'Correct':ok})
    df_case = pd.DataFrame(rows)
    print(f'Pick accuracy: {correct}/{total} = {100*correct/max(total,1):.1f}%')
    display(df_case)
else:
    print('Original repo data not found at expected Drive path – skipping case study.')

## 14 · Final Summary Table

In [None]:
summary = pd.DataFrame(PAPER).T.rename_axis('Model')
summary.loc['GATv2 (Repro)'] = pd.Series(repro_metrics)
summary = summary.round(3)
display(summary.style.background_gradient(cmap='RdYlGn_r', subset=['RMSE','MAE','MAPE'])
                       .background_gradient(cmap='RdYlGn',  subset=['CORR']))
summary.to_csv(f'{MODEL_DIR}/summary_table.csv')
print('Saved summary table to Drive.')

---

## ✅ All Done!

Your Drive `knowball/outputs/` folder now contains:

- `model/model.pth` — best GATv2-TCN checkpoint

- `model/test_metrics.json` — RMSE / MAE / MAPE / CORR

- `model/train_hist.npy` + `val_hist.npy` — loss curves

- `model/summary_table.csv` — full comparison table

- `figures/*.png` — all publication-quality figures
