<a href="https://colab.research.google.com/github/leafandsheep/QML-terrain-test/blob/main/model_details.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is used to compare results across device, season, and bike type.

In [None]:
# Mount Google Drive so the project code and data stored there are reachable from Colab.
from google.colab import drive
drive.mount('/content/drive')
import torch
# Prefer the GPU when the runtime provides one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import sys
# Make the project modules kept on Drive importable from this notebook.
sys.path.append('/content/drive/MyDrive/code')
import os
# Work from the project directory so that relative paths resolve against it.
os.chdir("/content/drive/MyDrive/code")
# NOTE: update the '/content/drive/MyDrive/code' paths above to match your own Drive layout.

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!ls

batch_infer.py			   results_iphone_11th_results.csv
data_factory			   results_iphone_16th_results.csv
device_comparison		   results_iphone_billy_results.csv
__pycache__			   results_iphone_Nst_results.csv
results_android_11th_results.csv   tran_update2.py
results_android_16th_results.csv   Untitled0.ipynb
results_android_billy_results.csv  win_56
results_android_Nst_results.csv


In [None]:
from tran_update2 import VQTransAE

In [None]:
#!/usr/bin/env python3

import json
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import average_precision_score, precision_recall_curve

# ============================================================================
# Configuration Parameters
# ============================================================================

class Config:
    """Single place for every tunable setting: paths, model size, training and scoring."""

    # ------------------------------------------------------------------
    # Data locations
    # ------------------------------------------------------------------
    ORIGINAL_DATA_DIR: str = './dataset/BIKE'
    PROCESSED_DATA_DIR: str = './dataset/BIKE_processed'
    OUTPUT_DIR: str = './vqtransae_results'

    # Optional checkpoint used to warm-start from pretrained weights.
    PRETRAINED_MODEL_PATH: str = './win_56/best_vqtransae.pth'

    # CSV columns fed to the model as input features.
    FEATURE_COLUMNS: List[str] = ['X', 'Y', 'Z', 'G']

    # ------------------------------------------------------------------
    # Model architecture
    # ------------------------------------------------------------------
    # Number of time steps per window.
    WIN_SIZE: int = 56
    # Number of input features per time step.
    IN_DIM: int = 4
    # Hidden layer width.
    HIDDEN_DIM: int = 64
    # Latent space width.
    LATENT_DIM: int = 32
    # Number of codebook entries.
    CODEBOOK_SIZE: int = 1024
    # Transformer model width.
    D_MODEL: int = 64
    # Number of attention heads.
    N_HEADS: int = 4
    # Number of Transformer layers.
    N_LAYERS: int = 3

    # ------------------------------------------------------------------
    # Data pipeline
    # ------------------------------------------------------------------
    # Stride between consecutive sliding windows.
    STEP: int = 5
    # Mini-batch size.
    BATCH_SIZE: int = 64
    # Fraction of data reserved for validation.
    VAL_RATIO: float = 0.15

    # ------------------------------------------------------------------
    # Optimisation
    # ------------------------------------------------------------------
    # Number of training epochs.
    EPOCHS: int = 100
    # Base learning rate.
    LEARNING_RATE: float = 1e-4
    # Learning rate for the codebook (deliberately higher than the base rate).
    CODEBOOK_LR: float = 5e-4
    # Weight decay.
    WEIGHT_DECAY: float = 1e-5

    # ------------------------------------------------------------------
    # Regularisation (intended to prevent codebook collapse)
    # ------------------------------------------------------------------
    # Weight of the entropy regularizer.
    ENTROPY_WEIGHT: float = 0.3
    # Target entropy, as a fraction of the maximum possible entropy.
    ENTROPY_TARGET_RATIO: float = 0.8
    # Weight of the diversity regularizer.
    DIVERSITY_WEIGHT: float = 0.2
    # Base weight of the VQ loss.
    VQ_WEIGHT_BASE: float = 2.0

    # ------------------------------------------------------------------
    # Codebook refresh
    # ------------------------------------------------------------------
    # Refresh the codebook every N epochs.
    REFRESH_EVERY: int = 5
    # Codes used fewer times than this are considered inactive.
    MIN_USAGE_THRESHOLD: int = 10
    # Desired number of active codes.
    ACTIVE_TOKEN_TARGET: int = 100

    # ------------------------------------------------------------------
    # Composite anomaly score: S = ALPHA*E + BETA*D + GAMMA*A
    # ------------------------------------------------------------------
    # Weight of the reconstruction error.
    ALPHA: float = 1.0
    # Weight of the quantization distance.
    BETA: float = 1.0
    # Weight of the attention term. Negative on purpose: per the original
    # author, anomalous samples have more focused attention.
    GAMMA: float = -0.5

    # Percentile of the validation-set scores used as the decision threshold.
    THRESHOLD_PERCENTILE: int = 20


# ============================================================================
# Utility Functions
# ============================================================================

def to_serializable(obj):
    """Recursively convert numpy/torch/path objects into JSON-friendly Python values.

    Conversions performed:
      * numpy scalars  -> native Python scalars
      * numpy arrays   -> (nested) lists
      * torch tensors  -> (nested) lists, after detaching and moving to CPU
      * ``Path``       -> ``str``
      * dicts          -> dicts with converted values; numpy-scalar keys are
                          unwrapped because ``json.dump`` rejects them
      * lists / tuples -> lists with converted items

    Anything else is returned unchanged.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        # json cannot encode ndarrays; tolist() also unwraps the numpy scalars inside.
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): to_serializable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [to_serializable(v) for v in obj]
    return obj


def save_json(data, path: Path, indent: int = 2):
    """Write ``data`` to ``path`` as UTF-8 JSON, creating missing parent directories.

    ``data`` is passed through ``to_serializable`` first; non-ASCII characters
    are written as-is rather than escaped.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with path.open('w', encoding='utf-8') as handle:
        json.dump(to_serializable(data), handle, indent=indent, ensure_ascii=False)


def speed_bins(speeds: np.ndarray, bounds: Tuple = (3.0, 6.0, 9.0)) -> np.ndarray:
    """Map each speed to the index of the interval it falls into.

    With the default bounds the bins are ``<3``, ``[3, 6)``, ``[6, 9)`` and
    ``>=9`` (indices 0..3). ``None`` or empty input yields an empty int array.
    """
    has_values = speeds is not None and len(speeds) > 0
    if has_values:
        return np.digitize(speeds, bounds, right=False)
    return np.zeros(0, dtype=int)


