# In [None]:
import sys
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import linalg
from scipy.sparse.linalg import eigsh
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans

# --- SETUP PATHS ---
sys.path.append(os.path.abspath('..'))

from spatial_fdr_evaluation.data.adbench_loader import load_from_ADbench
from spatial_fdr_evaluation.methods.kernels import compute_kernel_matrix, estimate_length_scale

# Global seaborn defaults applied to every figure created after this point.
sns.set_context(context="notebook", font_scale=1.2)
sns.set_style(style="whitegrid")


# ==========================================
# 1. DATASET DISCOVERY & SCANNING
# ==========================================
def get_all_adbench_datasets(candidate_paths=None):
    """Return a sorted list of the ADbench classical dataset names available locally.

    Searches the candidate directories in order and lists every ``.npz`` file
    in the first one that is an existing directory.

    Args:
        candidate_paths: Optional sequence of directories to search, in order.
            Defaults to the conventional ADbench locations relative to the
            current working directory.

    Returns:
        Sorted list of dataset names (file names with the trailing ``.npz``
        removed). If no candidate directory exists, a small hard-coded
        fallback list is returned instead.
    """
    if candidate_paths is None:
        candidate_paths = [
            '../third_party/ADbench/datasets/Classical',
            '../../third_party/ADbench/datasets/Classical',
            './datasets/Classical',
        ]

    # isdir rather than exists: a regular file with a candidate name would
    # otherwise be selected and make os.listdir fail.
    dataset_dir = next((p for p in candidate_paths if os.path.isdir(p)), None)

    if dataset_dir is None:
        print("Warning: Could not find ADbench dataset directory. Using default list.")
        return ['23_mammography', '37_satellite', '40_vowels', '2_annthyroid', '33_skin', '12_fault']

    suffix = '.npz'
    # Strip only the trailing suffix; str.replace would also remove any inner '.npz'.
    files = [f[:-len(suffix)] for f in os.listdir(dataset_dir) if f.endswith(suffix)]
    print(f"Found {len(files)} datasets in {dataset_dir}")
    return sorted(files)

def _leading_normalized_eigs(K, n_eigs=10):
    """Return the top ``n_eigs`` eigenpairs of D^{-1/2} K D^{-1/2}, largest eigenvalue first.

    ``K`` is densified if it exposes ``.toarray()``. Eigenvectors are the
    columns of the second return value, in the same order as the eigenvalues.
    """
    K = K.toarray() if hasattr(K, 'toarray') else np.asarray(K)
    d_inv_sqrt = 1.0 / np.sqrt(K.sum(axis=1) + 1e-10)
    # Broadcasting equals diag(d) @ K @ diag(d) without two dense O(N^3) matmuls.
    A_norm = d_inv_sqrt[:, None] * K * d_inv_sqrt[None, :]
    try:
        evals, evecs = eigsh(A_norm, k=n_eigs, which='LA')
    except Exception:
        # eigsh rejects n_eigs >= N and may fail to converge; use the dense solver.
        evals, evecs = linalg.eigh(A_norm)
        evals, evecs = evals[-n_eigs:], evecs[:, -n_eigs:]
    order = np.argsort(evals)[::-1]
    return evals[order], evecs[:, order]


