# Stage 1: Physics-Informed Hybrid Liquid Neural Network (PI-HLNN)
## LOS / NLOS Classification from Raw CIR

**Architecture**: Pure Liquid Neural Network with input+state driven time constants (tau).
**Physics-Informed**: RXPACC normalization, ODE dynamics, tau via `softplus` (no hardcoded targets).
**Hybrid**: Combines physics priors (preprocessing, loss, ODE structure) with data-driven learned dynamics.
**Input**: Raw 60-sample CIR window — no hand-crafted features. The LNN learns temporal dynamics directly from the signal.
**Training**: `combined_uwb_dataset.csv` (3600 samples, 6 UWB channels), 70/15/15 train/val/test split.
**Pipeline**: **Stage 1 (LNN → LOS/NLOS)** → Stage 2 (MLP → single/multi bounce) → Stage 3 (MLP → ranging error)

In [None]:
# Central hyper-parameter table read by the data, model and training cells.
CONFIG = dict(
    # --- CIR window cropped around the detected leading edge (samples) ---
    pre_crop=10,
    post_crop=50,
    total_len=60,
    # --- index range scanned for the CIR peak ---
    search_start=740,
    search_end=890,
    # --- model ---
    hidden_size=32,        # per circuit; total embedding = 64 (2 × 32)
    input_size=1,
    dropout=0.2,
    ode_unfolds=6,
    # --- optimisation ---
    batch_size=64,
    max_epochs=40,
    lr=1e-3,
    weight_decay=1e-4,
    warmup_epochs=3,
    patience=10,
    grad_clip=1.0,
    # --- dataset split and its random seed ---
    val_ratio=0.15,
    test_ratio=0.15,
    seed=42,
)


---
## Section 2: Data Loading, ROI Alignment & 70/15/15 Split

**ROI search range**: `[740, 890]` — derived from empirical CIR peak analysis:
- All CIR peaks across 36 CSV files (LOS + NLOS, 6 channels) fall within indices **743–807**.
- `search_start=740`: 3 indices before the earliest observed peak (743), provides margin.
- `search_end=890`: ~80 indices past the latest observed peak (807), captures any multipath tail.
- Noise floor estimated from indices 0–739 (740 samples of pure noise = robust estimate).

In [None]:
# ==========================================
# SMART ROI ALIGNMENT
# ==========================================
def get_roi_alignment(sig, search_start=CONFIG["search_start"],
                      search_end=CONFIG["search_end"]):
    """
    Locate the leading edge of the CIR pulse.

    The largest sample inside ``sig[search_start:search_end]`` is taken as the
    peak. The function then walks backwards from the peak, down to (but not
    including) index ``max(search_start - 20, 0)``, and returns the index
    right after the first sample that falls below the detection threshold.

    Threshold:
      - max(noise_mean + 3 * noise_std, 5% of the peak value), with the noise
        statistics taken from ``sig[:search_start]``;
      - only the 5%-of-peak rule when that noise section has 10 samples or fewer.

    Default search range [740, 890] comes from the empirical CIR peak analysis:
      - all peaks across 36 files fall within indices 743-807
      - 740 start = margin before the earliest peak
      - 890 end = margin past the latest multipath outlier

    Returns:
        The leading-edge index. Falls back to the peak index when no sample
        below the threshold is met while backtracking, and to
        ``np.argmax(sig)`` when the search slice is empty.
    """
    window = sig[search_start:search_end]
    if len(window) == 0:
        return np.argmax(sig)

    peak_idx = search_start + np.argmax(window)
    peak_floor = 0.05 * sig[peak_idx]

    # Samples ahead of the search window are treated as noise only.
    pre_search = sig[:search_start]
    if len(pre_search) > 10:
        threshold = max(np.mean(pre_search) + 3 * np.std(pre_search), peak_floor)
    else:
        threshold = peak_floor

    # First sub-threshold sample met while walking back from the peak; the edge
    # is the sample right after it. Without such a sample the peak itself is used.
    stop = max(search_start - 20, 0)
    return next(
        (i + 1 for i in range(peak_idx, stop, -1) if sig[i] < threshold),
        peak_idx,
    )


# ==========================================
# DATASET LOADER (CIR-ONLY)
# ==========================================
def _crop_and_normalise(sig, leading_edge, pre, total):
    """Cut a ``total``-sample window starting ``pre`` samples before ``leading_edge`` and min-max scale it to [0, 1]."""
    start = max(0, leading_edge - pre)
    end = start + total
    if end > len(sig):
        # Window runs past the end of the signal: slide it back so it ends on
        # the last sample.
        end = len(sig)
        start = max(0, end - total)

    crop = sig[start:end]
    if len(crop) < total:
        # Signal shorter than the window: zero-pad on the right.
        crop = np.pad(crop, (0, total - len(crop)), mode='constant')

    local_min = np.min(crop)
    rng = np.max(crop) - local_min
    if rng > 0:
        return (crop - local_min) / rng
    # Completely flat window: there is nothing to scale.
    return np.zeros(total)


