## 7) Neural‐network extension: Siamese DTW

In [None]:
# %%
import sys
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Make the parent of the current working directory importable so the
# `src.*` imports below resolve (assumes the notebook is launched from a
# sub-folder of the project — TODO confirm).
project_root = Path.cwd().parent
# Re-running this cell must not stack duplicate entries on sys.path.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# %%
# 7.1) Imports
from src.nn_utils.data import load_siamese_data
from src.nn_utils.model import build_siamese_dtw_model, build_baseline_nn
from src.nn_utils.training import train_siamese_model
from src.evaluation.evaluation import compute_metrics, plot_roc, plot_det, dtw_distance
from sklearn.model_selection import train_test_split
from keras.utils import to_categorical           # ← switched import to satisfy Pylance

# %%
# 7.2) Paths and constants
# Inputs are read from <project_root>/data; figures are written to
# <project_root>/figures.
PAIRS_PATH = project_root.joinpath("data", "pairs_meta.parquet")
CACHE_PATH = project_root.joinpath("data", "dtw_cache.parquet")
PROC_ROOT = project_root.joinpath("data", "processed")
FIG_DIR = project_root.joinpath("figures")
# Create the figure folder (and any missing parents) if it does not exist yet.
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Target sequence length passed to the data loader and the Siamese model.
SEQ_LEN = 100
# localFunctions dimension (per-time-step feature count).
N_FEATURES = 9

# %%
# 7.3) Load pairs metadata & prepare data
pairs_df = pd.read_parquet(PAIRS_PATH)
# Presumably X1/X2 hold the two sequences of each pair and y the pair label
# (verify against src.nn_utils.data.load_siamese_data).
X1, X2, y = load_siamese_data(pairs_df, PROC_ROOT, SEQ_LEN)

# %%
# 7.4) Train/test split
# 20 % hold-out, stratified on y; the fixed seed makes the split reproducible.
(
    X1_tr, X1_te,
    X2_tr, X2_te,
    y_tr, y_te,
) = train_test_split(X1, X2, y, test_size=0.2, stratify=y, random_state=42)

# One-hot encode the labels (two classes) for categorical crossentropy.
y_tr_cat, y_te_cat = (
    to_categorical(labels, num_classes=2) for labels in (y_tr, y_te)
)

# %%
# 7.5) Build & inspect Siamese-DTW model
# Input geometry comes from the constants of 7.2; the remaining arguments
# are architecture hyper-parameters forwarded as-is to the model builder.
siamese_model = build_siamese_dtw_model(
    n_features=N_FEATURES,
    sequence_length=SEQ_LEN,
    dtw_gamma=1.0,
    hidden_dims=(7, 5),
    post_hidden=(16, 8),
)
# Print the layer-by-layer overview for a quick sanity check.
siamese_model.summary()

# %%
# 7.6) Build & inspect baseline NN (no DTW)
def _time_averaged_pair_features(seq_a, seq_b):
    """Return the baseline feature matrix for a batch of sequence pairs.

    Each input is reduced with ``mean(axis=1)`` and the two results are
    concatenated along axis 1, so row i describes pair i.
    """
    return np.concatenate([seq_a.mean(axis=1), seq_b.mean(axis=1)], axis=1)


# Full-dataset features, kept in case later cells reference them.
X1_avg = X1.mean(axis=1)                           # expected (N, 9)
X2_avg = X2.mean(axis=1)                           # expected (N, 9)
X_base = np.concatenate([X1_avg, X2_avg], axis=1)  # expected (N, 18)

# Build the train/test baseline features from the arrays already produced by
# the 7.4 split instead of calling train_test_split a second time. The mean is
# taken per sample, so the values are the same, and the rows are aligned with
# y_tr / y_te by construction rather than by two calls happening to share
# identical arguments.
X_base_tr = _time_averaged_pair_features(X1_tr, X2_tr)
X_base_te = _time_averaged_pair_features(X1_te, X2_te)

# Input width is twice the per-sequence feature count (one half per side).
baseline_model = build_baseline_nn(
    input_dim=2 * N_FEATURES,
    hidden_dims=(32, 16),
)
baseline_model.summary()

# %%
# 7.7) Train Siamese-DTW model
# Both networks are fitted with the same hyper-parameters, so they are
# declared once and forwarded to each training call.
fit_kwargs = {"batch_size": 32, "epochs": 20, "validation_split": 0.2}

history_siamese = train_siamese_model(
    siamese_model, X1_tr, X2_tr, y_tr_cat, **fit_kwargs
)