def scan_all_datasets(datasets, sigma_factor=0.5, max_samples_for_scan=2000,
                      min_gap=0.02, min_block_size=50, n_eigs=10):
    """
    Screen datasets for a usable spectral block (cluster) structure.

    For each dataset: load ``X_train``, randomly subsample to at most
    ``max_samples_for_scan`` rows, standardize, build an RBF kernel with
    bandwidth ``median heuristic * sigma_factor``, take the leading
    eigenvectors of D^{-1/2} K D^{-1/2}, choose k >= 2 with the eigengap
    heuristic, and cluster the embedding with KMeans. A dataset is accepted
    when its eigengap is at least ``min_gap`` and at least two clusters have
    ``min_block_size`` or more members. One status row is printed per dataset;
    datasets that raise are reported as ERROR rows instead of being dropped.

    Args:
        datasets: Iterable of ADbench dataset names.
        sigma_factor: Multiplier applied to the median-heuristic length scale.
        max_samples_for_scan: Row cap for the subsample used by the scan.
        min_gap: Minimum eigengap for acceptance.
        min_block_size: Minimum cluster size counted as a valid block.
        n_eigs: Number of leading eigenpairs examined.

    Returns:
        dict mapping accepted dataset name -> {'gap', 'k', 'sigma_factor'}.
    """
    valid_results = {}

    print(f"{'Dataset':<25} | {'N':<6} | {'Gap':<8} | {'k':<3} | {'Status'}")
    print("-" * 75)

    for name in datasets:
        try:
            # 1. Load data; subsample for speed during the scan
            X_full = load_from_ADbench(name)['X_train']
            if len(X_full) > max_samples_for_scan:
                sub_idx = np.random.choice(len(X_full), max_samples_for_scan, replace=False)
                X_full = X_full[sub_idx]
            X = StandardScaler().fit_transform(X_full)

            # 2. Kernel with a sharpened (scaled) median-heuristic bandwidth
            sigma = estimate_length_scale(X, method='median') * sigma_factor
            K = compute_kernel_matrix(X, kernel_type='rbf', length_scale=sigma)

            # 3. Spectral embedding (eigenvalues in decreasing order)
            evals, evecs = _leading_normalized_eigs(K, n_eigs)

            # 4. Eigengap heuristic: gaps[j] = lambda_j - lambda_{j+1} (0-based)
            # supports k = j + 1 clusters; skipping j = 0 forces k >= 2.
            gaps = evals[:-1] - evals[1:]
            valid_gaps = gaps[1:]

            if len(valid_gaps) == 0:  # fewer than 3 eigenvalues available
                print(f"{name:<25} | {len(X):<6} | {'N/A':<8} | -   | REJECT (Flat)")
                continue

            local_best_idx = int(np.argmax(valid_gaps))
            optimal_k = local_best_idx + 2
            max_gap = valid_gaps[local_best_idx]

            if max_gap < min_gap:
                print(f"{name:<25} | {len(X):<6} | {max_gap:.4f}   | {optimal_k:<3} | REJECT (Gap too small)")
                continue

            # 5. Check block sizes
            kmeans = KMeans(n_clusters=optimal_k, random_state=42, n_init=10)
            labels = kmeans.fit_predict(evecs[:, :optimal_k])
            valid_blocks = np.sum(np.bincount(labels) >= min_block_size)

            if valid_blocks < 2:
                print(f"{name:<25} | {len(X):<6} | {max_gap:.4f}   | {optimal_k:<3} | REJECT (Only 1 valid block)")
                continue

            print(f"{name:<25} | {len(X):<6} | {max_gap:.4f}   | {optimal_k:<3} | OK")

            valid_results[name] = {
                'gap': max_gap,
                'k': optimal_k,
                'sigma_factor': sigma_factor,
            }

        except Exception as e:
            # Best-effort scan: keep going, but make the failure visible.
            print(f"{name:<25} | {'-':<6} | {'-':<8} | -   | ERROR ({type(e).__name__}: {e})")

    print("-" * 75)
    print(f"Found {len(valid_results)} suitable datasets.")
    return valid_results


