## 7) Neural-network extension: Siamese DTW

In [None]:
# %%
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split

# Make the repository root importable so the `src.*` imports below resolve.
# Path.cwd().parent assumes the notebook is launched from a direct
# subdirectory of the project root (e.g. notebooks/) — TODO confirm.
project_root = Path.cwd().parent
_project_root_str = str(project_root)
# Re-running this cell used to prepend another copy of the root every time.
# Strip existing copies in place (sys.path must stay the same list object),
# then insert once at the front so the project still takes import priority.
sys.path[:] = [entry for entry in sys.path if entry != _project_root_str]
sys.path.insert(0, _project_root_str)

# 7.1) Imports
from src.nn_utils.data import load_siamese_data
from src.nn_utils.model import build_siamese_dtw_model
from src.nn_utils.training import train_siamese_model
from src.evaluation.evaluation import compute_metrics, plot_roc, plot_det

# 7.2) Paths and constants
# Inputs live under <project_root>/data; figures are written to <project_root>/figures.
_data_dir = project_root / "data"
PAIRS_PATH = _data_dir / "pairs_meta.parquet"
CACHE_PATH = _data_dir / "dtw_cache.parquet"  # not used in this cell
PROC_ROOT = _data_dir / "processed"

FIG_DIR = project_root / "figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)  # no error if it already exists

# SEQ_LEN feeds load_siamese_data and build_siamese_dtw_model;
# N_FEATURES feeds build_siamese_dtw_model.
SEQ_LEN = 100  # target sequence length
N_FEATURES = 9  # localFunctions dimension

# 7.3) Load pairs metadata & prepare data
pairs_df = pd.read_parquet(PAIRS_PATH)
# load_siamese_data yields the two branch inputs (X1, X2) and the pair labels (y).
# How sequences are brought to SEQ_LEN is defined in src.nn_utils.data.
X1, X2, y = load_siamese_data(pairs_df, PROC_ROOT, SEQ_LEN)

# 7.4) Train/test split
# Stratified on y so both splits keep the same label proportions; the fixed
# random_state makes the split reproducible across runs.
# NOTE(review): this splits individual *pairs*. If the same signer/sample
# appears in several pairs, it can land in both train and test and inflate the
# test metrics. A group-aware split (e.g. GroupShuffleSplit on a signer column
# of pairs_df, if one exists) would avoid that — TODO confirm the schema.
X1_tr, X1_te, X2_tr, X2_te, y_tr, y_te = train_test_split(
    X1, X2, y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)

# 7.5) Build & inspect model
# hidden_dims and gamma are forwarded unchanged to build_siamese_dtw_model;
# their meaning is defined in src.nn_utils.model (gamma is presumably the
# soft-DTW smoothing parameter — TODO confirm).
model = build_siamese_dtw_model(
    n_features=N_FEATURES,
    sequence_length=SEQ_LEN,
    gamma=1.0,
    hidden_dims=(7, 5),
)
model.summary()  # inspect the architecture before training

# 7.6) Train model
# Only the training split is handed to the trainer, so the test arrays stay
# unseen. validation_split is forwarded as-is (presumably a Keras-style
# fraction of the training arrays held out for validation — confirm).
history = train_siamese_model(
    model,
    X1_tr,
    X2_tr,
    y_tr,
    epochs=20,
    batch_size=32,
    validation_split=0.2,
)

# 7.7) Evaluate on test
# One score per test pair: feed both branches and flatten the output to 1-D.
y_pred = model.predict([X1_te, X2_te]).ravel()
df_test = pd.DataFrame({"label": y_te, "score": y_pred})

# compute_metrics receives the scores under the column name "d_raw".
# NOTE(review): "d_raw" reads like a raw *distance* (lower = more alike). If the
# Siamese head outputs a similarity/probability instead (higher = more alike),
# the ROC would be inverted and AUC reported as 1 - AUC. Confirm which
# orientation compute_metrics assumes.
df_test_scored = df_test.rename(columns={"score": "d_raw"})
metrics = compute_metrics(df_test_scored)
print(f"AUC (Siamese-DTW): {metrics['auc']:.4f}")
print(f"EER: {metrics['eer']:.4f} @ thr={metrics['eer_threshold']:.5f}")

# 7.8) Plot ROC & DET
# Each figure is closed in `finally`, so a failing plot/save call does not
# leave an open figure behind. Open figures accumulate across cell re-runs
# and may be rendered inline by the notebook.
fig, ax = plt.subplots(figsize=(6, 6))
try:
    plot_roc(metrics["fpr"], metrics["tpr"], metrics["auc"], ax=ax)
    fig.savefig(FIG_DIR / "siamese_roc.png", dpi=300)
finally:
    plt.close(fig)

fig, ax = plt.subplots(figsize=(6, 6))
try:
    # DET curve: false-negative rate (1 - TPR) against false-positive rate.
    plot_det(metrics["fpr"], 1 - metrics["tpr"], ax=ax)
    fig.savefig(FIG_DIR / "siamese_det.png", dpi=300)
finally:
    plt.close(fig)
