In [None]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
import warnings
# Silence every Python warning for the remainder of the session.
# NOTE(review): this blanket filter also hides PyTorch's shape-broadcast
# UserWarnings (e.g. from F.mse_loss with mismatched sizes), which can mask
# real bugs; consider narrowing it to specific categories.
warnings.filterwarnings(action='ignore')

# Runtime environment
# Presumably a workaround for the Intel OpenMP "libiomp5 already initialized"
# abort when two libraries ship their own OpenMP runtime.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})
# Agg is a non-GUI backend: figures can be saved to files but not shown.
plt.switch_backend(newbackend='Agg')

def set_global_seed(seed=42):
    """Seed every RNG the pipeline relies on and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, PyTorch's
            CPU generator and the CUDA generators (current and all devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    print(f"Global random seed set to: {seed}")

def ensure_1d_tensor(tensor):
    """Coerce *tensor* (a tensor, scalar or array-like) to a 1-D float32 tensor.

    Size-1 dimensions are squeezed away first; a resulting 0-d tensor becomes
    shape ``(1,)`` and anything still multi-dimensional is flattened.
    Non-tensor inputs are converted with ``dtype=torch.float32``.
    """
    if isinstance(tensor, torch.Tensor):
        out = tensor
    else:
        out = torch.tensor(tensor, dtype=torch.float32)
    # After squeeze(), reshape(-1) covers both remaining cases:
    # 0-d -> (1,), and n-d -> flat.
    return out.squeeze().reshape(-1).to(dtype=torch.float32)

class SpacingFocusedEarlyStopping:
    """Patience-based early stopping on a blend of validation loss and spacing RMSE.

    The monitored score is
    ``(1 - spacing_weight) * val_loss + spacing_weight * val_spacing_rmse``.
    An epoch only counts as an improvement if it beats the best score so far
    by more than ``min_delta``. Training should stop once ``patience``
    consecutive epochs fail to improve.
    """

    def __init__(self, patience=10, min_delta=1e-3, spacing_weight=0.6):
        self.patience = patience
        self.min_delta = min_delta
        self.spacing_weight = spacing_weight
        self.best_score = float('inf')
        self.wait = 0  # consecutive epochs without improvement

    def __call__(self, val_loss, val_spacing_rmse):
        """Record one epoch's metrics; return True when training should stop."""
        w = self.spacing_weight
        score = (1 - w) * val_loss + w * val_spacing_rmse
        improved = score < self.best_score - self.min_delta
        if improved:
            self.best_score = score
            self.wait = 0
        else:
            self.wait += 1
        return (not improved) and self.wait >= self.patience

class OVRVModel:
    """Optimal Velocity Relative Velocity (OVRV) car-following model.

    acc = k1 * (spacing - eta - tau * v_follow) + k2 * (v_lead - v_follow),
    clamped to [-10, 10].
    """

    def __init__(self, k1, k2, eta, tau):
        self.k1 = k1
        self.k2 = k2
        self.eta = eta
        self.tau = tau

    def predict_acceleration(self, spacing, follow_speed, lead_speed):
        """Return the clamped OVRV acceleration as a 1-D float32 tensor.

        Each input may be a scalar, an array-like or a tensor. It is coerced
        to 1-D float32 by ``ensure_1d_tensor``, which also performs the
        tensor conversion for non-tensor inputs, before the terms broadcast.
        """
        s, v, v_lead = (ensure_1d_tensor(x) for x in (spacing, follow_speed, lead_speed))
        gap_term = self.k1 * (s - self.eta - self.tau * v)
        speed_term = self.k2 * (v_lead - v)
        return torch.clamp(gap_term + speed_term, -10.0, 10.0)

class IDMModel:
    """Intelligent Driver Model (IDM) car-following model.

    acc = a * [1 - (v / v0)^delta - (s* / s)^2]
    s*  = s0 + max(0, v*T + v*(v - v_lead) / (2*sqrt(a*b)))

    The result is clamped to [-10, 10].

    Args:
        a: maximum acceleration; the absolute value is used, floored at 1e-6.
        b: comfortable deceleration; the absolute value is used, floored at 1e-6.
        delta: free-road acceleration exponent.
        s0: minimum (jam) gap.
        T: desired time headway.
        v0: desired speed.
    """

    def __init__(self, a, b, delta, s0, T, v0):
        # a and b feed sqrt(a*b) in a denominator, so keep them strictly positive.
        self.a = max(abs(a), 1e-6)
        self.b = max(abs(b), 1e-6)
        self.delta, self.s0, self.T, self.v0 = delta, s0, T, v0

    def predict_acceleration(self, spacing, follow_speed, lead_speed):
        """Return the clamped IDM acceleration as a 1-D float32 tensor.

        Inputs may be scalars, array-likes or tensors. Each is coerced to 1-D
        float32 by ``ensure_1d_tensor`` (which also converts non-tensors)
        before the terms broadcast.

        Guards that keep the formula well defined:

        * The dynamic part of the desired gap is floored at 0, as in the
          standard IDM. Without it, a fast-receding leader drives s*
          negative, and ``(s*/s)^2`` then turns into spurious braking.
        * Speed is floored at 0 in the free-road term, so a slightly
          negative measured speed cannot yield NaN for a non-integer
          ``delta``.
        * Spacing is floored at 0.1 to avoid division by zero.
        """
        spacing = ensure_1d_tensor(spacing)
        follow_speed = ensure_1d_tensor(follow_speed)
        lead_speed = ensure_1d_tensor(lead_speed)

        sqrt_ab = torch.sqrt(torch.tensor(self.a * self.b))
        dynamic_gap = follow_speed * self.T + (follow_speed * (follow_speed - lead_speed)) / (2 * sqrt_ab)
        s_star = self.s0 + torch.clamp(dynamic_gap, min=0.0)

        free_road = (torch.clamp(follow_speed, min=0.0) / self.v0) ** self.delta
        interaction = (s_star / torch.clamp(spacing, min=0.1)) ** 2
        acc = self.a * (1 - free_road - interaction)
        return torch.clamp(acc, -10.0, 10.0)