# ==========================================
# 2. STRATIFIED SAMPLING
# ==========================================
def run_stratified_experiment(dataset_name, config, n_h1=50, n_h0_compact=150, n_h0_close=30, n_h0_far=270):
    """
    Build one 4-group stratified sample from the full ADbench ``X_train``.

    The data are standardized, an RBF kernel is built with bandwidth
    ``median heuristic * config['sigma_factor']``, and KMeans with
    ``config['k']`` clusters is run on the leading eigenvectors of
    D^{-1/2} K D^{-1/2}. A random cluster with at least ``n_h1`` points becomes
    the signal block. The remaining (noise) points are ranked by kernel
    similarity to the signal medoid and sampled into:

      group 0  H1 signal   - points of the signal cluster
      group 1  H0 compact  - one other cluster (k > 2), or for k == 2 the noise
                             points strictly between the 50th and 80th
                             similarity percentiles
      group 2  H0 close    - noise at or above the 80th similarity percentile
      group 3  H0 far      - noise at or below the 50th similarity percentile

    All draws are without replacement and the four groups are disjoint. When
    a pool is too small its quota is reduced; a compact shortfall is added to
    the far quota.

    Args:
        dataset_name: ADbench dataset name.
        config: dict with at least 'sigma_factor' and 'k' (as produced by
            ``scan_all_datasets``).
        n_h1, n_h0_compact, n_h0_close, n_h0_far: Target group sizes.

    Returns:
        dict with 'X' (standardized rows), 'K' (kernel restricted to the
        sample), 'true_labels', 'group_ids' and 'indices' (rows of X_train),
        or None when no cluster has ``n_h1`` points or there are no noise points.
    """
    # 1. Reload full data
    data = load_from_ADbench(dataset_name)
    X = StandardScaler().fit_transform(data['X_train'])

    # 2. Kernel and normalized affinity
    sigma = estimate_length_scale(X, method='median') * config['sigma_factor']
    K = compute_kernel_matrix(X, kernel_type='rbf', length_scale=sigma)
    K = K.toarray() if hasattr(K, 'toarray') else np.asarray(K)

    d_inv_sqrt = 1.0 / np.sqrt(K.sum(axis=1) + 1e-10)
    # Broadcasting equals diag(d) @ K @ diag(d) without two dense O(N^3) matmuls.
    A_norm = d_inv_sqrt[:, None] * K * d_inv_sqrt[None, :]

    try:
        evals, evecs = eigsh(A_norm, k=10, which='LA')
    except Exception:
        # eigsh rejects k >= N and may fail to converge; use the dense solver.
        evals, evecs = linalg.eigh(A_norm)
        evals, evecs = evals[-10:], evecs[:, -10:]
    evecs = evecs[:, np.argsort(evals)[::-1]]

    optimal_k = config['k']
    kmeans = KMeans(n_clusters=optimal_k, random_state=42, n_init=10)
    labels_full = kmeans.fit_predict(evecs[:, :optimal_k])

    # --- Signal block ---
    valid_clusters = [c for c in range(optimal_k) if np.sum(labels_full == c) >= n_h1]
    if not valid_clusters:
        return None

    signal_cid = np.random.choice(valid_clusters)
    pool_h1 = np.where(labels_full == signal_cid)[0]

    # --- Partition noise by similarity to the signal medoid ---
    pool_noise_all = np.where(labels_full != signal_cid)[0]
    if len(pool_noise_all) == 0:
        return None  # no H0 candidates; np.percentile below would fail

    sim_matrix_h1 = K[np.ix_(pool_h1, pool_h1)]
    medoid_idx = pool_h1[np.argmax(sim_matrix_h1.sum(axis=1))]
    sim_to_signal = K[pool_noise_all, medoid_idx]

    thresh_close = np.percentile(sim_to_signal, 80)
    thresh_far = np.percentile(sim_to_signal, 50)

    pool_close = pool_noise_all[sim_to_signal >= thresh_close]
    pool_far = pool_noise_all[sim_to_signal <= thresh_far]

    # --- Compact H0 pool ---
    if optimal_k > 2:
        other_clusters = [c for c in range(optimal_k) if c != signal_cid]
        compact_cid = np.random.choice(other_clusters)
        pool_compact = np.where(labels_full == compact_cid)[0]
    else:
        mask_middle = (sim_to_signal > thresh_far) & (sim_to_signal < thresh_close)
        pool_compact = pool_noise_all[mask_middle]

    # Compact shortfall is reassigned to the far group.
    if len(pool_compact) < n_h0_compact:
        n_h0_far += n_h0_compact - len(pool_compact)
        n_h0_compact = len(pool_compact)

    # --- Draw groups, keeping them disjoint ---
    idx_h1 = np.random.choice(pool_h1, n_h1, replace=False)
    if n_h0_compact > 0:
        idx_compact = np.random.choice(pool_compact, n_h0_compact, replace=False)
    else:
        idx_compact = np.array([], dtype=int)

    # The compact cluster lies inside the noise pool, and the two percentile
    # thresholds coincide when many kernel values tie (e.g. underflow to 0).
    # Exclude already-drawn points so no index lands in two groups. Boolean
    # masking keeps the original pool order.
    pool_close = pool_close[~np.isin(pool_close, idx_compact)]
    n_h0_close = min(n_h0_close, len(pool_close))
    idx_close = np.random.choice(pool_close, n_h0_close, replace=False)

    pool_far = pool_far[~np.isin(pool_far, np.concatenate([idx_compact, idx_close]))]
    n_h0_far = min(n_h0_far, len(pool_far))
    idx_far = np.random.choice(pool_far, n_h0_far, replace=False)

    final_indices = np.concatenate([idx_h1, idx_compact, idx_close, idx_far]).astype(int)
    np.random.shuffle(final_indices)

    # NOTE(review): H1 is encoded as 0 and H0 as 1, the reverse of the common
    # "1 = signal" convention; confirm against downstream consumers.
    true_labels = np.ones(len(final_indices), dtype=int)
    true_labels[np.isin(final_indices, idx_h1)] = 0

    group_ids = np.zeros(len(final_indices), dtype=int)
    group_ids[np.isin(final_indices, idx_compact)] = 1
    group_ids[np.isin(final_indices, idx_close)] = 2
    group_ids[np.isin(final_indices, idx_far)] = 3

    return {
        'X': X[final_indices],
        'K': K[np.ix_(final_indices, final_indices)],
        'true_labels': true_labels,
        'group_ids': group_ids,
        'indices': final_indices,
    }