def load_cir_dataset(filepath="../dataset/channels/combined_uwb_dataset.csv"):
    """
    Load and preprocess CIR data. Returns only the CIR sequence and labels.
    No hand-crafted features — the LNN learns directly from the signal.

    Per-sample pipeline:
      1. parse the ``CIR<n>`` columns (sorted by ``n``) as floats, with
         non-numeric entries replaced by 0;
      2. divide by RXPACC (column ``RXPACC`` or ``RX_PACC``; 128.0 when neither
         column exists; rows whose value is not > 0 are left unscaled);
      3. find the leading edge with ``get_roi_alignment``;
      4. crop ``CONFIG["total_len"]`` samples starting ``CONFIG["pre_crop"]``
         samples before the edge;
      5. min-max normalise the crop to [0, 1].

    Args:
        filepath: CSV file containing the ``CIR<n>`` columns, a ``Label``
            column and optionally an RXPACC column.

    Returns:
        (X, y): ``X`` is float32 with shape (n_samples, total_len, 1) and
        ``y`` is float32 with shape (n_samples,). The summary print counts
        label 0 as LOS and label 1 as NLOS.
    """
    pre = CONFIG["pre_crop"]
    total = CONFIG["total_len"]

    print(f"Loading: {filepath}")
    df = pd.read_csv(filepath)
    cir_cols = sorted(
        [c for c in df.columns if c.startswith('CIR')],
        key=lambda x: int(x.replace('CIR', ''))
    )
    print(f"  Samples: {len(df)}, CIR columns: {len(cir_cols)}")

    # 1. Parse the whole CIR matrix once instead of once per row.
    cir = (
        df[cir_cols]
        .apply(pd.to_numeric, errors='coerce')
        .fillna(0)
        .to_numpy(dtype=float)
    )

    # 2. The RXPACC column choice does not depend on the row, so resolve it once.
    rxpacc_col = 'RXPACC' if 'RXPACC' in df.columns else 'RX_PACC'
    if rxpacc_col in df.columns:
        rxpacc_all = df[rxpacc_col].astype(float).to_numpy()
    else:
        rxpacc_all = np.full(len(df), 128.0)

    processed_seqs = []
    for sig, rxpacc in zip(cir, rxpacc_all):
        # PHYSICS NORMALIZATION: divide by RXPACC. NaN or non-positive values
        # fail the test and the row is left unscaled.
        # NOTE(review): the min-max scaling in step 5 is invariant to a positive
        # scale factor and the ROI threshold scales with the signal, so this
        # division looks like it cancels out of the final crop — confirm
        # whether that is intended.
        if rxpacc > 0:
            sig = sig / rxpacc

        # 3. Smart ROI alignment
        leading_edge = get_roi_alignment(sig)

        # 4 + 5. Crop around the leading edge and normalise to [0, 1]
        processed_seqs.append(_crop_and_normalise(sig, leading_edge, pre, total))

    X = np.array(processed_seqs).reshape(-1, total, 1).astype(np.float32)
    y = df['Label'].astype(float).to_numpy().astype(np.float32)

    print(f"  Output shape: X={X.shape}, y={y.shape}")
    print(f"  LOS: {int(np.sum(y == 0))}, NLOS: {int(np.sum(y == 1))}")
    return X, y


# Load the full dataset, then carve it into train / validation / test subsets.
X_all, y_all = load_cir_dataset("../dataset/channels/combined_uwb_dataset.csv")

# 70/15/15 stratified split, done in two stages that share the same seed.
val_ratio = CONFIG["val_ratio"]
test_ratio = CONFIG["test_ratio"]

# Stage 1: hold out (val + test) of the data; the remainder is the training set.
X_train, X_temp, y_train, y_temp = train_test_split(
    X_all, y_all,
    test_size=val_ratio + test_ratio,
    stratify=y_all,
    random_state=CONFIG["seed"]
)
# Stage 2: divide the held-out part between validation and test. test_size is
# relative to the held-out part, hence the renormalised ratio.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp,
    test_size=test_ratio / (val_ratio + test_ratio),
    stratify=y_temp,
    random_state=CONFIG["seed"]
)

# Report subset sizes and per-class counts (label 0 = LOS, label 1 = NLOS).
print(f"\nSplit (70/15/15):")
print(f"  Train: {X_train.shape[0]} (LOS: {int(np.sum(y_train==0))}, NLOS: {int(np.sum(y_train==1))})")
print(f"  Val:   {X_val.shape[0]} (LOS: {int(np.sum(y_val==0))}, NLOS: {int(np.sum(y_val==1))})")
print(f"  Test:  {X_test.shape[0]} (LOS: {int(np.sum(y_test==0))}, NLOS: {int(np.sum(y_test==1))})")

---
## Section 3: Model Architecture — Dual-Circuit PI-HLNN

### Design: Two Parallel PILiquidCell Circuits with Cross-Circuit Communication

Two specialised LTC circuits run in parallel on the same CIR sequence:

- **cell_los** — develops dynamics suited to LOS-type CIR (sharp, single-peak)
- **cell_nlos** — develops dynamics suited to NLOS-type CIR (broader, multi-peak)

At every timestep the circuits **talk to each other** via gated projection matrices:

$$h_{\text{los,in}} = h_{\text{los}} + g_{\text{los}} \cdot P_{\text{nlos}\to\text{los}}(h_{\text{nlos}})$$
$$h_{\text{nlos,in}} = h_{\text{nlos}} + g_{\text{nlos}} \cdot P_{\text{los}\to\text{nlos}}(h_{\text{los}})$$

where $g = \sigma(\text{Linear}([h_{\text{own}},\, P(h_{\text{cross}})]))$ is a learned sigmoid gate computed from the circuit's own state and the *projected* state of the other circuit.

Each circuit's hidden state is attention-pooled independently, then **concatenated** to form a 64-dim fused embedding:

```
cell_los  (hidden=32) → attn → h_los_pooled  (32)  ┐
                                                      ├─ cat → 64-dim → classifier
cell_nlos (hidden=32) → attn → h_nlos_pooled (32)  ┘
```

**Parameter count (32+32 dual)**: ~17 k — comparable to the original single-circuit (64-hidden, ~19 k).  
**Embedding dim = 64** — same as before, so Stages 2 & 3 are unchanged.

### Why τ is still adaptive

In each PILiquidCell:  $\tau = C_m / (g_{\text{leak}} + \sum w \cdot \text{gate}(v))$

τ adapts per timestep to the signal energy. The cross-circuit gate additionally allows the NLOS circuit to slow down the LOS circuit's dynamics (increase τ) when it detects a complex multipath structure — implementing the FP-to-LOS threshold insight from the EDA.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PILiquidCell(nn.Module):
    """
    Conductance-based LTC cell (Hasani et al. 2020).

    - Recurrent synapses: conductance model with reversal potentials (``erev``)
    - Sensory synapses: sigmoid-gated input added to the numerator only
    - Softplus is applied to the conductance-like parameters
      (``gleak``, ``cm``, ``w``, ``sensory_w``) so they stay positive
    - ODE solved via semi-implicit Euler, ``ode_unfolds`` sub-steps per call

    Args:
        input_size: number of input features per timestep.
        hidden_size: number of neurons (state dimension).
        ode_unfolds: solver sub-steps per ``forward`` call; must be >= 1
            because ``forward`` reads the last sub-step's conductance sum.
    """
    def __init__(self, input_size, hidden_size, ode_unfolds=6):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size  = input_size
        self.ode_unfolds = ode_unfolds

        def uniform(low, high, *shape):
            # Parameters are drawn in declaration order from the global RNG.
            return nn.Parameter(torch.empty(*shape).uniform_(low, high))

        # Per-neuron membrane parameters
        self.gleak = uniform(0.001, 1.0, hidden_size)
        self.vleak = uniform(-0.2, 0.2, hidden_size)
        self.cm    = uniform(0.4, 0.6, hidden_size)

        # Recurrent synapses, indexed (pre, post)
        self.w     = uniform(0.001, 1.0, hidden_size, hidden_size)
        self.erev  = uniform(-0.2, 0.2, hidden_size, hidden_size)
        self.mu    = uniform(0.3, 0.8, hidden_size, hidden_size)
        self.sigma = uniform(3, 8, hidden_size, hidden_size)

        # Sensory synapses, indexed (input, post)
        self.sensory_w     = uniform(0.001, 1.0, input_size, hidden_size)
        self.sensory_mu    = uniform(0.3, 0.8, input_size, hidden_size)
        self.sensory_sigma = uniform(3, 8, input_size, hidden_size)

    def forward(self, x_t, h_prev, dt=1.0):
        """
        Advance the cell state by one input timestep.

        Args:
            x_t: input at this timestep, shape (batch, input_size).
            h_prev: previous state, shape (batch, hidden_size).
            dt: elapsed time covered by this call, split evenly over the
                solver sub-steps.

        Returns:
            (v, tau): the new state and the time constant
            ``cm / (gleak + sum_pre(w * gate))``, both (batch, hidden_size).
            ``tau`` uses the gate values evaluated in the last sub-step.
        """
        g_leak = F.softplus(self.gleak)
        c_m    = F.softplus(self.cm)
        w_rec  = F.softplus(self.w).unsqueeze(0)
        w_in   = F.softplus(self.sensory_w)

        # Sensory drive does not depend on the state, so it is computed once.
        x_col       = x_t.unsqueeze(-1)
        input_gate  = torch.sigmoid(self.sensory_sigma * (x_col - self.sensory_mu))
        input_drive = (w_in * input_gate * x_col).sum(dim=1)

        # State-independent pieces of the update, hoisted out of the solver loop.
        c_step     = c_m / (dt / self.ode_unfolds)
        leak_drive = g_leak * self.vleak
        den_base   = c_step + g_leak
        sigma_b    = self.sigma.unsqueeze(0)
        mu_b       = self.mu.unsqueeze(0)
        erev_b     = self.erev.unsqueeze(0)

        v = h_prev
        for _ in range(self.ode_unfolds):
            gate     = torch.sigmoid(sigma_b * (v.unsqueeze(2) - mu_b))
            syn      = w_rec * gate
            syn_erev = (syn * erev_b).sum(dim=1)
            w_den    = syn.sum(dim=1)
            v = (c_step * v + leak_drive + syn_erev + input_drive) / (den_base + w_den + 1e-8)

        tau = c_m / (g_leak + w_den + 1e-8)
        return v, tau