class EnhancedDataset(Dataset):
    """Sliding-window car-following dataset with 6 features per time step.

    Per-step feature columns, in order:
        0 lead_speed, 1 follow_speed, 2 spacing, 3 relative_speed
        (lead - follow), 4 safety_margin, 5 base_acc (the traditional
        model's prediction for that step).

    Sample ``idx`` is the window of steps ``idx .. idx + seq_length - 1``.
    The five targets refer to the window's last step ``t``: follower speed,
    spacing and lead speed at ``t``, spacing at ``t + 1``, and actual
    acceleration at ``t``.

    Args:
        data: DataFrame with ``lead_speed``, ``follow_speed``, ``spacing``
            and ``actual_acc`` columns.
        traditional_model: object exposing
            ``predict_acceleration(spacing, follow_speed, lead_speed)``.
        seq_length: window length.
        delta_t: simulation step (stored only).
        device: torch device for all tensors.
        noise_std: augmentation noise std. ``None`` selects the
            model-specific default (0.002 for OVRV, 0.005 otherwise).
        augment_prob: probability of augmenting a sample while in training
            mode. ``None`` selects the model-specific default (0.08 for
            OVRV, 0.12 otherwise).
        model_type: ``'OVRV'``; any other value selects the non-OVRV defaults.
    """

    def __init__(self, data, traditional_model, seq_length=12, delta_t=0.02, device='cpu',
                 noise_std=None, augment_prob=None, model_type='OVRV'):
        self.seq_length = seq_length
        self.delta_t = delta_t
        self.device = device
        # Model-specific defaults apply only when the caller did not choose a
        # value, so explicit settings (e.g. 0.0 for validation) are respected.
        is_ovrv = model_type == 'OVRV'
        self.noise_std = noise_std if noise_std is not None else (0.002 if is_ovrv else 0.005)
        self.augment_prob = augment_prob if augment_prob is not None else (0.08 if is_ovrv else 0.12)
        self.training = False

        # Core per-step signals (full length N).
        self.lead_speed = torch.tensor(data['lead_speed'].values, dtype=torch.float32, device=device)
        self.follow_speed = torch.tensor(data['follow_speed'].values, dtype=torch.float32, device=device)
        self.spacing = torch.tensor(data['spacing'].values, dtype=torch.float32, device=device)
        self.actual_acc = torch.tensor(data['actual_acc'].values, dtype=torch.float32, device=device)

        # Derived per-step signals (full length N).
        self.relative_speed = self.lead_speed - self.follow_speed
        self.base_acc = self._compute_base_predictions(traditional_model)
        self.safety_margin = self._compute_safety_margin()

        # Pair step t with the spacing at t + 1: drop the last step from the
        # inputs/targets, and the first step from next_spacing (length N - 1).
        self.next_spacing = self.spacing[1:]
        self.current_lead_speed = self.lead_speed[:-1]
        self.current_follow_speed = self.follow_speed[:-1]
        self.actual_acc = self.actual_acc[:-1]
        self.features = self._extract_features()[:-1]

        print(f"Dataset - Size: {len(self.features)}, Features: {self.features.shape[1]}D")

    def _compute_safety_margin(self):
        """Relative gap above ``2 + 1.5 * max(v_follow, 0)``, clipped to [-5, 8]."""
        safe_distance = 2.0 + 1.5 * torch.clamp(self.follow_speed, min=0)
        safety_margin = (self.spacing - safe_distance) / torch.clamp(safe_distance, min=1e-6)
        return torch.clamp(safety_margin, -5.0, 8.0)

    def _compute_base_predictions(self, model):
        """Traditional-model acceleration for every step (length N, float32).

        A single batched call is tried first. If it raises or returns the
        wrong number of values, each step is predicted individually, and any
        step whose prediction fails gets 0.0.
        """
        n = len(self.lead_speed)
        try:
            batched = model.predict_acceleration(self.spacing, self.follow_speed, self.lead_speed)
            batched = torch.as_tensor(batched, dtype=torch.float32).reshape(-1)
            if batched.numel() == n:
                return batched.to(self.device)
        except Exception:
            # Deliberate: fall through to the per-step path, which isolates failures.
            pass

        base_acc = torch.zeros(n, device=self.device)
        for i in range(n):
            try:
                acc = model.predict_acceleration(self.spacing[i], self.follow_speed[i], self.lead_speed[i])
                base_acc[i] = acc.item() if acc.dim() > 0 else acc
            except Exception:
                base_acc[i] = 0.0
        return base_acc

    def _extract_features(self):
        """Stack the six per-step features into an ``(N, 6)`` tensor."""
        return torch.stack([
            self.lead_speed, self.follow_speed, self.spacing,
            self.relative_speed, self.safety_margin, self.base_acc
        ], dim=1)

    def set_training(self, training):
        """Enable (True) or disable (False) random augmentation in ``__getitem__``."""
        self.training = training

    def __len__(self):
        # Windows start at 0 .. len(features) - seq_length inclusive.
        return max(0, len(self.features) - self.seq_length + 1)

    def _augment_data(self, features):
        """With probability ``augment_prob`` (training only), add scaled Gaussian noise.

        Speed is kept >= 0 and spacing >= 0.5. Relative speed and safety
        margin are then recomputed from the noisy base signals. Column 5
        (base_acc) keeps its noise and is not recomputed.
        """
        if not self.training or torch.rand(1) > self.augment_prob:
            return features

        noise_scales = torch.tensor([1.0, 1.0, 0.3, 0.8, 0.2, 0.5], device=features.device)
        noise = torch.randn_like(features) * self.noise_std * noise_scales.unsqueeze(0)
        augmented = features + noise

        # Physical constraints
        augmented[:, 1] = torch.clamp(augmented[:, 1], min=0)  # Speed >= 0
        augmented[:, 2] = torch.clamp(augmented[:, 2], min=0.5)  # Spacing >= 0.5

        # Recalculate derived features so they stay consistent with columns 0-2
        augmented[:, 3] = augmented[:, 0] - augmented[:, 1]  # Relative speed
        safe_distance = 2.0 + 1.5 * torch.clamp(augmented[:, 1], min=0)
        augmented[:, 4] = torch.clamp((augmented[:, 2] - safe_distance) / torch.clamp(safe_distance, min=1e-6), -5.0, 8.0)

        return augmented

    def __getitem__(self, idx):
        """Return ``(window, v_follow, spacing, v_lead, next_spacing, actual_acc)``.

        ``window`` has shape ``(seq_length, 6)``. Each target has shape
        ``(1,)`` and refers to the window's last step.
        """
        end_idx = idx + self.seq_length
        features = self._augment_data(self.features[idx:end_idx].clone())
        last = slice(end_idx - 1, end_idx)  # a length-1 slice keeps a 1-D (1,) shape
        return (
            features,
            self.current_follow_speed[last],
            self.spacing[last],
            self.current_lead_speed[last],
            self.next_spacing[last],
            self.actual_acc[last],
        )