# ==========================================
# 3. COMPREHENSIVE VISUALIZATION
# ==========================================
def verify_dataset_comprehensive(experiment):
    """
    Plot and print diagnostics for one stratified sample.

    (A) Heatmap of the sample kernel matrix reordered by group, with white
    lines at group boundaries. (B) Boxplot of each point's mean kernel
    similarity to the H1 block, per group (self-similarity excluded for H1).
    (C) Printed H1-H1 coherence (off-diagonal mean), H1-vs-noise mean
    similarity, and their difference.

    Args:
        experiment: dict from ``run_stratified_experiment``; uses 'K' and
            'group_ids' (0=H1 signal, 1=H0 compact, 2=H0 close, 3=H0 far).

    Returns:
        dict with 'coherence', 'separation' and 'gap'; entries are NaN when
        the corresponding groups are too small to define them.
    """
    K = np.asarray(experiment['K'])
    g_ids = np.asarray(experiment['group_ids'])
    labels = ["H1 Signal", "H0 Compact", "H0 Close", "H0 Far"]

    # A. Kernel matrix sorted by group, with group boundaries marked
    sort_idx = np.argsort(g_ids, kind='stable')
    K_sorted = K[np.ix_(sort_idx, sort_idx)]
    boundaries = np.cumsum(np.bincount(g_ids, minlength=len(labels)))[:-1]

    plt.figure(figsize=(6, 5))
    ax = sns.heatmap(K_sorted, cmap="viridis", cbar=False, xticklabels=False, yticklabels=False)
    for b in np.unique(boundaries):
        if 0 < b < len(g_ids):
            ax.axhline(b, color="white", linewidth=1)
            ax.axvline(b, color="white", linewidth=1)
    plt.title("Kernel Matrix (Sorted by Group)\nBlocks top-left to bottom-right: H1, Compact, Close, Far")
    plt.show()

    h1_indices = np.where(g_ids == 0)[0]
    n_h1 = len(h1_indices)
    if n_h1 == 0:
        print("  > No H1 points in sample; skipping similarity diagnostics.")
        return {'coherence': np.nan, 'separation': np.nan, 'gap': np.nan}

    # B. Similarity to signal (boxplot)
    k_h1_h1 = K[np.ix_(h1_indices, h1_indices)]
    sims, names = [], []
    if n_h1 > 1:
        # Drop the diagonal so "H1 Self" matches the coherence statistic below.
        sims.append((k_h1_h1.sum(axis=1) - np.diag(k_h1_h1)) / (n_h1 - 1))
        names.append("H1 Self")

    for gid in (1, 2, 3):
        idx = np.where(g_ids == gid)[0]
        if len(idx) > 0:
            sims.append(K[np.ix_(idx, h1_indices)].mean(axis=1))
            names.append(labels[gid])

    if sims:
        plt.figure(figsize=(8, 4))
        # Tick labels are set separately: boxplot's `labels=` kwarg was renamed
        # `tick_labels` in Matplotlib 3.9 and the old name is deprecated.
        plt.boxplot(sims)
        plt.xticks(range(1, len(names) + 1), names)
        plt.title("Similarity to Signal Block (H1)")
        plt.ylabel("Kernel Value")
        plt.show()

    # C. Separation stats
    sim_h1_h1 = k_h1_h1[~np.eye(n_h1, dtype=bool)].mean() if n_h1 > 1 else np.nan

    # H1 vs all noise
    noise_idx = np.where(g_ids != 0)[0]
    sim_h1_noise = K[np.ix_(h1_indices, noise_idx)].mean() if len(noise_idx) > 0 else np.nan
    gap = sim_h1_h1 - sim_h1_noise

    print(f"  > Signal Coherence (H1-H1): {sim_h1_h1:.4f}")
    print(f"  > Noise Separation (H1-All): {sim_h1_noise:.4f}")
    print(f"  > GAP: {gap:.4f}")
    return {'coherence': sim_h1_h1, 'separation': sim_h1_noise, 'gap': gap}