class DualCircuit_PI_HLNN(nn.Module):
    """
    Dual-circuit PI-HLNN with cross-circuit communication.

    Two PILiquidCell circuits of ``hidden_size`` neurons each process the same
    input sequence in parallel:
      - cell_los:  intended to specialise in LOS dynamics (sharp, single-peak CIR)
      - cell_nlos: intended to specialise in NLOS dynamics (broader, multi-peak CIR)
    Before every cell update each circuit adds a sigmoid-gated linear
    projection of the other circuit's state to its own state.
    Both state histories are attention-pooled over time and concatenated into
    a ``2 * hidden_size`` embedding (32 + 32 = 64 with the defaults), which a
    small MLP ending in a Sigmoid maps to one probability per sample.
    """
    def __init__(self, input_size=1, hidden_size=32, dropout=0.4, ode_unfolds=6):
        super().__init__()
        self.hidden_size = hidden_size  # width of ONE circuit
        fused_size = hidden_size * 2

        # The two liquid circuits
        self.cell_los  = PILiquidCell(input_size, hidden_size, ode_unfolds)
        self.cell_nlos = PILiquidCell(input_size, hidden_size, ode_unfolds)

        # Bias-free linear maps that carry one circuit's state to the other
        self.P_nlos2los = nn.Linear(hidden_size, hidden_size, bias=False)
        self.P_los2nlos = nn.Linear(hidden_size, hidden_size, bias=False)

        # Gates see [own state | projected cross state]
        self.gate_los  = nn.Linear(fused_size, hidden_size)
        self.gate_nlos = nn.Linear(fused_size, hidden_size)

        # One scalar attention score per timestep, per circuit
        self.los_attn  = nn.Linear(hidden_size, 1)
        self.nlos_attn = nn.Linear(hidden_size, 1)

        # Classifier head: 2*hidden -> hidden -> 1 probability
        head_layers = [
            nn.Linear(fused_size, hidden_size),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def _run_circuits(self, x_seq):
        """
        Unroll both circuits over a (batch, seq_len, input_size) sequence.

        Returns, in order: LOS state history, NLOS state history, LOS tau
        history, NLOS tau history (each (batch, seq_len, hidden)), then the
        time-averaged LOS and NLOS tau (each (batch, hidden)).
        """
        n_batch, n_steps, _ = x_seq.size()
        h_los  = torch.zeros(n_batch, self.hidden_size, device=x_seq.device)
        h_nlos = torch.zeros(n_batch, self.hidden_size, device=x_seq.device)
        tau_los_acc  = torch.zeros_like(h_los)
        tau_nlos_acc = torch.zeros_like(h_nlos)

        los_trace, nlos_trace = [], []
        tau_los_trace, tau_nlos_trace = [], []

        for x_t in x_seq.unbind(dim=1):
            # Both projections and gates use the states from BEFORE this step.
            from_nlos = self.P_nlos2los(h_nlos)
            from_los  = self.P_los2nlos(h_los)
            gate_l = torch.sigmoid(self.gate_los(torch.cat((h_los, from_nlos), dim=1)))
            gate_n = torch.sigmoid(self.gate_nlos(torch.cat((h_nlos, from_los), dim=1)))

            mixed_los  = h_los  + gate_l * from_nlos
            mixed_nlos = h_nlos + gate_n * from_los

            h_los,  tau_los  = self.cell_los(x_t, mixed_los)
            h_nlos, tau_nlos = self.cell_nlos(x_t, mixed_nlos)

            los_trace.append(h_los)
            nlos_trace.append(h_nlos)
            tau_los_trace.append(tau_los)
            tau_nlos_trace.append(tau_nlos)
            tau_los_acc  = tau_los_acc  + tau_los
            tau_nlos_acc = tau_nlos_acc + tau_nlos

        return (
            torch.stack(los_trace, dim=1),
            torch.stack(nlos_trace, dim=1),
            torch.stack(tau_los_trace, dim=1),
            torch.stack(tau_nlos_trace, dim=1),
            tau_los_acc / n_steps,
            tau_nlos_acc / n_steps,
        )

    def _pool_and_fuse(self, los_all, nlos_all):
        """Attention-pool each state history over time and concatenate: (batch, 2*hidden)."""
        def attend(states, scorer):
            weights = F.softmax(scorer(states).squeeze(-1), dim=1).unsqueeze(-1)
            return (states * weights).sum(dim=1)

        pooled = [attend(los_all, self.los_attn), attend(nlos_all, self.nlos_attn)]
        return torch.cat(pooled, dim=1)

    def forward(self, x_seq, return_dynamics=False):
        """
        Classify a batch of sequences.

        Returns ``(pred, tau_los_mean, tau_nlos_mean)`` by default. With
        ``return_dynamics=True`` returns ``(pred, los_all, nlos_all,
        tau_los_hist, tau_nlos_hist, tau_los_mean, tau_nlos_mean)``.
        ``pred`` is the classifier's sigmoid output, shape (batch, 1).
        """
        dynamics = self._run_circuits(x_seq)
        pred = self.classifier(self._pool_and_fuse(dynamics[0], dynamics[1]))
        if not return_dynamics:
            return pred, dynamics[4], dynamics[5]
        return (pred, *dynamics)

    def embed(self, x_seq):
        """Return 64-dim fused embedding for Stage 2/3."""
        los_all, nlos_all, *_ = self._run_circuits(x_seq)
        return self._pool_and_fuse(los_all, nlos_all)


# Parameter count
# Build a throwaway model only to report how its parameters are distributed.
_m = DualCircuit_PI_HLNN(input_size=1, hidden_size=32)
_total = sum(p.numel() for p in _m.parameters())
print(f"DualCircuit_PI_HLNN parameter count: {_total:,}")
# cell_los is counted and doubled: both circuits are constructed with the same sizes.
print(f"  cell_los + cell_nlos:  {sum(p.numel() for p in _m.cell_los.parameters()) * 2:,}")
print(f"  Projection matrices:   {sum(p.numel() for p in [*_m.P_nlos2los.parameters(), *_m.P_los2nlos.parameters()]):,}")
print(f"  Gates:                 {sum(p.numel() for p in [*_m.gate_los.parameters(), *_m.gate_nlos.parameters()]):,}")
print(f"  Attention + classifier:{sum(p.numel() for p in [*_m.los_attn.parameters(), *_m.nlos_attn.parameters(), *_m.classifier.parameters()]):,}")
print(f"  Embedding dim: {_m.hidden_size * 2} (2 × {_m.hidden_size})")
# Drop the temporary instance; the trained model is created later by train_model().
del _m


---
## Section 4: Loss Function

$$\mathcal{L} = \mathcal{L}_{BCE}$$

Pure binary cross-entropy. The time constant $\tau$ is **not** constrained by the loss — it emerges entirely from the ODE dynamics driven by `[x_t, h_t]`. This lets the LNN discover its own temporal behavior from the signal.

The tau values are still computed and available for post-training diagnostics (tau distribution, temporal evolution).

In [None]:
# ==========================================
# LOSS FUNCTION (Pure BCE)
# ==========================================
# Module-level BCE criterion. nn.BCELoss expects probabilities, which the
# model's final Sigmoid layer provides.
# NOTE(review): train_model() constructs its own nn.BCELoss(), so this object
# does not appear to be used in the visible cells — confirm before removing.
criterion_fn = nn.BCELoss()

print("Loss: Binary Cross-Entropy (pure BCE, no tau constraint)")

---
## Section 5: Training Pipeline (70/15/15 Split)

In [None]:

# ==========================================
# TRAINING PIPELINE (70/15/15)
# ==========================================
import math

def train_model(X_train, y_train, X_val, y_val, config=CONFIG):
    """
    Train a DualCircuit_PI_HLNN with BCE loss and early stopping.

    Optimiser: AdamW with gradient-norm clipping. Learning rate: linear warm-up
    for ``warmup_epochs`` epochs, then cosine decay floored at 1% of ``lr``.
    The weights of the epoch with the highest validation accuracy are restored
    before returning; training stops early after ``patience`` epochs without
    a strict improvement.

    Args:
        X_train, y_train: training inputs and labels (array-likes accepted by
            ``torch.tensor``); labels get a trailing dimension added.
        X_val, y_val: validation inputs and labels, same layout. The whole
            validation set is evaluated as a single batch each epoch.
        config: hyper-parameter mapping; reads ``batch_size``, ``input_size``,
            ``hidden_size``, ``dropout``, ``ode_unfolds`` (optional, default 6),
            ``lr``, ``weight_decay``, ``warmup_epochs``, ``max_epochs``,
            ``grad_clip`` and ``patience``.

    Returns:
        (model, (X_va, y_va), history): the trained model, the validation
        tensors on ``device``, and a dict of per-epoch lists with keys
        ``train_loss``, ``val_loss``, ``train_acc``, ``val_acc`` and ``lr``.
    """
    print(f"Training on {len(X_train)} samples, validating on {len(X_val)}")

    X_tr = torch.tensor(X_train).to(device)
    y_tr = torch.tensor(y_train).unsqueeze(1).to(device)
    X_va = torch.tensor(X_val).to(device)
    y_va = torch.tensor(y_val).unsqueeze(1).to(device)

    train_ds     = TensorDataset(X_tr, y_tr)
    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)

    model = DualCircuit_PI_HLNN(
        input_size=config["input_size"],
        hidden_size=config["hidden_size"],
        dropout=config["dropout"],
        ode_unfolds=config.get("ode_unfolds", 6)
    ).to(device)

    criterion = nn.BCELoss()
    optimizer = optim.AdamW(model.parameters(), lr=config["lr"],
                            weight_decay=config["weight_decay"])

    warmup_epochs = config["warmup_epochs"]
    total_epochs  = config["max_epochs"]

    def lr_lambda(epoch):
        # Linear warm-up, then cosine decay that never drops below 1% of the base LR.
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "lr": []}
    best_val_acc      = 0
    best_model_state  = None
    patience_counter  = 0

    for epoch in range(total_epochs):
        model.train()
        train_loss_sum = 0
        train_correct, train_total = 0, 0

        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            # The tau outputs are diagnostics only; the loss is pure BCE.
            pred, _, _ = model(batch_x)
            loss = criterion(pred, batch_y)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])
            optimizer.step()
            # Weight by batch size so the epoch mean is exact even when the
            # last batch is smaller.
            train_loss_sum += loss.item() * len(batch_x)
            train_correct  += ((pred > 0.5).float() == batch_y).sum().item()
            train_total    += len(batch_x)

        train_loss = train_loss_sum / train_total
        train_acc  = train_correct / train_total

        model.eval()
        with torch.no_grad():
            val_pred, _, _ = model(X_va)
            val_loss = criterion(val_pred, y_va)
            val_acc  = ((val_pred > 0.5).float() == y_va).float().mean().item()

        # Record the LR that was in effect for this epoch, before stepping the scheduler.
        lr_now = optimizer.param_groups[0]['lr']
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss.item())
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)
        history["lr"].append(lr_now)

        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc     = val_acc
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 5 == 0 or epoch == total_epochs - 1:
            print(f"  Ep {epoch:>3} | Loss: {train_loss:.4f} | Val Acc: {100*val_acc:.2f}% | Best: {100*best_val_acc:.2f}% | LR: {lr_now:.1e}")

        if patience_counter >= config["patience"]:
            print(f"  Early stopping at epoch {epoch}")
            break

    # No snapshot exists when no epoch ever improved on the initial best of 0
    # (or when max_epochs is 0); keep the current weights in that case instead
    # of passing None to load_state_dict.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    print(f"\nBest Validation Accuracy: {100*best_val_acc:.2f}%")
    return model, (X_va, y_va), history