class PhaseAwareLSTM(nn.Module):
    """LSTM residual predictor with separate acceleration/deceleration branches.

    An LSTM encodes the input window. Two attention-pooling branches
    ("accel" and "decel") each map the encoding to a scalar residual scaled
    by ``config['residual_scale']``. The two residuals are blended per
    sample with a weight taken from the last time step's relative speed
    (feature 3) and safety margin (feature 4).

    Required ``config`` keys: input_size, hidden_size, num_layers,
    attention_dim, dropout, residual_scale.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden = config['hidden_size']
        drop = config['dropout']

        # Inter-layer LSTM dropout is switched off explicitly; dropout is
        # applied inside the attention/head MLPs instead.
        self.lstm = nn.LSTM(
            config['input_size'], hidden, config['num_layers'],
            batch_first=True, dropout=0.0
        )

        # Submodule creation order fixes both the state_dict keys and the RNG
        # draws consumed by the default initialisers, so keep it stable.
        att_dim = config['attention_dim']
        self.accel_attention = self._build_attention(hidden, att_dim, drop)
        self.decel_attention = self._build_attention(hidden, att_dim, drop)

        head_dim = max(8, hidden // 2)
        self.accel_head = self._build_head(hidden, head_dim, drop)
        self.decel_head = self._build_head(hidden, head_dim, drop)

        self._init_weights()

    @staticmethod
    def _scoring_mlp(in_dim, mid_dim, dropout):
        """Linear -> LayerNorm -> ReLU -> Dropout -> Linear(mid_dim, 1)."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mid_dim, 1)
        )

    def _build_attention(self, hidden_size, attention_dim, dropout):
        """Per-time-step attention scorer: (*, hidden_size) -> (*, 1)."""
        return self._scoring_mlp(hidden_size, attention_dim, dropout)

    def _build_head(self, hidden_size, mid_dim, dropout):
        """Residual head: pooled (*, hidden_size) -> (*, 1)."""
        return self._scoring_mlp(hidden_size, mid_dim, dropout)

    def _init_weights(self):
        """Initialise LSTM weights and all biases.

        LSTM input weights use Xavier-uniform and recurrent weights use
        orthogonal init. Every bias is zeroed, and each LSTM bias then gets
        its forget-gate slice set to 1. Linear/LayerNorm weights keep
        PyTorch's defaults.
        """
        for pname, p in self.named_parameters():
            if 'weight_ih' in pname:
                nn.init.xavier_uniform_(p.data)
                continue
            if 'weight_hh' in pname:
                nn.init.orthogonal_(p.data)
                continue
            if 'bias' not in pname:
                continue
            p.data.fill_(0)
            if 'lstm' in pname:
                # PyTorch packs LSTM gates as (input, forget, cell, output),
                # so the forget gate is the second quarter of the bias vector.
                quarter = p.size(0) // 4
                p.data[quarter:2 * quarter].fill_(1)

    def forward(self, x, relative_speed, current_spacing=None):
        """Return ``(residual, spacing_correction)``, both of shape ``(batch,)``.

        ``x`` must be ``(batch, seq, features)``. ``relative_speed`` and
        ``current_spacing`` are accepted for interface compatibility but are
        not used; phase detection reads features 3 and 4 of ``x`` instead.
        ``spacing_correction`` is always zeros.
        """
        n_batch, _, _ = x.size()
        dev = x.device

        encoded, _ = self.lstm(x)

        # Phase detection from the window's last time step
        last_step = x[:, -1, :]
        rel_v, margin = last_step[:, 3], last_step[:, 4]
        accelerating = (rel_v > 0.1) | (margin > 2.0)
        decelerating = (rel_v < -0.1) | (margin < -1.0)
        # Weight of the accel branch: 0.8 accel phase, 0.2 decel phase, 0.5 otherwise.
        w_accel = torch.where(
            accelerating,
            torch.tensor(0.8, device=dev),
            torch.where(decelerating, torch.tensor(0.2, device=dev), torch.tensor(0.5, device=dev)),
        )

        # Both attentions run before both heads, matching the original dropout RNG order.
        pooled_accel = self._apply_attention(encoded, self.accel_attention)
        pooled_decel = self._apply_attention(encoded, self.decel_attention)

        scale = self.config['residual_scale']
        r_accel = self.accel_head(pooled_accel).squeeze(-1) * scale
        r_decel = self.decel_head(pooled_decel).squeeze(-1) * scale

        blended = w_accel * r_accel + (1 - w_accel) * r_decel
        return blended, torch.zeros(n_batch, device=dev)

    def _apply_attention(self, lstm_out, attention_layer):
        """Softmax-over-time attention pooling: ``(B, T, H) -> (B, H)``."""
        b, t, h = lstm_out.size()
        scores = attention_layer(lstm_out.reshape(b * t, h)).reshape(b, t, 1)
        weights = F.softmax(scores, dim=1)
        return (lstm_out * weights).sum(dim=1)

def compute_loss(predicted_acc, real_acc, current_speed, current_spacing, lead_speed, next_real_spacing, dt,
                spacing_correction=None, alpha=0.7, beta=0.25, gamma=0.05, model_type='IDM'):
    """Training/validation loss for the hybrid models.

    All per-sample inputs are flattened to 1-D first. Both branches roll the
    follower forward one step: speed first (clamped >= 0), then spacing using
    the updated speed (clamped >= 0.1). The predicted next spacing is then
    compared with ``next_real_spacing``.

    * ``model_type == 'OVRV'``: ``0.7 * Huber(acc) + 0.2 * MSE(spacing)
      + 0.1 * var(predicted_acc)``. This branch uses fixed weights and
      ignores ``alpha``, ``beta``, ``gamma`` and ``spacing_correction``.
    * otherwise: ``alpha * MSE(acc) + beta * physics + gamma * 0.01 * L1(acc)``.
      ``physics`` is the spacing MSE plus 0.1 times a safe-gap hinge
      penalty. When ``spacing_correction`` is given, ``physics`` is blended
      70/30 with the MSE of the corrected spacing.

    Returns:
        A scalar loss tensor.
    """
    (predicted_acc, real_acc, current_speed, current_spacing,
     lead_speed, next_real_spacing) = (
        ensure_1d_tensor(t) for t in (predicted_acc, real_acc, current_speed,
                                       current_spacing, lead_speed, next_real_spacing)
    )

    # One Euler step: speed first, then spacing with the updated speed.
    next_speed = torch.clamp(current_speed + predicted_acc * dt, min=0)
    kinematic_spacing = current_spacing + (lead_speed - next_speed) * dt
    next_spacing = torch.clamp(kinematic_spacing, min=0.1)
    spacing_mse = F.mse_loss(next_spacing, next_real_spacing)

    if model_type == 'OVRV':
        # Huber loss with delta 1: 0.5*d^2 for |d| < 1, |d| - 0.5 otherwise.
        huber = F.smooth_l1_loss(predicted_acc, real_acc, beta=1.0)
        if predicted_acc.numel() > 1:
            spread = torch.var(predicted_acc)
        else:
            spread = torch.tensor(0.0, device=predicted_acc.device)
        return 0.7 * huber + 0.2 * spacing_mse + 0.1 * spread

    acc_mse = F.mse_loss(predicted_acc, real_acc)

    # Hinge penalty whenever the predicted gap drops below 2 + 1.0 * v_next.
    safety_gap = 2.0 + 1.0 * torch.clamp(next_speed, min=0)
    safety_violation = F.relu(safety_gap - next_spacing).mean()
    physics = spacing_mse + 0.1 * safety_violation

    if spacing_correction is not None:
        correction = ensure_1d_tensor(spacing_correction)
        corrected = torch.clamp(kinematic_spacing + correction, min=0.1)
        physics = 0.7 * physics + 0.3 * F.mse_loss(corrected, next_real_spacing)

    l1_term = 0.01 * torch.mean(torch.abs(predicted_acc - real_acc))
    return alpha * acc_mse + beta * physics + gamma * l1_term

def get_loss_weights(epoch, total_epochs, model_type):
    """Return the ``(alpha, beta, gamma)`` loss-weight schedule for *epoch*.

    OVRV switches once, at epoch 40. Every other model type moves through
    three stages, changing at epochs 25 and 50. ``total_epochs`` is accepted
    for API symmetry but is not used.
    """
    if model_type == 'OVRV':
        return (0.8, 0.15, 0.05) if epoch >= 40 else (0.95, 0.0, 0.05)
    schedule = (
        (25, (0.8, 0.15, 0.05)),
        (50, (0.6, 0.3, 0.1)),
    )
    for bound, weights in schedule:
        if epoch < bound:
            return weights
    return (0.4, 0.5, 0.1)