# ==========================================
# 4. MAIN EXECUTION (LOOP ALL)
# ==========================================

# A. Discovery
# Seed NumPy's global RNG once. Both the scan's subsampling and the stratified
# sampling draw from it (KMeans already uses a fixed random_state), so this
# makes the whole run reproducible.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

all_datasets = get_all_adbench_datasets()

# B. Scan
valid_datasets = scan_all_datasets(all_datasets, sigma_factor=0.5)

# C. Process ALL candidates
if valid_datasets:
    print(f"\n\n{'='*60}")
    print(f"STARTING ANALYSIS OF {len(valid_datasets)} CANDIDATES")
    print(f"{'='*60}\n")

    # Largest eigengap first, for presentation
    sorted_candidates = sorted(valid_datasets.items(), key=lambda item: item[1]['gap'], reverse=True)

    for name, config in sorted_candidates:
        print(f"\n>>> DATASET: {name} (Gap: {config['gap']:.4f}, k={config['k']})")

        # Best-effort: one failing dataset must not stop the rest of the analysis.
        try:
            experiment = run_stratified_experiment(
                name,
                config,
                n_h1=50,
                n_h0_compact=150,
                n_h0_close=20,
                n_h0_far=280,
            )

            if experiment is not None:
                verify_dataset_comprehensive(experiment)
            else:
                print("  [Failed to sample valid groups]")

        except Exception as e:
            print(f"  [Error processing {name}: {type(e).__name__}: {e}]")
else:
    print("No suitable datasets found.")