# %%
# 7.8) Train baseline NN
history_baseline = baseline_model.fit(X_base_tr, y_tr_cat, **fit_kwargs)

# %%
# 7.9) Evaluate Siamese-DTW on test set
# Column 1 of the prediction is used as the score, i.e. the output that
# corresponds to class 1 under the one-hot encoding.
siam_outputs = siamese_model.predict([X1_te, X2_te])
y_pred_siam = siam_outputs[:, 1]
df_siam = pd.DataFrame({"label": y_te, "score": y_pred_siam})

# compute_metrics is fed the score under the column name "d_raw".
siam_eval_frame = df_siam.rename(columns={"score": "d_raw"})
metrics_siam = compute_metrics(siam_eval_frame)
print(
    f"Siamese-DTW   → AUC={metrics_siam['auc']:.4f}, "
    f"EER={metrics_siam['eer']:.4f} @ thr={metrics_siam['eer_threshold']:.5f}"
)

# %%
# 7.10) Evaluate baseline NN on test set
# Same protocol as for the Siamese model: take output column 1 as the score.
base_outputs = baseline_model.predict(X_base_te)
y_pred_base = base_outputs[:, 1]
df_base = pd.DataFrame({"label": y_te, "score": y_pred_base})

# compute_metrics is fed the score under the column name "d_raw".
base_eval_frame = df_base.rename(columns={"score": "d_raw"})
metrics_base = compute_metrics(base_eval_frame)
print(
    f"Baseline NN  → AUC={metrics_base['auc']:.4f}, "
    f"EER={metrics_base['eer']:.4f} @ thr={metrics_base['eer_threshold']:.5f}"
)

# %%
# 7.11) Evaluate classic (non-differentiable) DTW classifier
# One DTW distance per test pair, computed pair by pair.
dists = np.array(list(map(dtw_distance, X1_te, X2_te)))
# Negate so that a higher score means "more similar".
# NOTE(review): the negated distance is passed under the column name "d_raw";
# confirm compute_metrics expects "higher = class 1" and that class 1 is the
# "similar" class, otherwise this ROC is mirrored.
scores_dtw = np.negative(dists)
df_dtw = pd.DataFrame({"label": y_te, "score": scores_dtw})

dtw_eval_frame = df_dtw.rename(columns={"score": "d_raw"})
metrics_dtw = compute_metrics(dtw_eval_frame)
print(
    f"Classic DTW → AUC={metrics_dtw['auc']:.4f}, "
    f"EER={metrics_dtw['eer']:.4f} @ thr={metrics_dtw['eer_threshold']:.5f}"
)

# %%
# 7.12) Plot ROC comparison
fig, ax = plt.subplots(figsize=(6, 6))
# One ROC curve per classifier, drawn in a fixed order so the legend and
# the default colour cycle stay stable across runs.
for curve_name, curve_metrics in (
    ("Siamese-DTW", metrics_siam),
    ("Baseline NN", metrics_base),
    ("Classic DTW", metrics_dtw),
):
    ax.plot(
        curve_metrics["fpr"],
        curve_metrics["tpr"],
        label=f"{curve_name} (AUC={curve_metrics['auc']:.3f})",
    )
# Dashed diagonal as a chance-level reference line.
ax.plot([0, 1], [0, 1], "k--")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
fig.savefig(FIG_DIR / "comparison_roc.png", dpi=300)
# Close the figure after saving to release it.
plt.close(fig)

# %%
# 7.13) Plot DET comparison
fig, ax = plt.subplots(figsize=(6, 6))
# Miss rate (1 − TPR) against false-positive rate, on linear axes, for each
# classifier in the same order as the ROC plot.
for curve_name, curve_metrics in (
    ("Siamese-DTW", metrics_siam),
    ("Baseline NN", metrics_base),
    ("Classic DTW", metrics_dtw),
):
    miss_rate = 1 - curve_metrics["tpr"]
    ax.plot(curve_metrics["fpr"], miss_rate, label=curve_name)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("Miss Rate (1 − TPR)")
ax.legend()
fig.savefig(FIG_DIR / "comparison_det.png", dpi=300)
# Close the figure after saving to release it.
plt.close(fig)


Epoch 1/20
768/768 ━━━━━━━━━━━━━━━━━━━━ 2292s 3s/step - accuracy: 0.7749 - loss: 0.7487 - val_accuracy: 0.8553 - val_loss: 0.3496
Epoch 2/20
 76/768 ━━━━━━━━━━━━━━━━━━━━ 31:16 3s/step - accuracy: 0.8555 - loss: 0.3703