class TrainingManager:
    """Owns the train/validation datasets and loaders and runs validation passes.

    The model type is inferred from the traditional model's class name
    ('OVRV' if the name contains 'OVRV', otherwise 'IDM') and passed to both
    datasets. The validation dataset is built with ``noise_std=0.0`` and
    ``augment_prob=0.0``, and its training flag is off.

    Args:
        train_data, val_data: DataFrames accepted by ``EnhancedDataset``.
        traditional_model: physics model providing ``predict_acceleration``.
        delta_t: integration step for the one-step spacing prediction.
        seq_length: LSTM window length.
        device: torch device for the dataset tensors.
    """

    def __init__(self, train_data, val_data, traditional_model, delta_t, seq_length=12, device='cpu'):
        self.device = device
        self.traditional_model = traditional_model
        self.delta_t = delta_t

        model_type = 'OVRV' if 'OVRV' in str(traditional_model.__class__.__name__) else 'IDM'

        self.train_dataset = EnhancedDataset(train_data, traditional_model, seq_length, delta_t, device, model_type=model_type)
        self.val_dataset = EnhancedDataset(val_data, traditional_model, seq_length, delta_t, device,
                                           noise_std=0.0, augment_prob=0.0, model_type=model_type)
        self.val_dataset.set_training(False)

        self.train_dataloader = DataLoader(self.train_dataset, batch_size=4, shuffle=True, drop_last=True)
        self.val_dataloader = DataLoader(self.val_dataset, batch_size=8, shuffle=False, drop_last=True)

        print(f"Train batches: {len(self.train_dataloader)}, Val batches: {len(self.val_dataloader)}")

    def validate_model(self, model, alpha=0.7, beta=0.25, gamma=0.05, model_type='IDM'):
        """Evaluate *model* on the validation loader without gradients.

        For each batch, the hybrid acceleration is the traditional model's
        prediction plus the model's residual, clipped to [-8, 6]. The batch's
        loss comes from ``compute_loss``, and the one-step spacing RMSE is
        computed as well.

        Returns:
            ``(mean_val_loss, mean_spacing_rmse)`` averaged over batches, or
            ``(nan, nan)`` when the loader yields no batches.
        """
        model.eval()
        val_losses, val_spacing_rmses = [], []

        with torch.no_grad():
            for batch in self.val_dataloader:
                features, current_follow_speed, current_spacing, current_lead_speed, next_real_spacing, real_acceleration = batch
                relative_speed = current_lead_speed - current_follow_speed

                acc_residual_predictions, spacing_correction = model(features.contiguous(), relative_speed.contiguous(), current_spacing)

                # Flatten every per-sample quantity to (batch,). Mixing the
                # loader's (batch, 1[, 1]) tensors with (batch,) predictions
                # would silently broadcast to (batch, batch) and pair sample i's
                # state with sample j's acceleration.
                follow_v = ensure_1d_tensor(current_follow_speed)
                spacing = ensure_1d_tensor(current_spacing)
                lead_v = ensure_1d_tensor(current_lead_speed)
                target_spacing = ensure_1d_tensor(next_real_spacing)
                correction = ensure_1d_tensor(spacing_correction)

                base_acc_batch = ensure_1d_tensor(self.traditional_model.predict_acceleration(spacing, follow_v, lead_v))
                predicted_acc = torch.clamp(base_acc_batch + ensure_1d_tensor(acc_residual_predictions), -8.0, 6.0)

                pred_next_speed = torch.clamp(follow_v + predicted_acc * self.delta_t, min=0)
                pred_next_spacing = torch.clamp(
                    spacing + (lead_v - pred_next_speed) * self.delta_t + correction, min=0.1)

                val_loss = compute_loss(predicted_acc, real_acceleration, follow_v, spacing,
                                        lead_v, target_spacing, self.delta_t, spacing_correction,
                                        alpha, beta, gamma, model_type)
                spacing_rmse = torch.sqrt(F.mse_loss(pred_next_spacing, target_spacing))

                val_losses.append(val_loss.item())
                val_spacing_rmses.append(spacing_rmse.item())

        if not val_losses:
            # Same NaN result np.mean([]) would give, without its RuntimeWarning.
            return float('nan'), float('nan')
        return np.mean(val_losses), np.mean(val_spacing_rmses)

class ModelConfig:
    """Central source of hybrid-model hyper-parameters."""

    # Per-model presets; get_hybrid_config hands out a fresh copy because
    # callers mutate the returned dict (e.g. adding 'input_size').
    _HYBRID_PRESETS = {
        'OVRV': {
            'hidden_size': 20, 'num_layers': 2, 'dropout': 0.45,
            'residual_scale': 0.06, 'attention_dim': 12, 'model_type': 'OVRV'
        },
        'IDM': {
            'hidden_size': 32, 'num_layers': 2, 'dropout': 0.5,
            'residual_scale': 0.12, 'attention_dim': 16, 'model_type': 'IDM'
        }
    }

    # Ensemble diversity table, one row per member index mod 5:
    # (hidden_size, dropout, lr_multiplier)
    _DIVERSITY_TABLE = (
        (16, 0.4, 1.0),
        (20, 0.45, 0.8),
        (24, 0.5, 1.2),
        (28, 0.42, 0.9),
        (32, 0.48, 1.1),
    )

    @staticmethod
    def get_hybrid_config(model_name):
        """Return a new config dict for *model_name*; unknown names get the IDM preset."""
        presets = ModelConfig._HYBRID_PRESETS
        return dict(presets.get(model_name, presets['IDM']))

    @staticmethod
    def get_training_params(model_type, model_id=0, diversity_factor=0.15):
        """Return the per-member training overrides for ensemble member *model_id*.

        ``model_type`` is accepted for API symmetry but not used.
        ``weight_decay_factor`` grows linearly with ``model_id``.
        """
        hidden, dropout, lr_mult = ModelConfig._DIVERSITY_TABLE[model_id % 5]
        return {
            'hidden_size': hidden,
            'dropout': dropout,
            'lr_multiplier': lr_mult,
            'weight_decay_factor': 1 + diversity_factor * model_id
        }