# Train on the training split and select the checkpoint on the validation
# split; best_data holds the validation tensors returned by train_model().
best_model, best_data, best_history = train_model(X_train, y_train, X_val, y_val)


---
## Section 6: Diagnostics

In [None]:

# ==========================================
# DIAGNOSTIC GRID (2 rows x 3 columns) — DualCircuit_PI_HLNN
# ==========================================
def plot_diagnostics(model, val_data, history):
    """
    Draw a 2x3 diagnostic figure for a trained DualCircuit_PI_HLNN.

    Panels: (1) loss curves, (2) accuracy curves, (3) confusion matrix,
    (4) ROC curve, (5) PCA of the concatenated circuit state trajectories,
    (6) KDE of the time-averaged tau per circuit and per true class.

    Args:
        model: model whose ``forward(X, return_dynamics=True)`` returns the
            7-tuple (pred, los states, nlos states, tau_los history,
            tau_nlos history, tau_los mean, tau_nlos mean).
        val_data: ``(X_va, y_va)`` tensors; labels 0 = LOS, 1 = NLOS.
        history: dict with ``train_loss``, ``val_loss``, ``train_acc`` and
            ``val_acc`` lists, as produced by ``train_model``.
    """
    X_va, y_va = val_data
    model.eval()

    with torch.no_grad():
        preds, los_hist, nlos_hist, tau_los_hist, tau_nlos_hist, tau_los_mean, tau_nlos_mean = \
            model(X_va, return_dynamics=True)

    y_true = y_va.cpu().numpy().flatten()
    y_prob = preds.cpu().numpy().flatten()
    y_pred = (y_prob > 0.5).astype(float)

    # Average tau across neurons for each sample
    tau_los_np  = tau_los_mean.cpu().numpy().mean(axis=1)   # (batch,)
    tau_nlos_np = tau_nlos_mean.cpu().numpy().mean(axis=1)

    los_hist_np  = los_hist.cpu().numpy()   # (batch, seq_len, hidden)
    nlos_hist_np = nlos_hist.cpu().numpy()

    fig, axs = plt.subplots(2, 3, figsize=(24, 14))
    plt.subplots_adjust(hspace=0.35, wspace=0.3)

    # --- 1. LEARNING CURVES ---
    ax = axs[0, 0]
    ax.plot(history["train_loss"], label='Train Loss', color='#3498db', lw=2)
    ax.plot(history["val_loss"],   label='Val Loss',   color='#e74c3c', lw=2, ls='--')
    ax.set_title("Learning Curves")
    ax.set_xlabel("Epoch"); ax.set_ylabel("Loss (BCE)")
    ax.legend(fontsize=9); ax.grid(True, alpha=0.3)

    # --- 2. ACCURACY CURVES ---
    ax = axs[0, 1]
    ax.plot(history["train_acc"], label='Train Acc', color='#3498db', lw=2)
    ax.plot(history["val_acc"],   label='Val Acc',   color='#e74c3c', lw=2, ls='--')
    ax.set_title("Accuracy Curves")
    ax.set_xlabel("Epoch"); ax.set_ylabel("Accuracy")
    ax.set_ylim([0.4, 1.05]); ax.legend(); ax.grid(True, alpha=0.3)

    # --- 3. CONFUSION MATRIX ---
    ax = axs[0, 2]
    # Pin the label set so the matrix stays 2x2 (matching the two display
    # labels) even if only one class occurs in y_true / y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    disp = ConfusionMatrixDisplay(cm, display_labels=['LOS', 'NLOS'])
    disp.plot(ax=ax, cmap='Blues', colorbar=False)
    acc = (y_true == y_pred).mean()
    ax.set_title(f"Confusion Matrix (Acc: {100*acc:.1f}%)")

    # --- 4. ROC CURVE ---
    ax = axs[1, 0]
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, color='#e74c3c', lw=2, label=f'AUC = {roc_auc:.4f}')
    ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5)
    ax.set_title("ROC Curve")
    ax.set_xlabel("False Positive Rate"); ax.set_ylabel("True Positive Rate")
    ax.legend(loc='lower right'); ax.grid(True, alpha=0.3)

    # --- 5. PCA PHASE SPACE (fused 64-dim embedding) ---
    ax = axs[1, 1]
    batch_size, seq_len, h = los_hist_np.shape
    los_idx  = np.where(y_true == 0)[0]
    nlos_idx = np.where(y_true == 1)[0]
    # Same number of trajectories per class, capped at 25 each.
    n_show   = min(len(los_idx), len(nlos_idx), 25)
    show_idx = np.concatenate([los_idx[:n_show], nlos_idx[:n_show]])

    # Concatenate both circuit histories; PCA is fitted on every
    # (sample, timestep) state, not only on the plotted samples.
    fused_hist = np.concatenate([los_hist_np, nlos_hist_np], axis=2)  # (batch, seq, 2*hidden)
    h_flat = fused_hist.reshape(-1, h * 2)
    pca    = PCA(n_components=2)
    h_pca  = pca.fit_transform(h_flat).reshape(batch_size, seq_len, 2)

    for i in show_idx:
        color = '#2ecc71' if y_true[i] == 0 else '#e74c3c'
        ax.plot(h_pca[i, :, 0], h_pca[i, :, 1], color=color, alpha=0.3, lw=0.8)
        # Dot marks the state after the last timestep.
        ax.scatter(h_pca[i, -1, 0], h_pca[i, -1, 1], color=color, s=15, zorder=5)

    ax.set_title(f"Fused Phase Space (PCA on 64-dim, n={n_show*2})")
    ax.set_xlabel(f"PC1 ({100*pca.explained_variance_ratio_[0]:.1f}%)")
    ax.set_ylabel(f"PC2 ({100*pca.explained_variance_ratio_[1]:.1f}%)")
    ax.grid(True, alpha=0.2)

    # --- 6. TAU DISTRIBUTION — LOS vs NLOS circuits ---
    ax = axs[1, 2]
    kde_specs = [
        (tau_los_np,  0, '#27ae60', 'LOS circuit / LOS sample',   {}),
        (tau_los_np,  1, '#82e0aa', 'LOS circuit / NLOS sample',  {}),
        (tau_nlos_np, 0, '#e74c3c', 'NLOS circuit / LOS sample',  {'ls': '--'}),
        (tau_nlos_np, 1, '#c0392b', 'NLOS circuit / NLOS sample', {'ls': '--'}),
    ]
    for tau_values, class_value, color, name, extra in kde_specs:
        group = tau_values[y_true == class_value]
        if group.size == 0:
            # No sample of this class: there is no density (or mean) to show.
            continue
        sns.kdeplot(group, ax=ax, fill=True, color=color,
                    label=f'{name} (mean={group.mean():.2f})', alpha=0.5, **extra)
    ax.set_title("Emergent Tau: LOS vs NLOS circuits")
    ax.set_xlabel("Mean Tau"); ax.legend(fontsize=7); ax.grid(True, alpha=0.3)

    plt.suptitle("DualCircuit_PI_HLNN — Stage 1 Diagnostics",
                 fontsize=16, fontweight='bold', y=1.01)
    plt.tight_layout()
    plt.show()