def robust_scale_by_group(values: np.ndarray, groups: np.ndarray) -> np.ndarray:
    """
    Median/MAD-normalize ``values`` separately inside each group.

    For every distinct group id the members are rescaled as
    ``(value - median) / (MAD + 1e-6)`` with ``MAD = median(|value - median|)``;
    the small constant avoids division by zero. An empty ``groups`` array
    means "treat all values as one group". Empty ``values`` are returned
    unchanged; otherwise the result is a new float64 array.

    The callers use speed bins as groups, to remove baseline differences
    across speeds.
    """
    if values.size == 0:
        return values

    result = np.zeros_like(values, dtype=np.float64)

    if groups.size:
        selectors = (groups == group_id for group_id in np.unique(groups))
    else:
        selectors = iter([np.ones_like(values, dtype=bool)])

    for selector in selectors:
        if not np.any(selector):
            continue
        members = values[selector]
        center = np.median(members)
        spread = np.median(np.abs(members - center)) + 1e-6
        result[selector] = (members - center) / spread

    return result


def token_stats(counts: np.ndarray) -> Tuple[int, float, np.ndarray]:
    """Summarize how the codebook tokens were used.

    Returns ``(active, perplexity, probs)`` where ``active`` is the number of
    tokens with a non-zero count, ``perplexity`` is ``exp`` of the Shannon
    entropy (natural log) of the usage distribution, and ``probs`` is that
    distribution. All-zero counts give ``(0, 0.0, zeros)``.
    """
    n_total = counts.sum()
    if n_total == 0:
        return 0, 0.0, np.zeros_like(counts, dtype=np.float64)

    distribution = counts.astype(np.float64) / n_total
    used = distribution[distribution > 0]
    shannon = -np.sum(used * np.log(used))

    n_active = int(np.count_nonzero(counts > 0))
    return n_active, float(np.exp(shannon)), distribution


# ============================================================================
# Dataset and DataLoader
# ============================================================================