class SimplifiedEnsembleTrainer:
    """Trains and queries an ensemble of hybrid residual models.

    Each member predicts a residual on top of a traditional car-following
    model's acceleration. ``predict`` returns the mean residual across
    members.

    Args:
        n_models: number of ensemble members to train.
        diversity_factor: per-member weight-decay growth; member ``i`` uses
            ``base_weight_decay * (1 + diversity_factor * i)``.
    """

    def __init__(self, n_models=5, diversity_factor=0.15):
        self.n_models = n_models
        self.diversity_factor = diversity_factor
        self.models = []
        self.training_histories = []
        self.traditional_model = None

    def train_models(self, train_data, val_data, traditional_model, model_type, delta_t, seq_length=12, device='cpu'):
        """Train ``n_models`` hybrid members and append them to ``self.models``.

        Returns:
            ``(self.models, self.training_histories)``.
        """
        self.traditional_model = traditional_model

        # Only hybrid models are supported
        trainer = TrainingManager(train_data, val_data, traditional_model, delta_t, seq_length, device)

        print(f"\nTraining {model_type} ({'Ensemble' if self.n_models > 1 else 'Single'}) ({self.n_models} model{'s' if self.n_models > 1 else ''})")

        base_config = ModelConfig.get_hybrid_config(model_type)
        base_config['input_size'] = 6  # 5 state features + traditional-model prediction
        base_config['seq_length'] = seq_length
        model_class = PhaseAwareLSTM

        for i in range(self.n_models):
            # Distinct, reproducible seed per member.
            torch.manual_seed(42 + i * 7)
            np.random.seed(42 + i * 7)

            train_params = ModelConfig.get_training_params(model_type, i, self.diversity_factor)
            config = {**base_config, **train_params}

            model, history = self._train_hybrid_model(trainer, config, model_class, device, i)

            if model is not None:
                self.models.append(model)
                self.training_histories.append(history)
                print(f"  Model {i+1}: Early stop at epoch {len(history['train_losses'])}, "
                      f"Val: {history['best_val_loss']:.4f}")

        return self.models, self.training_histories

    def _train_hybrid_model(self, trainer, config, model_class, device, model_id):
        """Train one member with AdamW, ReduceLROnPlateau and early stopping.

        The weights with the lowest validation loss are restored before
        returning.

        Returns:
            ``(model, history)``. ``history`` holds the per-epoch
            ``train_losses``, ``val_losses`` and ``val_spacing_rmses``, plus
            ``best_val_loss`` and the ``config`` used.
        """
        model = model_class(config).to(device)
        model_type = config['model_type']
        is_ovrv = model_type == 'OVRV'

        lr = 5e-5 * config.get('lr_multiplier', 1.0)
        weight_decay = (5e-4 if is_ovrv else 1e-4) * (1 + self.diversity_factor * model_id)

        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.8, min_lr=1e-7)

        early_stopping = SpacingFocusedEarlyStopping(
            patience=30 if is_ovrv else 25, min_delta=1e-3,
            spacing_weight=0.8 if is_ovrv else 0.6)

        train_losses, val_losses, val_spacing_rmses = [], [], []
        best_val_loss, best_model_state = float('inf'), None
        max_epochs, accumulation_steps = 250, 2

        for epoch in range(max_epochs):
            alpha, beta, gamma = get_loss_weights(epoch, max_epochs, model_type)

            avg_train_loss = self._train_epoch(model, trainer, optimizer, alpha, beta, gamma,
                                               model_type, accumulation_steps)
            train_losses.append(avg_train_loss)

            val_loss, val_spacing_rmse = trainer.validate_model(model, alpha, beta, gamma, model_type)
            val_losses.append(val_loss)
            val_spacing_rmses.append(val_spacing_rmse)

            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # state_dict() holds references to the live parameter tensors, so
                # a shallow dict copy would keep following later optimizer updates.
                # Clone the tensors to snapshot the weights.
                best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

            if epoch % 20 == 0:
                print(f"  Epoch {epoch}: Train={avg_train_loss:.4f}, Val={val_loss:.4f}, Spacing={val_spacing_rmse:.4f}")

            # Early stopping only becomes active after a 40-epoch warm-up.
            if epoch >= 40 and early_stopping(val_loss, val_spacing_rmse):
                break

        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        return model, {
            'train_losses': train_losses, 'val_losses': val_losses, 'val_spacing_rmses': val_spacing_rmses,
            'best_val_loss': best_val_loss, 'config': config
        }

    def _train_epoch(self, model, trainer, optimizer, alpha, beta, gamma, model_type, accumulation_steps):
        """Run one training epoch with gradient accumulation; return the mean batch loss."""
        model.train()
        trainer.train_dataset.set_training(True)
        optimizer.zero_grad()
        batch_losses = []
        pending = 0  # micro-batches whose gradients have not been applied yet

        for features, follow_v, spacing, lead_v, next_spacing, real_acc in trainer.train_dataloader:
            relative_speed = lead_v - follow_v
            residual, spacing_correction = model(features.contiguous(), relative_speed.contiguous(), spacing)

            base_acc = ensure_1d_tensor(self.traditional_model.predict_acceleration(spacing, follow_v, lead_v))
            predicted_acc = torch.clamp(base_acc + ensure_1d_tensor(residual), -8.0, 6.0)

            loss = compute_loss(predicted_acc, real_acc, follow_v, spacing, lead_v, next_spacing,
                                trainer.delta_t, spacing_correction, alpha, beta, gamma, model_type)
            # Explicit L2 penalty on parameter norms, on top of AdamW's weight decay.
            loss = loss + 1e-5 * sum(torch.norm(p) for p in model.parameters())

            (loss / accumulation_steps).backward()
            pending += 1
            if pending == accumulation_steps:
                self._optimizer_step(model, optimizer)
                pending = 0

            batch_losses.append(loss.item())

        # Apply the gradients of a trailing partial group instead of discarding them.
        if pending:
            self._optimizer_step(model, optimizer)

        return np.mean(batch_losses)

    @staticmethod
    def _optimizer_step(model, optimizer):
        """Clip the gradient norm to 0.3, step the optimizer and clear gradients."""
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.3)
        optimizer.step()
        optimizer.zero_grad()

    def predict(self, features, relative_speed, current_spacing=None, use_temperature=False):
        """Return the mean residual prediction of all members (eval mode, no grad).

        When ``use_temperature`` is true, a member whose
        ``config['model_type']`` is 'OVRV' has its prediction divided by 1.2
        before averaging.

        Raises:
            ValueError: if the ensemble holds no trained models.
        """
        if not self.models:
            raise ValueError("No trained models in ensemble")

        predictions = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                pred, _ = model(features, relative_speed, current_spacing)
            # config is a dict: hasattr() never sees its keys, so look the key up.
            member_cfg = getattr(model, 'config', None) or {}
            if use_temperature and member_cfg.get('model_type') == 'OVRV':
                pred = pred / 1.2
            predictions.append(pred)

        return torch.mean(torch.stack(predictions), dim=0)

def evaluate_simulation_unified(model_ensemble, test_data, delta_t, seq_length=12, device='cpu',
                               model_type='hybrid', traditional_model=None):
    """Run the closed-loop hybrid-vs-traditional simulation on *test_data*.

    Returns ``None`` in two cases: *test_data* has no rows beyond the initial
    ``seq_length`` history window, or the ensemble holds no models. Otherwise
    it returns the metrics/trajectory dict produced by
    ``_evaluate_hybrid_simulation``. ``model_type`` is accepted for API
    compatibility and ignored; only hybrid evaluation is supported.
    """
    has_eval_rows = len(test_data) > seq_length
    if not (has_eval_rows and model_ensemble.models):
        return None
    return _evaluate_hybrid_simulation(model_ensemble, test_data, traditional_model, delta_t, seq_length, device)

def _scalar_base_acc(traditional_model, spacing, follow_speed, lead_speed):
    """Return the traditional model's acceleration for one state as a float.

    Any exception from the model, or from reading its result, yields 0.0, so
    a single bad state cannot abort the whole simulation.
    """
    try:
        acc = traditional_model.predict_acceleration(spacing, follow_speed, lead_speed)
        acc = acc.item() if acc.dim() > 0 else acc
    except Exception:
        return 0.0
    return float(acc)


def _scalar_safety_margin(spacing, follow_speed):
    """Return the gap's relative excess over ``2.0 + 1.5 * max(v, 0)``, clipped to [-5, 8].

    Scalar counterpart of the safety-margin feature the dataset builds in
    ``EnhancedDataset._compute_safety_margin``.
    """
    safe_distance = 2.0 + 1.5 * max(follow_speed, 0)
    margin = (spacing - safe_distance) / max(safe_distance, 1e-6)
    return min(max(margin, -5.0), 8.0)