# Diagnostics are drawn from the validation tensors returned by train_model().
plot_diagnostics(best_model, best_data, best_history)


In [None]:

# ==========================================
# TAU TEMPORAL EVOLUTION — both circuits
# ==========================================
def plot_tau_temporal(model, val_data, n_samples=5):
    """
    Plot input signals and per-timestep tau for a few LOS and NLOS samples.

    The figure has two rows (top: LOS samples, bottom: NLOS samples) and three
    columns (normalised CIR input, tau of the LOS circuit, tau of the NLOS
    circuit). Tau is averaged over the neurons of each circuit.

    Args:
        model: model whose ``forward(X, return_dynamics=True)`` returns a
            7-tuple with the tau histories at positions 3 and 4.
        val_data: ``(X_va, y_va)`` tensors; labels 0 = LOS, 1 = NLOS.
        n_samples: number of samples drawn per class (the first ones found).
    """
    X_va, y_va = val_data
    model.eval()
    with torch.no_grad():
        (_pred, _los, _nlos, tau_los_hist, tau_nlos_hist,
         _tau_los_mean, _tau_nlos_mean) = model(X_va, return_dynamics=True)

    labels = y_va.cpu().numpy().flatten()
    cir_input = X_va.cpu().numpy().squeeze(-1)
    # Mean over neurons -> (batch, seq_len)
    tau_los_curve  = tau_los_hist.cpu().numpy().mean(axis=2)
    tau_nlos_curve = tau_nlos_hist.cpu().numpy().mean(axis=2)

    los_rows  = np.where(labels == 0)[0][:n_samples]
    nlos_rows = np.where(labels == 1)[0][:n_samples]

    fig, axs = plt.subplots(2, 3, figsize=(22, 10))
    plt.subplots_adjust(hspace=0.4, wspace=0.3)

    # (title, sample rows, colour, series, is this the raw-signal panel?)
    panels = [
        ('LOS signal',           los_rows,  '#2ecc71', cir_input,      True),
        ('LOS τ — LOS circuit',  los_rows,  '#27ae60', tau_los_curve,  False),
        ('LOS τ — NLOS circuit', los_rows,  '#82e0aa', tau_nlos_curve, False),
        ('NLOS signal',          nlos_rows, '#e74c3c', cir_input,      True),
        ('NLOS τ — LOS circuit', nlos_rows, '#c0392b', tau_los_curve,  False),
        ('NLOS τ — NLOS circuit', nlos_rows, '#8e44ad', tau_nlos_curve, False),
    ]

    for ax, (title, rows, color, series, is_signal) in zip(axs.flat, panels):
        for row in rows:
            ax.plot(series[row], alpha=0.55, color=color, lw=1.3)
        ax.set_title(title, fontsize=10, fontweight='bold')
        ax.set_xlabel("Timestep")
        ax.grid(True, alpha=0.3)
        if is_signal:
            ax.set_ylabel("Normalised CIR")
            continue
        # Tau panels share a fixed y-range so the circuits can be compared.
        ax.set_ylabel("Mean Tau")
        ax.set_ylim([0.3, 7.0])

    plt.suptitle("DualCircuit Tau Temporal Evolution — How Each Circuit Adapts",
                 fontsize=13, fontweight='bold', y=1.01)
    plt.tight_layout()
    plt.show()