class WindowDataset(Dataset):
    """Sliding-window view over a CSV time series.

    Each item is a tuple ``(window, label, speed)``:
      * ``window`` - float32 tensor of shape ``(win_size, len(features))``
      * ``label``  - maximum of the ``anomaly`` column inside the window, so a
                     window is flagged if any of its rows is flagged
      * ``speed``  - ``Speed`` value at the window centre (NaN replaced by 0.0)

    The ``anomaly`` and ``Speed`` columns are optional and default to zeros.

    Parameters
    ----------
    csv_path : str
        CSV file to read.
    win_size : int
        Number of rows per window (must be positive).
    step : int
        Stride between window starts (must be positive).
    features : List[str], optional
        Feature columns; defaults to ``Config.FEATURE_COLUMNS``.

    Raises
    ------
    ValueError
        If ``win_size`` or ``step`` is not positive.
    """

    def __init__(self, csv_path: str, win_size: int, step: int,
                 features: List[str] = None):
        if features is None:
            features = Config.FEATURE_COLUMNS
        if win_size <= 0 or step <= 0:
            raise ValueError(
                f"win_size and step must be positive, got win_size={win_size}, step={step}"
            )

        df = pd.read_csv(csv_path)
        self.data = df[features].values.astype(np.float32)
        self.labels = df['anomaly'].values if 'anomaly' in df.columns else np.zeros(len(df))
        self.speeds = df['Speed'].values if 'Speed' in df.columns else np.zeros(len(df))
        self.win_size = win_size
        self.step = step
        # A file shorter than one window must give an empty dataset; without
        # the clamp the count goes negative and len() raises ValueError.
        self.n_windows = max(0, (len(self.data) - win_size) // step + 1)

    def __len__(self):
        return self.n_windows

    def __getitem__(self, idx):
        # Support negative indices and fail loudly on out-of-range access
        # instead of returning a truncated window.
        if idx < 0:
            idx += self.n_windows
        if not 0 <= idx < self.n_windows:
            raise IndexError(f"window index out of range for dataset of {self.n_windows} windows")

        start = idx * self.step
        end = start + self.win_size
        center = start + self.win_size // 2

        window = self.data[start:end]
        label = float(np.max(self.labels[start:end]))  # Window is anomaly if any point is anomaly
        speed = float(self.speeds[center]) if not np.isnan(self.speeds[center]) else 0.0

        return torch.from_numpy(window), torch.tensor(label), torch.tensor(speed)


def create_dataloaders(
    data_dir: str,
    win_size: int = Config.WIN_SIZE,
    step: int = Config.STEP,
    batch_size: int = Config.BATCH_SIZE
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train/val/test DataLoaders from ``train.csv``, ``val.csv`` and ``test.csv``.

    All three files are read from ``data_dir`` and windowed with the same
    ``win_size`` and ``step``. Only the training loader shuffles and drops
    the last incomplete batch; validation and test keep file order.
    """
    root = Path(data_dir)

    # Read every split before building any loader.
    train_ds, val_ds, test_ds = [
        WindowDataset(root / f'{split}.csv', win_size, step)
        for split in ('train', 'val', 'test')
    ]

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True),
        DataLoader(val_ds, batch_size=batch_size, shuffle=False),
        DataLoader(test_ds, batch_size=batch_size, shuffle=False),
    )


# ============================================================================
# Step 3: Regularization Loss Functions
# ============================================================================

def entropy_regularizer(batch_counts: torch.Tensor, target_ratio: float,
                        device: torch.device) -> torch.Tensor:
    """Penalty for a token-usage entropy that falls below a target.

    The counts are normalized to a distribution and its Shannon entropy
    (natural log, zero-count entries ignored) is compared with
    ``target_ratio * log(number of entries in batch_counts)``. The result is
    ``relu(target - entropy)``: zero once the target is met.

    Returns a one-element zero tensor of shape ``(1,)`` when all counts are
    zero, otherwise a 0-dim tensor on ``device``.

    NOTE(review): when ``batch_counts`` are plain usage counts (as produced
    by ``torch.bincount`` in ``train_epoch``) the value presumably carries no
    gradient towards the model - confirm this term actually influences training.
    """
    if batch_counts.sum() == 0:
        return torch.zeros(1, device=device)

    weights = batch_counts.float().to(device)
    distribution = weights / weights.sum()
    support = distribution[distribution > 0]

    observed = -torch.sum(support * torch.log(support))
    wanted = target_ratio * math.log(batch_counts.numel())

    return torch.relu(wanted - observed)


def diversity_regularizer(indices: torch.Tensor) -> torch.Tensor:
    """
    Penalty for concentrating usage on a few of the tokens that appear.

    The entropy of the usage distribution over the distinct values found in
    ``indices`` is compared with the maximum entropy for that many distinct
    values; the relative shortfall (clamped at zero) is returned as a 0-dim
    tensor.

    Special cases: empty ``indices`` gives a one-element zero tensor, and a
    single distinct value gives a one-element tensor equal to 1.

    NOTE(review): the value is derived from integer indices via
    ``torch.unique`` and so presumably carries no gradient - confirm that it
    influences training as intended.
    """
    target_device = indices.device

    if indices.numel() == 0:
        return torch.zeros(1, device=target_device)

    distinct, frequency = torch.unique(indices, return_counts=True)
    n_distinct = distinct.numel()
    if n_distinct <= 1:
        return torch.ones(1, device=target_device)

    share = frequency.float() / frequency.sum()
    observed = -(share * torch.log(share + 1e-10)).sum()
    ceiling = torch.log(torch.tensor(float(n_distinct), device=target_device))

    shortfall = (ceiling - observed) / (ceiling + 1e-8)
    return torch.relu(shortfall)


def refresh_inactive_codes(model, usage_counts: np.ndarray, min_usage: int,
                           reuse_active: bool = True) -> int:
    """
    Overwrite rarely used codebook vectors in place and report how many changed.

    A code is dormant when its usage count is below ``min_usage``. Each
    dormant row of ``model.quant.embed.weight`` is replaced by either

    * a randomly chosen active row plus a small perturbation
      (``reuse_active=True`` and at least one active code exists), or
    * Gaussian noise scaled by ``1 / sqrt(embedding_dim)``.

    Nothing happens (and 0 is returned) when all counts are zero or no code
    is dormant. Random numbers come from both torch and ``np.random``.
    """
    if usage_counts.sum() == 0:
        return 0

    dormant_idx = np.nonzero(usage_counts < min_usage)[0]
    n_dormant = int(dormant_idx.size)
    if n_dormant == 0:
        return 0

    with torch.no_grad():
        codebook = model.quant.embed.weight
        dim = codebook.shape[1]
        # Draw the noise first so the torch RNG is consumed before numpy's.
        jitter = torch.randn((n_dormant, dim), device=codebook.device) * (1.0 / math.sqrt(dim))
        dormant_rows = torch.from_numpy(dormant_idx).to(codebook.device)

        replacement = jitter
        if reuse_active:
            active_idx = np.nonzero(usage_counts >= min_usage)[0]
            if active_idx.size > 0:
                donors = np.random.choice(active_idx, size=n_dormant, replace=True)
                donor_rows = torch.from_numpy(donors).to(codebook.device)
                replacement = codebook[donor_rows] + 0.1 * jitter

        codebook[dormant_rows] = replacement

    return n_dormant


# ============================================================================
# Step 4: Training Functions
# ============================================================================

def train_epoch(model, loader, optimizer, device, epoch: int,
                entropy_weight: float, entropy_target_ratio: float,
                diversity_weight: float, vq_weight: float) -> Dict:
    """Run one optimisation pass over ``loader`` and return averaged statistics.

    Per batch the objective is ``mse(recon, x) + vq_weight * loss_vq`` plus the
    entropy and diversity regularizers when their weights are positive.
    Gradients are clipped to a total norm of 1.0 before each optimizer step.

    ``epoch`` is accepted for interface compatibility but not used here.

    Returns a dict with the batch-averaged ``loss``, ``recon`` and ``vq``
    values and ``token_counts``, a numpy histogram of how often every
    codebook entry was selected during the epoch.
    """
    model.train()

    n_codes = model.quant.embed.num_embeddings
    usage_hist = torch.zeros(n_codes, dtype=torch.long)
    running = {'loss': 0, 'recon': 0, 'vq': 0}

    for inputs, _labels, _speeds in loader:
        inputs = inputs.to(device)

        # Forward pass: reconstruction, (unused) attention, VQ loss, code ids
        reconstruction, _attn, vq_term, code_ids = model(inputs)

        recon_term = nn.functional.mse_loss(reconstruction, inputs)
        objective = recon_term + vq_weight * vq_term

        # Histogram of the codes chosen in this batch
        counts_now = torch.bincount(code_ids.detach().cpu().flatten(), minlength=n_codes)
        usage_hist += counts_now

        if entropy_weight > 0:
            objective = objective + entropy_weight * entropy_regularizer(
                counts_now, entropy_target_ratio, device
            )

        if diversity_weight > 0:
            objective = objective + diversity_weight * diversity_regularizer(code_ids)

        # Backward pass with gradient clipping
        optimizer.zero_grad()
        objective.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running['loss'] += objective.item()
        running['recon'] += recon_term.item()
        running['vq'] += vq_term.item()

    denom = max(1, len(loader))
    summary = {name: total / denom for name, total in running.items()}
    summary['token_counts'] = usage_hist.numpy()
    return summary


def validate(model, loader, device) -> Dict:
    """Average ``mse(recon, x) + loss_vq`` over ``loader`` without gradients.

    The VQ term is added with weight 1 (not the training ``vq_weight``) and
    no regularizers are included. Puts the model in eval mode. Returns
    ``{'loss': mean_batch_loss}``; an empty loader yields 0.
    """
    model.eval()

    def _batch_loss(batch_x):
        batch_x = batch_x.to(device)
        recon, _, vq_term, _ = model(batch_x)
        return (nn.functional.mse_loss(recon, batch_x) + vq_term).item()

    with torch.no_grad():
        summed = sum(_batch_loss(x) for x, _, _ in loader)

    return {'loss': summed / max(1, len(loader))}


def train_model(
    data_dir: str = Config.PROCESSED_DATA_DIR,
    output_dir: str = Config.OUTPUT_DIR,
    pretrained_path: str = None,
    epochs: int = Config.EPOCHS,
    config: Config = None
) -> Tuple[nn.Module, List[Dict], Path]:
    """
    Train Model

    Builds the data loaders and a VQTransAE, optionally warm-starts the
    non-codebook weights from a checkpoint, then trains with a staged VQ
    weight, periodic codebook refresh and a cosine learning-rate schedule.
    Checkpoints, the training history (JSON) and training curves (PNG) are
    written to ``output_dir``.

    The "best" checkpoint is the epoch with the most active tokens; ties are
    broken by lower validation loss.

    Parameters
    ----------
    data_dir : str
        Preprocessed data directory
    output_dir : str
        Output directory
    pretrained_path : str
        Pretrained model path (optional); ignored if the file does not exist
    epochs : int
        Number of training epochs
    config : Config
        Configuration object (defaults to a fresh ``Config()``)

    Returns
    -------
    model : nn.Module
        The model as it stands after the final epoch (the best weights are
        only written to ``best_model_path``)
    history : List[Dict]
        Training history, one record per epoch
    best_model_path : Path
        Path to best model
    """

    if config is None:
        config = Config()

    print("\n" + "=" * 70)
    print("Step 2: Train Model")
    print("=" * 70)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Create DataLoaders
    print("\nCreating DataLoaders...")
    train_loader, val_loader, _ = create_dataloaders(
        data_dir, config.WIN_SIZE, config.STEP, config.BATCH_SIZE
    )
    print(f"   Train: {len(train_loader.dataset)} windows, {len(train_loader)} batches")
    print(f"   Val: {len(val_loader.dataset)} windows, {len(val_loader)} batches")

    # Initialize model
    print("\nInitializing model...")
    from tran_update2 import VQTransAE

    # Architecture description, stored in every checkpoint so that the model
    # can be rebuilt at evaluation time.
    hyperparams = {
        'win_size': config.WIN_SIZE,
        'in_dim': config.IN_DIM,
        'hidden': config.HIDDEN_DIM,
        'latent': config.LATENT_DIM,
        'codebook': config.CODEBOOK_SIZE,
        'd_model': config.D_MODEL,
        'heads': config.N_HEADS,
        'layers': config.N_LAYERS
    }
    codebook_size = config.CODEBOOK_SIZE

    model = VQTransAE(
        win_size=config.WIN_SIZE,
        in_dim=config.IN_DIM,
        hidden=config.HIDDEN_DIM,
        latent=config.LATENT_DIM,
        codebook=config.CODEBOOK_SIZE,
        d_model=config.D_MODEL,
        heads=config.N_HEADS,
        layers=config.N_LAYERS
    ).to(device)

    # Re-initialize codebook (critical!)
    print("   Re-initializing codebook")
    nn.init.xavier_uniform_(model.quant.embed.weight)

    # Load pretrained weights (optional)
    if pretrained_path and Path(pretrained_path).exists():
        print(f"   Loading pretrained weights: {pretrained_path}")
        checkpoint = torch.load(pretrained_path, map_location=device)
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

        # Only load encoder/decoder, skip codebook
        filtered_state = {k: v for k, v in state_dict.items()
                         if 'quant.embed' not in k and 'ema' not in k}
        model.load_state_dict(filtered_state, strict=False)
        print("   Loaded encoder/decoder weights (skipped codebook)")

    # Optimizer: the quantizer gets its own (higher) learning rate
    optimizer = optim.AdamW([
        {'params': model.encoder.parameters(), 'lr': config.LEARNING_RATE},
        {'params': model.decoder.parameters(), 'lr': config.LEARNING_RATE},
        {'params': model.quant.parameters(), 'lr': config.CODEBOOK_LR},
        {'params': model.tf_layers.parameters(), 'lr': config.LEARNING_RATE},
    ], weight_decay=config.WEIGHT_DECAY)

    # Learning rate scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    # Training history
    history = []
    best_val_loss = float('inf')
    best_active_tokens = 0
    best_model_state = None

    print("\n" + "=" * 70)
    print(f"Starting training ({epochs} epochs)")
    print(f"   Entropy weight: {config.ENTROPY_WEIGHT}")
    print(f"   Diversity weight: {config.DIVERSITY_WEIGHT}")
    print(f"   VQ weight: {config.VQ_WEIGHT_BASE}")
    print("=" * 70)

    for epoch in range(1, epochs + 1):
        # Adaptive VQ weight: stronger early on, relaxed in later epochs
        if epoch < 30:
            vq_weight = config.VQ_WEIGHT_BASE * 1.5
        elif epoch < 60:
            vq_weight = config.VQ_WEIGHT_BASE
        else:
            vq_weight = config.VQ_WEIGHT_BASE * 0.8

        # Training
        train_result = train_epoch(
            model, train_loader, optimizer, device, epoch,
            config.ENTROPY_WEIGHT, config.ENTROPY_TARGET_RATIO,
            config.DIVERSITY_WEIGHT, vq_weight
        )

        # Token statistics
        active_tokens, perplexity, _ = token_stats(train_result['token_counts'])

        # Codebook refresh (a lower usage threshold is used while few tokens are active)
        refresh_count = 0
        if epoch % config.REFRESH_EVERY == 0:
            threshold = config.MIN_USAGE_THRESHOLD
            if active_tokens < config.ACTIVE_TOKEN_TARGET // 2:
                threshold = max(3, threshold // 2)
            refresh_count = refresh_inactive_codes(model, train_result['token_counts'], threshold)

        # Validation
        val_result = validate(model, val_loader, device)

        # Update learning rate
        scheduler.step()

        # Record
        record = {
            'epoch': epoch,
            'train_loss': train_result['loss'],
            'train_recon': train_result['recon'],
            'train_vq': train_result['vq'],
            'val_loss': val_result['loss'],
            'active_tokens': active_tokens,
            'perplexity': perplexity,
            'refresh_count': refresh_count
        }
        history.append(record)

        # Save best model
        if active_tokens > best_active_tokens or \
           (active_tokens == best_active_tokens and val_result['loss'] < best_val_loss):
            best_val_loss = val_result['loss']
            best_active_tokens = active_tokens
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        # Print progress
        status = f"Epoch {epoch:3d}: Loss={train_result['loss']:.4f}, Val={val_result['loss']:.4f}, "
        status += f"Active={active_tokens:4d}/{codebook_size}, Perplexity={perplexity:.1f}"
        if refresh_count > 0:
            status += f", Refreshed={refresh_count}"
        print(status)

        # Periodic save
        if epoch % 20 == 0 or epoch == epochs:
            save_path = output_path / f'model_epoch_{epoch:03d}.pth'
            torch.save({
                'state_dict': model.state_dict(),
                'epoch': epoch,
                'active_tokens': active_tokens,
                'perplexity': perplexity,
                'hyperparams': hyperparams
            }, save_path)
            print(f"   Saved: {save_path}")

    # No epoch qualified as "best" (e.g. epochs == 0, or the validation loss
    # was never finite while no token was active): fall back to the current
    # weights so that the saved checkpoint is always loadable.
    if best_model_state is None:
        best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # Save best model
    print("\n" + "=" * 70)
    print("Training complete!")
    print(f"   Best active tokens: {best_active_tokens}/{codebook_size}")
    print("=" * 70)

    best_path = output_path / 'best_model.pth'
    torch.save({
        'state_dict': best_model_state,
        'active_tokens': best_active_tokens,
        'hyperparams': hyperparams
    }, best_path)
    print(f"   Best model: {best_path}")

    # Save training history
    save_json(history, output_path / 'training_history.json')

    # Plot training curves
    plot_training_curves(history, output_path / 'training_curves.png')

    return model, history, best_path


# ============================================================================
# Step 5: Evaluation Functions
# ============================================================================

def compute_scores(model, data_loader, device) -> Dict[str, np.ndarray]:
    """
    Compute three component scores

    The model is put in eval mode and run without gradients over every batch
    of ``data_loader``. NaN/inf values in the three scores are replaced by 0.
    An empty loader yields empty arrays instead of raising.

    Returns
    -------
    scores : Dict
        - E: Reconstruction error (higher means more anomalous)
        - D: Quantization distance (higher means more anomalous)
        - A: Attention dispersion (lower means more anomalous!)
        - labels: Ground truth labels
        - speeds: Speed values
    """

    recon_errors = []
    quant_dists = []
    attention_scores = []
    all_labels = []
    all_speeds = []

    model.eval()
    with torch.no_grad():
        for x, labels, speeds in data_loader:
            x = x.to(device)

            # Forward pass
            recon, attn_list, _, indices = model(x)

            # 1. Reconstruction error: per-window mean squared error
            rec_err = nn.functional.mse_loss(recon, x, reduction='none').mean(dim=(1, 2))
            recon_errors.append(rec_err.cpu().numpy())

            # 2. Quantization distance: squared distance between the encoder
            #    output and the codebook vectors selected for it, averaged
            #    over dim 1
            z_e = model.encoder(x)
            quantized = model.quant.embed(indices.view(-1)).view_as(z_e)
            q_dist = (z_e - quantized).pow(2).sum(dim=2).mean(dim=1)
            quant_dists.append(q_dist.cpu().numpy())

            # 3. Attention dispersion from the last attention map
            #    (presumably shaped (batch, heads, T, T) - TODO confirm):
            #    average of the normalized row entropy and 1 - max weight
            attn = attn_list[-1].mean(dim=1).clamp_min(1e-9)
            t_dim = attn.shape[-1]
            log_t = math.log(max(t_dim, 2))
            entropy = -(attn * torch.log(attn)).sum(dim=-1) / (log_t + 1e-9)
            gini_like = 1.0 - attn.max(dim=-1).values
            composite_attn = 0.5 * (entropy.mean(dim=-1) + gini_like.mean(dim=-1))
            attention_scores.append(composite_attn.cpu().numpy())

            all_labels.append(labels.numpy())
            all_speeds.append(speeds.numpy())

    def _stack(chunks):
        # np.concatenate rejects an empty list, which happens when the loader has no batches.
        return np.concatenate(chunks) if chunks else np.zeros(0, dtype=np.float32)

    def _clean(chunks):
        return np.nan_to_num(_stack(chunks), nan=0.0, posinf=0.0, neginf=0.0)

    return {
        'E': _clean(recon_errors),
        'D': _clean(quant_dists),
        'A': _clean(attention_scores),
        'labels': _stack(all_labels),
        'speeds': _stack(all_speeds)
    }


def compute_composite_score(
    scores: Dict, alpha: float, beta: float, gamma: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Combine the component scores into one anomaly score per window.

    ``E``, ``D`` and ``A`` are each median/MAD-normalized within speed bins
    (bins derived from ``scores['speeds']``), non-finite values are replaced
    by 0, and the result is ``alpha * E_norm + beta * D_norm + gamma * A_norm``.

    Parameters
    ----------
    scores : Dict
        Must provide the arrays ``'E'``, ``'D'``, ``'A'`` and ``'speeds'``.
    alpha, beta, gamma : float
        Weights of the reconstruction, quantization and attention terms.

    Returns
    -------
    composite, E_norm, D_norm, A_norm : np.ndarray
        The combined score followed by the three normalized components.

    NOTE(review): normalization statistics come from the array passed in, so
    scores of two different splits are each scaled by their own medians/MADs.
    """

    groups = speed_bins(scores['speeds'])

    def _normalized(key: str) -> np.ndarray:
        scaled = robust_scale_by_group(scores[key], groups)
        return np.nan_to_num(scaled, nan=0.0, posinf=0.0, neginf=0.0)

    E_norm = _normalized('E')
    D_norm = _normalized('D')
    A_norm = _normalized('A')

    composite = alpha * E_norm + beta * D_norm + gamma * A_norm
    composite = np.nan_to_num(composite, nan=0.0, posinf=0.0, neginf=0.0)

    return composite, E_norm, D_norm, A_norm


def evaluate_with_threshold(scores: np.ndarray, labels: np.ndarray,
                            threshold: float) -> Dict:
    """Binary classification metrics for ``scores > threshold`` against ``labels``.

    Label 1 is the positive (anomaly) class and label 0 the negative class;
    other label values are counted in neither. A small epsilon in every
    denominator keeps the ratios defined when a count is zero.

    Returns a dict with ``precision``, ``recall``, ``f1``, ``accuracy``,
    ``specificity`` (floats) and the confusion counts ``tp``, ``fp``, ``tn``,
    ``fn`` (ints).
    """

    flagged = scores > threshold
    is_positive = labels == 1
    is_negative = labels == 0

    tp = int(np.sum(flagged & is_positive))
    fp = int(np.sum(flagged & is_negative))
    tn = int(np.sum(~flagged & is_negative))
    fn = int(np.sum(~flagged & is_positive))

    eps = 1e-8
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    accuracy = (tp + tn) / (tp + fp + tn + fn + eps)
    specificity = tn / (tn + fp + eps)

    report = {
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'accuracy': float(accuracy),
        'specificity': float(specificity),
    }
    report.update(tp=tp, fp=fp, tn=tn, fn=fn)
    return report


def evaluate_model(
    model_path: str,
    data_dir: str = Config.PROCESSED_DATA_DIR,
    output_dir: str = None,
    alpha: float = Config.ALPHA,
    beta: float = Config.BETA,
    gamma: float = Config.GAMMA,
    percentile: int = Config.THRESHOLD_PERCENTILE
) -> Dict:
    """
    Evaluate Model

    Threshold Selection Principle:
    1. Validation set is assumed to contain only normal samples
    2. Compute composite score distribution of validation set
    3. Take Nth percentile as threshold
    4. Samples with scores above threshold are classified as anomalies

    The checkpoint may be either a dict with a ``'state_dict'`` entry (as
    written by ``train_model``) or a bare state dict. Windows are built with
    the window size stored in the checkpoint so that data and model agree.

    Side effects: writes ``evaluation_results.json``, ``scores.npz`` and the
    evaluation plots to the output directory.

    Parameters
    ----------
    model_path : str
        Model path
    data_dir : str
        Data directory
    output_dir : str
        Output directory (optional; defaults to ``<model dir>/evaluation``)
    alpha, beta, gamma : float
        Composite score weights
    percentile : int
        Threshold percentile

    Returns
    -------
    results : Dict
        Evaluation results
    """

    print("\n" + "=" * 70)
    print("Step 3: Evaluate Model")
    print("=" * 70)
    print(f"   Composite score: S = {alpha}*E_norm + {beta}*D_norm + ({gamma})*A_norm")
    print(f"   Threshold percentile: {percentile}th percentile")
    print("=" * 70)

    if output_dir is None:
        output_dir = Path(model_path).parent / 'evaluation'
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nDevice: {device}")

    # Load model
    print("\nLoading model...")
    from tran_update2 import VQTransAE

    checkpoint = torch.load(model_path, map_location=device)
    # Checkpoints without stored hyperparameters fall back to the Config defaults.
    hyper = checkpoint.get('hyperparams', {
        'win_size': Config.WIN_SIZE, 'in_dim': Config.IN_DIM,
        'hidden': Config.HIDDEN_DIM, 'latent': Config.LATENT_DIM,
        'codebook': Config.CODEBOOK_SIZE, 'd_model': Config.D_MODEL,
        'heads': Config.N_HEADS, 'layers': Config.N_LAYERS
    })

    model = VQTransAE(
        hyper['win_size'], hyper['in_dim'],
        hidden=hyper['hidden'], latent=hyper['latent'],
        codebook=hyper['codebook'], d_model=hyper['d_model'],
        heads=hyper['heads'], layers=hyper['layers']
    ).to(device)

    # Accept both wrapped checkpoints and bare state dicts.
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    active_tokens = checkpoint.get('active_tokens', 'N/A')
    print(f"   Active tokens: {active_tokens}")

    # Create DataLoaders (window size must match the one the model was built with)
    print("\nCreating DataLoaders...")
    _, val_loader, test_loader = create_dataloaders(data_dir, win_size=hyper['win_size'])

    print(f"   Validation windows: {len(val_loader.dataset)}")
    print(f"   Test windows: {len(test_loader.dataset)}")

    # Compute scores
    print("\nComputing scores...")
    val_scores = compute_scores(model, val_loader, device)
    test_scores = compute_scores(model, test_loader, device)

    # Compute composite scores.
    # NOTE(review): validation and test scores are normalized independently,
    # each with its own per-speed-bin median/MAD.
    val_composite, _, _, _ = compute_composite_score(val_scores, alpha, beta, gamma)
    test_composite, E_norm, D_norm, A_norm = compute_composite_score(test_scores, alpha, beta, gamma)
    test_labels = test_scores['labels']

    # Component separability
    print("\nComponent separability (anomaly mean - normal mean):")
    separations = {}
    for name, data in [('E', E_norm), ('D', D_norm), ('A', A_norm), ('S', test_composite)]:
        normal = data[test_labels == 0]
        anomaly = data[test_labels == 1]
        sep = anomaly.mean() - normal.mean()
        separations[name] = sep
        print(f"   {name}: {sep:+.4f}")

    # Determine threshold - use valid values to compute percentile
    valid_val = val_composite[np.isfinite(val_composite)]

    def _val_threshold(pct):
        # Same rule for the main threshold and the percentile sweep:
        # percentile of the finite validation scores, 0.0 if there are none.
        if len(valid_val) > 0:
            return np.percentile(valid_val, pct)
        return 0.0

    threshold = _val_threshold(percentile)
    print(f"\nThreshold: {threshold:.4f} ({percentile}th percentile of validation set)")

    # Evaluate
    metrics = evaluate_with_threshold(test_composite, test_labels, threshold)

    print("\n" + "=" * 70)
    print("Evaluation Results")
    print("=" * 70)
    print(f"   Precision:   {metrics['precision']:.4f}")
    print(f"   Recall:      {metrics['recall']:.4f}")
    print(f"   F1 Score:    {metrics['f1']:.4f}")
    print(f"   Accuracy:    {metrics['accuracy']:.4f}")
    print(f"   Specificity: {metrics['specificity']:.4f}")
    print(f"\n   Confusion Matrix:")
    print(f"      TP={metrics['tp']}, FP={metrics['fp']}")
    print(f"      FN={metrics['fn']}, TN={metrics['tn']}")

    # PR-AUC - filter out NaN values
    valid_mask = np.isfinite(test_composite) & np.isfinite(test_labels)
    valid_composite = test_composite[valid_mask]
    valid_labels = test_labels[valid_mask]

    # Average precision is undefined unless both classes are present.
    if len(valid_labels) > 0 and len(np.unique(valid_labels)) > 1:
        pr_auc = average_precision_score(valid_labels, valid_composite)
    else:
        pr_auc = float('nan')
    print(f"\n   PR-AUC:  {pr_auc:.4f}")
    print("=" * 70)

    # Performance at different percentiles
    print("\nPerformance at different percentile thresholds:")
    percentile_results = []
    for pct in [10, 15, 20, 25, 30, 40, 50, 60, 70, 80]:
        th = _val_threshold(pct)
        m = evaluate_with_threshold(test_composite, test_labels, th)
        percentile_results.append({'percentile': pct, 'threshold': th, **m})
        print(f"   {pct:2d}th: F1={m['f1']:.4f}, P={m['precision']:.4f}, R={m['recall']:.4f}")

    # Find best percentile
    best_pct_result = max(percentile_results, key=lambda x: x['f1'])
    print(f"\nBest percentile: {best_pct_result['percentile']}th, F1={best_pct_result['f1']:.4f}")

    # Generate visualizations
    print("\nGenerating visualizations...")
    plot_evaluation_results(
        val_composite, test_composite, test_labels,
        threshold, metrics, pr_auc, output_path
    )

    # Save results
    results = {
        'configuration': {
            'formula': f'S = {alpha}*E_norm + {beta}*D_norm + ({gamma})*A_norm',
            'alpha': alpha, 'beta': beta, 'gamma': gamma,
            'threshold_percentile': percentile,
            'threshold_value': float(threshold)
        },
        'separations': {k: float(v) for k, v in separations.items()},
        'metrics': metrics,
        'pr_auc': float(pr_auc),
        'percentile_results': percentile_results,
        'best_percentile_result': best_pct_result,
        'model_info': {'active_tokens': active_tokens}
    }

    save_json(results, output_path / 'evaluation_results.json')

    # Save scores
    np.savez(
        output_path / 'scores.npz',
        val_composite=val_composite,
        test_composite=test_composite,
        test_E_norm=E_norm,
        test_D_norm=D_norm,
        test_A_norm=A_norm,
        test_labels=test_labels
    )

    print(f"\nResults saved to: {output_path}")

    return results


# ============================================================================
# Visualization Functions
# ============================================================================

def plot_training_curves(history: List[Dict], save_path: Path):
    """Draw a 2x2 summary figure of a training run and save it to ``save_path``.

    Panels (row-major): train/validation loss, number of active codebook
    tokens with the ``Config.ACTIVE_TOKEN_TARGET`` reference line, token
    perplexity, and the reconstruction vs. VQ parts of the training loss.

    Every record in ``history`` must provide the keys ``epoch``,
    ``train_loss``, ``val_loss``, ``active_tokens``, ``perplexity``,
    ``train_recon`` and ``train_vq``. The parent directory of ``save_path``
    is created when it does not exist yet.
    """

    def column(key):
        # Collect one per-epoch series from the list of history records.
        return [record[key] for record in history]

    save_path.parent.mkdir(parents=True, exist_ok=True)

    epochs = column('epoch')
    train_loss, val_loss = column('train_loss'), column('val_loss')
    active_tokens = column('active_tokens')
    perplexity = column('perplexity')

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    (loss_ax, tokens_ax), (perplexity_ax, parts_ax) = axes

    # Top-left: total loss on both splits.
    loss_ax.plot(epochs, train_loss, 'b-', label='Train')
    loss_ax.plot(epochs, val_loss, 'r-', label='Val')
    loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Training & Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Top-right: codebook usage against the configured target.
    tokens_ax.plot(epochs, active_tokens, 'g-', linewidth=2)
    target = Config.ACTIVE_TOKEN_TARGET
    tokens_ax.axhline(target, color='r', linestyle='--', label=f'Target={target}')
    tokens_ax.set(xlabel='Epoch', ylabel='Active Tokens', title='Codebook Utilization')
    tokens_ax.legend()
    tokens_ax.grid(True, alpha=0.3)

    # Bottom-left: perplexity of the token distribution.
    perplexity_ax.plot(epochs, perplexity, 'm-', linewidth=2)
    perplexity_ax.set(xlabel='Epoch', ylabel='Perplexity',
                      title='Token Distribution Perplexity')
    perplexity_ax.grid(True, alpha=0.3)

    # Bottom-right: the two logged components of the training loss.
    recon_loss = column('train_recon')
    vq_loss = column('train_vq')
    parts_ax.plot(epochs, recon_loss, 'b-', label='Recon')
    parts_ax.plot(epochs, vq_loss, 'orange', label='VQ')
    parts_ax.set(xlabel='Epoch', ylabel='Loss', title='Reconstruction vs VQ Loss')
    parts_ax.legend()
    parts_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
    print(f"Training curves saved: {save_path}")


def plot_evaluation_results(val_scores, test_scores, test_labels,
                            threshold, metrics, pr_auc, output_dir):
    """Write the evaluation figures for one thresholded run into ``output_dir``.

    Three PNG files are produced:

    * ``score_distribution.png`` - score histograms (validation, test normal,
      test anomaly) with the threshold marked, next to a bar chart of
      precision / recall / F1 / specificity.
    * ``pr_curve.png`` - precision-recall curve with the operating point.
    * ``confusion_matrix.png`` - 2x2 confusion matrix built from ``metrics``.

    Parameters
    ----------
    val_scores : array-like
        Validation scores (used only for the histogram).
    test_scores : array-like
        Test scores, one per entry of ``test_labels``.
    test_labels : array-like
        Test labels; entries equal to 0 are drawn as normal, 1 as anomaly.
    threshold : float
        Decision threshold drawn as a vertical line on the histogram.
    metrics : dict
        Must contain ``precision``, ``recall``, ``f1``, ``specificity``,
        ``tn``, ``fp``, ``fn`` and ``tp``.
    pr_auc : float
        Value shown in the PR-curve legend.
    output_dir : str or Path
        Destination directory; created if it does not exist.
    """

    output_path = Path(output_dir)
    # Match plot_training_curves: never fail in savefig on a missing directory.
    output_path.mkdir(parents=True, exist_ok=True)

    # Boolean masking below requires ndarrays; with a plain list,
    # ``test_labels == 0`` would evaluate to a single False.
    test_scores = np.asarray(test_scores)
    test_labels = np.asarray(test_labels)

    # 1. Score distribution
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    test_normal = test_scores[test_labels == 0]
    test_anomaly = test_scores[test_labels == 1]

    axes[0].hist(val_scores, bins=50, alpha=0.5, label=f'Val (n={len(val_scores)})',
                 color='green', density=True)
    axes[0].hist(test_normal, bins=50, alpha=0.5, label=f'Test Normal (n={len(test_normal)})',
                 color='steelblue', density=True)
    axes[0].hist(test_anomaly, bins=50, alpha=0.5, label=f'Test Anomaly (n={len(test_anomaly)})',
                 color='coral', density=True)
    axes[0].axvline(threshold, color='red', linestyle='--', linewidth=2,
                    label=f'Threshold={threshold:.2f}')
    axes[0].set_xlabel('Composite Score')
    axes[0].set_ylabel('Density')
    axes[0].set_title('Score Distribution')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Performance metrics
    metric_names = ['Precision', 'Recall', 'F1', 'Specificity']
    metric_values = [metrics['precision'], metrics['recall'], metrics['f1'], metrics['specificity']]
    colors = ['steelblue', 'coral', 'green', 'purple']

    bars = axes[1].bar(metric_names, metric_values, color=colors)
    axes[1].set_ylim(0, 1.1)
    axes[1].set_ylabel('Score')
    axes[1].set_title(f'Performance Metrics (F1 = {metrics["f1"]:.4f})')
    axes[1].grid(True, alpha=0.3, axis='y')

    # Print each value just above its bar.
    for bar, val in zip(bars, metric_values):
        axes[1].annotate(f'{val:.3f}', xy=(bar.get_x() + bar.get_width()/2, val),
                         xytext=(0, 3), textcoords='offset points', ha='center')

    fig.tight_layout()
    fig.savefig(output_path / 'score_distribution.png', dpi=150)
    # Close this specific figure rather than whatever happens to be current.
    plt.close(fig)

    # 2. PR Curve
    fig, ax = plt.subplots(figsize=(8, 6))

    precision_curve, recall_curve, _ = precision_recall_curve(test_labels, test_scores)
    ax.plot(recall_curve, precision_curve, linewidth=2, label=f'PR-AUC = {pr_auc:.4f}')
    # Mark the precision/recall obtained at the chosen threshold.
    ax.scatter([metrics['recall']], [metrics['precision']], color='red', s=100, zorder=5,
               label='Operating Point')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path / 'pr_curve.png', dpi=150)
    plt.close(fig)

    # 3. Confusion matrix
    fig, ax = plt.subplots(figsize=(6, 5))

    # Rows are the true class, columns the predicted class.
    cm = np.array([[metrics['tn'], metrics['fp']],
                   [metrics['fn'], metrics['tp']]])

    im = ax.imshow(cm, cmap='Blues')
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(['Pred Normal', 'Pred Anomaly'])
    ax.set_yticklabels(['True Normal', 'True Anomaly'])
    ax.set_title('Confusion Matrix')

    # White text on dark cells, black on light ones, for legibility.
    half_max = cm.max() / 2
    for i in range(2):
        for j in range(2):
            ax.text(j, i, cm[i, j], ha='center', va='center',
                    color='white' if cm[i, j] > half_max else 'black', fontsize=16)

    plt.colorbar(im)
    fig.tight_layout()
    fig.savefig(output_path / 'confusion_matrix.png', dpi=150)
    plt.close(fig)

    print(f"Evaluation plots saved to: {output_path}")


# ============================================================================
# Main Function: Complete Pipeline
# ============================================================================

def run_complete_pipeline(
    data_dir: str = Config.PROCESSED_DATA_DIR,
    output_dir: str = Config.OUTPUT_DIR,
    pretrained_path: Optional[str] = None,
    epochs: int = Config.EPOCHS,
    skip_training: bool = False
) -> Dict:
    """
    Run Complete Pipeline (Training and Evaluation)

    Trains the model (unless ``skip_training`` is set), evaluates the best
    checkpoint with the composite-score settings from ``Config``, and prints
    a text summary of the results.

    Parameters
    ----------
    data_dir : str
        Processed data directory (must contain train.csv, val.csv, test.csv)
    output_dir : str
        Output directory; created if missing. Models are written to
        ``<output_dir>/models`` and evaluation artifacts to
        ``<output_dir>/evaluation``.
    pretrained_path : str, optional
        Pretrained model path. Only forwarded to ``train_model``, so it has
        no effect when ``skip_training`` is True.
    epochs : int
        Number of training epochs
    skip_training : bool
        Whether to skip training. When True, the checkpoint at
        ``<output_dir>/models/best_model.pth`` is evaluated.

    Returns
    -------
    results : Dict
        The dictionary returned by ``evaluate_model``.
    """

    print("=" * 70)
    print("VQTransAE Training and Evaluation Pipeline")
    print("=" * 70)
    print(f"   Data directory: {data_dir}")
    print(f"   Output directory: {output_dir}")
    print(f"   Training epochs: {epochs}")
    print("=" * 70)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    processed_dir = Path(data_dir)

    # Step 2: Training (numbering follows the banner printed by train_model)
    if not skip_training:
        # Only the checkpoint path is needed here; the model object and the
        # training history are not used by the rest of the pipeline.
        _model, _history, best_model_path = train_model(
            data_dir=str(processed_dir),
            output_dir=str(output_path / 'models'),
            pretrained_path=pretrained_path,
            epochs=epochs
        )
    else:
        print("\nSkipping training")
        best_model_path = output_path / 'models' / 'best_model.pth'

    # Step 3: Evaluation
    results = evaluate_model(
        model_path=str(best_model_path),
        data_dir=str(processed_dir),
        output_dir=str(output_path / 'evaluation'),
        alpha=Config.ALPHA,
        beta=Config.BETA,
        gamma=Config.GAMMA,
        percentile=Config.THRESHOLD_PERCENTILE
    )

    # Generate paper summary
    # E and D use the '+' format flag so a negative separability prints as
    # "-0.123" instead of the malformed "+-0.123".
    print("\n" + "=" * 70)
    print("Paper Summary")
    print("=" * 70)
    print(f"""
VQTransAE model performance on bike road surface anomaly detection:

Model Configuration:
- Codebook size: {Config.CODEBOOK_SIZE}
- Window size: {Config.WIN_SIZE}, Step: {Config.STEP}

Composite Anomaly Score:
- Formula: S = {Config.ALPHA}*E_norm + {Config.BETA}*D_norm + ({Config.GAMMA})*A_norm
- E_norm: Reconstruction error (separability: {results['separations']['E']:+.3f})
- D_norm: Quantization distance (separability: {results['separations']['D']:+.3f})
- A_norm: Attention dispersion (separability: {results['separations']['A']:.3f})

Performance Metrics (@{Config.THRESHOLD_PERCENTILE}th percentile threshold):
- F1 Score: {results['metrics']['f1']:.4f}
- Precision: {results['metrics']['precision']:.4f}
- Recall: {results['metrics']['recall']:.4f}
- PR-AUC: {results['pr_auc']:.4f}

Best percentile: {results['best_percentile_result']['percentile']}th
Best F1: {results['best_percentile_result']['f1']:.4f}
""")
    print("=" * 70)

    print("\nComplete pipeline finished!")
    print(f"   All results saved to: {output_path}")

    return results


# ============================================================================
# Entry Point
# ============================================================================

if __name__ == '__main__':
    # Train (warm-started from the win_56 checkpoint) and then evaluate on the
    # already-processed dataset; all artifacts go under ./vqtransae_results.
    results = run_complete_pipeline(
        skip_training=False,
        epochs=100,
        pretrained_path='./win_56/best_vqtransae.pth',  # optional pretrained weights
        output_dir='./vqtransae_results',
        data_dir='./vqtransae_results/processed_data',  # existing processed data
    )


VQTransAE Training and Evaluation Pipeline
   Data directory: ./vqtransae_results/processed_data
   Output directory: ./vqtransae_results
   Training epochs: 100

Step 2: Train Model
Device: cuda

Creating DataLoaders...
   Train: 2073 windows, 32 batches
   Val: 357 windows, 6 batches

Initializing model...
   Re-initializing codebook
   Loading pretrained weights: ./win_56/best_vqtransae.pth
   Loaded encoder/decoder weights (skipped codebook)

Starting training (100 epochs)
   Entropy weight: 0.3
   Diversity weight: 0.2
   VQ weight: 2.0
Epoch   1: Loss=2.6207, Val=2.1678, Active=  10/1024, Perplexity=1.1
Epoch   2: Loss=2.5700, Val=2.1585, Active=   1/1024, Perplexity=1.0
Epoch   3: Loss=2.5655, Val=2.1791, Active=   1/1024, Perplexity=1.0
Epoch   4: Loss=2.5692, Val=2.1828, Active=   1/1024, Perplexity=1.0
Epoch   5: Loss=2.5728, Val=2.1690, Active=   1/1024, Perplexity=1.0, Refreshed=1023
Epoch   6: Loss=2.5372, Val=2.1706, Active= 177/1024, Perplexity=1.2
Epoch   7: Loss=2.5649