def _evaluate_hybrid_simulation(ensemble_trainer, test_data, traditional_model, delta_t, seq_length, device):
    """Closed-loop simulation of a traditional follower and a hybrid follower.

    Both followers start from the observed state at row ``seq_length`` and
    are driven by the observed lead speed. Each step uses semi-implicit Euler:
    speed is updated first (clamped >= 0), then spacing using the new speed
    (clamped >= 0.1).

    The hybrid follower adds the ensemble's residual to the traditional
    acceleration and clips the sum to [-8, 6]. The residual is predicted
    from a rolling window of the last ``seq_length`` feature rows, seeded
    with observed rows ``0 .. seq_length - 1``.

    Returns:
        dict with the time axis, actual/traditional/hybrid trajectories,
        per-signal RMSEs, percentage improvements (positive = hybrid better)
        and ``n_models``.
    """
    from collections import deque  # local import keeps the file-level imports untouched

    eval_data = test_data.iloc[seq_length:].reset_index(drop=True)
    n = len(eval_data)

    traditional_acc = np.zeros(n)
    hybrid_acc = np.zeros(n)
    traditional_speed = np.zeros(n)
    hybrid_speed = np.zeros(n)
    traditional_spacing = np.zeros(n)
    hybrid_spacing = np.zeros(n)

    # Both followers start from the observed state at the first evaluated row.
    traditional_current_speed = hybrid_current_speed = test_data['follow_speed'].iloc[seq_length]
    traditional_current_spacing = hybrid_current_spacing = test_data['spacing'].iloc[seq_length]

    # Rolling feature window; maxlen drops the oldest row on every append.
    history_buffer = deque(maxlen=seq_length)
    for i in range(seq_length):
        lead_speed = test_data['lead_speed'].iloc[i]
        follow_speed = test_data['follow_speed'].iloc[i]
        spacing = test_data['spacing'].iloc[i]
        history_buffer.append([
            lead_speed, follow_speed, spacing, lead_speed - follow_speed,
            _scalar_safety_margin(spacing, follow_speed),
            _scalar_base_acc(traditional_model, spacing, follow_speed, lead_speed),
        ])

    model_type = 'OVRV' if isinstance(traditional_model, OVRVModel) else 'IDM'
    use_temperature = (model_type == 'OVRV')

    with torch.no_grad():
        for i in range(n):
            lead_speed = eval_data['lead_speed'].iloc[i]

            # --- Traditional follower ---
            trad_acc = _scalar_base_acc(traditional_model, traditional_current_spacing,
                                        traditional_current_speed, lead_speed)
            traditional_acc[i] = trad_acc
            traditional_speed[i] = traditional_current_speed
            traditional_spacing[i] = traditional_current_spacing
            traditional_current_speed = max(0, traditional_current_speed + trad_acc * delta_t)
            traditional_current_spacing = max(0.1, traditional_current_spacing + (lead_speed - traditional_current_speed) * delta_t)

            # --- Hybrid follower ---
            current_base_acc = _scalar_base_acc(traditional_model, hybrid_current_spacing,
                                                hybrid_current_speed, lead_speed)
            relative_speed = lead_speed - hybrid_current_speed
            history_buffer.append([
                lead_speed, hybrid_current_speed, hybrid_current_spacing, relative_speed,
                _scalar_safety_margin(hybrid_current_spacing, hybrid_current_speed),
                current_base_acc,
            ])

            if len(history_buffer) == seq_length:
                features_tensor = torch.tensor(list(history_buffer), dtype=torch.float32, device=device).unsqueeze(0)
                relative_speed_tensor = torch.tensor([relative_speed], dtype=torch.float32, device=device)
                current_spacing_tensor = torch.tensor([hybrid_current_spacing], dtype=torch.float32, device=device)

                residual = ensemble_trainer.predict(features_tensor, relative_speed_tensor, current_spacing_tensor, use_temperature)
                residual = residual.item() if residual.dim() > 0 else residual
                hybrid_acc_i = min(max(current_base_acc + residual, -8.0), 6.0)
            else:
                hybrid_acc_i = current_base_acc

            hybrid_acc[i] = hybrid_acc_i
            hybrid_speed[i] = hybrid_current_speed
            hybrid_spacing[i] = hybrid_current_spacing
            hybrid_current_speed = max(0, hybrid_current_speed + hybrid_acc_i * delta_t)
            hybrid_current_spacing = max(0.1, hybrid_current_spacing + (lead_speed - hybrid_current_speed) * delta_t)

    actual_acc = eval_data['actual_acc'].values
    actual_speed = eval_data['follow_speed'].values
    actual_spacing = eval_data['spacing'].values

    def rmse(pred, actual):
        return np.sqrt(np.mean((pred - actual) ** 2))

    def improvement_pct(baseline, candidate):
        # Relative RMSE reduction in percent; 0 when the baseline is already perfect.
        return (baseline - candidate) / baseline * 100 if baseline > 0 else 0

    traditional_acc_rmse = rmse(traditional_acc, actual_acc)
    traditional_speed_rmse = rmse(traditional_speed, actual_speed)
    traditional_spacing_rmse = rmse(traditional_spacing, actual_spacing)
    hybrid_acc_rmse = rmse(hybrid_acc, actual_acc)
    hybrid_speed_rmse = rmse(hybrid_speed, actual_speed)
    hybrid_spacing_rmse = rmse(hybrid_spacing, actual_spacing)

    acc_improvement = improvement_pct(traditional_acc_rmse, hybrid_acc_rmse)
    speed_improvement = improvement_pct(traditional_speed_rmse, hybrid_speed_rmse)
    spacing_improvement = improvement_pct(traditional_spacing_rmse, hybrid_spacing_rmse)

    print(f"{traditional_model.__class__.__name__} Simulation: Acc {hybrid_acc_rmse:.4f} ({acc_improvement:+.1f}%), "
          f"Spacing {hybrid_spacing_rmse:.4f} ({spacing_improvement:+.1f}%)")

    return {
        'time': eval_data['time'].values,
        'actual_acc': actual_acc,
        'actual_speed': actual_speed,
        'actual_spacing': actual_spacing,
        'traditional_acc': traditional_acc,
        'traditional_speed': traditional_speed,
        'traditional_spacing': traditional_spacing,
        'hybrid_acc': hybrid_acc,
        'hybrid_speed': hybrid_speed,
        'hybrid_spacing': hybrid_spacing,
        'traditional_acc_rmse': traditional_acc_rmse,
        'traditional_speed_rmse': traditional_speed_rmse,
        'traditional_spacing_rmse': traditional_spacing_rmse,
        'hybrid_acc_rmse': hybrid_acc_rmse,
        'hybrid_speed_rmse': hybrid_speed_rmse,
        'hybrid_spacing_rmse': hybrid_spacing_rmse,
        'acc_improvement': acc_improvement,
        'speed_improvement': speed_improvement,
        'spacing_improvement': spacing_improvement,
        'n_models': len(ensemble_trainer.models)
    }

def save_comprehensive_simulation_csv(ovrv_result, idm_result, dataset_name, output_dir, is_train=False):
    """Write the OVRV and IDM simulation traces for one split into a single CSV.

    Nothing is written when either result is ``None``.

    Columns, in order:
      * ``Time`` - the OVRV result's time stamps shifted to start at zero;
      * ``Actual_{Acc,Speed,Spacing}`` - the measured series (from the OVRV
        result);
      * for OVRV then IDM: the ``Traditional`` series followed by the
        ``AI_Correction`` (hybrid) series, each as Acc, Speed, Spacing.

    The file is ``<output_dir>/<dataset_name>{_train|_test}_comprehensive_simulation_data.csv``.
    """
    if ovrv_result is None or idm_result is None:
        return

    start_time = ovrv_result['time'][0]
    frame_columns = {
        'Time': ovrv_result['time'] - start_time,
        'Actual_Acc': ovrv_result['actual_acc'],
        'Actual_Speed': ovrv_result['actual_speed'],
        'Actual_Spacing': ovrv_result['actual_spacing'],
    }

    # Generate the 12 model columns; loop nesting fixes the column order
    # (model -> traditional/AI -> acc/speed/spacing).
    model_sources = (('OVRV', ovrv_result), ('IDM', idm_result))
    variant_names = (('traditional', 'Traditional'), ('hybrid', 'AI_Correction'))
    quantity_names = (('acc', 'Acc'), ('speed', 'Speed'), ('spacing', 'Spacing'))
    for model_label, model_result in model_sources:
        for variant_key, variant_label in variant_names:
            for quantity_key, quantity_label in quantity_names:
                column_name = f'{model_label}_{variant_label}_{quantity_label}'
                frame_columns[column_name] = model_result[f'{variant_key}_{quantity_key}']

    split_tag = "_train" if is_train else "_test"
    out_file = os.path.join(output_dir, f'{dataset_name}{split_tag}_comprehensive_simulation_data.csv')
    pd.DataFrame(frame_columns).to_csv(out_file, index=False)
    print(f"Comprehensive simulation data saved to: {out_file}")

