# In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader, TensorDataset

# ==========================================
# 1. ROBUST DATA LOADING (RX_PACC + INSTANCE NORM + HARDWARE RISE TIME)
# ==========================================
# Number of CIR taps kept before / after the crop centre.
PRE_CROP = 10
POST_CROP = 50
# Derived from the two margins so the window length can never drift from them.
TOTAL_LEN = PRE_CROP + POST_CROP


def _crop_window(sig, center):
    """Return taps [center - PRE_CROP, center + POST_CROP) of sig, zero-padded past either end."""
    window = np.zeros(TOTAL_LEN, dtype=float)
    start = center - PRE_CROP
    lo = max(start, 0)
    hi = min(start + TOTAL_LEN, len(sig))
    if hi > lo:
        window[lo - start:hi - start] = sig[lo:hi]
    return window


def _rise_time(sig):
    """Return the tap distance from the first pre-peak sample >= 10% of the peak to the peak (0 if none)."""
    peak_idx = int(np.argmax(sig))
    crossings = np.flatnonzero(sig[:peak_idx] >= 0.1 * sig[peak_idx])
    return peak_idx - int(crossings[0]) if crossings.size else 0


def load_optimized_dataset(filepath="combined_uwb_dataset.csv"):
    """Load a UWB CIR dataset from CSV and build model inputs.

    For every row the CIR taps (columns named ``CIR<number>``, ordered by
    that number) are parsed, optionally divided by ``RX_PACC``, cropped to a
    ``TOTAL_LEN``-tap window around the first path and min-max normalised
    per sample. Four context features are computed from the full
    (RX_PACC-normalised) signal: peak amplitude, total energy, share of the
    energy in taps ``[fp - 2, fp + 6)`` and a 10%-threshold rise time.

    Args:
        filepath: Path of the CSV file. It must contain a ``Label`` column;
            ``RX_PACC`` and ``FP_INDEX`` are optional.

    Returns:
        ``(sequences, context, labels)`` where ``sequences`` has shape
        ``(n_rows, TOTAL_LEN, 1)``, ``context`` holds the four features per
        row and ``labels`` the ``Label`` values as floats. Returns
        ``(None, None, None)`` if the file does not exist.

    Raises:
        KeyError: If the ``Label`` column is missing.
        ValueError: If the file has rows but no ``CIR<number>`` columns.
    """
    processed_seqs, context_features, labels = [], [], []
    print(f"Reading dataset: {filepath}...")

    try:
        df = pd.read_csv(filepath)
    except FileNotFoundError:
        print("Error: File not found. Please check the filepath.")
        return None, None, None

    # Only "CIR<digits>" columns are taps; names such as "CIR_PWR" are
    # skipped instead of crashing the integer sort key.
    cir_cols = sorted(
        (c for c in df.columns if c.startswith('CIR') and c[3:].isdecimal()),
        key=lambda name: int(name[3:]),
    )
    n_rows = len(df)
    if n_rows and not cir_cols:
        raise ValueError("Dataset contains no 'CIR<number>' columns.")

    # 1. Parse Raw Signal (one vectorised pass; unparseable cells become 0)
    signals = (
        df[cir_cols]
        .apply(pd.to_numeric, errors='coerce')
        .fillna(0)
        .to_numpy(dtype=float)
    )

    def optional_column(name):
        # Missing column or unparseable cell -> NaN, i.e. "not available".
        if name not in df.columns:
            return np.full(n_rows, np.nan)
        return pd.to_numeric(df[name], errors='coerce').to_numpy(dtype=float)

    rx_pacc = optional_column('RX_PACC')
    fp_index = optional_column('FP_INDEX')

    for sig, rx, fp_raw, label in zip(signals, rx_pacc, fp_index, df['Label']):
        # 2. PHYSICS NORMALIZATION: Divide by RX_PACC (NaN > 0 is False)
        if rx > 0:
            sig = sig / float(rx)

        # 3. Dynamic Crop (Center on First Path)
        peak_idx = int(np.argmax(sig))
        fp_idx = peak_idx if np.isnan(fp_raw) else int(fp_raw)
        # A first-path index too close to the start falls back to the peak.
        center = fp_idx if fp_idx > PRE_CROP else peak_idx
        # Windows reaching past either end are zero-padded rather than
        # replaced by an all-zero sample.
        crop = _crop_window(sig, center)

        # 4. INSTANCE NORMALIZATION (Crucial for LNN)
        local_min = crop.min()
        range_val = crop.max() - local_min
        if range_val > 0:
            crop = (crop - local_min) / range_val
        else:
            crop = np.zeros_like(crop)
        processed_seqs.append(crop)

        # 5. Extract Physical Context Features
        c_max = np.max(sig)
        total_energy = np.sum(sig ** 2) + 1e-9
        # Clamp the window so a negative or oversized first-path index gives
        # an empty slice instead of wrapping around the signal.
        lo = max(0, fp_idx - 2)
        hi = max(lo, min(len(sig), fp_idx + 6))
        c_pow_ratio = np.sum(sig[lo:hi] ** 2) / total_energy

        # Feature D: Hardware-Style Rise Time (Threshold based)
        rise_time = _rise_time(sig)

        context_features.append([c_max, total_energy, c_pow_ratio, rise_time])
        labels.append(float(label))

    return (
        np.array(processed_seqs).reshape(-1, TOTAL_LEN, 1),
        np.array(context_features),
        np.array(labels),
    )