# Show tau over time for the first 5 LOS and first 5 NLOS validation samples.
plot_tau_temporal(best_model, best_data, n_samples=5)


---
## Section 7: Test Set Evaluation & Save Artifacts

In [None]:

# ==========================================
# TEST SET EVALUATION
# ==========================================
# Evaluate the selected checkpoint on the held-out test split.
best_model.eval()
X_te = torch.tensor(X_test).to(device)
y_te = torch.tensor(y_test).unsqueeze(1).to(device)

with torch.no_grad():
    # Default forward returns (pred, tau_los_mean, tau_nlos_mean); only pred is needed here.
    test_pred, _, _ = best_model(X_te)
    # Threshold the sigmoid output at 0.5 to get hard decisions (0 = LOS, 1 = NLOS).
    test_acc = ((test_pred > 0.5).float() == y_te).float().mean().item()
    test_pred_np = (test_pred.cpu().numpy().flatten() > 0.5).astype(float)
    test_true_np = y_test.flatten()

print(f"Test Accuracy: {100*test_acc:.2f}%")
print(f"\nClassification Report:")
print(classification_report(test_true_np, test_pred_np, target_names=['LOS', 'NLOS']))

# Confusion matrix of the hard test predictions.
cm = confusion_matrix(test_true_np, test_pred_np)
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
disp = ConfusionMatrixDisplay(cm, display_labels=['LOS', 'NLOS'])
disp.plot(ax=ax, cmap='Blues', colorbar=False)
ax.set_title(f"Test Set Confusion Matrix — DualCircuit_PI_HLNN (Acc: {100*test_acc:.1f}%)")
plt.tight_layout()
plt.show()

# Short textual summary of the evaluated model.
print(f"\nModel summary:")
print(f"  Architecture: DualCircuit_PI_HLNN (hidden={best_model.hidden_size} per circuit)")
print(f"  Embedding dim: {best_model.hidden_size * 2} (2 × {best_model.hidden_size})")
print(f"  Total params: {sum(p.numel() for p in best_model.parameters()):,}")


In [None]:
# ==========================================
# SAVE ARTIFACTS
# ==========================================
# 1. Model weights
# Only the state_dict is stored, so loading requires constructing
# DualCircuit_PI_HLNN with matching sizes first.
torch.save(best_model.state_dict(), "stage1_pi_hlnn_best.pt")
print("Saved: stage1_pi_hlnn_best.pt")

# 2. Configuration (for reproducibility)
# The CONFIG dict is wrapped under the "config" key of the saved object.
torch.save({"config": CONFIG}, "stage1_config.pt")
print("Saved: stage1_config.pt")

print(f"\nArtifacts ready for Stage 2.")

In [None]:
# Final recap of the files written by the save-artifacts cell.
print("Stage 1 complete.")
print("Model artifact: stage1_pi_hlnn_best.pt")
print("Config artifact: stage1_config.pt")
print("\nReady for Stage 2.")