def preprocess_data(data, skip_warmup=1000, skip_ending=2000, max_rows=20000):
    """Trim, clean and unit-convert a car-following trajectory table.

    Args:
        data: DataFrame with at least ``time``, ``follow_speed``,
            ``lead_speed`` and ``spacing`` columns.
        skip_warmup: number of leading rows to skip. The window starts
            earlier (never before row 0) when that is needed to fill
            ``max_rows`` rows.
        skip_ending: number of trailing rows that are always dropped.
        max_rows: upper bound on the window size before cleaning.

    Returns:
        ``(data, delta_t)``. ``data`` is the cleaned frame, with speeds
        scaled by 5/18 (km/h -> m/s) and an added ``actual_acc`` column
        clipped to [-10, 10]. ``delta_t`` is the median time step, or 0.1
        when that median is missing or non-positive.

    Raises:
        ValueError: if ``data`` has fewer than
            ``skip_warmup + skip_ending + 100`` rows, or if fewer than 100
            rows survive cleaning.
    """
    if len(data) < skip_warmup + skip_ending + 100:
        raise ValueError(f"Data length {len(data)} too short")

    # Extract data range: the window always ends skip_ending rows early.
    end_idx = len(data) - skip_ending if skip_ending > 0 else len(data)
    # Start at skip_warmup, or earlier if that is needed to reach max_rows.
    # Clamp at 0: a negative start is read by iloc as an offset from the END
    # of the frame, which would silently select a much shorter tail window.
    start_idx = max(0, min(skip_warmup, end_idx - max_rows))
    end_idx = min(start_idx + max_rows, end_idx)
    data = data.iloc[start_idx:end_idx].copy()
    print(f"Using data range [{start_idx}:{end_idx}]: {len(data)} points")

    # Data cleaning
    data = data.dropna()
    data = data[(data['follow_speed'] >= 0) & (data['lead_speed'] >= 0)]
    data = data[(data['spacing'] > 0.5) & (data['spacing'] < 150)]

    # Remove outliers: drop rows where either speed exceeds the larger of the
    # two columns' 99th percentiles. Copy so the in-place conversion below
    # operates on an independent frame, not a filtered view.
    speed_99th = data[['follow_speed', 'lead_speed']].quantile(0.99).max()
    data = data[(data['follow_speed'] <= speed_99th) & (data['lead_speed'] <= speed_99th)].copy()

    # Convert to m/s and calculate acceleration
    data[['follow_speed', 'lead_speed']] *= 5 / 18
    delta_t = data['time'].diff().median()
    if pd.isna(delta_t) or delta_t <= 0:
        delta_t = 0.1
        print("Warning: Invalid delta_t, defaulting to 0.1")

    # NOTE(review): rows removed above leave gaps in time, but the speed
    # difference is always divided by the median step, so accelerations
    # across a gap are overstated (then clipped). Consider dividing by
    # data['time'].diff() instead.
    data['actual_acc'] = data['follow_speed'].diff() / delta_t
    data['actual_acc'] = data['actual_acc'].clip(-10, 10)
    data = data.dropna()  # drops the first row, whose diff is NaN

    if len(data) < 100:
        raise ValueError(f"Insufficient data after preprocessing: {len(data)} points")

    print(f"Processed points: {len(data)}")
    return data, delta_t

def split_data(data, train_ratio=0.6, val_ratio=0.25, test_ratio=0.15):
    """Cut ``data`` chronologically into train, validation and test frames.

    No shuffling is done. The first ``int(n * train_ratio)`` rows form the
    training set. Rows up to ``int(n * (train_ratio + val_ratio))`` form the
    validation set. The rest is the test set, so ``test_ratio`` is accepted
    only for signature symmetry and is not used. Each part is an independent
    copy.
    """
    total = len(data)
    cut_points = [0, int(total * train_ratio), int(total * (train_ratio + val_ratio)), None]
    train_data, val_data, test_data = (
        data.iloc[lower:upper].copy() for lower, upper in zip(cut_points[:-1], cut_points[1:])
    )

    print(f"Data split - Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")
    return train_data, val_data, test_data

def load_and_preprocess_data(file_path):
    """Read a raw trajectory CSV, keep and rename the needed columns, then preprocess.

    Only the four source columns in ``column_map`` are kept. They are renamed
    to the names ``preprocess_data`` expects, and the result of
    ``preprocess_data`` (with a 1000-row warmup skip, 2000-row ending skip and
    20000-row cap) is returned unchanged.
    """
    # Source header -> internal column name (insertion order = column order).
    column_map = {
        'Time': 'time',
        'Smooth Speed Follower': 'follow_speed',
        'Smooth Speed Leader': 'lead_speed',
        'Spacing': 'spacing',
    }
    raw_frame = pd.read_csv(file_path)
    selected = raw_frame[list(column_map)].rename(columns=column_map)
    return preprocess_data(selected, skip_warmup=1000, skip_ending=2000, max_rows=20000)