# ==========================================
# 2. PURE LIQUID NEURAL NETWORK
# ==========================================
class LiquidCell(nn.Module):
    """Recurrent cell whose per-unit time constant is set by a context vector.

    One step computes ``S = tanh(W [x_t, h_prev] + b)`` and returns
    ``h_new = (h_prev + dt * S * A) / (1 + dt / tau)``, where
    ``tau = 1 + 4 * sigmoid(MLP(context))`` lies strictly between 1 and 5.
    """

    def __init__(self, input_size, hidden_size, context_size):
        """Create the synapse layer, the gain vector ``A`` and the tau MLP.

        Args:
            input_size: Number of features of each per-step input.
            hidden_size: Number of hidden units.
            context_size: Number of features of the context vector.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Submodules are created in a fixed order (synapse, A, controller)
        # so seeded initialisation and state_dict keys stay stable.
        self.synapse = nn.Linear(input_size + hidden_size, hidden_size)
        # Learnable per-unit gain, initialised to -0.5 everywhere.
        self.A = nn.Parameter(torch.full((hidden_size,), -0.5))

        controller_layers = [
            nn.Linear(context_size, 16),
            nn.Tanh(),
            nn.Linear(16, hidden_size),
            nn.Sigmoid(),
        ]
        self.tau_controller = nn.Sequential(*controller_layers)

    def forward(self, x_t, h_prev, context, dt=1.0):
        """Advance the hidden state by one step.

        Args:
            x_t: Input at the current step, concatenated with ``h_prev``
                along dim 1.
            h_prev: Previous hidden state.
            context: Context vector fed to the tau controller.
            dt: Step size of the update.

        Returns:
            Tuple ``(h_new, tau)`` with the updated hidden state and the
            time constants used for this step.
        """
        # Sigmoid output in (0, 1) mapped onto time constants in (1, 5).
        tau = 1.0 + (4.0 * self.tau_controller(context))

        drive = torch.tanh(self.synapse(torch.cat((x_t, h_prev), dim=1)))

        h_new = (h_prev + (dt * drive * self.A)) / (1.0 + (dt / tau))
        return h_new, tau

class PureLNN(nn.Module):
    """Sequence classifier: one LiquidCell, mean-pooled states, sigmoid head.

    The cell is unrolled over the sequence with a 64-unit hidden state; the
    hidden states of all steps are averaged and passed through a small MLP
    that ends in a sigmoid.
    """

    def __init__(self, context_size=3):
        """Build the cell and the classifier head.

        Args:
            context_size: Number of context features given to the cell's
                tau controller; it must match the width of ``x_ctx``.
        """
        super().__init__()
        self.hidden_size = 64
        self.cell = LiquidCell(1, self.hidden_size, context_size)

        head_layers = [
            nn.Linear(self.hidden_size, 32),
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x_seq, x_ctx, return_dynamics=False):
        """Classify a batch of sequences.

        Args:
            x_seq: 3-D tensor unpacked as (batch, steps, features); each
                step slice ``x_seq[:, t, :]`` is fed to the cell.
            x_ctx: Context tensor passed unchanged to the cell at every step.
            return_dynamics: When true, also return the hidden-state
                trajectory and the tau tensor of the final step.

        Returns:
            The classifier output, or ``(output, states, tau)`` when
            ``return_dynamics`` is true, with ``states`` being the per-step
            hidden states concatenated along dim 1.
        """
        n_samples, n_steps, _ = x_seq.size()
        state = torch.zeros(n_samples, self.hidden_size).to(x_seq.device)

        running_total = torch.zeros_like(state)
        trajectory = []
        last_tau = 0

        for step in range(n_steps):
            state, tau = self.cell(x_seq[:, step, :], state, x_ctx)
            running_total += state

            if not return_dynamics:
                continue
            trajectory.append(state.unsqueeze(1))
            # Only the most recent tau is kept.
            if isinstance(tau, torch.Tensor):
                last_tau = tau

        # Mean over time of the hidden states feeds the classifier.
        out = self.classifier(running_total / n_steps)

        if not return_dynamics:
            return out
        return out, torch.cat(trajectory, dim=1), last_tau

# ==========================================
# 3. ENHANCED DIAGNOSTICS (2x2 Grid)
# ==========================================
def plot_full_diagnostics(model, X_seq, X_ctx, y_true, train_hist, val_hist):
    """Draw a 2x2 diagnostic figure for a trained model on one data split.

    Panels: loss curves, confusion matrix at a 0.5 threshold, PCA projection
    of hidden-state trajectories for up to 25 samples per class, and the
    per-class distribution of the learned time constants.

    Args:
        model: Model called as ``model(X_seq, X_ctx, return_dynamics=True)``
            and expected to return ``(probabilities, hidden_history, taus)``.
        X_seq: Sequence inputs for the model.
        X_ctx: Context inputs for the model.
        y_true: Tensor of ground-truth labels (0 = LOS, 1 = NLOS).
        train_hist: Per-epoch training losses.
        val_hist: Per-epoch validation losses.

    The model is switched to eval mode and left in it.
    """
    model.eval()
    with torch.no_grad():
        y_probs, h_hist, taus = model(X_seq, X_ctx, return_dynamics=True)

    # 1. Prepare Data
    h_hist_np = h_hist.cpu().numpy()
    taus_np = taus.cpu().numpy()
    if taus.dim() > 1:
        # Collapse the hidden dimension to one tau value per sample.
        taus_np = taus_np.mean(axis=1)

    y_true_np = y_true.cpu().numpy().flatten()
    y_pred_np = (y_probs.cpu().numpy().flatten() > 0.5).astype(float)
    batch_size, seq_len, hidden_dim = h_hist_np.shape

    # 2. Setup 2x2 Plot
    fig, axs = plt.subplots(2, 2, figsize=(18, 12))
    plt.subplots_adjust(hspace=0.3)

    # --- PLOT 1: Learning Curves ---
    axs[0, 0].plot(train_hist, label='Train Loss', color='#3498db', linewidth=2)
    axs[0, 0].plot(val_hist, label='Val Loss', color='#e74c3c', linewidth=2, linestyle='--')
    axs[0, 0].set_title("Learning Dynamics (Loss Curve)")
    axs[0, 0].set_xlabel("Epochs")
    axs[0, 0].set_ylabel("Binary Cross Entropy")
    axs[0, 0].grid(True, alpha=0.3)
    axs[0, 0].legend()

    # --- PLOT 2: Confusion Matrix ---
    # Pin both classes so the matrix stays 2x2 (matching the two display
    # labels) even when one class is absent from truth and predictions.
    cm = confusion_matrix(y_true_np, y_pred_np, labels=[0.0, 1.0])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['LOS', 'NLOS'])
    disp.plot(ax=axs[0, 1], cmap='Blues', colorbar=False)
    axs[0, 1].set_title(f"Confusion Matrix (Acc: {100*(y_true_np==y_pred_np).mean():.1f}%)")

    # --- PLOT 3: Phase Space (PCA) ---
    # Select a class-balanced subset for clarity
    los_indices = np.where(y_true_np == 0)[0]
    nlos_indices = np.where(y_true_np == 1)[0]
    num_samples = min(len(los_indices), len(nlos_indices), 25)
    selected_indices = np.concatenate([los_indices[:num_samples], nlos_indices[:num_samples]])

    # PCA is fitted on every hidden state of every sample and time step.
    h_flat = h_hist_np.reshape(-1, hidden_dim)
    pca = PCA(n_components=2)
    h_pca = pca.fit_transform(h_flat)
    h_pca_traj = h_pca.reshape(batch_size, seq_len, 2)

    for i in selected_indices:
        color = '#2ecc71' if y_true_np[i] == 0 else '#e74c3c'
        axs[1, 0].plot(h_pca_traj[i, :, 0], h_pca_traj[i, :, 1], color=color, alpha=0.3)

    axs[1, 0].set_title(f"Liquid State Phase Space (n={num_samples*2})")
    axs[1, 0].set_xlabel("PC1")
    axs[1, 0].set_ylabel("PC2")
    axs[1, 0].grid(True, alpha=0.2)

    # --- PLOT 4: Tau Distribution ---
    for class_value, color, label in ((0, "#2ecc71", "LOS Tau"), (1, "#e74c3c", "NLOS Tau")):
        class_taus = taus_np[y_true_np == class_value]
        # Skip a class with no samples instead of estimating a density of nothing.
        if class_taus.size:
            sns.kdeplot(class_taus, ax=axs[1, 1], fill=True, color=color, label=label)
    axs[1, 1].set_title("Learned Time Constants (Tau)")
    axs[1, 1].set_xlabel("Tau (Time Steps)")
    axs[1, 1].legend()

    plt.show()

# ==========================================
# 4. MAIN EXECUTION
# ==========================================
def main():
    """Train and evaluate PureLNN with 5-fold stratified cross-validation.

    Loads ``combined_uwb_dataset.csv``, trains a fresh model for 40 epochs on
    each fold, reports the best validation accuracy seen per fold and their
    mean, then plots diagnostics for the last fold's final-epoch model.
    Returns early without training if the dataset file is missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load Data
    X_seq_all, X_ctx_raw, y_all = load_optimized_dataset("combined_uwb_dataset.csv")
    if X_seq_all is None:
        return

    num_context_features = X_ctx_raw.shape[1]
    print(f"Context Features: {num_context_features} (Ampl, Energy, Ratio, HW_RiseTime)")

    def to_device(array):
        # float32 copy of a NumPy array, moved to the training device.
        return torch.tensor(array.astype(np.float32)).to(device)

    # The loss is stateless, so one instance serves every batch and fold.
    criterion = nn.BCELoss()

    # 2. Cross Validation
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    fold_accuracies = []

    # Storage for the last fold's history for plotting
    final_train_hist = []
    final_val_hist = []
    final_model = None
    final_val_data = None

    for fold, (train_idx, val_idx) in enumerate(skf.split(X_seq_all, y_all)):
        print(f"\n--- FOLD {fold+1} ---")

        # Scale Context Features: fit on the training rows of this fold only,
        # so validation statistics never leak into training.
        scaler = StandardScaler().fit(X_ctx_raw[train_idx])

        # Prepare Tensors
        X_seq_tr = to_device(X_seq_all[train_idx])
        X_ctx_tr = to_device(scaler.transform(X_ctx_raw[train_idx]))
        y_tr = to_device(y_all[train_idx]).unsqueeze(1)

        X_seq_va = to_device(X_seq_all[val_idx])
        X_ctx_va = to_device(scaler.transform(X_ctx_raw[val_idx]))
        y_va = to_device(y_all[val_idx]).unsqueeze(1)

        train_ds = TensorDataset(X_seq_tr, X_ctx_tr, y_tr)
        train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

        model = PureLNN(context_size=num_context_features).to(device)
        optimizer = optim.AdamW(model.parameters(), lr=0.005, weight_decay=1e-4)
        # mode='max' because the scheduler is stepped with validation accuracy.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

        # Lists to store loss per epoch
        epoch_train_losses = []
        epoch_val_losses = []

        best_val_acc = 0

        for epoch in range(40):
            model.train()
            train_loss_accum = 0

            for batch_x, batch_ctx, batch_y in train_loader:
                optimizer.zero_grad()
                preds = model(batch_x, batch_ctx)
                loss = criterion(preds, batch_y)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss_accum += loss.item()

            # Record Train Loss (mean over batches)
            avg_train_loss = train_loss_accum / len(train_loader)
            epoch_train_losses.append(avg_train_loss)

            # Validation Step (whole validation split in one batch)
            model.eval()
            with torch.no_grad():
                val_preds = model(X_seq_va, X_ctx_va)
                val_loss = criterion(val_preds, y_va).item()
                val_acc = ((val_preds > 0.5).float() == y_va).float().mean().item()

            # Record Val Loss
            epoch_val_losses.append(val_loss)

            scheduler.step(val_acc)

            if epoch % 10 == 0:
                print(f"Ep {epoch:<3} | Loss: {avg_train_loss:.4f} | Val Acc: {100*val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc

        # NOTE: this is the best epoch's accuracy on the same split used for
        # monitoring, not the accuracy of the final-epoch model.
        fold_accuracies.append(best_val_acc)
        print(f"Fold {fold+1} Best Accuracy: {100*best_val_acc:.2f}%")

        # Save data from the last fold for visualization
        final_train_hist = epoch_train_losses
        final_val_hist = epoch_val_losses
        final_model = model
        final_val_data = (X_seq_va, X_ctx_va, y_va)

    print(f"\n🏆 Final Mean Accuracy: {100*np.mean(fold_accuracies):.2f}%")

    # 3. Final Diagnostics
    print("\n📊 Generating Complete Diagnostics (Loss, Confusion Matrix, PCA, Tau)...")
    if final_model is not None:
        plot_full_diagnostics(final_model, *final_val_data, final_train_hist, final_val_hist)

# Run the full train/evaluate/plot pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()