def plot_training_curves_unified(models_and_histories, dataset_name, output_dir):
    """Plot train/validation loss curves of every ensemble member for OVRV and IDM.

    The left panel is OVRV and the right panel IDM. A panel is left blank when
    its key is absent from ``models_and_histories``. Each entry there is a
    ``(_, histories)`` pair, where every history has ``'train_losses'`` and
    ``'val_losses'`` sequences. Member ``i`` uses colour ``C{i}``, with a
    solid line for training and a dashed line for validation. The figure is
    saved as ``<output_dir>/<dataset_name>_training_curves.png``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    panels = (
        ('ovrv', 'OVRV Ensemble Training', 'OVRV Model'),
        ('idm', 'IDM Ensemble Training', 'IDM Model'),
    )

    for ax, (model_key, panel_title, series_label) in zip(axes, panels):
        if model_key not in models_and_histories:
            continue
        _, member_histories = models_and_histories[model_key]
        for member_idx, history in enumerate(member_histories):
            member_no = member_idx + 1
            ax.plot(history['train_losses'], f'C{member_idx}-', alpha=0.7,
                    label=f'{series_label} {member_no} Train')
            ax.plot(history['val_losses'], f'C{member_idx}--', alpha=0.7,
                    label=f'{series_label} {member_no} Val')

        ax.set_title(panel_title)
        ax.set_ylabel('Loss')
        ax.set_xlabel('Epoch')
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    save_path = os.path.join(output_dir, f'{dataset_name}_training_curves.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Training curves saved: {save_path}")

def plot_simulation_results_unified(results, dataset_name, output_dir, is_train=False):
    """Plot actual vs. traditional vs. AI-corrected trajectories for OVRV and IDM.

    Produces a 3x2 grid. Rows are acceleration, speed and spacing; the left
    column is OVRV and the right column IDM. A column stays blank when its
    entry in ``results`` is missing or ``None``.

    Args:
        results: dict with optional ``'ovrv'`` / ``'idm'`` entries, each a
            result dict like the one built by ``evaluate_simulation_unified``
            (``time``, ``actual_*``, ``traditional_*``, ``hybrid_*`` and the
            matching ``*_rmse`` values).
        dataset_name: prefix of the output file name.
        output_dir: directory the PNG is written to.
        is_train: selects the ``_train_simulation`` / ``_test_simulation``
            file suffix. Only the test plot shifts time to start at zero.
    """
    # (row title, result-key suffix, y-axis label) for each grid row.
    # The superscript is written as an escape so the label cannot be
    # mis-encoded again (it previously rendered with a stray "A-circumflex").
    metrics = [
        ('Acceleration', 'acc', 'Acceleration (m/s\u00b2)'),
        ('Speed', 'speed', 'Speed (m/s)'),
        ('Spacing', 'spacing', 'Spacing (m)'),
    ]

    fig, axes = plt.subplots(3, 2, figsize=(12, 12))

    for col, (title_prefix, result_key) in enumerate([('OVRV', 'ovrv'), ('IDM', 'idm')]):
        result = results.get(result_key)
        if result is None:
            continue

        time_axis = result['time']
        if not is_train:
            # Test plots are shown relative to the split's first sample;
            # train plots keep the original time stamps.
            time_axis = time_axis - time_axis[0]

        for row, (metric_name, metric_key, ylabel) in enumerate(metrics):
            ax = axes[row, col]

            # Actual values
            ax.plot(time_axis, result[f'actual_{metric_key}'], 'k-',
                    label='Actual ' + metric_name, linewidth=2)

            # Traditional model
            traditional_rmse = result[f'traditional_{metric_key}_rmse']
            ax.plot(time_axis, result[f'traditional_{metric_key}'], 'r--',
                    label=f'{title_prefix} (RMSE: {traditional_rmse:.4f})', linewidth=2)

            # AI correction
            hybrid_rmse = result[f'hybrid_{metric_key}_rmse']
            ax.plot(time_axis, result[f'hybrid_{metric_key}'], 'b-',
                    label=f'AI Correction (RMSE: {hybrid_rmse:.4f})', linewidth=2)

            ax.set_title(f'{title_prefix}: {metric_name}')
            ax.set_ylabel(ylabel)
            if row == len(metrics) - 1:  # x label only on the bottom row
                ax.set_xlabel('Time (s)')
            ax.legend()
            ax.grid(True)

    fig.tight_layout()
    suffix = "_train_simulation" if is_train else "_test_simulation"
    save_path = os.path.join(output_dir, f'{dataset_name}{suffix}.png')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"{'Train' if is_train else 'Test'} simulation plot saved: {save_path}")

def print_final_summary(models_and_histories, test_results):
    """Print the end-of-run summary for the hybrid (AI-corrected) models.

    For each model with a non-empty history list, this reports the lowest
    ``best_val_loss`` across its ensemble members and the test-split
    ``spacing_improvement`` percentage. Models without a truthy entry in
    ``test_results`` produce no line.

    Args:
        models_and_histories: mapping ``name -> (_, histories)``. The first
            tuple element is ignored, and every history is a dict containing
            ``'best_val_loss'``.
        test_results: mapping ``name -> result dict`` (or a falsy value).
    """
    print("\n=== FINAL SUMMARY ===")

    for model_name, (_, histories) in models_and_histories.items():
        if not histories:
            continue
        best_val_loss = min(history['best_val_loss'] for history in histories)

        if model_name not in test_results or not test_results[model_name]:
            continue
        spacing_gain = test_results[model_name]['spacing_improvement']
        print(f"{model_name.upper()}: Val Loss {best_val_loss:.4f}, Test Spacing {spacing_gain:+.1f}%")

    print("Hybrid model simulation completed successfully!")

def main():
    """Train OVRV/IDM hybrid ensembles for each configured dataset and report results.

    For every dataset below, the pipeline:
      1. loads and preprocesses the CSV, skipping the dataset if that raises;
      2. splits it 60/25/15 chronologically;
      3. trains a 5-member ensemble on top of each traditional model (OVRV
         first, then IDM);
      4. saves the training curves;
      5. evaluates on the train and test splits, writing a plot and a
         combined CSV for each;
      6. prints a summary.

    Plots and CSVs go to a new timestamped folder under ``base_dir``.
    """
    set_global_seed(42)

    # Each run writes into its own timestamped sub-folder.
    base_dir = r"C:\Users\yliu117\Desktop\CarFollowing_Results"
    run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = os.path.join(base_dir, f"StreamlinedEnsemble_{run_stamp}")
    os.makedirs(output_dir, exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # Most recent test-split results, handed to the summary printer.
    test_results = {}

    dataset_config = {
        'short_55': {
            'path': r'C:\Users\yliu117\Desktop\EV\EV-ACC data\combined\short_55.csv',
            'ovrv_params': {'k1': 0.0717, 'k2': 0.6541, 'eta': 17.9107, 'tau': 0.5452},
            'idm_params': {'a': 1.6932, 'b': 10.0000, 'delta': 5.0000, 's0': 6.0000, 'T': 1.0325, 'v0': 40.0000}
        }
    }

    # (internal key, label passed to the trainer), in training/evaluation order.
    model_specs = (('ovrv', 'OVRV'), ('idm', 'IDM'))

    for dataset_name, config in dataset_config.items():
        print(f"\nProcessing dataset: {dataset_name}")

        try:
            data, delta_t = load_and_preprocess_data(config['path'])
        except Exception as e:
            print(f"Data loading failed: {e}")
            continue

        train_data, val_data, test_data = split_data(data, 0.6, 0.25, 0.15)

        traditional_models = {
            'ovrv': OVRVModel(**config['ovrv_params']),
            'idm': IDMModel(**config['idm_params']),
        }

        # NOTE(review): the literal 12 passed to training/evaluation is opaque
        # here; confirm its meaning against SimplifiedEnsembleTrainer.
        trainers = {}
        models_and_histories = {}
        for model_key, model_label in model_specs:
            trainer = SimplifiedEnsembleTrainer(n_models=5, diversity_factor=0.15)
            models_and_histories[model_key] = trainer.train_models(
                train_data, val_data, traditional_models[model_key], model_label, delta_t, 12, device)
            trainers[model_key] = trainer

        plot_training_curves_unified(models_and_histories, dataset_name, output_dir)

        for is_train, data_split in ((True, train_data), (False, test_data)):
            results = {
                model_key: evaluate_simulation_unified(
                    trainers[model_key], data_split, delta_t, 12, device, 'hybrid',
                    traditional_models[model_key])
                for model_key, _ in model_specs
            }

            if not is_train:
                test_results = results

            plot_simulation_results_unified(results, dataset_name, output_dir, is_train)
            save_comprehensive_simulation_csv(
                results['ovrv'], results['idm'], dataset_name, output_dir, is_train)

        print_final_summary(models_and_histories, test_results)
        print(f"Results saved to: {output_dir}")

# Script entry point: run the full pipeline only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()