In [None]:
# Mount Google Drive at /content/drive (requires the google.colab runtime).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Notebook shell command (not Python): install the packages named below.
!pip install torch torchvision matplotlib seaborn scikit-learn tqdm pandas numpy

In [None]:
#1- transformer

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import logging
import random
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Logging setup: configure the root logger at INFO level and create the
# module-level logger used by the functions and classes below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Reproducibility
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)
    with the same value, then asks cuDNN to use deterministic algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

class SimpleDiverSignDataset(Dataset):
    """In-memory dataset of scaled sequences with integer-encoded labels.

    Args:
        sequences: NumPy array whose last axis is the feature axis. It is
            flattened to 2-D for scaling and then restored to its shape.
        labels: Class labels, one per sequence.
        label_encoder: Already-fitted encoder to reuse (e.g. the training
            set's). When ``None`` a new ``LabelEncoder`` is fitted on
            ``labels``.
        scaler: Already-fitted scaler to reuse. When ``None`` a new
            ``StandardScaler`` is fitted on ``sequences``.

    Attributes:
        class_weights: Float tensor with one "balanced" weight per class known
            to the encoder: ``n_samples / (n_classes * class_count)``. A class
            that does not occur in this split gets weight 0.
    """

    def __init__(self, sequences, labels, label_encoder=None, scaler=None):
        # Label encoding: fit only when no encoder was handed in
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            encoded_labels = self.label_encoder.fit_transform(labels)
        else:
            self.label_encoder = label_encoder
            encoded_labels = self.label_encoder.transform(labels)

        # Scalers expect 2-D input, so collapse every leading axis into rows,
        # scale per feature, then restore the original shape.
        flat = sequences.reshape(-1, sequences.shape[-1])
        if scaler is None:
            self.scaler = StandardScaler()
            flat = self.scaler.fit_transform(flat)
        else:
            self.scaler = scaler
            flat = self.scaler.transform(flat)
        scaled_sequences = flat.reshape(sequences.shape)

        self.sequences = torch.FloatTensor(scaled_sequences)
        self.labels = torch.LongTensor(encoded_labels)

        # Class weights are indexed by encoded label, so size them by the
        # encoder's class list rather than by the classes that happen to be
        # present: a split that misses a class would otherwise divide by zero
        # or produce a vector shorter than the number of model outputs.
        num_classes = len(self.label_encoder.classes_)
        class_counts = Counter(int(label) for label in encoded_labels)
        total_samples = len(encoded_labels)
        self.class_weights = torch.FloatTensor([
            total_samples / (num_classes * class_counts[i]) if class_counts[i] else 0.0
            for i in range(num_classes)
        ])

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx]

    def get_label_encoder(self):
        """Return the encoder that maps class labels to integer ids."""
        return self.label_encoder

    def get_scaler(self):
        """Return the scaler that was applied to the features."""
        return self.scaler

    def get_class_weights(self):
        """Return the per-class weight tensor computed at construction."""
        return self.class_weights

class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled down for confidently-correct samples.

    Per sample the loss is ``alpha * (1 - p_t) ** gamma * CE``, where ``p_t``
    is recovered from the cross-entropy as ``exp(-CE)``.

    Args:
        alpha: Scalar multiplier applied to every sample's loss.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
        reduce: If true, ``forward`` returns the mean over samples; otherwise
            it returns the per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduce=True):
        super().__init__()
        self.alpha, self.gamma, self.reduce = alpha, gamma, reduce

    def forward(self, inputs, targets):
        per_sample_ce = nn.functional.cross_entropy(inputs, targets, reduction='none')
        prob_true_class = torch.exp(-per_sample_ce)
        weighted = self.alpha * (1 - prob_true_class) ** self.gamma * per_sample_ce
        return torch.mean(weighted) if self.reduce else weighted

class OptimizedDiverSignTransformer(nn.Module):
    """Transformer-encoder sequence classifier.

    Pipeline: per-timestep linear projection -> learned positional embedding
    -> pre-norm Transformer encoder -> mean pooling over time -> MLP head.

    Args:
        input_dim: Number of features per timestep.
        d_model: Encoder width.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
        num_classes: Number of output logits.
        max_seq_length: Longest sequence the positional embedding covers.
        dropout: Dropout probability used in every sub-module.
    """

    def __init__(self,
                 input_dim=69,
                 d_model=256,           # Medium model size
                 n_heads=8,             # Standard number of heads
                 n_layers=6,            # Optimal depth
                 num_classes=10,
                 max_seq_length=150,
                 dropout=0.1):
        super().__init__()

        self.input_dim = input_dim
        self.d_model = d_model
        self.num_classes = num_classes

        # Input projection
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Learnable positional encoding
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_length, d_model) * 0.02)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, n_layers)

        # Global average pooling (used when no padding mask is supplied)
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, d_model // 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 4, num_classes)
        )

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise Linear layers; reset LayerNorm to identity."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x, mask=None):
        """Classify a batch of sequences.

        Args:
            x: Tensor of shape ``[batch, seq_len, input_dim]``.
            mask: Optional key-padding mask of shape ``[batch, seq_len]``,
                passed to the encoder as ``src_key_padding_mask`` (true /
                non-zero entries mark padded timesteps).

        Returns:
            Logits of shape ``[batch, num_classes]``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional-embedding length.
        """
        _, seq_len, _ = x.shape

        max_len = self.pos_embedding.shape[1]
        if seq_len > max_len:
            # Without this check the addition below fails with an opaque
            # broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_length {max_len}"
            )

        # Input projection
        x = self.input_projection(x)

        # Add positional encoding
        x = x + self.pos_embedding[:, :seq_len, :]

        # Transformer encoding
        encoded = self.transformer_encoder(x, src_key_padding_mask=mask)

        if mask is None:
            # Plain mean over time
            pooled = self.global_pool(encoded.transpose(1, 2)).squeeze(-1)  # [batch, d_model]
        else:
            # Average only the real timesteps; padded positions must not
            # dilute (or leak into) the pooled representation.
            padded = mask.bool().unsqueeze(-1)                     # [batch, seq_len, 1]
            valid_count = (~padded).sum(dim=1).clamp(min=1)         # [batch, 1]
            pooled = encoded.masked_fill(padded, 0.0).sum(dim=1) / valid_count.to(encoded.dtype)

        # Classification
        return self.classifier(pooled)

class SmartTrainer:
    """Train/validate loop with class-weighted, label-smoothed cross-entropy.

    Uses AdamW plus ``ReduceLROnPlateau`` (stepped on validation F1), gradient
    clipping, best-checkpoint saving and early stopping.

    Args:
        model: The ``nn.Module`` to train; it is moved to ``device``.
        class_weights: Optional per-class weight tensor for the loss.
        device: Device string; defaults to CUDA when available.

    Attributes:
        train_losses, val_losses, train_accuracies, val_accuracies,
        val_f1_scores: Per-epoch metric histories filled by ``train``.
        learning_rates: Learning rate each epoch was trained with.
    """

    def __init__(self, model, class_weights=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model
        self.device = device
        self.model.to(device)

        # Loss function with class weights
        if class_weights is not None:
            class_weights = class_weights.to(device)

        self.criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

        # Optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=1e-3,  # Initial learning rate
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )

        # Scheduler. `verbose` is deliberately not passed: it is deprecated
        # in PyTorch and rejected by newer releases. `train` logs learning
        # rate changes itself instead.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=10
        )

        # Metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.val_f1_scores = []
        self.learning_rates = []

    def train_epoch(self, dataloader):
        """Run one optimisation pass over ``dataloader``.

        Returns:
            ``(mean batch loss, accuracy in percent, weighted F1)``.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        for data, target in tqdm(dataloader, desc="Training"):
            data, target = data.to(self.device), target.to(self.device)

            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass
            output = self.model(data)
            loss = self.criterion(output, target)

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Optimizer step
            self.optimizer.step()

            # Statistics
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        # Cast to a plain float so the value stored in histories/checkpoints
        # is not a NumPy scalar (those break weights-only checkpoint loading).
        f1 = float(f1_score(all_targets, all_preds, average='weighted'))

        return avg_loss, accuracy, f1

    def validate(self, dataloader):
        """Evaluate the model on ``dataloader`` without gradient tracking.

        Returns:
            ``(mean batch loss, accuracy in percent, weighted F1,
            predictions, targets)``.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in tqdm(dataloader, desc="Validating"):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = float(f1_score(all_targets, all_preds, average='weighted'))

        return avg_loss, accuracy, f1, all_preds, all_targets

    def train(self, train_loader, val_loader, epochs=100, save_path="best_model.pth", patience=20):
        """Train with early stopping and keep the best checkpoint.

        Args:
            train_loader: Training batches.
            val_loader: Validation batches.
            epochs: Maximum number of epochs.
            save_path: Where the best checkpoint is written.
            patience: Epochs without validation-F1 improvement before stopping.

        Returns:
            The best validation F1 seen (0.0 if no epoch ran).
        """
        # None means "nothing saved yet", so the first epoch always writes a
        # checkpoint, even when its F1 is exactly 0.
        best_val_f1 = None
        patience_counter = 0

        logger.info(f"Training started: {epochs} epochs, device: {self.device}")

        for epoch in range(epochs):
            epoch_lr = self.optimizer.param_groups[0]["lr"]

            # Train
            train_loss, train_acc, train_f1 = self.train_epoch(train_loader)

            # Validate
            val_loss, val_acc, val_f1, _, _ = self.validate(val_loader)

            # Scheduler step (report LR changes ourselves, see __init__)
            self.scheduler.step(val_f1)
            new_lr = self.optimizer.param_groups[0]["lr"]
            if new_lr != epoch_lr:
                logger.info(f"Learning rate changed: {epoch_lr:.6f} -> {new_lr:.6f}")

            # Save metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            self.val_f1_scores.append(val_f1)
            self.learning_rates.append(epoch_lr)

            # Logging
            if (epoch + 1) % 5 == 0:
                logger.info(f'Epoch {epoch+1}/{epochs}:')
                logger.info(f'Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%, F1: {train_f1:.4f}')
                logger.info(f'Val - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%, F1: {val_f1:.4f}')
                logger.info(f'LR: {self.optimizer.param_groups[0]["lr"]:.6f}')

            # Save best model
            if best_val_f1 is None or val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_val_f1': best_val_f1,
                    'train_losses': self.train_losses,
                    'val_losses': self.val_losses,
                    'train_accuracies': self.train_accuracies,
                    'val_accuracies': self.val_accuracies,
                    'val_f1_scores': self.val_f1_scores,
                    'learning_rates': self.learning_rates
                }, save_path)
                logger.info(f"✅ New best model saved: F1 {val_f1:.4f}")
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch+1}")
                break

        if best_val_f1 is None:
            best_val_f1 = 0.0

        logger.info(f"🎯 Training completed. Best F1 Score: {best_val_f1:.4f}")
        return best_val_f1

def smart_prepare_data(csv_path, sequence_length=150, overlap_ratio=0.3):
    """Load a labelled CSV and slice each class's rows into overlapping windows.

    Every column except ``'class'`` is treated as a feature. NaNs are
    forward-filled (then zero-filled), infinities are replaced by 0, classes
    with fewer than ``3 * sequence_length`` rows are dropped, and a sliding
    window of ``sequence_length`` rows is moved over each remaining class.

    Args:
        csv_path: Path of the CSV file; it must contain a ``'class'`` column.
        sequence_length: Number of consecutive rows per window.
        overlap_ratio: Fraction of a window shared with the next one. The
            resulting stride is clamped to at least one row.

    Returns:
        ``(sequences, labels, feature_count)``: a NumPy array of the stacked
        windows, a list with one class name per window, and the number of
        feature columns.
    """
    logger.info(f"📊 Loading CSV data: {csv_path}")

    df = pd.read_csv(csv_path)
    logger.info(f"Total rows: {len(df):,}")
    logger.info(f"Classes: {list(df['class'].unique())}")
    logger.info(f"Class distribution:\n{df['class'].value_counts()}")

    # Feature columns
    feature_columns = [col for col in df.columns if col != 'class']
    logger.info(f"Number of features: {len(feature_columns)}")

    # Data quality check
    logger.info("🔍 Checking data quality...")

    # Check for NaN and infinite values
    nan_count = df[feature_columns].isnull().sum().sum()
    inf_count = np.isinf(df[feature_columns].values).sum()

    if nan_count > 0:
        logger.warning(f"⚠️ Found {nan_count} NaN values, applying forward fill...")
        # DataFrame.ffill() replaces the deprecated fillna(method='ffill');
        # the trailing fillna(0) covers leading rows with nothing to copy.
        df[feature_columns] = df[feature_columns].ffill().fillna(0)

    if inf_count > 0:
        logger.warning(f"⚠️ Found {inf_count} infinite values, applying clipping...")
        # Despite the log wording, infinities are replaced by 0, not clipped
        df[feature_columns] = df[feature_columns].replace([np.inf, -np.inf], np.nan).fillna(0)

    # Minimum samples per class check
    min_samples_per_class = sequence_length * 3  # Minimum 3 sequences per class
    class_counts = df['class'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples_per_class].index

    if len(valid_classes) < len(class_counts):
        dropped_classes = set(class_counts.index) - set(valid_classes)
        logger.warning(f"⚠️ Dropping classes with insufficient data: {list(dropped_classes)}")
        df = df[df['class'].isin(valid_classes)]

    logger.info(f"✅ Classes to process: {list(valid_classes)}")

    sequences = []
    labels = []

    # Overlapping sliding window. Clamp the stride to >= 1: an overlap_ratio
    # close to (or above) 1 would otherwise give range() a zero/negative step.
    stride = max(1, int(sequence_length * (1 - overlap_ratio)))

    for class_name in valid_classes:
        # Convert the class's rows to NumPy once; each window is then a cheap
        # slice instead of a per-window DataFrame selection.
        # NOTE(review): all rows of a class are treated as one continuous
        # recording — confirm windows may not span separate recordings.
        class_values = df.loc[df['class'] == class_name, feature_columns].to_numpy()
        class_sequences = 0

        for start in range(0, len(class_values) - sequence_length + 1, stride):
            window = class_values[start:start + sequence_length]

            # Defensive quality check: skip windows that still hold NaN/inf
            if np.isfinite(window).all():
                sequences.append(window)
                labels.append(class_name)
                class_sequences += 1

        logger.info(f"  {class_name}: {class_sequences} sequences")

    sequences = np.array(sequences)

    logger.info(f"📈 Total sequences: {len(sequences):,}")
    logger.info(f"📐 Sequence shape: {sequences.shape}")
    logger.info(f"🎯 Final class distribution: {dict(Counter(labels))}")

    return sequences, labels, len(feature_columns)

def plot_results(trainer, save_path="training_results.png"):
    """Plot loss, accuracy, validation F1 and learning rate per epoch.

    Args:
        trainer: Object exposing ``train_losses``, ``val_losses``,
            ``train_accuracies``, ``val_accuracies``, ``val_f1_scores`` and
            ``optimizer``. If it also exposes a ``learning_rates`` list with
            one entry per epoch, that history is plotted in the last panel.
        save_path: File the figure is written to before it is shown.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    epochs = range(1, len(trainer.train_losses) + 1)

    # Loss plot
    ax1.plot(epochs, trainer.train_losses, 'b-', label='Train Loss', alpha=0.8)
    ax1.plot(epochs, trainer.val_losses, 'r-', label='Validation Loss', alpha=0.8)
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Accuracy plot
    ax2.plot(epochs, trainer.train_accuracies, 'b-', label='Train Accuracy', alpha=0.8)
    ax2.plot(epochs, trainer.val_accuracies, 'r-', label='Validation Accuracy', alpha=0.8)
    ax2.set_title('Training and Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # F1 Score plot
    ax3.plot(epochs, trainer.val_f1_scores, 'g-', label='Validation F1', alpha=0.8, linewidth=2)
    ax3.set_title('Validation F1 Score')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('F1 Score')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Learning rate plot. Prefer a recorded per-epoch history; repeating the
    # optimizer's current LR for every epoch would draw a flat line that
    # looks like a schedule but is not one, so that fallback is labelled.
    lr_history = getattr(trainer, 'learning_rates', None)
    if lr_history is not None and len(lr_history) == len(epochs):
        lrs = list(lr_history)
        lr_label = 'Learning Rate'
    else:
        lrs = [trainer.optimizer.param_groups[0]['lr']] * len(epochs)
        lr_label = 'Learning Rate (final value only)'
    ax4.semilogy(epochs, lrs, 'purple', label=lr_label, alpha=0.8)
    ax4.set_title('Learning Rate Schedule')
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Learning Rate (log scale)')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

def plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names):
    """Draw the validation and test confusion matrices side by side.

    The figure is saved as ``validation_test_confusion_matrices.png`` and
    then shown.

    Args:
        val_targets: True labels of the validation set.
        val_preds: Predicted labels of the validation set.
        test_targets: True labels of the test set.
        test_preds: Predicted labels of the test set.
        target_names: Tick labels used on both axes of each heatmap.
    """
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    # (title, colour map, true labels, predicted labels) for each panel
    panels = (
        ('Validation Confusion Matrix', 'Blues', val_targets, val_preds),
        ('Test Confusion Matrix', 'Greens', test_targets, test_preds),
    )

    for axis, (title, colormap, truth, predicted) in zip(axes, panels):
        matrix = confusion_matrix(truth, predicted)
        sns.heatmap(matrix, annot=True, fmt='d', cmap=colormap,
                    xticklabels=target_names, yticklabels=target_names, ax=axis)
        axis.set_title(title, fontsize=16, fontweight='bold')
        axis.set_ylabel('True Label', fontsize=12)
        axis.set_xlabel('Predicted Label', fontsize=12)
        axis.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig('validation_test_confusion_matrices.png', dpi=150, bbox_inches='tight')
    plt.show()

def main():
    """Run the Transformer experiment end to end: load, train, evaluate, plot.

    Returns:
        ``(trainer, model, train_dataset, val_targets, val_preds,
        test_targets, test_preds)``, or ``None`` when the CSV yields no valid
        sequences.
    """
    set_seed(42)

    # IMPROVED Hyperparameters
    CSV_PATH = ""  # Update your CSV file path here
    MODEL_PATH = "best_model.pth"  # Single source for the save and load path
    SEQUENCE_LENGTH = 150
    BATCH_SIZE = 12          # Smaller batch size for stable gradients
    EPOCHS = 150             # 🔥 Increased epochs (80 -> 150)
    VAL_SIZE = 0.15          # 15% validation
    TEST_SIZE = 0.15         # 15% test (70% remaining for train)
    OVERLAP_RATIO = 0.5      # 🔥 More overlap (0.3 -> 0.5)

    logger.info("🚀 IMPROVED Diver Sign Language Transformer Training!")
    logger.info(f"📊 Data split: 70% Train, 15% Validation, 15% Test")

    # Data preparation - more data augmentation
    sequences, labels, feature_count = smart_prepare_data(
        CSV_PATH, sequence_length=SEQUENCE_LENGTH, overlap_ratio=OVERLAP_RATIO
    )

    if len(sequences) == 0:
        logger.error("❌ No valid sequences found!")
        return

    # NOTE(review): the windows overlap (OVERLAP_RATIO) and are split at
    # random below, so neighbouring windows that share rows can land in
    # different splits. Validation/test scores may therefore be optimistic;
    # consider splitting the raw rows before windowing.

    # First split: train+val (85%) and test (15%)
    X_temp, X_test, y_temp, y_test = train_test_split(
        sequences, labels,
        test_size=TEST_SIZE,
        random_state=42,
        stratify=labels
    )

    # Second split: train (70%) and val (15%) from the remaining 85%
    val_size_adjusted = VAL_SIZE / (1 - TEST_SIZE)  # 0.15 / 0.85 ≈ 0.176
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp,
        test_size=val_size_adjusted,
        random_state=42,
        stratify=y_temp
    )

    logger.info(f"🔄 Train sequences: {len(X_train):,} ({len(X_train)/len(sequences)*100:.1f}%)")
    logger.info(f"🔄 Validation sequences: {len(X_val):,} ({len(X_val)/len(sequences)*100:.1f}%)")
    logger.info(f"🔄 Test sequences: {len(X_test):,} ({len(X_test)/len(sequences)*100:.1f}%)")

    # Create datasets (validation/test reuse the encoder and scaler fitted on
    # the training split)
    train_dataset = SimpleDiverSignDataset(X_train, y_train)
    val_dataset = SimpleDiverSignDataset(
        X_val, y_val,
        train_dataset.get_label_encoder(),
        train_dataset.get_scaler()
    )
    test_dataset = SimpleDiverSignDataset(
        X_test, y_test,
        train_dataset.get_label_encoder(),
        train_dataset.get_scaler()
    )

    # DataLoaders. Pinned memory only helps host->GPU copies; requesting it
    # on a CPU-only runtime just triggers a warning.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_memory)

    # IMPROVED Model - larger architecture
    num_classes = len(train_dataset.get_label_encoder().classes_)
    model = OptimizedDiverSignTransformer(
        input_dim=feature_count,
        d_model=384,             # 🔥 Increased! (256 -> 384)
        n_heads=12,              # 🔥 More heads! (8 -> 12)
        n_layers=8,              # 🔥 Deeper! (6 -> 8)
        num_classes=num_classes,
        max_seq_length=SEQUENCE_LENGTH,
        dropout=0.15             # Slightly increased dropout (prevent overfitting)
    )

    logger.info(f"🤖 IMPROVED Model created:")
    logger.info(f"   - Number of classes: {num_classes}")
    logger.info(f"   - Input dimension: {feature_count}")
    logger.info(f"   - Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # IMPROVED Trainer
    trainer = SmartTrainer(model, train_dataset.get_class_weights())

    # 🔥 IMPROVED Training parameters (replace the trainer's defaults)
    trainer.optimizer = optim.AdamW(
        model.parameters(),
        lr=8e-4,                 # Slightly lower (1e-3 -> 8e-4)
        weight_decay=0.02,       # Increased regularization
        betas=(0.9, 0.95)        # Lower beta2 than the 0.999 default
    )

    # Better scheduler. `verbose` is not passed: it is deprecated in PyTorch
    # and rejected by newer releases.
    trainer.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        trainer.optimizer, mode='max', factor=0.7, patience=15, min_lr=1e-6
    )

    # Training
    trainer.train(train_loader, val_loader, epochs=EPOCHS, save_path=MODEL_PATH)

    # Load best model for final evaluation. The checkpoint was written by
    # this very run, so full unpickling (weights_only=False) is acceptable;
    # do not use this flag on checkpoints from untrusted sources.
    # map_location keeps the load working on a CPU-only runtime.
    checkpoint = torch.load(MODEL_PATH, map_location=trainer.device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Final validation evaluation
    val_loss, val_acc, val_f1, val_preds, val_targets = trainer.validate(val_loader)

    # Final test evaluation
    test_loss, test_acc, test_f1, test_preds, test_targets = trainer.validate(test_loader)

    logger.info("🎯 IMPROVED Final Results:")
    logger.info(f"   📊 VALIDATION:")
    logger.info(f"     - Accuracy: {val_acc:.2f}%")
    logger.info(f"     - F1 Score: {val_f1:.4f}")
    logger.info(f"   📊 TEST:")
    logger.info(f"     - Accuracy: {test_acc:.2f}%")
    logger.info(f"     - F1 Score: {test_f1:.4f}")

    # Classification reports
    target_names = train_dataset.get_label_encoder().classes_

    print("\n📊 VALIDATION Classification Report:")
    print("="*60)
    print(classification_report(val_targets, val_preds, target_names=target_names))

    print("\n📊 TEST Classification Report:")
    print("="*60)
    print(classification_report(test_targets, test_preds, target_names=target_names))

    # Confusion matrices - both validation and test
    plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names)

    # Training history visualization
    plot_results(trainer, "improved_training_results.png")

    logger.info("✅ IMPROVED Training completed successfully!")
    logger.info(f"📁 Files saved:")
    logger.info(f"   - best_model.pth: Best model checkpoint")
    logger.info(f"   - validation_test_confusion_matrices.png: Confusion matrices")
    logger.info(f"   - improved_training_results.png: Training graphs")

    return trainer, model, train_dataset, val_targets, val_preds, test_targets, test_preds

if __name__ == "__main__":
    # main() returns None when no valid sequences were found; unpacking that
    # directly would raise a TypeError, so only unpack a real result.
    results = main()
    if results is not None:
        trainer, model, dataset, val_targets, val_preds, test_targets, test_preds = results

In [None]:
#2-LSTM+attention
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import logging
import random
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Logging setup: configure the root logger at INFO level and create the
# module-level logger used by the code in this cell.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def set_seed(seed=42):
    """Make runs repeatable by seeding all random number generators.

    Covers Python's ``random``, NumPy and PyTorch (CPU plus every CUDA
    device), and switches cuDNN to deterministic algorithms.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Independent of the seeding calls, so it can be set up front
    torch.backends.cudnn.deterministic = True

    # Host-side generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then all CUDA devices
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

class BalancedDiverSignDataset(Dataset):
    """Scaled sequence dataset with optional light noise augmentation.

    Args:
        sequences: NumPy array whose last axis is the feature axis. It is
            flattened to 2-D for scaling and then restored to its shape.
        labels: Class labels, one per sequence.
        label_encoder: Already-fitted encoder to reuse. When ``None`` a new
            ``LabelEncoder`` is fitted on ``labels``.
        scaler: Already-fitted scaler to reuse. When ``None`` a new
            ``StandardScaler`` is fitted on ``sequences``.
        noise_std: Standard deviation of the Gaussian noise added in
            ``__getitem__``.
        train_mode: When true, roughly 20% of fetched items get noise added.

    Attributes:
        class_weights: Float tensor with one "balanced" weight per class known
            to the encoder: ``n_samples / (n_classes * class_count)``. A class
            that does not occur in this split gets weight 0.
    """

    def __init__(self, sequences, labels, label_encoder=None, scaler=None,
                 noise_std=0.005, train_mode=True):
        # Label encoding: fit only when no encoder was handed in
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            encoded_labels = self.label_encoder.fit_transform(labels)
        else:
            self.label_encoder = label_encoder
            encoded_labels = self.label_encoder.transform(labels)

        # Scalers expect 2-D input, so collapse every leading axis into rows,
        # scale per feature, then restore the original shape.
        flat = sequences.reshape(-1, sequences.shape[-1])
        if scaler is None:
            self.scaler = StandardScaler()
            flat = self.scaler.fit_transform(flat)
        else:
            self.scaler = scaler
            flat = self.scaler.transform(flat)
        scaled_sequences = flat.reshape(sequences.shape)

        self.sequences = torch.FloatTensor(scaled_sequences)
        self.labels = torch.LongTensor(encoded_labels)
        self.noise_std = noise_std
        self.train_mode = train_mode

        # Class weights are indexed by encoded label, so size them by the
        # encoder's class list rather than by the classes that happen to be
        # present: a split that misses a class would otherwise divide by zero
        # or produce a vector shorter than the number of model outputs.
        num_classes = len(self.label_encoder.classes_)
        class_counts = Counter(int(label) for label in encoded_labels)
        total_samples = len(encoded_labels)
        self.class_weights = torch.FloatTensor([
            total_samples / (num_classes * class_counts[i]) if class_counts[i] else 0.0
            for i in range(num_classes)
        ])

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]

        # Add light noise during training for regularization (20% of fetches)
        if self.train_mode and random.random() < 0.2:
            noise = torch.randn_like(sequence) * self.noise_std
            sequence = sequence + noise

        return sequence, label

    def get_label_encoder(self):
        """Return the encoder that maps class labels to integer ids."""
        return self.label_encoder

    def get_scaler(self):
        """Return the scaler that was applied to the features."""
        return self.scaler

    def get_class_weights(self):
        """Return the per-class weight tensor computed at construction."""
        return self.class_weights

class AttentionLayer(nn.Module):
    """Additive attention pooling over the time axis.

    Each timestep is scored by a small network (Linear -> Tanh -> Linear
    without bias); the scores are softmax-normalised across timesteps and
    used to form a weighted sum of the inputs.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Scoring network: maps every timestep vector to one scalar score
        score_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1, bias=False),
        ]
        self.attention = nn.Sequential(*score_layers)

    def forward(self, lstm_output):
        """Pool a sequence into a single context vector.

        Args:
            lstm_output: [batch_size, seq_len, hidden_dim]
        Returns:
            context: [batch_size, hidden_dim]
            attention_weights: [batch_size, seq_len]
        """
        # One score per timestep: [batch, seq_len, 1] -> [batch, seq_len]
        scores = self.attention(lstm_output).squeeze(-1)

        # Normalise across the time axis
        weights = scores.softmax(dim=1)

        # Weighted sum of timesteps: [batch, 1, seq_len] @ [batch, seq_len, hidden]
        context = torch.matmul(weights.unsqueeze(1), lstm_output).squeeze(1)

        return context, weights

class LSTMAttentionModel(nn.Module):
    """LSTM + attention classifier for diver sign language recognition.

    Pipeline: per-timestep linear projection -> (bi)LSTM -> attention pooling
    over time -> MLP classification head.

    Args:
        input_dim: Number of features per timestep.
        hidden_dim: Width of the projection and of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        dropout: Dropout probability (between LSTM layers only when
            ``num_layers > 1``).
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self,
                 input_dim=69,
                 hidden_dim=192,        # LSTM hidden dimension
                 num_layers=3,          # Number of LSTM layers
                 num_classes=10,
                 dropout=0.2,
                 bidirectional=True):   # Use bidirectional LSTM
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_classes = num_classes

        # Input preprocessing layer
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # LSTM layer for sequential processing
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )

        # Calculate LSTM output dimension (double if bidirectional)
        lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim

        # Attention layer for focusing on important time steps
        self.attention = AttentionLayer(lstm_output_dim)

        # Classification head with multiple layers
        self.classifier = nn.Sequential(
            nn.LayerNorm(lstm_output_dim),
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim, lstm_output_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(lstm_output_dim // 2, lstm_output_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout * 0.8),
            nn.Linear(lstm_output_dim // 4, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-init Linear layers; Xavier/orthogonal-init the LSTM.

        Linear layers are found by walking modules. Walking parameters and
        testing ``isinstance(param, nn.Linear)`` can never match, because a
        parameter is a tensor, which would leave every Linear weight at its
        default initialisation.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        for name, param in self.lstm.named_parameters():
            if 'weight_ih' in name:  # Input-to-hidden weights
                nn.init.xavier_uniform_(param.data)
            elif 'weight_hh' in name:  # Hidden-to-hidden weights
                nn.init.orthogonal_(param.data)
            elif 'bias' in name:  # Bias terms
                param.data.fill_(0)

    def forward(self, x, return_attention=False):
        """Classify a batch of sequences.

        Args:
            x: [batch_size, seq_len, input_dim]
            return_attention: Return attention weights for visualization

        Returns:
            Logits of shape [batch_size, num_classes]; when
            ``return_attention`` is true, a ``(logits, attention_weights)``
            tuple instead.
        """
        # Unpacking also rejects inputs that are not 3-D
        _batch, _steps, _features = x.shape

        # Input preprocessing
        x = self.input_projection(x)  # [batch, seq_len, hidden_dim]

        # LSTM forward pass; the final hidden/cell states are not used
        lstm_output, _ = self.lstm(x)  # [batch, seq_len, lstm_output_dim]

        # Apply attention mechanism to focus on important time steps
        context_vector, attention_weights = self.attention(lstm_output)

        # Final classification
        output = self.classifier(context_vector)

        if return_attention:
            return output, attention_weights
        return output

    def get_attention_weights(self, x):
        """Return attention weights for ``x`` (puts the model in eval mode)."""
        self.eval()
        with torch.no_grad():
            _, attention_weights = self.forward(x, return_attention=True)
        return attention_weights

class LSTMAttentionTrainer:
    """Training and evaluation harness for the LSTM + Attention model.

    Owns the loss (class-weighted cross entropy with label smoothing), the
    AdamW optimizer, a plateau scheduler driven by the validation F1 score,
    and the per-epoch metric histories that the plotting helpers read.
    """
    def __init__(self, model, class_weights=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Args:
            model: network to train; it is moved to ``device``.
            class_weights: optional per-class weight tensor for the loss.
            device: target device; defaults to CUDA when it is available.
        """
        self.model = model
        self.device = device
        self.model.to(device)

        # Loss function with class weights and label smoothing
        if class_weights is not None:
            class_weights = class_weights.to(device)

        self.criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

        # Optimizer optimized for LSTM training
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=5e-4,                 # Slightly higher learning rate for LSTM
            weight_decay=0.02,
            betas=(0.9, 0.95)
        )

        # Learning rate scheduler (maximizes the monitored validation F1).
        # ``verbose`` is deliberately not passed: the argument is deprecated in
        # recent PyTorch releases and no longer accepted by newer ones.
        # Learning-rate reductions are logged explicitly in train() instead.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.7, patience=15, min_lr=1e-6
        )

        # Metrics tracking (one entry per completed epoch)
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.val_f1_scores = []

    def train_epoch(self, dataloader):
        """Train for one epoch.

        Returns:
            ``(avg_loss, accuracy, f1)`` where ``avg_loss`` is the mean of the
            per-batch losses (including the L2 penalty), ``accuracy`` is a
            percentage and ``f1`` is the weighted F1 over all seen samples.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        for data, target in tqdm(dataloader, desc="Training"):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            output = self.model(data)

            # Loss calculation with L2 regularization over all parameters
            ce_loss = self.criterion(output, target)
            l2_reg = sum(param.pow(2).sum() for param in self.model.parameters())
            loss = ce_loss + 2e-5 * l2_reg

            loss.backward()

            # Gradient clipping (important for LSTM stability)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Statistics tracking
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1

    def validate(self, dataloader):
        """Evaluate the model without updating it.

        Returns:
            ``(avg_loss, accuracy, f1, all_preds, all_targets)``; the loss is
            the plain criterion value (no L2 penalty), averaged per batch.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in tqdm(dataloader, desc="Validating"):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1, all_preds, all_targets

    def train(self, train_loader, val_loader, epochs=150):
        """Run the full training loop with early stopping.

        The best model (by validation F1) is written to
        ``lstm_attention_best_model.pth``. Training stops early after 20
        consecutive epochs without a new best F1.

        Returns:
            The best validation F1 score observed.
        """
        best_val_f1 = 0
        # Guarantees that a checkpoint exists after the first epoch even if the
        # validation F1 never rises above 0, so later loading cannot fail.
        checkpoint_saved = False
        patience = 20
        patience_counter = 0

        logger.info(f"LSTM + ATTENTION Training started: {epochs} epochs")

        for epoch in range(epochs):
            train_loss, train_acc, train_f1 = self.train_epoch(train_loader)
            val_loss, val_acc, val_f1, _, _ = self.validate(val_loader)

            # Monitor overfitting by tracking train-validation gap
            train_val_gap = train_acc - val_acc

            # Step the scheduler and report reductions ourselves (replaces the
            # scheduler's former ``verbose`` output).
            lr_before = self.optimizer.param_groups[0]["lr"]
            self.scheduler.step(val_f1)
            lr_after = self.optimizer.param_groups[0]["lr"]
            if lr_after < lr_before:
                logger.info(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            # Save metrics for visualization
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            self.val_f1_scores.append(val_f1)

            # Logging every 5 epochs
            if (epoch + 1) % 5 == 0:
                logger.info(f'Epoch {epoch+1}/{epochs}:')
                logger.info(f'Train: Loss={train_loss:.3f}, Acc={train_acc:.1f}%, F1={train_f1:.3f}')
                logger.info(f'Val: Loss={val_loss:.3f}, Acc={val_acc:.1f}%, F1={val_f1:.3f}')
                logger.info(f'Gap: {train_val_gap:.1f}%, LR: {self.optimizer.param_groups[0]["lr"]:.2e}')

            # Save best model based on validation F1 score
            if val_f1 > best_val_f1 or not checkpoint_saved:
                best_val_f1 = val_f1
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'epoch': epoch,
                    # Plain Python floats keep the checkpoint loadable with
                    # torch.load(..., weights_only=True).
                    'val_f1': float(val_f1),
                    'train_val_gap': float(train_val_gap),
                    'model_type': 'LSTM_Attention'
                }, "lstm_attention_best_model.pth")
                checkpoint_saved = True
                logger.info(f"✅ New best model saved: F1={val_f1:.3f}, Gap={train_val_gap:.1f}%")
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping to prevent overfitting
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        logger.info(f"🎯 Training completed. Best F1 Score: {best_val_f1:.3f}")
        return best_val_f1

def visualize_attention_weights(model, dataloader, class_names, device, num_samples=3):
    """Plot per-frame attention curves for the first ``num_samples`` samples.

    Samples are taken in loader order, one stacked subplot each. The figure is
    written to ``lstm_attention_weights.png`` and then shown.
    """
    model.eval()

    _figure, panels = plt.subplots(num_samples, 1, figsize=(15, 4*num_samples))
    if num_samples == 1:
        # A single-panel figure yields a bare Axes; wrap it so it can be
        # indexed like the multi-panel case.
        panels = [panels]

    drawn = 0
    with torch.no_grad():
        for batch, batch_labels in dataloader:
            if drawn >= num_samples:
                break

            batch = batch.to(device)
            # Take only as many rows from this batch as are still needed.
            still_needed = num_samples - drawn

            for row in range(min(batch.size(0), still_needed)):
                one_sample = batch[row:row+1]  # [1, seq_len, features]
                label_index = batch_labels[row].item()

                # Attention distribution over the frames of this sample
                weights = model.get_attention_weights(one_sample)
                weights = weights[0].cpu().numpy()  # [seq_len]

                panel = panels[drawn]
                frame_axis = range(len(weights))
                panel.plot(frame_axis, weights, linewidth=2, color='blue')
                panel.fill_between(frame_axis, weights, alpha=0.3, color='blue')
                panel.set_title(f'Attention Weights - Class: {class_names[label_index]}')
                panel.set_xlabel('Frame Index')
                panel.set_ylabel('Attention Weight')
                panel.grid(True, alpha=0.3)

                drawn += 1

    plt.tight_layout()
    plt.savefig('lstm_attention_weights.png', dpi=150)
    plt.show()

def prepare_balanced_data(csv_path, sequence_length=150, overlap_ratio=0.3):
    """Load a frame-level CSV and cut it into fixed-length sliding windows.

    Every column except ``class`` is treated as a feature. Missing values are
    forward-filled (then zero-filled) and infinities are replaced by 0. Classes
    with fewer than ``4 * sequence_length`` rows are dropped. Windows are cut
    per class over that class's rows in file order.

    Args:
        csv_path: path of the CSV file to read.
        sequence_length: number of rows per window.
        overlap_ratio: fraction of a window shared with the next one. Values
            that would give a zero stride are clamped to a stride of 1 row.

    Returns:
        ``(sequences, labels, feature_count)``: an array of windows shaped
        ``[num_windows, sequence_length, feature_count]`` (empty when no class
        qualifies), the class name of each window, and the feature count.
    """
    logger.info(f"📊 Loading CSV data from: {csv_path}")

    df = pd.read_csv(csv_path)
    feature_columns = [col for col in df.columns if col != 'class']

    # Data cleaning - handle NaN and infinite values.
    # DataFrame.ffill() replaces the deprecated fillna(method='ffill').
    df[feature_columns] = df[feature_columns].ffill().fillna(0)
    df[feature_columns] = df[feature_columns].replace([np.inf, -np.inf], 0)

    # Keep only classes with sufficient data for reliable training
    min_samples = sequence_length * 4
    class_counts = df['class'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples].index
    df = df[df['class'].isin(valid_classes)]

    logger.info(f"Valid classes after filtering: {list(valid_classes)}")

    sequences = []
    labels = []

    # Create sequences using sliding window with overlap. The stride is kept
    # at >= 1 because range() rejects a zero step (overlap_ratio close to 1).
    stride = max(1, int(sequence_length * (1 - overlap_ratio)))

    for class_name in valid_classes:
        # Extract the feature matrix once per class; windows are array slices.
        class_values = df.loc[df['class'] == class_name, feature_columns].to_numpy()
        last_start = len(class_values) - sequence_length

        for start in range(0, last_start + 1, stride):
            sequences.append(class_values[start:start + sequence_length])
            labels.append(class_name)

    sequences = np.array(sequences)
    logger.info(f"Total sequences created: {len(sequences)}")
    logger.info(f"Class distribution: {dict(Counter(labels))}")

    return sequences, labels, len(feature_columns)

def plot_training_results(trainer, model_name="LSTM_Attention"):
    """Draw a 2x2 dashboard of the trainer's recorded metric histories.

    Panels: loss curves, accuracy curves, train-validation accuracy gap and
    validation F1. The figure is saved as ``<model_name lowercased>_results.png``
    and then shown.
    """
    _figure, grid = plt.subplots(2, 2, figsize=(15, 10))
    (loss_ax, acc_ax), (gap_ax, f1_ax) = grid

    epoch_axis = range(1, len(trainer.train_losses) + 1)

    def finish_panel(axis, title, y_label, legend=True):
        # Shared decoration applied after a panel's curves are drawn.
        axis.set_title(f'{model_name} - {title}')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        if legend:
            axis.legend()
        axis.grid(True, alpha=0.3)

    # Loss curves
    loss_ax.plot(epoch_axis, trainer.train_losses, 'b-', label='Training Loss', linewidth=2)
    loss_ax.plot(epoch_axis, trainer.val_losses, 'r-', label='Validation Loss', linewidth=2)
    finish_panel(loss_ax, 'Loss Curves', 'Loss')

    # Accuracy curves
    acc_ax.plot(epoch_axis, trainer.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
    acc_ax.plot(epoch_axis, trainer.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
    finish_panel(acc_ax, 'Accuracy Curves', 'Accuracy (%)')

    # Train-validation gap with two reference lines
    accuracy_gaps = [
        train_acc - val_acc
        for train_acc, val_acc in zip(trainer.train_accuracies, trainer.val_accuracies)
    ]
    gap_ax.plot(epoch_axis, accuracy_gaps, 'purple', linewidth=2)
    gap_ax.axhline(y=12, color='orange', linestyle='--', label='Warning Threshold')
    gap_ax.axhline(y=8, color='green', linestyle='--', label='Good Threshold')
    finish_panel(gap_ax, 'Train-Validation Gap', 'Accuracy Gap (%)')

    # Validation F1 progression (single unlabeled curve, so no legend)
    f1_ax.plot(epoch_axis, trainer.val_f1_scores, 'green', linewidth=2)
    finish_panel(f1_ax, 'Validation F1 Score', 'F1 Score', legend=False)

    plt.tight_layout()
    plt.savefig(f'{model_name.lower()}_results.png', dpi=150)
    plt.show()

def plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names):
    """Draw validation and test confusion matrices side by side.

    The figure is saved as ``lstm_attention_confusion_matrices.png`` and shown.
    """
    _figure, (val_axis, test_axis) = plt.subplots(1, 2, figsize=(20, 8))

    # (axis, true labels, predictions, colormap, title), left panel first.
    panels = (
        (val_axis, val_targets, val_preds, 'Blues', 'Validation Confusion Matrix'),
        (test_axis, test_targets, test_preds, 'Greens', 'Test Confusion Matrix'),
    )

    for axis, truth, predicted, palette, title in panels:
        matrix = confusion_matrix(truth, predicted)
        sns.heatmap(matrix, annot=True, fmt='d', cmap=palette,
                    xticklabels=target_names, yticklabels=target_names, ax=axis)
        axis.set_title(title, fontsize=16, fontweight='bold')
        axis.set_ylabel('True Label', fontsize=12)
        axis.set_xlabel('Predicted Label', fontsize=12)
        axis.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig('lstm_attention_confusion_matrices.png', dpi=150)
    plt.show()

def main():
    """Train, evaluate and visualize the LSTM + Attention model end to end.

    Steps: build sliding-window sequences from the CSV, split them 70/15/15
    (stratified) into train/validation/test, train with early stopping, reload
    the best checkpoint, report metrics on validation and test, and save the
    confusion-matrix, training-curve and attention plots.

    NOTE(review): windows overlap and are split at random, so neighbouring
    windows of the same recording can land in different splits; the reported
    validation/test scores may therefore be optimistic.

    Returns:
        ``None`` when no sequences could be built, otherwise the tuple
        ``(trainer, model, train_dataset, val_targets, val_preds,
        test_targets, test_preds)``.
    """
    set_seed(42)

    # Configuration parameters
    CSV_PATH = ""  # Update with your dataset path
    SEQUENCE_LENGTH = 150
    BATCH_SIZE = 16          # Smaller batch size for LSTM stability
    EPOCHS = 150             # Extended training epochs
    VAL_SIZE = 0.15          # 15% for validation
    TEST_SIZE = 0.15         # 15% for test (70% remaining for training)
    HIDDEN_DIM = 192         # LSTM hidden dimension
    NUM_LAYERS = 3           # Number of LSTM layers
    BIDIRECTIONAL = True     # Use bidirectional LSTM

    logger.info("🚀 LSTM + ATTENTION Model Training Started!")
    logger.info(f"📊 Data split: 70% Train, 15% Validation, 15% Test")

    # Data preparation with balanced approach
    sequences, labels, feature_count = prepare_balanced_data(
        CSV_PATH, SEQUENCE_LENGTH, overlap_ratio=0.3
    )

    if len(sequences) == 0:
        logger.error("❌ No valid sequences found!")
        return

    # First split: train+val (85%) and test (15%)
    X_temp, X_test, y_temp, y_test = train_test_split(
        sequences, labels, test_size=TEST_SIZE, random_state=42, stratify=labels
    )

    # Second split: train (70%) and val (15%) from remaining 85%
    val_size_adjusted = VAL_SIZE / (1 - TEST_SIZE)  # 0.15 / 0.85 ≈ 0.176
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted, random_state=42, stratify=y_temp
    )

    logger.info(f"🔄 Training sequences: {len(X_train):,} ({len(X_train)/len(sequences)*100:.1f}%)")
    logger.info(f"🔄 Validation sequences: {len(X_val):,} ({len(X_val)/len(sequences)*100:.1f}%)")
    logger.info(f"🔄 Test sequences: {len(X_test):,} ({len(X_test)/len(sequences)*100:.1f}%)")

    # Create datasets; the encoder and scaler are fitted on the training split
    # only and reused for validation/test.
    train_dataset = BalancedDiverSignDataset(X_train, y_train, train_mode=True)
    val_dataset = BalancedDiverSignDataset(
        X_val, y_val, train_dataset.get_label_encoder(),
        train_dataset.get_scaler(), train_mode=False
    )
    test_dataset = BalancedDiverSignDataset(
        X_test, y_test, train_dataset.get_label_encoder(),
        train_dataset.get_scaler(), train_mode=False
    )

    # Create data loaders. Pinned memory only helps host-to-GPU copies, so it
    # is requested only when CUDA is available.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_memory)

    # Initialize LSTM + Attention model
    num_classes = len(train_dataset.get_label_encoder().classes_)
    model = LSTMAttentionModel(
        input_dim=feature_count,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        num_classes=num_classes,
        dropout=0.2,
        bidirectional=BIDIRECTIONAL
    )

    logger.info(f"🤖 LSTM + ATTENTION Model Configuration:")
    logger.info(f"   - Number of classes: {num_classes}")
    logger.info(f"   - Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    logger.info(f"   - LSTM hidden dimension: {HIDDEN_DIM}")
    logger.info(f"   - LSTM layers: {NUM_LAYERS}")
    logger.info(f"   - Bidirectional: {BIDIRECTIONAL}")

    # Initialize trainer with class weights
    trainer = LSTMAttentionTrainer(model, train_dataset.get_class_weights())

    # Start training process
    best_f1 = trainer.train(train_loader, val_loader, epochs=EPOCHS)

    # Load best model for final evaluation. The checkpoint is produced locally
    # by LSTMAttentionTrainer.train, so it is trusted: weights_only=False keeps
    # its non-tensor metadata loadable on PyTorch versions that default to
    # weights_only=True, and map_location makes the load device-independent.
    checkpoint = torch.load(
        "lstm_attention_best_model.pth", map_location=trainer.device, weights_only=False
    )
    model.load_state_dict(checkpoint['model_state_dict'])

    # Final validation evaluation
    val_loss, val_acc, val_f1, val_preds, val_targets = trainer.validate(val_loader)

    # Final test evaluation
    test_loss, test_acc, test_f1, test_preds, test_targets = trainer.validate(test_loader)

    # Report the gap recorded for the checkpointed (best) epoch, so it belongs
    # to the same model whose validation/test metrics are printed below.
    accuracy_gap = checkpoint['train_val_gap']

    logger.info("🎯 LSTM + ATTENTION Final Results:")
    logger.info(f"   📊 VALIDATION:")
    logger.info(f"     - Accuracy: {val_acc:.1f}%")
    logger.info(f"     - F1 Score: {val_f1:.3f}")
    logger.info(f"   📊 TEST:")
    logger.info(f"     - Accuracy: {test_acc:.1f}%")
    logger.info(f"     - F1 Score: {test_f1:.3f}")
    logger.info(f"   📊 OVERFITTING ANALYSIS:")
    logger.info(f"     - Train-Validation Gap: {accuracy_gap:.1f}%")

    # Detailed classification reports
    target_names = train_dataset.get_label_encoder().classes_

    print("\n📊 VALIDATION Classification Report:")
    print("="*60)
    print(classification_report(val_targets, val_preds, target_names=target_names))

    print("\n📊 TEST Classification Report:")
    print("="*60)
    print(classification_report(test_targets, test_preds, target_names=target_names))

    # Generate confusion matrices for both validation and test
    plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names)

    # Plot training progress and metrics
    plot_training_results(trainer, "LSTM_Attention")

    # Visualize attention weights for model interpretability
    logger.info("🔍 Generating attention weight visualizations...")
    visualize_attention_weights(model, test_loader, target_names, trainer.device, num_samples=3)

    logger.info("✅ LSTM + ATTENTION training completed successfully!")
    logger.info(f"📁 Files saved:")
    logger.info(f"   - lstm_attention_best_model.pth: Best model checkpoint")
    logger.info(f"   - lstm_attention_confusion_matrices.png: Confusion matrices")
    logger.info(f"   - lstm_attention_results.png: Training curves")
    logger.info(f"   - lstm_attention_weights.png: Attention visualizations")

    return trainer, model, train_dataset, val_targets, val_preds, test_targets, test_preds

if __name__ == "__main__":
    # main() returns None when no sequences could be built; unpacking that
    # directly would raise a TypeError, so only unpack an actual result.
    results = main()
    if results is not None:
        trainer, model, dataset, val_targets, val_preds, test_targets, test_preds = results

In [None]:
#3-lstm+GRU

# LSTM + GRU Hybrid Model for Diver Sign Language Recognition
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import logging
import random
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Logging setup.
# NOTE: logging.basicConfig() does nothing when the root logger already has
# handlers, so repeating this setup in a later notebook cell is harmless.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions and classes defined below.
logger = logging.getLogger(__name__)

def set_seed(seed=42):
    """Seed every random number generator used in this file.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and configures cuDNN for repeatable results.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may select different convolution algorithms from run to
    # run; switch it off explicitly so determinism does not rely on the default.
    torch.backends.cudnn.benchmark = False

class BalancedDiverSignDataset(Dataset):
    """Sequence dataset with label encoding, feature scaling, class weights
    and optional random augmentation for diver sign language data.

    Pass the ``label_encoder`` and ``scaler`` of the training dataset when
    building validation/test datasets so all splits share one encoding.
    """
    def __init__(self, sequences, labels, label_encoder=None, scaler=None,
                 noise_std=0.005, train_mode=True, augment_prob=0.3):
        """
        Args:
            sequences: array shaped ``[num_sequences, seq_len, num_features]``.
            labels: one class label per sequence.
            label_encoder: fitted encoder to reuse; a new one is fitted if None.
            scaler: fitted scaler to reuse; a new one is fitted if None.
            noise_std: standard deviation of the Gaussian-noise augmentation.
            train_mode: augmentation is only applied when this is True.
            augment_prob: probability of augmenting a sample on access.
        """
        # Label encoding
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            encoded_labels = self.label_encoder.fit_transform(labels)
        else:
            self.label_encoder = label_encoder
            encoded_labels = self.label_encoder.transform(labels)

        # Data scaling: the scaler works on 2-D data, so all frames are
        # flattened to [num_frames, num_features] and reshaped back afterwards.
        flat_frames = sequences.reshape(-1, sequences.shape[-1])
        if scaler is None:
            self.scaler = StandardScaler()
            flat_scaled = self.scaler.fit_transform(flat_frames)
        else:
            self.scaler = scaler
            flat_scaled = self.scaler.transform(flat_frames)
        scaled_sequences = flat_scaled.reshape(sequences.shape)

        self.sequences = torch.FloatTensor(scaled_sequences)
        self.labels = torch.LongTensor(encoded_labels)
        self.noise_std = noise_std
        self.train_mode = train_mode
        self.augment_prob = augment_prob

        # Class weights are sized by the encoder's classes, not by the classes
        # that happen to occur in this split.
        self.class_weights = self._compute_class_weights(
            encoded_labels, len(self.label_encoder.classes_)
        )

    @staticmethod
    def _compute_class_weights(encoded_labels, num_classes):
        """Return inverse-frequency weights ``total / (num_classes * count)``.

        Classes with no sample in this split get weight 0 instead of causing a
        division by zero.
        """
        counts = np.bincount(np.asarray(encoded_labels, dtype=np.int64), minlength=num_classes)
        weights = np.zeros(len(counts), dtype=np.float64)
        present = counts > 0
        weights[present] = counts.sum() / (len(counts) * counts[present])
        return torch.tensor(weights, dtype=torch.float32)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # Clone so augmentation never modifies the stored tensor.
        sequence = self.sequences[idx].clone()
        label = self.labels[idx]

        # Apply data augmentation during training
        if self.train_mode and random.random() < self.augment_prob:
            sequence = self._augment_sequence(sequence)

        return sequence, label

    def _augment_sequence(self, sequence):
        """Apply one randomly chosen augmentation (noise, scale or shift)."""
        aug_type = random.choice(['noise', 'scale', 'shift'])

        if aug_type == 'noise':
            # Add Gaussian noise for robustness
            noise = torch.randn_like(sequence) * self.noise_std
            sequence = sequence + noise

        elif aug_type == 'scale':
            # Apply random scaling
            scale_factor = random.uniform(0.98, 1.02)
            sequence = sequence * scale_factor

        elif aug_type == 'shift':
            # Temporal shift of up to 3 frames; torch.roll is circular, so
            # frames pushed past one end re-enter at the other.
            shift = random.randint(-3, 3)
            sequence = torch.roll(sequence, shift, dims=0)

        return sequence

    def get_label_encoder(self):
        return self.label_encoder

    def get_scaler(self):
        return self.scaler

    def get_class_weights(self):
        return self.class_weights

class LSTMGRUHybridModel(nn.Module):
    """Hybrid sequence classifier: a bidirectional LSTM feeding a bidirectional GRU.

    Each recurrent stage is pooled over time by its own attention scorer; the
    two pooled vectors are concatenated, fused and classified.
    """
    def __init__(self, input_dim=69, lstm_hidden=256, gru_hidden=256,
                 num_classes=10, dropout=0.2, num_layers=2):
        super().__init__()

        self.input_dim = input_dim
        self.lstm_hidden = lstm_hidden
        self.gru_hidden = gru_hidden
        self.num_classes = num_classes
        self.num_layers = num_layers

        # Inter-layer dropout is only meaningful for stacked recurrent layers.
        recurrent_dropout = dropout if num_layers > 1 else 0
        # Both recurrent stages are bidirectional, hence twice the hidden size.
        lstm_width = lstm_hidden * 2
        gru_width = gru_hidden * 2
        fused_in = lstm_width + gru_width

        # Per-frame input projection
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, lstm_hidden),
            nn.LayerNorm(lstm_hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # First recurrent stage
        self.lstm = nn.LSTM(
            input_size=lstm_hidden,
            hidden_size=lstm_hidden,
            num_layers=num_layers,
            batch_first=True,
            dropout=recurrent_dropout,
            bidirectional=True
        )

        # Second recurrent stage, consuming the LSTM outputs
        self.gru = nn.GRU(
            input_size=lstm_width,
            hidden_size=gru_hidden,
            num_layers=num_layers,
            batch_first=True,
            dropout=recurrent_dropout,
            bidirectional=True
        )

        # One attention scorer per recurrent stage
        self.lstm_attention = self._make_scorer(lstm_width, lstm_hidden)
        self.gru_attention = self._make_scorer(gru_width, gru_hidden)

        # Fusion of the two pooled representations
        self.feature_fusion = nn.Sequential(
            nn.Linear(fused_in, fused_in // 2),
            nn.LayerNorm(fused_in // 2),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Classification head with progressively narrower layers
        self.classifier = nn.Sequential(
            nn.Linear(fused_in // 2, fused_in // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fused_in // 4, fused_in // 8),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(fused_in // 8, num_classes)
        )

        self._init_weights()

    @staticmethod
    def _make_scorer(in_features, hidden_features):
        """Build a two-layer scorer mapping each time step to one scalar score."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.Tanh(),
            nn.Linear(hidden_features, 1, bias=False)
        )

    @staticmethod
    def _attend(scorer, states):
        """Pool ``states`` over time with softmax weights from ``scorer``.

        Returns the pooled vectors and the attention weights.
        """
        scores = scorer(states).squeeze(-1)  # [batch, seq_len]
        weights = torch.softmax(scores, dim=1)
        pooled = torch.bmm(weights.unsqueeze(1), states).squeeze(1)
        return pooled, weights

    def _init_weights(self):
        """Re-initialize weights: Xavier for input-side and linear matrices,
        orthogonal for hidden-to-hidden matrices, zeros for 1-D biases."""
        for name, tensor in self.named_parameters():
            if tensor.dim() < 2:
                # 1-D parameters: zero the biases, leave the rest untouched.
                if 'bias' in name:
                    nn.init.constant_(tensor.data, 0)
                continue

            is_input_side = 'weight_ih' in name
            if is_input_side or 'weight_hh' in name:
                # Recurrent matrices of the LSTM and GRU share one scheme.
                if 'lstm' in name or 'gru' in name:
                    if is_input_side:
                        nn.init.xavier_uniform_(tensor.data)
                    else:
                        nn.init.orthogonal_(tensor.data)
            elif 'weight' in name:
                nn.init.xavier_uniform_(tensor.data)

    def forward(self, x, return_attention=False):
        """
        Forward pass through the hybrid model
        Args:
            x: [batch_size, seq_len, input_dim]
            return_attention: Whether to return attention weights
        Returns:
            The class scores, or ``(scores, lstm_weights, gru_weights)`` when
            ``return_attention`` is set.
        """
        # Unpacking the shape doubles as a rank check for 3-D input.
        _batch, _steps, _features = x.shape

        projected = self.input_projection(x)  # [batch, seq_len, lstm_hidden]

        lstm_states, _ = self.lstm(projected)  # [batch, seq_len, lstm_hidden*2]
        lstm_pooled, lstm_weights = self._attend(self.lstm_attention, lstm_states)

        # The GRU refines the full LSTM output sequence, not the pooled vector.
        gru_states, _ = self.gru(lstm_states)  # [batch, seq_len, gru_hidden*2]
        gru_pooled, gru_weights = self._attend(self.gru_attention, gru_states)

        fused = self.feature_fusion(torch.cat([lstm_pooled, gru_pooled], dim=1))
        scores = self.classifier(fused)

        if not return_attention:
            return scores
        return scores, lstm_weights, gru_weights

    def get_attention_weights(self, x):
        """Return ``(lstm_weights, gru_weights)`` for ``x``; the model is put
        in eval mode and gradients are disabled for the pass."""
        self.eval()
        with torch.no_grad():
            _scores, lstm_weights, gru_weights = self.forward(x, return_attention=True)
        return lstm_weights, gru_weights

class LSTMGRUTrainer:
    """Training and evaluation harness for the LSTM-GRU hybrid model.

    Owns the loss (class-weighted cross entropy with label smoothing), the
    AdamW optimizer, a plateau scheduler driven by the validation F1 score,
    and the per-epoch metric histories used for plotting.
    """
    def __init__(self, model, class_weights=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Args:
            model: network to train; it is moved to ``device``.
            class_weights: optional per-class weight tensor for the loss.
            device: target device; defaults to CUDA when it is available.
        """
        self.model = model
        self.device = device
        self.model.to(device)

        # Loss function with class balancing
        if class_weights is not None:
            class_weights = class_weights.to(device)

        self.criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

        # Optimizer configuration
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=1e-3,
            weight_decay=0.01,
            betas=(0.9, 0.95)
        )

        # Learning rate scheduler (maximizes the monitored validation F1).
        # ``verbose`` is deliberately not passed: the argument is deprecated in
        # recent PyTorch releases and no longer accepted by newer ones.
        # Learning-rate reductions are logged explicitly in train() instead.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.7, patience=12, min_lr=1e-6
        )

        # Metrics tracking (one entry per completed epoch)
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.val_f1_scores = []

    def train_epoch(self, dataloader):
        """Execute one training epoch.

        Returns:
            ``(avg_loss, accuracy, f1)`` where ``avg_loss`` is the mean of the
            per-batch losses (including the L2 penalty), ``accuracy`` is a
            percentage and ``f1`` is the weighted F1 over all seen samples.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        for data, target in tqdm(dataloader, desc="Training"):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            output = self.model(data)

            # Loss calculation with L2 regularization over all parameters
            ce_loss = self.criterion(output, target)
            l2_reg = sum(param.pow(2).sum() for param in self.model.parameters())
            loss = ce_loss + 1e-5 * l2_reg

            loss.backward()

            # Gradient clipping for training stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Statistics tracking
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1

    def validate(self, dataloader):
        """Evaluate the model without updating it.

        Returns:
            ``(avg_loss, accuracy, f1, all_preds, all_targets)``; the loss is
            the plain criterion value (no L2 penalty), averaged per batch.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in tqdm(dataloader, desc="Validating"):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1, all_preds, all_targets

    def train(self, train_loader, val_loader, epochs=150):
        """Run the full training loop with validation and early stopping.

        The best model (by validation F1) is written to
        ``lstm_gru_hybrid_best_model.pth``. Training stops early after 20
        consecutive epochs without a new best F1.

        Returns:
            The best validation F1 score observed.
        """
        best_val_f1 = 0
        # Guarantees that a checkpoint exists after the first epoch even if the
        # validation F1 never rises above 0, so later loading cannot fail.
        checkpoint_saved = False
        patience = 20
        patience_counter = 0

        logger.info(f"🚀 LSTM-GRU Hybrid Training Started: {epochs} epochs")

        for epoch in range(epochs):
            train_loss, train_acc, train_f1 = self.train_epoch(train_loader)
            val_loss, val_acc, val_f1, _, _ = self.validate(val_loader)

            # Learning rate scheduling; reductions are reported here (replaces
            # the scheduler's former ``verbose`` output).
            lr_before = self.optimizer.param_groups[0]["lr"]
            self.scheduler.step(val_f1)
            lr_after = self.optimizer.param_groups[0]["lr"]
            if lr_after < lr_before:
                logger.info(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

            # Calculate train-validation gap for overfitting analysis
            train_val_gap = train_acc - val_acc

            # Save metrics for plotting
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            self.val_f1_scores.append(val_f1)

            # Progress logging every 5 epochs
            if (epoch + 1) % 5 == 0:
                logger.info(f'Epoch {epoch+1}/{epochs}:')
                logger.info(f'  Train: Loss={train_loss:.3f}, Acc={train_acc:.1f}%, F1={train_f1:.3f}')
                logger.info(f'  Val:   Loss={val_loss:.3f}, Acc={val_acc:.1f}%, F1={val_f1:.3f}')
                logger.info(f'  Gap: {train_val_gap:.1f}%, LR: {self.optimizer.param_groups[0]["lr"]:.2e}')

            # Save best model based on validation F1 score
            if val_f1 > best_val_f1 or not checkpoint_saved:
                best_val_f1 = val_f1
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'epoch': epoch,
                    # Plain Python floats keep the checkpoint loadable with
                    # torch.load(..., weights_only=True).
                    'val_f1': float(val_f1),
                    'train_val_gap': float(train_val_gap),
                    'model_type': 'LSTM_GRU_Hybrid'
                }, "lstm_gru_hybrid_best_model.pth")
                checkpoint_saved = True
                logger.info(f"  ✅ New best model saved: F1={val_f1:.3f}, Gap={train_val_gap:.1f}%")
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping to prevent overfitting
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        logger.info(f"🎯 Training completed. Best F1 Score: {best_val_f1:.3f}")
        return best_val_f1

def visualize_dual_attention_weights(model, dataloader, class_names, device, num_samples=3):
    """Plot the LSTM and GRU attention curves of a few samples side by side.

    One figure row is drawn per sample: left panel LSTM attention, right panel
    GRU attention. The figure is saved as
    ``lstm_gru_hybrid_attention_weights.png`` and then shown.

    Args:
        model: Model exposing ``get_attention_weights(batch)`` that returns an
            ``(lstm_attention, gru_attention)`` pair; the model is switched to
            eval mode. Each returned tensor is indexed with ``[0]`` and plotted
            as a 1-D curve (presumably one weight per frame; confirm against
            the model).
        dataloader: Iterable yielding ``(data, targets)`` batches.
        class_names: Sequence indexed by the integer target to get a label.
        device: Device the input batches are moved to.
        num_samples: Number of samples (figure rows) to draw.
    """
    def _draw_panel(axis, frame_axis, weights, colour, curve_label, heading):
        # One attention curve with a translucent fill underneath it.
        axis.plot(frame_axis, weights, linewidth=2, color=colour, label=curve_label)
        axis.fill_between(frame_axis, weights, alpha=0.3, color=colour)
        axis.set_title(heading)
        axis.set_xlabel('Frame Index')
        axis.set_ylabel('Attention Weight')
        axis.grid(True, alpha=0.3)
        axis.legend()

    model.eval()

    # squeeze=False keeps the axes grid two-dimensional even for a single row.
    _, grid = plt.subplots(num_samples, 2, figsize=(16, 4 * num_samples), squeeze=False)

    plotted = 0
    with torch.no_grad():
        for batch, batch_targets in dataloader:
            if plotted >= num_samples:
                break

            batch = batch.to(device)
            rows_to_take = min(batch.size(0), num_samples - plotted)

            for row in range(rows_to_take):
                class_idx = batch_targets[row].item()

                # Run the model on a batch of one and keep its only entry.
                lstm_weights, gru_weights = model.get_attention_weights(batch[row:row + 1])
                lstm_weights = lstm_weights[0].cpu().numpy()
                gru_weights = gru_weights[0].cpu().numpy()

                frame_axis = range(len(lstm_weights))
                _draw_panel(
                    grid[plotted, 0], frame_axis, lstm_weights, 'blue',
                    'LSTM Attention', f'LSTM Attention - Class: {class_names[class_idx]}'
                )
                _draw_panel(
                    grid[plotted, 1], frame_axis, gru_weights, 'red',
                    'GRU Attention', f'GRU Attention - Class: {class_names[class_idx]}'
                )

                plotted += 1

    plt.tight_layout()
    plt.savefig('lstm_gru_hybrid_attention_weights.png', dpi=150)
    plt.show()

def plot_training_results(trainer, model_name="LSTM_GRU_Hybrid"):
    """Draw a 2x2 summary figure of the metrics recorded on ``trainer``.

    Panels: loss curves, accuracy curves, validation F1 and the per-epoch
    train-minus-validation accuracy gap with two reference lines. The figure
    is saved as ``<model_name lower-cased>_training_results.png`` and shown.

    Args:
        trainer: Object exposing ``train_losses``, ``val_losses``,
            ``train_accuracies``, ``val_accuracies`` and ``val_f1_scores``.
        model_name: Prefix used in panel titles and in the output file name.
    """
    _, panel_grid = plt.subplots(2, 2, figsize=(15, 10))
    loss_ax, acc_ax = panel_grid[0]
    f1_ax, gap_ax = panel_grid[1]

    epoch_axis = range(1, len(trainer.train_losses) + 1)

    # The loss and accuracy panels share one layout: train in blue, val in red.
    paired_panels = (
        (loss_ax, trainer.train_losses, trainer.val_losses,
         'Training Loss', 'Validation Loss', f'{model_name} - Loss Curves', 'Loss'),
        (acc_ax, trainer.train_accuracies, trainer.val_accuracies,
         'Training Accuracy', 'Validation Accuracy', f'{model_name} - Accuracy Curves', 'Accuracy (%)'),
    )
    for axis, train_series, val_series, train_label, val_label, heading, y_label in paired_panels:
        axis.plot(epoch_axis, train_series, 'b-', label=train_label, linewidth=2)
        axis.plot(epoch_axis, val_series, 'r-', label=val_label, linewidth=2)
        axis.set_title(heading)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()
        axis.grid(True, alpha=0.3)

    # Validation F1 per epoch
    f1_ax.plot(epoch_axis, trainer.val_f1_scores, 'green', linewidth=2)
    f1_ax.set_title(f'{model_name} - Validation F1 Score')
    f1_ax.set_xlabel('Epoch')
    f1_ax.set_ylabel('F1 Score')
    f1_ax.grid(True, alpha=0.3)

    # Train-minus-validation accuracy as an overfitting indicator
    accuracy_gaps = []
    for train_acc, val_acc in zip(trainer.train_accuracies, trainer.val_accuracies):
        accuracy_gaps.append(train_acc - val_acc)

    gap_ax.plot(epoch_axis, accuracy_gaps, 'purple', linewidth=2)
    gap_ax.axhline(y=10, color='orange', linestyle='--', label='Warning Threshold')
    gap_ax.axhline(y=5, color='green', linestyle='--', label='Good Threshold')
    gap_ax.set_title(f'{model_name} - Train-Validation Gap')
    gap_ax.set_xlabel('Epoch')
    gap_ax.set_ylabel('Accuracy Gap (%)')
    gap_ax.legend()
    gap_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{model_name.lower()}_training_results.png', dpi=150)
    plt.show()

def prepare_balanced_dataset(csv_path, sequence_length=150, overlap_ratio=0.3):
    """Build fixed-length sliding-window sequences per class from a CSV file.

    The CSV must contain a ``class`` column; every other column is treated as
    a feature. Missing values are forward-filled over the whole frame in file
    order (then zero-filled), infinities are replaced by 0, and classes with
    fewer than ``4 * sequence_length`` rows are dropped. All rows of a class
    are then concatenated and cut into windows.

    Args:
        csv_path: Path or buffer accepted by ``pandas.read_csv``.
        sequence_length: Number of rows in each window.
        overlap_ratio: Fraction of a window shared with the next one. The
            resulting stride is clamped to at least one row, so ratios at or
            near 1.0 no longer produce a zero step.

    Returns:
        Tuple ``(sequences, labels, feature_count)``: ``sequences`` is an
        array of shape ``(n_windows, sequence_length, feature_count)`` (an
        empty array when no class qualifies), ``labels`` is a list with one
        class name per window, ``feature_count`` is the number of feature
        columns.
    """
    logger.info(f"📊 Loading data from: {csv_path}")

    df = pd.read_csv(csv_path)
    feature_columns = [col for col in df.columns if col != 'class']

    # Data cleaning: handle missing and infinite values.
    # DataFrame.ffill() is the supported spelling of the deprecated
    # fillna(method='ffill').
    df[feature_columns] = df[feature_columns].ffill().fillna(0)
    df[feature_columns] = df[feature_columns].replace([np.inf, -np.inf], 0)

    # Filter classes with sufficient data for reliable training
    min_samples = sequence_length * 4
    class_counts = df['class'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples].index
    df = df[df['class'].isin(valid_classes)]

    logger.info(f"Valid classes after filtering: {list(valid_classes)}")

    # int() truncates, so a high overlap could yield 0 and make range() raise.
    stride = max(1, int(sequence_length * (1 - overlap_ratio)))

    sequences = []
    labels = []

    for class_name in valid_classes:
        # Convert each class to NumPy once instead of slicing the frame per window.
        class_values = df.loc[df['class'] == class_name, feature_columns].to_numpy()

        for start in range(0, len(class_values) - sequence_length + 1, stride):
            sequences.append(class_values[start:start + sequence_length])
            labels.append(class_name)

    sequences = np.array(sequences)
    logger.info(f"Total sequences created: {len(sequences)}")
    logger.info(f"Class distribution: {dict(Counter(labels))}")

    return sequences, labels, len(feature_columns)

def plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names):
    """Draw the validation and test confusion matrices next to each other.

    The figure is saved as ``lstm_gru_hybrid_confusion_matrices.png`` and
    then shown.

    Args:
        val_targets: True labels of the validation split.
        val_preds: Predicted labels of the validation split.
        test_targets: True labels of the test split.
        test_preds: Predicted labels of the test split.
        target_names: Tick labels used on both axes of each heatmap.
    """
    _, axes_pair = plt.subplots(1, 2, figsize=(20, 8))

    # (axis, true labels, predictions, colour map, title) for each panel
    panels = (
        (axes_pair[0], val_targets, val_preds, 'Blues', 'Validation Confusion Matrix'),
        (axes_pair[1], test_targets, test_preds, 'Greens', 'Test Confusion Matrix'),
    )

    for axis, y_true, y_pred, palette, heading in panels:
        matrix = confusion_matrix(y_true, y_pred)
        sns.heatmap(
            matrix,
            annot=True,
            fmt='d',
            cmap=palette,
            xticklabels=target_names,
            yticklabels=target_names,
            ax=axis,
        )
        axis.set_title(heading, fontsize=16, fontweight='bold')
        axis.set_ylabel('True Label', fontsize=12)
        axis.set_xlabel('Predicted Label', fontsize=12)
        axis.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig('lstm_gru_hybrid_confusion_matrices.png', dpi=150)
    plt.show()

def main():
    """Run the LSTM-GRU hybrid pipeline: data, training, evaluation, plots.

    Steps, in order: seed the RNGs, build windowed sequences from the CSV,
    make a stratified train/validation/test split, train the model, reload
    the best checkpoint, evaluate on validation and test, then write the
    confusion-matrix, training-curve and attention figures.

    Returns:
        ``None`` when no sequences could be built from the CSV; otherwise the
        tuple ``(trainer, model, train_dataset, val_targets, val_preds,
        test_targets, test_preds)``.
    """
    set_seed(42)

    # Configuration parameters
    CSV_PATH = ""  # Update with your dataset path
    SEQUENCE_LENGTH = 150
    OVERLAP_RATIO = 0.3
    BATCH_SIZE = 20
    EPOCHS = 150             # Extended training epochs
    VAL_SIZE = 0.15          # 15% for validation
    TEST_SIZE = 0.15         # 15% for test (70% remaining for training)
    LSTM_HIDDEN = 256
    GRU_HIDDEN = 256
    NUM_LAYERS = 2
    DROPOUT = 0.2
    # Must match the file name the trainer writes its best checkpoint to.
    CHECKPOINT_PATH = "lstm_gru_hybrid_best_model.pth"

    train_pct = (1 - VAL_SIZE - TEST_SIZE) * 100

    logger.info("🚀 LSTM-GRU HYBRID Model Training Started!")
    logger.info(
        f"📊 Data split: {train_pct:.0f}% Train, "
        f"{VAL_SIZE * 100:.0f}% Validation, {TEST_SIZE * 100:.0f}% Test"
    )

    # Data preparation with balanced approach
    sequences, labels, feature_count = prepare_balanced_dataset(
        CSV_PATH, SEQUENCE_LENGTH, overlap_ratio=OVERLAP_RATIO
    )

    if len(sequences) == 0:
        logger.error("❌ No valid sequences found!")
        return

    # NOTE(review): windows overlap (OVERLAP_RATIO > 0) and the split below is
    # done per window, so neighbouring windows that share frames can end up in
    # different splits. Validation/test scores may therefore be optimistic.

    # First split: train+val (85%) and test (15%)
    X_temp, X_test, y_temp, y_test = train_test_split(
        sequences, labels, test_size=TEST_SIZE, random_state=42, stratify=labels
    )

    # Second split: train (70%) and val (15%) from remaining 85%
    val_size_adjusted = VAL_SIZE / (1 - TEST_SIZE)  # 0.15 / 0.85 ≈ 0.176
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted, random_state=42, stratify=y_temp
    )

    logger.info(f"🔄 Training sequences: {len(X_train):,} ({len(X_train)/len(sequences)*100:.1f}%)")
    logger.info(f"🔄 Validation sequences: {len(X_val):,} ({len(X_val)/len(sequences)*100:.1f}%)")
    logger.info(f"🔄 Test sequences: {len(X_test):,} ({len(X_test)/len(sequences)*100:.1f}%)")

    # The label encoder and scaler are fitted on the training split only and
    # reused for validation and test.
    train_dataset = BalancedDiverSignDataset(X_train, y_train, train_mode=True)
    val_dataset = BalancedDiverSignDataset(
        X_val, y_val, train_dataset.get_label_encoder(),
        train_dataset.get_scaler(), train_mode=False
    )
    test_dataset = BalancedDiverSignDataset(
        X_test, y_test, train_dataset.get_label_encoder(),
        train_dataset.get_scaler(), train_mode=False
    )

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

    # Initialize LSTM-GRU hybrid model
    num_classes = len(train_dataset.get_label_encoder().classes_)
    model = LSTMGRUHybridModel(
        input_dim=feature_count,
        lstm_hidden=LSTM_HIDDEN,
        gru_hidden=GRU_HIDDEN,
        num_classes=num_classes,
        dropout=DROPOUT,
        num_layers=NUM_LAYERS
    )

    # The logged values come from the same constants used to build the model.
    logger.info(f"🤖 LSTM-GRU Hybrid Model Configuration:")
    logger.info(f"   - Number of classes: {num_classes}")
    logger.info(f"   - Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    logger.info(f"   - LSTM hidden dimension: {LSTM_HIDDEN}")
    logger.info(f"   - GRU hidden dimension: {GRU_HIDDEN}")
    logger.info(f"   - Number of layers: {NUM_LAYERS}")

    # Initialize trainer with class weights
    trainer = LSTMGRUTrainer(model, train_dataset.get_class_weights())

    # Start training process
    best_f1 = trainer.train(train_loader, val_loader, epochs=EPOCHS)

    # Load best model for final evaluation; map_location keeps the load
    # working when the checkpoint was written on a different device.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=trainer.device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Final validation evaluation
    val_loss, val_acc, val_f1, val_preds, val_targets = trainer.validate(val_loader)

    # Final test evaluation
    test_loss, test_acc, test_f1, test_preds, test_targets = trainer.validate(test_loader)

    # Gap between the accuracies recorded in the LAST training epoch (not
    # necessarily the epoch of the reloaded best checkpoint).
    final_train_acc = trainer.train_accuracies[-1]
    final_val_acc = trainer.val_accuracies[-1]
    accuracy_gap = final_train_acc - final_val_acc

    logger.info("🎯 LSTM-GRU Hybrid Final Results:")
    logger.info(f"   📊 VALIDATION:")
    logger.info(f"     - Accuracy: {val_acc:.1f}%")
    logger.info(f"     - F1 Score: {val_f1:.3f}")
    logger.info(f"   📊 TEST:")
    logger.info(f"     - Accuracy: {test_acc:.1f}%")
    logger.info(f"     - F1 Score: {test_f1:.3f}")
    logger.info(f"   📊 OVERFITTING ANALYSIS:")
    logger.info(f"     - Train-Validation Gap: {accuracy_gap:.1f}%")

    # Detailed classification reports
    target_names = train_dataset.get_label_encoder().classes_

    print("\n📊 VALIDATION Classification Report:")
    print("="*60)
    print(classification_report(val_targets, val_preds, target_names=target_names))

    print("\n📊 TEST Classification Report:")
    print("="*60)
    print(classification_report(test_targets, test_preds, target_names=target_names))

    # Generate confusion matrices for both validation and test
    plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names)

    # Plot comprehensive training results
    plot_training_results(trainer, "LSTM_GRU_Hybrid")

    # Visualize dual attention weights for model interpretability
    logger.info("🔍 Generating dual attention weight visualizations...")
    visualize_dual_attention_weights(model, test_loader, target_names, trainer.device, num_samples=3)

    logger.info("✅ LSTM-GRU Hybrid training completed successfully!")
    logger.info(f"📁 Files saved:")
    logger.info(f"   - {CHECKPOINT_PATH}: Best model checkpoint")
    logger.info(f"   - lstm_gru_hybrid_confusion_matrices.png: Confusion matrices")
    logger.info(f"   - lstm_gru_hybrid_training_results.png: Training curves")
    logger.info(f"   - lstm_gru_hybrid_attention_weights.png: Attention visualizations")

    return trainer, model, train_dataset, val_targets, val_preds, test_targets, test_preds

if __name__ == "__main__":
    # main() returns None when no sequences could be built, so only unpack a
    # real result instead of failing with a TypeError on the tuple assignment.
    results = main()
    if results is not None:
        trainer, model, dataset, val_targets, val_preds, test_targets, test_preds = results

In [None]:
#4-residual lstm

# Enhanced Deep Learning Model for Diver Sign Language Recognition
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import logging
import random
from collections import Counter
import warnings
# Suppress every warning category for the rest of the session, including
# deprecation warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

# Logging setup: INFO-level root configuration plus a module-level logger.
# basicConfig() does nothing if the root logger already has handlers, e.g.
# when an earlier cell has already configured logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            PyTorch (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, so it
    # is switched off explicitly alongside the deterministic flag.
    torch.backends.cudnn.benchmark = False

class EnhancedDiverSignDataset(Dataset):
    """Sequence dataset with scaling, class weights, oversampling and augmentation.

    In training mode the dataset reports ``len(sequences) + len(sequences) // 3``
    items; indices past the real data are answered with a randomly drawn
    sample that favours the smaller classes. Training-mode items are also
    randomly augmented with probability ``augment_prob``.

    Attributes set by ``__init__``:
        label_encoder / scaler: the fitted (or supplied) transformers.
        sequences: float tensor holding the scaled input array.
        labels: long tensor of encoded labels.
        class_weights: float tensor with one weight per encoder class.
        class_indices: dict mapping encoded label -> list of sample indices.
    """

    def __init__(self, sequences, labels, label_encoder=None, scaler=None,
                 train_mode=True, augment_prob=0.4):
        """Encode labels, scale features and precompute sampling metadata.

        Args:
            sequences: NumPy array whose last axis is the feature axis; it is
                flattened to 2-D for scaling and restored afterwards.
            labels: Raw labels, one per sequence.
            label_encoder: Fitted encoder to reuse; a new ``LabelEncoder`` is
                fitted on ``labels`` when omitted.
            scaler: Fitted scaler to reuse; a new ``StandardScaler`` is fitted
                on ``sequences`` when omitted.
            train_mode: Enables oversampling and augmentation.
            augment_prob: Probability of augmenting an item in training mode.
        """
        # Label encoding: fit a new encoder or reuse the supplied one
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            encoded_labels = self.label_encoder.fit_transform(labels)
        else:
            self.label_encoder = label_encoder
            encoded_labels = self.label_encoder.transform(labels)

        # Feature scaling on the flattened (rows, features) view
        flat = sequences.reshape(-1, sequences.shape[-1])
        if scaler is None:
            self.scaler = StandardScaler()
            scaled_flat = self.scaler.fit_transform(flat)
        else:
            self.scaler = scaler
            scaled_flat = self.scaler.transform(flat)
        scaled_sequences = scaled_flat.reshape(sequences.shape)

        self.sequences = torch.FloatTensor(scaled_sequences)
        self.labels = torch.LongTensor(encoded_labels)
        self.train_mode = train_mode
        self.augment_prob = augment_prob

        # Class weights are sized by the encoder's class list, so the vector
        # stays aligned with the class indices even when this split happens
        # to contain no sample of some class.
        class_counts = Counter(encoded_labels)
        self.class_weights = self._effective_number_weights(
            encoded_labels, len(self.label_encoder.classes_)
        )

        # Store class indices for balanced sampling
        self.class_indices = {}
        for i, label in enumerate(encoded_labels):
            self.class_indices.setdefault(label, []).append(i)

        # Computed once here instead of on every oversampled __getitem__ call.
        self._minority_classes = self._find_minority_classes(self.class_indices)

        logger.info(f"Dataset created: {len(self.sequences)} samples, {len(class_counts)} classes")
        logger.info(f"Class distribution: {dict(class_counts)}")

    @staticmethod
    def _effective_number_weights(encoded_labels, num_classes, beta=0.999):
        """Return normalised inverse effective-number weights per class.

        The effective number of a class with ``n`` samples is
        ``(1 - beta**n) / (1 - beta)``; its raw weight is the reciprocal.
        Weights of the classes that occur are rescaled to average 1. Classes
        with no samples get weight 0. If no class occurs at all, all-ones is
        returned.
        """
        counts = np.bincount(
            np.asarray(encoded_labels, dtype=np.int64), minlength=num_classes
        )

        raw_weights = []
        for count in counts:
            count = int(count)
            if count == 0:
                raw_weights.append(0.0)
                continue
            effective_num = (1 - beta ** count) / (1 - beta)
            raw_weights.append(1.0 / effective_num)

        total = sum(raw_weights)
        if total == 0:
            return torch.ones(len(raw_weights))

        present = sum(1 for w in raw_weights if w > 0)
        return torch.FloatTensor([w * present / total for w in raw_weights])

    @staticmethod
    def _find_minority_classes(class_indices):
        """Return the classes whose sample count is at or below the median."""
        if not class_indices:
            return []
        sizes = [(label, len(indices)) for label, indices in class_indices.items()]
        threshold = np.median([size for _, size in sizes])
        return [label for label, size in sizes if size <= threshold]

    def __len__(self):
        # Training mode exposes one third extra (virtual) items for oversampling
        if self.train_mode:
            return len(self.sequences) + len(self.sequences) // 3
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sequence, label)`` for ``idx``.

        In training mode an index past the real data is replaced by a random
        one: with probability 0.7 a sample from a minority class, otherwise a
        uniformly drawn sample. Any other index is wrapped modulo the number
        of real sequences.
        """
        if self.train_mode and idx >= len(self.sequences):
            if self._minority_classes and random.random() < 0.7:
                selected_class = random.choice(self._minority_classes)
                idx = random.choice(self.class_indices[selected_class])
            else:
                idx = random.randint(0, len(self.sequences) - 1)
        else:
            idx = idx % len(self.sequences)

        # clone() so augmentation never touches the stored tensor
        sequence = self.sequences[idx].clone()
        label = self.labels[idx]

        if self.train_mode and random.random() < self.augment_prob:
            sequence = self._apply_enhanced_augmentation(sequence, label)

        return sequence, label

    def _apply_enhanced_augmentation(self, sequence, label):
        """Apply one or two randomly chosen augmentations to a sequence.

        Args:
            sequence: 2-D tensor; dimension 0 is treated as time and
                dimension 1 as features.
            label: Unused; kept so the call signature stays stable.

        Returns:
            The augmented tensor, same shape as the input.
        """
        aug_types = random.sample(['noise', 'smooth', 'scale', 'shift'], k=random.randint(1, 2))

        for aug_type in aug_types:
            if aug_type == 'noise':
                # Gaussian noise at 2% of each feature's std over time
                seq_std = torch.std(sequence, dim=0, keepdim=True)
                noise = torch.randn_like(sequence) * seq_std * 0.02
                sequence = sequence + noise

            elif aug_type == 'smooth':
                # 3-tap moving average along time with reflected edges.
                # The batch dimension is added before padding so reflect
                # padding always sees a 3-D (1, features, time) input.
                kernel_size = 3
                channels_first = sequence.transpose(0, 1).unsqueeze(0)
                padded = F.pad(channels_first, (1, 1), mode='reflect')
                smoothed = F.avg_pool1d(padded, kernel_size, stride=1, padding=0)
                sequence = smoothed.squeeze(0).transpose(0, 1)

            elif aug_type == 'scale':
                # One scale factor per feature, clamped to [0.95, 1.05]
                scale_factors = torch.normal(1.0, 0.03, size=(1, sequence.size(1)))
                sequence = sequence * scale_factors.clamp(0.95, 1.05)

            elif aug_type == 'shift':
                # Circular shift of up to 5 steps; torch.roll wraps around
                shift = random.randint(-5, 5)
                sequence = torch.roll(sequence, shift, dims=0)

        return sequence

    def get_label_encoder(self):
        """Return the label encoder used by this dataset."""
        return self.label_encoder

    def get_scaler(self):
        """Return the feature scaler used by this dataset."""
        return self.scaler

    def get_class_weights(self):
        """Return the per-class weight tensor."""
        return self.class_weights

class EnhancedDeepModel(nn.Module):
    """Bidirectional LSTM classifier with self-attention and attention pooling.

    Pipeline: LayerNorm -> MLP projection plus linear skip -> stacked
    bidirectional LSTM -> multi-head self-attention added residually ->
    learned attention pooling over time -> MLP classifier.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Width of the projection and of each LSTM direction; the
            LSTM output and attention width is ``2 * hidden_dim``, which must
            be divisible by ``num_heads``.
        num_classes: Number of output logits.
        dropout: Base dropout rate; several layers use a fraction of it.
        num_layers: Number of stacked LSTM layers.
        num_heads: Number of self-attention heads.

    Note:
        The classifier contains ``BatchNorm1d`` layers, which need more than
        one sample per batch while the module is in training mode.
    """
    def __init__(self, input_dim=69, hidden_dim=160, num_classes=12, dropout=0.25,
                 num_layers=3, num_heads=8):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.num_heads = num_heads

        # Input processing: normalise, then a two-layer projection
        self.input_norm = nn.LayerNorm(input_dim)
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Linear shortcut added to the projection output in forward()
        self.input_skip = nn.Linear(input_dim, hidden_dim)

        # Stacked bidirectional LSTM. Inter-layer dropout only exists with
        # more than one layer, so it is disabled for a single-layer stack.
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True
        )

        # Self-attention over the LSTM outputs (width 2 * hidden_dim)
        self.multihead_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Scores one scalar per time step for attention pooling
        self.attention_pool = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden_dim, 1, bias=False)
        )

        # Classification head with batch normalisation
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim * 2),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout * 0.7),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise weights: orthogonal for LSTM, Xavier elsewhere, zero biases.

        One-dimensional weights (LayerNorm / BatchNorm scales) are left at
        their defaults. Every bias is zeroed, and the second quarter of each
        LSTM ``bias_hh`` vector is set to 1; PyTorch orders LSTM gates as
        input, forget, cell, output, so that slice is the forget gate.
        """
        for name, param in self.named_parameters():
            if param.dim() >= 2:
                if 'lstm' in name and 'weight' in name:
                    nn.init.orthogonal_(param)
                elif 'weight' in name and 'norm' not in name and 'batch' not in name:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)
                if 'lstm' in name and 'bias_hh' in name:
                    n = param.size(0)
                    param.data[n//4:n//2].fill_(1.0)

    def forward(self, x):
        """
        Forward pass through the model.

        Args:
            x: Input tensor of shape [batch_size, seq_len, input_dim]
        Returns:
            output: Classification logits of shape [batch_size, num_classes]
        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape [batch_size, seq_len, input_dim], got {tuple(x.shape)}"
            )

        # Input processing with a residual shortcut
        x_norm = self.input_norm(x)
        x_proj = self.input_projection(x_norm)
        x_skip = self.input_skip(x_norm)
        x_input = x_proj + x_skip

        # Sequential modelling
        lstm_out, _ = self.lstm(x_input)

        # Self-attention over the LSTM outputs, added back residually
        attn_out, _ = self.multihead_attention(lstm_out, lstm_out, lstm_out)
        combined = lstm_out + attn_out

        # Attention pooling: softmax over time, then a weighted sum
        attn_scores = self.attention_pool(combined).squeeze(-1)
        attn_weights = F.softmax(attn_scores, dim=1)
        pooled = torch.bmm(attn_weights.unsqueeze(1), combined).squeeze(1)

        # Final classification
        output = self.classifier(pooled)

        return output

class EnhancedModelTrainer:
    """Training and evaluation loop for a classification model.

    Uses class-weighted cross entropy with label smoothing (mixed with focal
    loss during training), AdamW with per-component learning rates, cosine
    annealing with warm restarts, gradient clipping and early stopping on the
    validation F1 score.

    Attributes:
        train_losses, val_losses, train_accuracies, val_accuracies,
        val_f1_scores: per-epoch metric histories filled by ``train``.
    """
    def __init__(self, model, class_weights=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Move the model to ``device`` and build loss, optimizer and scheduler.

        Args:
            model: The ``nn.Module`` to train.
            class_weights: Optional per-class weight tensor for the loss.
            device: Device the model and batches are moved to.
        """
        self.model = model
        self.device = device
        self.model.to(device)

        # Loss function with optional class balancing
        if class_weights is not None:
            class_weights = class_weights.to(device)

        self.criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

        # Optimizer with a different learning rate per model component
        self.optimizer = optim.AdamW(
            self._build_param_groups(model), weight_decay=1e-4, betas=(0.9, 0.95)
        )

        # Cosine annealing with warm restarts, stepped once per epoch
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=15, T_mult=2, eta_min=1e-6
        )

        # Metrics tracking
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.val_f1_scores = []

    @staticmethod
    def _build_param_groups(model):
        """Split the model's parameters into four disjoint optimizer groups.

        Group order is classifier, lstm, attention, everything else. Each
        parameter goes to the first group whose keyword occurs in its name,
        so a name matching several keywords (e.g. ``lstm_attention``) can no
        longer land in two groups, which the optimizer rejects.
        """
        keyword_lrs = (('classifier', 2e-3), ('lstm', 8e-4), ('attention', 1e-3))
        keyword_groups = [{'params': [], 'lr': lr} for _, lr in keyword_lrs]
        remainder = {'params': [], 'lr': 1e-3}

        for name, param in model.named_parameters():
            for (keyword, _), group in zip(keyword_lrs, keyword_groups):
                if keyword in name:
                    group['params'].append(param)
                    break
            else:
                remainder['params'].append(param)

        return keyword_groups + [remainder]

    def focal_loss(self, pred, target, alpha=1.0, gamma=2.0):
        """Return the mean focal loss ``alpha * (1 - p_t)**gamma * CE``.

        ``p_t`` is recovered from the unweighted per-sample cross entropy as
        ``exp(-CE)``.
        """
        ce_loss = F.cross_entropy(pred, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = alpha * (1 - pt) ** gamma * ce_loss
        return focal_loss.mean()

    def train_epoch(self, dataloader):
        """Run one training pass over ``dataloader``.

        Returns:
            Tuple ``(avg_loss, accuracy_percent, weighted_f1)``. The loss is
            the training objective (cross entropy + focal + L2 term) averaged
            over batches.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        for data, target in tqdm(dataloader, desc="Training"):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            output = self.model(data)

            # Combined loss: 70% weighted cross entropy + 30% focal loss
            ce_loss = self.criterion(output, target)
            focal_loss = self.focal_loss(output, target, alpha=2.0, gamma=2.0)
            loss = 0.7 * ce_loss + 0.3 * focal_loss

            # Explicit L2 penalty over all parameters (applied in addition to
            # the optimizer's weight decay)
            l2_reg = sum(param.pow(2).sum() for param in self.model.parameters())
            loss = loss + 1e-5 * l2_reg

            loss.backward()

            # Gradient clipping for training stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)

            self.optimizer.step()

            # Statistics tracking
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1

    def validate(self, dataloader):
        """Evaluate the model on ``dataloader`` without gradient tracking.

        Returns:
            Tuple ``(avg_loss, accuracy_percent, weighted_f1, predictions,
            targets)``. The loss is the criterion only (no focal or L2 term).
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in tqdm(dataloader, desc="Validating"):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1, all_preds, all_targets

    def train(self, train_loader, val_loader, epochs=150, patience=20,
              checkpoint_path="enhanced_deep_model_best.pth"):
        """Train with early stopping and keep the best checkpoint on disk.

        Args:
            train_loader: Loader for the training split.
            val_loader: Loader for the validation split.
            epochs: Maximum number of epochs.
            patience: Epochs without a validation-F1 improvement before
                training stops.
            checkpoint_path: Where the best model state is saved.

        Returns:
            The best validation F1 score observed (0 if no epoch ran).
        """
        best_val_f1 = 0
        patience_counter = 0

        logger.info(f"🚀 Enhanced Model Training Started: {epochs} epochs")

        for epoch in range(epochs):
            train_loss, train_acc, train_f1 = self.train_epoch(train_loader)
            val_loss, val_acc, val_f1, _, _ = self.validate(val_loader)

            # Learning rate scheduling
            self.scheduler.step()

            # Save metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            self.val_f1_scores.append(val_f1)

            # Progress logging every 5 epochs (LR shown is the first group's)
            if (epoch + 1) % 5 == 0:
                logger.info(f'Epoch {epoch+1}/{epochs}:')
                logger.info(f'  Train: Loss={train_loss:.3f}, Acc={train_acc:.1f}%, F1={train_f1:.3f}')
                logger.info(f'  Val:   Loss={val_loss:.3f}, Acc={val_acc:.1f}%, F1={val_f1:.3f}')
                logger.info(f'  LR: {self.optimizer.param_groups[0]["lr"]:.2e}')

            # Save on improvement. The first epoch always saves, so a
            # checkpoint from this run exists even if validation F1 stays 0.
            if epoch == 0 or val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'epoch': epoch,
                    # Plain float keeps the checkpoint free of NumPy scalars.
                    'val_f1': float(val_f1),
                    'model_type': 'Enhanced_Deep_Model'
                }, checkpoint_path)
                logger.info(f"  ✅ New best model saved: F1={val_f1:.3f}")
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        logger.info(f"🎯 Training completed. Best F1 Score: {best_val_f1:.3f}")
        return best_val_f1

def plot_comprehensive_training_results(trainer, model_name="Enhanced_Deep_Model"):
    """Render the recorded training history as a 2x2 figure.

    Top row: loss and accuracy curves (train vs. validation). Bottom row:
    validation F1 and the train-minus-validation accuracy gap with reference
    lines at 5 and 10. The figure is saved as
    ``<model_name lower-cased>_comprehensive_results.png`` and then shown.

    Args:
        trainer: Object exposing ``train_losses``, ``val_losses``,
            ``train_accuracies``, ``val_accuracies`` and ``val_f1_scores``.
        model_name: Prefix used in panel titles and in the output file name.
    """
    def _label_axes(axis, heading, y_label):
        # Title and axis captions shared by all four panels.
        axis.set_title(heading)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)

    _, axes_grid = plt.subplots(2, 2, figsize=(15, 10))
    top_left, top_right = axes_grid[0]
    bottom_left, bottom_right = axes_grid[1]

    x_values = range(1, len(trainer.train_losses) + 1)

    # Loss curves
    top_left.plot(x_values, trainer.train_losses, 'b-', label='Training Loss', linewidth=2)
    top_left.plot(x_values, trainer.val_losses, 'r-', label='Validation Loss', linewidth=2)
    _label_axes(top_left, f'{model_name} - Loss Curves', 'Loss')
    top_left.legend()
    top_left.grid(True, alpha=0.3)

    # Accuracy curves
    top_right.plot(x_values, trainer.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
    top_right.plot(x_values, trainer.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
    _label_axes(top_right, f'{model_name} - Accuracy Curves', 'Accuracy (%)')
    top_right.legend()
    top_right.grid(True, alpha=0.3)

    # Validation F1 per epoch
    bottom_left.plot(x_values, trainer.val_f1_scores, 'green', linewidth=2)
    _label_axes(bottom_left, f'{model_name} - Validation F1 Score', 'F1 Score')
    bottom_left.grid(True, alpha=0.3)

    # Per-epoch accuracy gap as an overfitting indicator
    gap_series = []
    for train_acc, val_acc in zip(trainer.train_accuracies, trainer.val_accuracies):
        gap_series.append(train_acc - val_acc)

    bottom_right.plot(x_values, gap_series, 'purple', linewidth=2)
    bottom_right.axhline(y=10, color='orange', linestyle='--', label='Warning Threshold')
    bottom_right.axhline(y=5, color='green', linestyle='--', label='Good Threshold')
    _label_axes(bottom_right, f'{model_name} - Train-Validation Gap Analysis', 'Accuracy Gap (%)')
    bottom_right.legend()
    bottom_right.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{model_name.lower()}_comprehensive_results.png', dpi=150)
    plt.show()

def prepare_balanced_dataset(csv_path, sequence_length=150, overlap_ratio=0.3):
    """Load a frame-level CSV and cut it into fixed-length, overlapping windows per class.

    Every column except ``'class'`` is treated as a feature. Missing values
    are forward-filled, then zeroed; infinities are zeroed. Classes with fewer
    than ``3 * sequence_length`` rows are dropped. The remaining rows of each
    class are concatenated in file order and sliced into windows.

    Args:
        csv_path: Path of the CSV file; it must contain a ``'class'`` column.
        sequence_length: Number of rows per window.
        overlap_ratio: Fraction of a window shared with the next one. The
            resulting stride is clamped to at least one row, so ratios at or
            near 1 give maximally overlapping windows instead of an invalid
            zero step.

    Returns:
        tuple: ``(sequences, labels, feature_count)`` where ``sequences`` is an
        array of windows (empty when no class yields one), ``labels`` is the
        list of class names aligned with ``sequences``, and ``feature_count``
        is the number of feature columns.
    """
    logger.info(f"📊 Loading data from: {csv_path}")

    df = pd.read_csv(csv_path)
    feature_columns = [col for col in df.columns if col != 'class']

    # Forward-fill gaps, then zero whatever is still missing (leading NaNs).
    # DataFrame.ffill() replaces the deprecated fillna(method='ffill').
    # NOTE: the fill runs over the whole frame in file order, so a NaN in the
    # first row of one class inherits the value from the row before it.
    df[feature_columns] = df[feature_columns].ffill().fillna(0)
    df[feature_columns] = df[feature_columns].replace([np.inf, -np.inf], 0)

    # Keep only classes with enough rows to produce several windows.
    min_samples = sequence_length * 3
    class_counts = df['class'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples].index
    df = df[df['class'].isin(valid_classes)]

    logger.info(f"Valid classes after filtering: {list(valid_classes)}")
    logger.info(f"Class distribution: {dict(class_counts[valid_classes])}")

    sequences = []
    labels = []

    # Overlapping windows act as a simple form of augmentation. The stride is
    # clamped to 1 because range() rejects a zero step.
    stride = max(1, int(sequence_length * (1 - overlap_ratio)))

    for class_name in valid_classes:
        # Materialise the class's feature matrix once instead of re-slicing
        # the DataFrame for every window.
        class_values = df.loc[df['class'] == class_name, feature_columns].values

        # The range bound guarantees every slice is exactly sequence_length rows.
        for start in range(0, len(class_values) - sequence_length + 1, stride):
            sequences.append(class_values[start:start + sequence_length])
            labels.append(class_name)

    sequences = np.array(sequences)
    logger.info(f"Total sequences created: {len(sequences)}")
    logger.info(f"Final distribution: {dict(Counter(labels))}")

    return sequences, labels, len(feature_columns)

def plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names):
    """Draw validation and test confusion matrices side by side and save them.

    Args:
        val_targets: True labels of the validation set.
        val_preds: Predicted labels of the validation set.
        test_targets: True labels of the test set.
        test_preds: Predicted labels of the test set.
        target_names: Class names used as tick labels on both axes.

    The figure is written to ``enhanced_deep_model_confusion_matrices.png``.
    """
    _, panel_axes = plt.subplots(1, 2, figsize=(20, 8))

    # (title, colour map, true labels, predicted labels), left to right
    panel_specs = (
        ('Validation Confusion Matrix', 'Blues', val_targets, val_preds),
        ('Test Confusion Matrix', 'Greens', test_targets, test_preds),
    )

    for axis, (panel_title, colour_map, y_true, y_pred) in zip(panel_axes, panel_specs):
        matrix = confusion_matrix(y_true, y_pred)
        sns.heatmap(
            matrix,
            annot=True,
            fmt='d',
            cmap=colour_map,
            xticklabels=target_names,
            yticklabels=target_names,
            ax=axis,
        )
        axis.set_title(panel_title, fontsize=16, fontweight='bold')
        axis.set_ylabel('True Label', fontsize=12)
        axis.set_xlabel('Predicted Label', fontsize=12)
        axis.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig('enhanced_deep_model_confusion_matrices.png', dpi=150)
    plt.show()

def main():
    """Run the Enhanced Deep Model experiment end to end.

    Steps: seed the RNGs, build windowed sequences from the CSV, make a
    stratified train/validation/test split, train, reload the best checkpoint,
    evaluate on validation and test, then print reports and save the figures.

    Returns:
        tuple | None: ``(trainer, model, train_dataset, val_targets,
        val_preds, test_targets, test_preds)``, or ``None`` when no sequence
        could be built from the CSV.
    """
    set_seed(42)

    # Experiment configuration. Values that are both used and logged live here
    # so the log output cannot drift away from what is actually run.
    CSV_PATH = "/content/drive/MyDrive/dalgic_egitim/dataset.csv"  # Update with your dataset path
    CHECKPOINT_PATH = "enhanced_deep_model_best.pth"  # expected to be written by trainer.train()
    SEQUENCE_LENGTH = 150
    BATCH_SIZE = 32
    EPOCHS = 150             # Extended training epochs
    VAL_SIZE = 0.15          # fraction of all sequences used for validation
    TEST_SIZE = 0.15         # fraction of all sequences used for test
    HIDDEN_DIM = 160
    DROPOUT = 0.25

    logger.info("🚀 Enhanced Deep Learning Model Training Started!")
    logger.info(
        f"📊 Data split: {1 - VAL_SIZE - TEST_SIZE:.0%} Train, "
        f"{VAL_SIZE:.0%} Validation, {TEST_SIZE:.0%} Test"
    )
    logger.info("="*60)

    # Data preparation
    sequences, labels, feature_count = prepare_balanced_dataset(
        CSV_PATH, SEQUENCE_LENGTH, overlap_ratio=0.3
    )

    if len(sequences) == 0:
        logger.error("❌ No valid sequences found!")
        return

    # First split: hold out the test set.
    X_temp, X_test, y_temp, y_test = train_test_split(
        sequences, labels, test_size=TEST_SIZE, random_state=42, stratify=labels
    )

    # Second split: carve validation out of the remainder. The fraction is
    # rescaled so validation is still VAL_SIZE of the full data set.
    val_size_adjusted = VAL_SIZE / (1 - TEST_SIZE)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted, random_state=42, stratify=y_temp
    )

    logger.info(f"\n📊 Dataset Summary:")
    logger.info(f"   - Input Features: {feature_count}")
    logger.info(f"   - Sequence Length: {SEQUENCE_LENGTH}")
    logger.info(f"   - Training sequences: {len(X_train):,} ({len(X_train)/len(sequences)*100:.1f}%)")
    logger.info(f"   - Validation sequences: {len(X_val):,} ({len(X_val)/len(sequences)*100:.1f}%)")
    logger.info(f"   - Test sequences: {len(X_test):,} ({len(X_test)/len(sequences)*100:.1f}%)")

    # Datasets: the label encoder and scaler are fitted on the training split
    # only and reused for validation/test, so no statistics leak across splits.
    train_dataset = EnhancedDiverSignDataset(X_train, y_train, train_mode=True, augment_prob=0.5)
    val_dataset = EnhancedDiverSignDataset(
        X_val, y_val,
        train_dataset.get_label_encoder(),
        train_dataset.get_scaler(),
        train_mode=False
    )
    test_dataset = EnhancedDiverSignDataset(
        X_test, y_test,
        train_dataset.get_label_encoder(),
        train_dataset.get_scaler(),
        train_mode=False
    )

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin_memory)

    class_names = train_dataset.get_label_encoder().classes_
    num_classes = len(class_names)

    logger.info(f"   - Number of Classes: {num_classes}")
    logger.info(f"   - Class Names: {list(class_names)}")

    # Model
    model = EnhancedDeepModel(
        input_dim=feature_count,
        hidden_dim=HIDDEN_DIM,
        num_classes=num_classes,
        dropout=DROPOUT
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"\n🤖 Enhanced Deep Learning Model Configuration:")
    logger.info(f"   - Total Parameters: {total_params:,}")
    logger.info(f"   - Hidden Dimension: {HIDDEN_DIM}")
    logger.info(f"   - LSTM Layers: 3 (bidirectional)")
    logger.info(f"   - Multi-Head Attention: 8 heads")
    logger.info(f"   - Dropout Rate: {DROPOUT}")

    # Trainer
    trainer = EnhancedModelTrainer(
        model=model,
        class_weights=train_dataset.get_class_weights()
    )

    logger.info(f"\n🔥 Starting Enhanced Training ({EPOCHS} epochs)...")
    logger.info("="*60)

    best_f1 = trainer.train(train_loader, val_loader, epochs=EPOCHS)

    # Reload the best checkpoint for the final evaluation. map_location keeps
    # the tensors on whatever device the model currently lives on. If no
    # checkpoint was written, fall back to the in-memory weights instead of
    # crashing after a full training run.
    if os.path.exists(CHECKPOINT_PATH):
        model_device = next(model.parameters()).device
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=model_device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        logger.warning(
            f"Checkpoint '{CHECKPOINT_PATH}' not found; evaluating the final-epoch weights instead."
        )

    # Final evaluation on both held-out splits
    val_loss, val_acc, val_f1, val_preds, val_targets = trainer.validate(val_loader)
    test_loss, test_acc, test_f1, test_preds, test_targets = trainer.validate(test_loader)

    logger.info(f"\n📊 Enhanced Model Final Results:")
    logger.info("="*45)
    logger.info(f"   📊 VALIDATION:")
    logger.info(f"     - Accuracy: {val_acc:.2f}%")
    logger.info(f"     - F1 Score: {val_f1:.4f}")
    logger.info(f"   📊 TEST:")
    logger.info(f"     - Accuracy: {test_acc:.2f}%")
    logger.info(f"     - F1 Score: {test_f1:.4f}")
    logger.info(f"   📊 TRAINING:")
    logger.info(f"     - Best Validation F1: {best_f1:.4f}")

    # Classification reports
    print(f"\n📋 VALIDATION Classification Report:")
    print("="*70)
    print(classification_report(val_targets, val_preds, target_names=class_names, digits=3))

    print(f"\n📋 TEST Classification Report:")
    print("="*70)
    print(classification_report(test_targets, test_preds, target_names=class_names, digits=3))

    # Per-class F1 on the test set with a coarse quality tier
    per_class_f1 = f1_score(test_targets, test_preds, average=None)

    print(f"\n📊 Per-Class Performance Analysis:")
    print("="*60)
    print(f"{'Class':<20} {'F1-Score':<10} {'Performance Status':<20}")
    print("-" * 60)

    for class_name, class_f1 in zip(class_names, per_class_f1):
        if class_f1 >= 0.85:
            status = "🟢 Excellent"
        elif class_f1 >= 0.75:
            status = "🟡 Good"
        elif class_f1 >= 0.65:
            status = "🟠 Fair"
        else:
            status = "🔴 Needs Improvement"

        print(f"{class_name:<20} {class_f1:<10.3f} {status:<20}")

    # Figures: dual confusion matrices and the training-history dashboard
    plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, class_names)
    plot_comprehensive_training_results(trainer, "Enhanced_Deep_Model")

    # Row-normalised (percentage) confusion matrix for the test set
    plt.figure(figsize=(14, 10))
    cm = confusion_matrix(test_targets, test_preds)

    # A row with no true samples is all zeros; dividing by max(total, 1) keeps
    # it at 0% instead of producing NaN.
    row_totals = np.maximum(cm.sum(axis=1)[:, np.newaxis], 1)
    cm_percent = cm.astype('float') / row_totals * 100

    sns.heatmap(
        cm_percent,
        annot=True,
        fmt='.1f',
        cmap='RdYlBu_r',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Percentage (%)'},
        linewidths=0.5
    )

    plt.title('Enhanced Deep Model - Test Confusion Matrix (%)', fontsize=18, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.ylabel('True Label', fontsize=14)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    # Footer with the headline numbers
    model_info = (f'Enhanced Deep Model | Parameters: {total_params:,} | '
                  f'Accuracy: {test_acc:.1f}% | F1: {test_f1:.3f}')
    plt.figtext(0.02, 0.02, model_info, fontsize=11,
                bbox=dict(boxstyle="round,pad=0.5", facecolor='lightgreen', alpha=0.8))

    plt.tight_layout()
    plt.savefig('enhanced_deep_model_detailed_confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Overfitting indicator: last-epoch training accuracy minus validation accuracy
    final_train_acc = trainer.train_accuracies[-1]
    final_val_acc = trainer.val_accuracies[-1]
    accuracy_gap = final_train_acc - final_val_acc

    logger.info(f"\n🏆 ENHANCED MODEL PERFORMANCE SUMMARY:")
    logger.info("="*50)
    logger.info(f"   📈 Final Results:")
    logger.info(f"     - Test Accuracy: {test_acc:.2f}%")
    logger.info(f"     - Test F1 Score: {test_f1:.4f}")
    logger.info(f"     - Validation F1: {val_f1:.4f}")
    logger.info(f"   📊 Training Analysis:")
    logger.info(f"     - Train-Validation Gap: {accuracy_gap:.1f}%")
    logger.info(f"     - Total Parameters: {total_params:,}")
    logger.info(f"     - Training Epochs: {len(trainer.train_losses)}")

    # Overall verdict from the test F1
    if test_f1 >= 0.85:
        performance_status = "🎉 Excellent Performance Achieved!"
    elif test_f1 >= 0.75:
        performance_status = "✅ Good Performance Achieved!"
    elif test_f1 >= 0.65:
        performance_status = "👍 Satisfactory Performance"
    else:
        performance_status = "⚠️ Performance Needs Improvement"

    logger.info(f"   🎯 Overall Assessment: {performance_status}")

    # Overfitting verdict from the accuracy gap (percentage points)
    if accuracy_gap <= 5:
        overfitting_status = "🟢 No Overfitting Detected"
    elif accuracy_gap <= 10:
        overfitting_status = "🟡 Mild Overfitting"
    else:
        overfitting_status = "🔴 Significant Overfitting Detected"

    logger.info(f"   📉 Overfitting Status: {overfitting_status}")

    logger.info("\n✅ Enhanced Deep Learning Model training completed successfully!")
    logger.info(f"📁 Files saved:")
    logger.info(f"   - enhanced_deep_model_best.pth: Best model checkpoint")
    logger.info(f"   - enhanced_deep_model_confusion_matrices.png: Dual confusion matrices")
    logger.info(f"   - enhanced_deep_model_comprehensive_results.png: Training curves")
    logger.info(f"   - enhanced_deep_model_detailed_confusion_matrix.png: Detailed test matrix")

    return trainer, model, train_dataset, val_targets, val_preds, test_targets, test_preds

if __name__ == "__main__":
    # main() returns None when no sequences could be built; unpacking that
    # directly would raise TypeError, so only unpack a real result.
    results = main()
    if results is not None:
        trainer, model, dataset, val_targets, val_preds, test_targets, test_preds = results

In [None]:
# 5-TPA-Net (Temporal Pyramid Attention Network) for Diver Sign Language Recognition

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import logging
import random
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Logging setup
# NOTE: logging.basicConfig() does nothing when the root logger already has
# handlers (for example after an earlier cell configured logging), so
# re-running this cell is harmless.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) for repeatable runs.

    Also switches cuDNN to its deterministic mode.

    Args:
        seed: Integer seed shared by every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

class AdvancedDiverSignDataset(Dataset):
    """Torch dataset of fixed-length feature sequences with optional augmentation.

    Labels are integer-encoded and features are standardised per feature
    column. When ``label_encoder`` / ``scaler`` are omitted they are fitted on
    the given data; pass the fitted objects from the training set when
    building validation/test sets so all splits share one encoding and scale.

    Args:
        sequences: NumPy array of sequences; the last axis is the feature axis
            and each item is a 2-D ``(time, features)`` matrix.
        labels: Class label per sequence.
        label_encoder: Fitted encoder exposing ``transform`` and ``classes_``,
            or ``None`` to fit a new ``LabelEncoder``.
        scaler: Fitted scaler exposing ``transform``, or ``None`` to fit a new
            ``StandardScaler``.
        noise_std: Stored on the instance; the augmentations below derive
            their noise level from the sequence itself and do not read it.
        train_mode: Augmentation is applied only when this is true.
        augment_prob: Probability of augmenting a sample when in train mode.
    """

    def __init__(self, sequences, labels, label_encoder=None, scaler=None,
                 noise_std=0.005, train_mode=True, augment_prob=0.3):

        # Label encoding: fit only when no encoder is supplied.
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            encoded_labels = self.label_encoder.fit_transform(labels)
        else:
            self.label_encoder = label_encoder
            encoded_labels = self.label_encoder.transform(labels)

        # Scaling works on 2-D input, so flatten every time step into a row,
        # scale per feature column, then restore the original shape.
        flat = sequences.reshape(-1, sequences.shape[-1])
        if scaler is None:
            self.scaler = StandardScaler()
            scaled_flat = self.scaler.fit_transform(flat)
        else:
            self.scaler = scaler
            scaled_flat = self.scaler.transform(flat)
        scaled_sequences = scaled_flat.reshape(sequences.shape)

        self.sequences = torch.FloatTensor(scaled_sequences)
        self.labels = torch.LongTensor(encoded_labels)
        self.noise_std = noise_std
        self.train_mode = train_mode
        self.augment_prob = augment_prob

        # Balanced class weights: total / (num_classes * count). The vector is
        # sized from the encoder's class list, not from the labels seen here,
        # so a class missing from this split neither shortens the vector nor
        # triggers a division by zero; such a class gets weight 0.
        num_classes = len(self.label_encoder.classes_)
        counts = np.bincount(np.asarray(encoded_labels, dtype=np.int64), minlength=num_classes)
        weights = np.zeros(len(counts), dtype=np.float64)
        present = counts > 0
        weights[present] = len(encoded_labels) / (num_classes * counts[present])
        self.class_weights = torch.FloatTensor(weights)

    def __len__(self):
        """Return the number of sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sequence, label)`` for ``idx``, augmented at random in train mode."""
        # Clone so augmentation never modifies the stored tensor.
        sequence = self.sequences[idx].clone()
        label = self.labels[idx]

        if self.train_mode and random.random() < self.augment_prob:
            sequence = self._apply_advanced_augmentation(sequence)

        return sequence, label

    def _apply_advanced_augmentation(self, sequence):
        """Apply one randomly chosen augmentation to a ``(time, features)`` tensor."""
        aug_type = random.choice(['noise', 'scale', 'shift', 'dropout', 'mixup'])

        if aug_type == 'noise':
            # Gaussian noise at 2% of each feature's standard deviation over time
            seq_std = sequence.std(dim=0, keepdim=True)
            noise = torch.randn_like(sequence) * seq_std * 0.02
            sequence = sequence + noise

        elif aug_type == 'scale':
            # One random gain per feature, drawn around 1.0
            scale_factors = torch.normal(1.0, 0.02, size=(1, sequence.size(1)))
            sequence = sequence * scale_factors

        elif aug_type == 'shift':
            # Circular shift along time by up to 5 steps in either direction
            shift = random.randint(-5, 5)
            sequence = torch.roll(sequence, shift, dims=0)

        elif aug_type == 'dropout':
            # Zero out whole time steps, each with probability 0.05
            mask = torch.rand(sequence.size(0), 1) > 0.05
            sequence = sequence * mask

        elif aug_type == 'mixup':
            # Blend the sequence with a time-shuffled copy of itself
            alpha = 0.2
            # float() turns NumPy's float64 scalar into a plain Python float so
            # the tensor arithmetic below keeps the sequence's dtype.
            lam = float(np.random.beta(alpha, alpha))
            rand_idx = torch.randperm(sequence.size(0))
            sequence = lam * sequence + (1 - lam) * sequence[rand_idx]

        return sequence

    def get_label_encoder(self):
        """Return the label encoder used by this dataset."""
        return self.label_encoder

    def get_scaler(self):
        """Return the feature scaler used by this dataset."""
        return self.scaler

    def get_class_weights(self):
        """Return the per-class weight tensor, indexed by encoded label."""
        return self.class_weights

class TemporalPyramidBlock(nn.Module):
    """Multi-scale 1-D convolution block with a global-context branch and SE gating.

    One convolution branch per kernel size in ``scales`` runs in parallel with
    a global-average branch. Their outputs are concatenated, fused back to
    ``hidden_dim`` channels by a 1x1 convolution, and re-weighted per channel
    by a squeeze-and-excitation gate.

    Args:
        input_dim: Number of input features per time step.
        hidden_dim: Number of output features per time step.
        scales: Kernel sizes of the convolution branches. Each must be a
            positive odd integer so that ``padding = k // 2`` preserves the
            sequence length.
        dropout: Dropout probability used in the branches and the fusion layer.

    Raises:
        ValueError: If ``scales`` is empty, contains an even or non-positive
            kernel size, or ``hidden_dim`` is smaller than ``len(scales)``.
    """

    def __init__(self, input_dim, hidden_dim, scales=(1, 3, 5, 7), dropout=0.2):
        super().__init__()

        scales = list(scales)  # copy: never alias the caller's (or the default) sequence
        if not scales:
            raise ValueError("scales must contain at least one kernel size")
        if any(k < 1 or k % 2 == 0 for k in scales):
            # An even kernel with padding k // 2 yields seq_len + 1 outputs,
            # which cannot be concatenated with the other branches.
            raise ValueError(f"scales must be positive odd kernel sizes, got {scales}")

        branch_dim = hidden_dim // len(scales)
        if branch_dim < 1:
            raise ValueError("hidden_dim must be at least len(scales)")

        self.scales = scales
        self.hidden_dim = hidden_dim

        # One convolution branch per temporal receptive field
        self.conv_layers = nn.ModuleList()
        for scale in scales:
            conv_block = nn.Sequential(
                nn.Conv1d(input_dim, branch_dim, kernel_size=scale, padding=scale // 2),
                nn.BatchNorm1d(branch_dim),
                nn.GELU(),
                nn.Dropout(dropout)
            )
            self.conv_layers.append(conv_block)

        # Global-average branch: one summary vector per sequence
        self.global_branch = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(input_dim, branch_dim, 1),
            nn.GELU()
        )

        # Fusion input width is the real concatenated width: every scale
        # branch plus the global branch. It equals
        # hidden_dim + hidden_dim // len(scales) only when hidden_dim is
        # divisible by len(scales).
        total_channels = branch_dim * (len(scales) + 1)
        self.fusion = nn.Sequential(
            nn.Conv1d(total_channels, hidden_dim, 1),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Squeeze-and-excitation: per-channel gate in (0, 1)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(hidden_dim, hidden_dim // 4, 1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim // 4, hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: Input tensor ``[batch, seq_len, input_dim]``.

        Returns:
            Tensor ``[batch, seq_len, hidden_dim]``.
        """
        # Conv1d expects channels first: [batch, features, seq_len]
        x = x.transpose(1, 2)

        branch_outputs = [conv_layer(x) for conv_layer in self.conv_layers]

        # Broadcast the single global vector across every time step
        global_feature = self.global_branch(x).expand(-1, -1, x.size(2))

        fused = self.fusion(torch.cat(branch_outputs + [global_feature], dim=1))

        # Channel-wise gating
        attended = fused * self.se(fused)

        # Back to [batch, seq_len, features]
        return attended.transpose(1, 2)

class CrossScaleAttention(nn.Module):
    """Multi-head attention between three temporal resolutions of one sequence.

    The input is viewed at three scales: the original sequence and two
    stride-1 average-pooled copies (windows of 3 and 5), so all three keep the
    same length. Each scale in turn acts as the query and attends to every
    scale; the per-key results are averaged, and the three query results are
    concatenated and fused back to ``hidden_dim``.

    Args:
        hidden_dim: Feature size of the input and the output.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability on attention weights and in the fusion layer.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        super().__init__()

        if hidden_dim % num_heads != 0:
            # Otherwise the per-head reshape in forward() fails with an opaque error.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Query/key/value projections shared by all scales
        self.q_linear = nn.Linear(hidden_dim, hidden_dim)
        self.k_linear = nn.Linear(hidden_dim, hidden_dim)
        self.v_linear = nn.Linear(hidden_dim, hidden_dim)

        # One input projection per temporal resolution (fine, medium, coarse)
        self.scale_projections = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(3)
        ])

        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        # Fuses the three concatenated query-scale results
        self.scale_fusion = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

    def _split_heads(self, tensor):
        """Reshape ``[batch, seq, hidden]`` to ``[batch, heads, seq, head_dim]``."""
        batch_size, seq_len, _ = tensor.shape
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply cross-scale attention.

        Args:
            x: Input tensor ``[batch_size, seq_len, hidden_dim]``.

        Returns:
            Tensor ``[batch_size, seq_len, hidden_dim]``.
        """
        batch_size, seq_len, _ = x.shape

        # Three resolutions of equal length; avg_pool1d wants channels first.
        channels_first = x.transpose(1, 2)
        medium = F.avg_pool1d(channels_first, kernel_size=3, stride=1, padding=1)
        coarse = F.avg_pool1d(channels_first, kernel_size=5, stride=1, padding=2)
        scales = [x, medium.transpose(1, 2), coarse.transpose(1, 2)]

        projected_scales = [
            projection(scale) for projection, scale in zip(self.scale_projections, scales)
        ]

        # Keys and values depend only on the key scale, so compute them once
        # here instead of once per query scale.
        keys = [self._split_heads(self.k_linear(scale)) for scale in projected_scales]
        values = [self._split_heads(self.v_linear(scale)) for scale in projected_scales]
        score_scale = self.head_dim ** 0.5

        attended_scales = []
        for query_scale in projected_scales:
            q = self._split_heads(self.q_linear(query_scale))

            scale_attentions = []
            for k, v in zip(keys, values):
                # Scaled dot-product attention
                scores = torch.matmul(q, k.transpose(-2, -1)) / score_scale
                attn_weights = self.dropout(F.softmax(scores, dim=-1))

                attended = torch.matmul(attn_weights, v)
                attended = attended.transpose(1, 2).contiguous().view(
                    batch_size, seq_len, self.hidden_dim
                )
                scale_attentions.append(attended)

            # Average what this query scale gathered from every key scale
            attended_scales.append(torch.mean(torch.stack(scale_attentions), dim=0))

        # Concatenate the three query-scale results and fuse to hidden_dim
        fused_features = torch.cat(attended_scales, dim=-1)
        output = self.scale_fusion(fused_features)
        output = self.out_proj(output)

        return output

class AdaptiveTemporalPooling(nn.Module):
    """Collapse the time axis through several pooling branches mixed by learned weights.

    Each branch adaptively average-pools the sequence to a fixed number of
    bins, averages those bins into one vector, and transforms it. A small
    softmax network looks at all branch vectors and produces per-sample
    mixing weights; the weighted sum is projected to the final representation.

    NOTE: since every branch averages its bins again, the branches differ
    mainly through their learned transforms. When ``seq_len`` is a multiple of
    a pool size, that branch's pooled vector is exactly the plain temporal mean.

    Args:
        hidden_dim: Feature size of the input and the output.
        pool_sizes: Number of bins for each pooling branch.
        dropout: Dropout probability inside each branch transform.

    Raises:
        ValueError: If ``pool_sizes`` is empty.
    """

    def __init__(self, hidden_dim, pool_sizes=(2, 4, 8), dropout=0.1):
        super().__init__()

        pool_sizes = list(pool_sizes)  # copy: never alias the caller's (or the default) sequence
        if not pool_sizes:
            raise ValueError("pool_sizes must contain at least one size")

        self.pool_sizes = pool_sizes
        self.hidden_dim = hidden_dim

        # One adaptive pooling layer per temporal granularity
        self.adaptive_pools = nn.ModuleList([
            nn.AdaptiveAvgPool1d(size) for size in pool_sizes
        ])

        # Branch-specific transform of the pooled vector
        self.pool_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout)
            ) for _ in pool_sizes
        ])

        # Produces one softmax weight per branch from all branch vectors
        self.pool_attention = nn.Sequential(
            nn.Linear(hidden_dim * len(pool_sizes), hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, len(pool_sizes)),
            nn.Softmax(dim=-1)
        )

        # Final projection of the mixed representation
        self.final_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU()
        )

    def forward(self, x):
        """Pool a sequence into a single vector.

        Args:
            x: Input tensor ``[batch, seq_len, hidden_dim]``.

        Returns:
            tuple: ``(output, pool_weights)`` where ``output`` is
            ``[batch, hidden_dim]`` and ``pool_weights`` is
            ``[batch, num_pools]`` with rows summing to 1.
        """
        channels_first = x.transpose(1, 2)  # [batch, hidden_dim, seq_len]

        pooled_features = []
        for pool, transform in zip(self.adaptive_pools, self.pool_transforms):
            # [batch, hidden_dim, pool_size] -> mean over bins -> [batch, hidden_dim]
            pooled = pool(channels_first).mean(dim=-1)
            pooled_features.append(transform(pooled))

        # [batch, hidden_dim * num_pools]
        all_pooled = torch.cat(pooled_features, dim=-1)

        # Per-sample mixing weights over the branches: [batch, num_pools]
        pool_weights = self.pool_attention(all_pooled)

        # Weighted sum of the branch vectors
        weighted_features = sum(
            weight.unsqueeze(-1) * feature
            for weight, feature in zip(pool_weights.unbind(-1), pooled_features)
        )

        output = self.final_proj(weighted_features)

        return output, pool_weights

class TPANet(nn.Module):
    """Temporal Pyramid Attention Network for sequence classification.

    Pipeline: per-step input embedding -> stacked ``TemporalPyramidBlock``s ->
    ``CrossScaleAttention`` -> ``AdaptiveTemporalPooling`` -> sigmoid feature
    gate -> classification head plus a scalar uncertainty head.

    Args:
        input_dim: Number of input features per time step.
        hidden_dim: Internal feature size; must be divisible by ``num_heads``.
        num_classes: Number of output classes.
        num_layers: Number of stacked temporal pyramid blocks.
        dropout: Dropout probability used throughout the network.
        num_heads: Number of heads in the cross-scale attention.
    """

    def __init__(self, input_dim=69, hidden_dim=256, num_classes=10,
                 num_layers=3, dropout=0.2, num_heads=8):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers

        # Per-time-step input embedding (plain feed-forward, no skip connection)
        self.input_embedding = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Stack of Temporal Pyramid Blocks
        self.pyramid_blocks = nn.ModuleList([
            TemporalPyramidBlock(hidden_dim, hidden_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

        # Cross-Scale Attention mechanism
        self.cross_scale_attention = CrossScaleAttention(
            hidden_dim, num_heads=num_heads, dropout=dropout
        )

        # Adaptive Temporal Pooling
        self.adaptive_pooling = AdaptiveTemporalPooling(hidden_dim)

        # Feature gate: per-feature weight in (0, 1) applied to the pooled vector
        self.feature_selector = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.Sigmoid()
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.LayerNorm(hidden_dim // 4),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),

            nn.Linear(hidden_dim // 4, num_classes)
        )

        # Uncertainty head: one value in (0, 1) per sample
        self.uncertainty_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 4),
            nn.GELU(),
            nn.Linear(hidden_dim // 4, 1),
            nn.Sigmoid()
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise weights by layer type: Kaiming for Conv1d, Xavier for Linear.

        Layers are selected by module type rather than by parameter name: a
        name test such as ``'conv' in name`` only matches layers stored under
        an attribute with "conv" in it and silently skips convolutions held in
        other containers. Biases of both layer types are zeroed; normalisation
        layers keep their defaults.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
            else:
                continue
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x, return_features=False):
        """Run the network.

        Args:
            x: Input tensor ``[batch_size, seq_len, input_dim]``.
            return_features: When true, also return a dict of intermediates.

        Returns:
            ``(logits, uncertainty)``, or ``(logits, uncertainty, features)``
            when ``return_features`` is true. ``features`` holds
            ``pyramid_features`` (one tensor per block), ``pool_weights``,
            ``feature_importance`` and ``attended_features``.

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape [batch, seq_len, input_dim], got {tuple(x.shape)}"
            )

        # Input embedding
        x = self.input_embedding(x)

        # Temporal pyramid blocks, keeping each block's output for analysis
        pyramid_features = []
        for block in self.pyramid_blocks:
            x = block(x)
            pyramid_features.append(x)

        # Cross-scale attention over the last block's output
        attended = self.cross_scale_attention(x)

        # Collapse the time axis
        pooled, pool_weights = self.adaptive_pooling(attended)

        # Gate the pooled features
        feature_importance = self.feature_selector(pooled)
        selected_features = pooled * feature_importance

        # Both heads read the same gated representation
        logits = self.classifier(selected_features)
        uncertainty = self.uncertainty_head(selected_features)

        if return_features:
            return logits, uncertainty, {
                'pyramid_features': pyramid_features,
                'pool_weights': pool_weights,
                'feature_importance': feature_importance,
                'attended_features': attended
            }

        return logits, uncertainty

    def get_attention_analysis(self, x):
        """Return the intermediate-feature dict for ``x`` without tracking gradients.

        The forward pass runs in eval mode; the previous train/eval mode is
        restored afterwards, so calling this in the middle of training does
        not silently disable dropout and batch-norm updates for later steps.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                _, _, features = self(x, return_features=True)
        finally:
            self.train(was_training)
        return features

class TPANetTrainer:
    """Training / validation driver for a TPA-Net style model.

    The wrapped model must return ``(logits, uncertainty)`` from its forward
    call. The trainer owns the loss functions, an AdamW optimizer with three
    parameter groups, a cosine warm-restart LR scheduler and per-epoch
    metric histories.
    """
    def __init__(self, model, class_weights=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Args:
            model: Module to train; it is moved to ``device``.
            class_weights: Optional tensor of per-class weights for the
                cross-entropy loss.
            device: Device string; defaults to CUDA when available.
        """
        self.model = model
        self.device = device
        self.model.to(device)

        # Loss functions
        if class_weights is not None:
            class_weights = class_weights.to(device)

        self.classification_criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
        self.uncertainty_criterion = nn.MSELoss()

        # Split parameters into three optimizer groups by substring of their
        # qualified name; anything that matches neither pattern falls into
        # the "classifier" group.
        pyramid_params = []
        attention_params = []
        classifier_params = []

        for name, param in model.named_parameters():
            if 'pyramid' in name:
                pyramid_params.append(param)
            elif 'attention' in name or 'pooling' in name:
                attention_params.append(param)
            else:
                classifier_params.append(param)

        self.optimizer = optim.AdamW([
            {'params': pyramid_params, 'lr': 8e-4, 'weight_decay': 0.01},
            {'params': attention_params, 'lr': 1e-3, 'weight_decay': 0.005},
            {'params': classifier_params, 'lr': 1e-3, 'weight_decay': 0.01}
        ], betas=(0.9, 0.95))

        # Cosine annealing with warm restarts (stepped once per epoch in train()).
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=30, T_mult=2, eta_min=1e-6
        )

        # Per-epoch metric histories, appended to by train().
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.val_f1_scores = []
        self.uncertainty_scores = []

    @staticmethod
    def _flatten_uncertainty(uncertainty):
        """Flatten the uncertainty head output to a 1-D tensor.

        ``squeeze()`` would turn a batch of one into a 0-d tensor, which
        cannot be iterated by ``list.extend`` and no longer matches the shape
        of the per-sample target. Assumes the head emits one value per
        sample -- TODO confirm against the model definition.
        """
        return uncertainty.reshape(-1)

    def train_epoch(self, dataloader):
        """Run one training epoch.

        Returns:
            (avg_loss, avg_clf_loss, avg_unc_loss, accuracy_percent,
            weighted_f1, mean_predicted_uncertainty); the three losses are
            averaged over batches.
        """
        self.model.train()
        total_loss = 0
        total_clf_loss = 0
        total_unc_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []
        all_uncertainties = []

        for data, target in tqdm(dataloader, desc="Training"):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            logits, uncertainty = self.model(data)
            flat_uncertainty = self._flatten_uncertainty(uncertainty)

            # Classification loss
            clf_loss = self.classification_criterion(logits, target)

            # Uncertainty loss: regress the head towards 1 - max softmax
            # probability. The target is detached so this term only trains
            # the uncertainty branch, not the classifier's confidence.
            pred_confidence = torch.softmax(logits, dim=1).max(dim=1)[0]
            uncertainty_target = 1.0 - pred_confidence
            unc_loss = self.uncertainty_criterion(flat_uncertainty, uncertainty_target.detach())

            # Combined loss
            total_loss_batch = clf_loss + 0.1 * unc_loss

            # Explicit L2 penalty over all parameters (applied in addition to
            # the optimizer's weight decay).
            l2_reg = sum(param.pow(2).sum() for param in self.model.parameters())
            total_loss_batch += 1e-5 * l2_reg

            total_loss_batch.backward()

            # Gradient clipping for training stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Statistics tracking
            total_loss += total_loss_batch.item()
            total_clf_loss += clf_loss.item()
            total_unc_loss += unc_loss.item()

            _, predicted = logits.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_uncertainties.extend(flat_uncertainty.detach().cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        avg_clf_loss = total_clf_loss / len(dataloader)
        avg_unc_loss = total_unc_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')
        avg_uncertainty = np.mean(all_uncertainties)

        return avg_loss, avg_clf_loss, avg_unc_loss, accuracy, f1, avg_uncertainty

    def validate(self, dataloader):
        """Evaluate the model on ``dataloader`` without gradient tracking.

        Only the classification loss is accumulated here.

        Returns:
            (avg_loss, accuracy_percent, weighted_f1, all_preds, all_targets,
            mean_predicted_uncertainty)
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []
        all_uncertainties = []

        with torch.no_grad():
            for data, target in tqdm(dataloader, desc="Validating"):
                data, target = data.to(self.device), target.to(self.device)

                logits, uncertainty = self.model(data)
                loss = self.classification_criterion(logits, target)

                total_loss += loss.item()
                _, predicted = logits.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(target.cpu().numpy())
                all_uncertainties.extend(self._flatten_uncertainty(uncertainty).cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')
        avg_uncertainty = np.mean(all_uncertainties)

        return avg_loss, accuracy, f1, all_preds, all_targets, avg_uncertainty

    def train(self, train_loader, val_loader, epochs=150):
        """Train for up to ``epochs`` epochs with early stopping on validation F1.

        The best model (highest validation F1 so far) is written to
        "tpa_net_best_model.pth". Training stops after 25 consecutive epochs
        without an improvement.

        Returns:
            The best validation F1 observed (0 if it never exceeded 0, in
            which case no checkpoint is written).
        """
        best_val_f1 = 0
        patience = 25
        patience_counter = 0

        logger.info(f"🚀 TPA-Net Training Started: {epochs} epochs")
        logger.info(f"   - Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        for epoch in range(epochs):
            # Training
            train_loss, clf_loss, unc_loss, train_acc, train_f1, train_unc = self.train_epoch(train_loader)

            # Validation
            val_loss, val_acc, val_f1, _, _, val_unc = self.validate(val_loader)

            # Learning rate scheduling
            self.scheduler.step()

            # Train/validation accuracy gap, used as an overfitting indicator.
            train_val_gap = train_acc - val_acc

            # Record the epoch in the metric histories.
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            self.val_f1_scores.append(val_f1)
            self.uncertainty_scores.append(val_unc)

            # Regular progress logging
            if (epoch + 1) % 10 == 0:
                logger.info(f'Epoch {epoch+1}/{epochs}:')
                logger.info(f'  Train: Loss={train_loss:.3f} (Clf={clf_loss:.3f}, Unc={unc_loss:.3f})')
                logger.info(f'         Acc={train_acc:.1f}%, F1={train_f1:.3f}, Unc={train_unc:.3f}')
                logger.info(f'  Val:   Loss={val_loss:.3f}, Acc={val_acc:.1f}%, F1={val_f1:.3f}, Unc={val_unc:.3f}')
                logger.info(f'  Gap: {train_val_gap:.1f}%, LR: {self.optimizer.param_groups[0]["lr"]:.2e}')

            # Save best model based on validation F1
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                # Metadata is stored as plain Python numbers so the checkpoint
                # holds only tensors and builtin types.
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'epoch': epoch,
                    'val_f1': float(val_f1),
                    'train_val_gap': float(train_val_gap),
                    'model_type': 'TPA_Net'
                }, "tpa_net_best_model.pth")
                logger.info(f"  ✅ New best model saved: F1={val_f1:.3f}, Gap={train_val_gap:.1f}%")
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping once patience is exhausted.
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        logger.info(f"🎯 Training completed. Best F1 Score: {best_val_f1:.3f}")
        return best_val_f1

def visualize_tpa_attention(model, dataloader, class_names, device, num_samples=2):
    """Plot per-sample interpretability panels for the first ``num_samples`` samples.

    One figure row is drawn per sample with three panels: adaptive pooling
    weights (bar chart), feature-importance scores (line plot) and the mean
    activation of each pyramid layer. The figure is saved to
    'tpa_net_attention_analysis.png' and then shown.
    """
    model.eval()

    fig, axes = plt.subplots(num_samples, 3, figsize=(18, 6*num_samples))
    if num_samples == 1:
        # plt.subplots returns a 1-D array for a single row; restore 2-D indexing.
        axes = axes.reshape(1, -1)

    # Labels for the pooling bars; the weight vector is expected to have three entries.
    pool_names = ['Fine', 'Medium', 'Coarse']

    row = 0
    with torch.no_grad():
        for data, targets in dataloader:
            if row >= num_samples:
                break

            data = data.to(device)
            remaining = num_samples - row

            for i in range(min(data.size(0), remaining)):
                target_class = targets[i].item()
                class_label = class_names[target_class]

                # Analyse one sample at a time (batch dimension kept).
                features = model.get_attention_analysis(data[i:i+1])

                # Panel 1: adaptive pooling weights
                pool_ax = axes[row, 0]
                pool_weights = features['pool_weights'][0].cpu().numpy()
                bars = pool_ax.bar(pool_names, pool_weights)
                pool_ax.set_title(f'Adaptive Pooling Weights\nClass: {class_label}')
                pool_ax.set_ylabel('Weight')
                for bar, weight in zip(bars, pool_weights):
                    pool_ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                                 f'{weight:.3f}', ha='center', va='bottom')

                # Panel 2: feature importance
                imp_ax = axes[row, 1]
                feature_importance = features['feature_importance'][0].cpu().numpy()
                imp_ax.plot(feature_importance, linewidth=2, color='red')
                imp_ax.fill_between(range(len(feature_importance)), feature_importance, alpha=0.3, color='red')
                imp_ax.set_title(f'Dynamic Feature Importance\nClass: {class_label}')
                imp_ax.set_xlabel('Feature Dimension')
                imp_ax.set_ylabel('Importance Score')
                imp_ax.grid(True, alpha=0.3)

                # Panel 3: one curve per pyramid layer, averaged over the last axis
                pyr_ax = axes[row, 2]
                for layer_no, pyr_feat in enumerate(features['pyramid_features'], start=1):
                    activation = pyr_feat[0].mean(dim=-1).cpu().numpy()
                    pyr_ax.plot(activation, label=f'Layer {layer_no}', alpha=0.7, linewidth=2)

                pyr_ax.set_title(f'Pyramid Layer Activations\nClass: {class_label}')
                pyr_ax.set_xlabel('Time Step')
                pyr_ax.set_ylabel('Activation')
                pyr_ax.legend()
                pyr_ax.grid(True, alpha=0.3)

                row += 1

    plt.tight_layout()
    plt.savefig('tpa_net_attention_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()

def plot_comprehensive_tpa_results(trainer, model_name="TPA_Net"):
    """Draw a 2x3 dashboard from the trainer's metric histories.

    Panels: loss curves, accuracy curves, validation F1, validation
    uncertainty, train-validation accuracy gap, and a bar summary of the
    final metrics. The figure is saved to
    '<model_name lower-cased>_comprehensive_results.png' and then shown.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    epochs = range(1, len(trainer.train_losses) + 1)

    def _label_panel(panel, title, ylabel, with_legend=False):
        # Shared title / axis / grid styling for the five curve panels.
        panel.set_title(f'{model_name} - {title}')
        panel.set_xlabel('Epoch')
        panel.set_ylabel(ylabel)
        if with_legend:
            panel.legend()
        panel.grid(True, alpha=0.3)

    # Loss curves
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, trainer.train_losses, 'b-', label='Training Loss', linewidth=2)
    loss_ax.plot(epochs, trainer.val_losses, 'r-', label='Validation Loss', linewidth=2)
    _label_panel(loss_ax, 'Loss Curves', 'Loss', with_legend=True)

    # Accuracy curves
    acc_ax = axes[0, 1]
    acc_ax.plot(epochs, trainer.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
    acc_ax.plot(epochs, trainer.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
    _label_panel(acc_ax, 'Accuracy Curves', 'Accuracy (%)', with_legend=True)

    # Validation F1 progression
    f1_ax = axes[0, 2]
    f1_ax.plot(epochs, trainer.val_f1_scores, 'green', linewidth=2)
    _label_panel(f1_ax, 'Validation F1 Score', 'F1 Score')

    # Validation uncertainty over time
    unc_ax = axes[1, 0]
    unc_ax.plot(epochs, trainer.uncertainty_scores, 'purple', linewidth=2)
    _label_panel(unc_ax, 'Uncertainty Evolution', 'Average Uncertainty')

    # Train-validation accuracy gap with two reference lines
    gaps = []
    for train_acc, val_acc in zip(trainer.train_accuracies, trainer.val_accuracies):
        gaps.append(train_acc - val_acc)
    gap_ax = axes[1, 1]
    gap_ax.plot(epochs, gaps, 'orange', linewidth=2)
    gap_ax.axhline(y=10, color='red', linestyle='--', label='Warning Threshold')
    gap_ax.axhline(y=5, color='green', linestyle='--', label='Good Threshold')
    _label_panel(gap_ax, 'Train-Validation Gap', 'Gap (%)', with_legend=True)

    # Bar summary of the last epoch (plus the best F1 seen); values are percentages.
    summary_ax = axes[1, 2]
    last_uncertainty = trainer.uncertainty_scores[-1]
    final_metrics = {
        'Accuracy': trainer.val_accuracies[-1],
        'F1 Score': trainer.val_f1_scores[-1] * 100,
        'Best F1': max(trainer.val_f1_scores) * 100,
        'Confidence': (1 - last_uncertainty) * 100  # confidence = 1 - uncertainty
    }

    bars = summary_ax.bar(final_metrics.keys(), final_metrics.values(),
                          color=['blue', 'green', 'gold', 'purple'], alpha=0.7)
    summary_ax.set_title(f'{model_name} - Final Performance Metrics')
    summary_ax.set_ylabel('Score (%)')
    summary_ax.set_ylim(0, 100)

    # Print each value just above its bar.
    for bar, value in zip(bars, final_metrics.values()):
        x_center = bar.get_x() + bar.get_width()/2.
        summary_ax.text(x_center, bar.get_height() + 1,
                        f'{value:.1f}%', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig(f'{model_name.lower()}_comprehensive_results.png', dpi=150, bbox_inches='tight')
    plt.show()

def prepare_enhanced_dataset(csv_path, sequence_length=150, overlap_ratio=0.3):
    """Load a CSV of per-row features and cut it into fixed-length windows.

    Every column except 'class' is treated as a feature. Processing steps:
      1. Fill missing values (forward fill, backward fill, then 0) and
         replace +/-inf with 0.
      2. Column by column, drop rows further than 3 standard deviations from
         that column's mean (statistics are recomputed on the rows that
         survived the previous columns).
      3. Keep only classes with at least ``3 * sequence_length`` rows.
      4. Slide a window of ``sequence_length`` rows over each class with
         stride ``max(1, int(sequence_length * (1 - overlap_ratio)))``.

    Rows of one class are concatenated in file order before windowing, so a
    window can span rows that were not adjacent in the CSV.

    Args:
        csv_path: Path passed to ``pd.read_csv``.
        sequence_length: Number of rows per window.
        overlap_ratio: Fraction of a window shared with the next one.

    Returns:
        (sequences, labels, n_features) where ``sequences`` is an array of
        shape (n_windows, sequence_length, n_features) (or an empty array if
        no window could be built), ``labels`` is a list with one class name
        per window, and ``n_features`` is the number of feature columns.
    """
    logger.info(f"📊 Loading data from: {csv_path}")

    df = pd.read_csv(csv_path)
    feature_columns = [col for col in df.columns if col != 'class']

    # Data cleaning. ffill()/bfill() replace the deprecated fillna(method=...).
    original_size = len(df)
    df[feature_columns] = df[feature_columns].ffill().bfill().fillna(0)
    df[feature_columns] = df[feature_columns].replace([np.inf, -np.inf], 0)

    # Remove outliers (beyond 3 standard deviations), one column at a time.
    for col in feature_columns:
        mean_val = df[col].mean()
        std_val = df[col].std()
        df = df[np.abs(df[col] - mean_val) <= (3 * std_val)]

    # Guard the percentage against an empty input file.
    retained_pct = len(df) / original_size * 100 if original_size else 0.0
    logger.info(f"After cleaning: {len(df)} samples ({retained_pct:.1f}% retained)")

    # Filter classes with sufficient data for reliable training
    min_samples = sequence_length * 3
    class_counts = df['class'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples].index
    df = df[df['class'].isin(valid_classes)]

    logger.info(f"Valid classes after filtering: {list(valid_classes)}")

    sequences = []
    labels = []

    # Create sequences using sliding window with overlap
    stride = max(1, int(sequence_length * (1 - overlap_ratio)))

    for class_name in valid_classes:
        # Extract the feature matrix once per class; windows are then cheap
        # array slices instead of per-window DataFrame selections.
        class_values = df.loc[df['class'] == class_name, feature_columns].values

        last_start = len(class_values) - sequence_length
        for start in range(0, last_start + 1, stride):
            window = class_values[start:start + sequence_length]

            if not np.any(np.isnan(window)):
                sequences.append(window)
                labels.append(class_name)

    sequences = np.array(sequences)
    logger.info(f"Total sequences created: {len(sequences)}")
    logger.info(f"Class distribution: {dict(Counter(labels))}")

    return sequences, labels, len(feature_columns)

def analyze_model_complexity(model):
    """Log parameter statistics for ``model`` and return the two totals.

    Logs the total and trainable parameter counts, a size estimate that
    assumes 4 bytes per parameter, and the five leaf modules holding the most
    parameters.

    Returns:
        (total_params, trainable_params)
    """
    total_params = 0
    trainable_params = 0
    for p in model.parameters():
        count = p.numel()
        total_params += count
        if p.requires_grad:
            trainable_params += count

    # Parameter count per leaf module (modules without children).
    component_params = {}
    for name, module in model.named_modules():
        is_leaf = len(list(module.children())) == 0
        if not is_leaf:
            continue
        leaf_count = sum(p.numel() for p in module.parameters())
        if leaf_count > 0:
            component_params[name] = leaf_count

    size_mb = total_params * 4 / 1024 / 1024
    logger.info(f"\n🔧 Model Complexity Analysis:")
    logger.info(f"   - Total Parameters: {total_params:,}")
    logger.info(f"   - Trainable Parameters: {trainable_params:,}")
    logger.info(f"   - Model Size: {size_mb:.2f} MB")

    # Five largest leaf modules, biggest first.
    ranked = sorted(component_params.items(), key=lambda item: item[1], reverse=True)
    logger.info(f"   - Largest Components:")
    for name, params in ranked[:5]:
        logger.info(f"     * {name}: {params:,} parameters ({params/total_params*100:.1f}%)")

    return total_params, trainable_params

def plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names):
    """Draw validation and test confusion matrices side by side.

    Each panel is an annotated count heatmap labelled with ``target_names``.
    The figure is saved to 'tpa_net_confusion_matrices.png' and then shown.
    """
    fig, (val_ax, test_ax) = plt.subplots(1, 2, figsize=(20, 8))

    # (axis, true labels, predictions, colour map, title) for each panel
    panels = (
        (val_ax, val_targets, val_preds, 'Blues', 'Validation Confusion Matrix'),
        (test_ax, test_targets, test_preds, 'Greens', 'Test Confusion Matrix'),
    )

    for panel_ax, y_true, y_pred, colour_map, title in panels:
        counts = confusion_matrix(y_true, y_pred)
        sns.heatmap(counts, annot=True, fmt='d', cmap=colour_map,
                    xticklabels=target_names, yticklabels=target_names, ax=panel_ax)
        panel_ax.set_title(title, fontsize=16, fontweight='bold')
        panel_ax.set_ylabel('True Label', fontsize=12)
        panel_ax.set_xlabel('Predicted Label', fontsize=12)
        panel_ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig('tpa_net_confusion_matrices.png', dpi=150)
    plt.show()

def main():
    """End-to-end TPA-Net experiment: data prep, training, evaluation, reporting.

    Steps: build windowed sequences from ``CSV_PATH``, split them into
    train/validation/test, train with ``TPANetTrainer``, reload the best
    checkpoint, evaluate on validation and test, produce reports and plots,
    and save a comprehensive checkpoint.

    Returns:
        ``None`` when no sequences could be built; otherwise the tuple
        (trainer, model, train_dataset, val_targets, val_preds,
        test_targets, test_preds).

    Notes:
        * The reported train-validation gap uses the metrics of the last
          trained epoch, while the validation/test metrics are computed with
          the reloaded best checkpoint.
        * NOTE(review): windows overlap (overlap_ratio=0.35) and are split at
          random, so windows that share rows can end up in different splits;
          confirm this is acceptable for the reported test metrics.
    """
    set_seed(42)

    # Run configuration
    CSV_PATH = ""  # Update with your dataset path
    SEQUENCE_LENGTH = 150
    BATCH_SIZE = 16          # Smaller batch size for complex model
    EPOCHS = 150             # Extended training epochs
    VAL_SIZE = 0.15          # 15% for validation
    TEST_SIZE = 0.15         # 15% for test (70% remaining for training)
    TRAIN_SIZE = 1 - VAL_SIZE - TEST_SIZE

    logger.info("🚀 TPA-Net (Temporal Pyramid Attention Network) Training Started!")
    # Percentages are derived from the constants so the log cannot drift from the split.
    logger.info(f"📊 Data split: {TRAIN_SIZE*100:.0f}% Train, {VAL_SIZE*100:.0f}% Validation, {TEST_SIZE*100:.0f}% Test")
    logger.info("="*80)

    # Data preparation
    sequences, labels, feature_count = prepare_enhanced_dataset(
        CSV_PATH, SEQUENCE_LENGTH, overlap_ratio=0.35
    )

    if len(sequences) == 0:
        logger.error("❌ No valid sequences found!")
        return

    # First split: train+val (85%) and test (15%)
    X_temp, X_test, y_temp, y_test = train_test_split(
        sequences, labels, test_size=TEST_SIZE, random_state=42, stratify=labels
    )

    # Second split: train (70%) and val (15%) from remaining 85%
    val_size_adjusted = VAL_SIZE / (1 - TEST_SIZE)  # 0.15 / 0.85 ≈ 0.176
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted, random_state=42, stratify=y_temp
    )

    logger.info(f"\n📊 Dataset Summary:")
    logger.info(f"   - Input Features: {feature_count}")
    logger.info(f"   - Sequence Length: {SEQUENCE_LENGTH}")
    logger.info(f"   - Training sequences: {len(X_train):,} ({len(X_train)/len(sequences)*100:.1f}%)")
    logger.info(f"   - Validation sequences: {len(X_val):,} ({len(X_val)/len(sequences)*100:.1f}%)")
    logger.info(f"   - Test sequences: {len(X_test):,} ({len(X_test)/len(sequences)*100:.1f}%)")

    # Training dataset fits the label encoder and scaler; augmentation is on.
    train_dataset = AdvancedDiverSignDataset(
        X_train, y_train,
        train_mode=True,
        augment_prob=0.5,  # Higher augmentation for complex model
        noise_std=0.01
    )

    # Validation and test reuse the encoder and scaler fitted on the training data.
    val_dataset = AdvancedDiverSignDataset(
        X_val, y_val,
        train_dataset.get_label_encoder(),
        train_dataset.get_scaler(),
        train_mode=False
    )

    test_dataset = AdvancedDiverSignDataset(
        X_test, y_test,
        train_dataset.get_label_encoder(),
        train_dataset.get_scaler(),
        train_mode=False
    )

    # Data loaders share everything except shuffling.
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    class_names = train_dataset.get_label_encoder().classes_
    num_classes = len(class_names)

    logger.info(f"   - Number of Classes: {num_classes}")
    logger.info(f"   - Class Names: {list(class_names)}")

    # Single source of truth for the architecture hyperparameters: used to
    # build the model, in the log below, and in the saved checkpoint.
    model_config = {
        'input_dim': feature_count,
        'hidden_dim': 256,
        'num_classes': num_classes,
        'num_layers': 3,
        'dropout': 0.25
    }
    model = TPANet(**model_config)

    # Model complexity analysis
    total_params, trainable_params = analyze_model_complexity(model)

    logger.info(f"\n🤖 TPA-Net Architecture Configuration:")
    logger.info(f"   - Temporal Pyramid Blocks: {model_config['num_layers']} layers")
    logger.info(f"   - Cross-Scale Attention: 8 heads")
    logger.info(f"   - Adaptive Pooling: Multi-scale")
    logger.info(f"   - Uncertainty Estimation: Enabled")
    logger.info(f"   - Dynamic Feature Selection: Enabled")

    # Trainer
    trainer = TPANetTrainer(
        model=model,
        class_weights=train_dataset.get_class_weights()
    )

    logger.info(f"\n🔥 Starting Advanced Training ({EPOCHS} epochs)...")
    logger.info("="*80)

    # Training
    best_f1 = trainer.train(train_loader, val_loader, epochs=EPOCHS)

    # Load best model for final evaluation; map_location keeps the tensors on
    # the device the trainer is using.
    checkpoint = torch.load("tpa_net_best_model.pth", map_location=trainer.device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Final validation evaluation
    val_loss, val_acc, val_f1, val_preds, val_targets, val_unc = trainer.validate(val_loader)

    # Final test evaluation
    test_loss, test_acc, test_f1, test_preds, test_targets, test_unc = trainer.validate(test_loader)

    # Training analysis (last trained epoch, see docstring)
    final_train_acc = trainer.train_accuracies[-1]
    final_val_acc = trainer.val_accuracies[-1]
    accuracy_gap = final_train_acc - final_val_acc

    logger.info(f"\n📊 TPA-Net Final Results:")
    logger.info("="*50)
    logger.info(f"   📊 VALIDATION:")
    logger.info(f"     - Accuracy: {val_acc:.2f}%")
    logger.info(f"     - F1 Score: {val_f1:.4f}")
    logger.info(f"     - Uncertainty: {val_unc:.4f}")
    logger.info(f"     - Confidence: {(1-val_unc)*100:.1f}%")
    logger.info(f"   📊 TEST:")
    logger.info(f"     - Accuracy: {test_acc:.2f}%")
    logger.info(f"     - F1 Score: {test_f1:.4f}")
    logger.info(f"     - Uncertainty: {test_unc:.4f}")
    logger.info(f"     - Confidence: {(1-test_unc)*100:.1f}%")
    logger.info(f"   📊 TRAINING ANALYSIS:")
    logger.info(f"     - Train-Validation Gap: {accuracy_gap:.2f}%")
    logger.info(f"     - Best Validation F1: {best_f1:.4f}")

    # Model generalization analysis
    if accuracy_gap < 3:
        logger.info("   ✅ Excellent generalization achieved")
    elif accuracy_gap < 7:
        logger.info("   ✅ Good generalization achieved")
    elif accuracy_gap < 12:
        logger.info("   ⚠️ Moderate overfitting detected")
    else:
        logger.info("   ❌ Significant overfitting detected")

    # Classification reports
    target_names = class_names

    print(f"\n📋 VALIDATION Classification Report:")
    print("="*80)
    print(classification_report(val_targets, val_preds, target_names=target_names, digits=3))

    print(f"\n📋 TEST Classification Report:")
    print("="*80)
    print(classification_report(test_targets, test_preds, target_names=target_names, digits=3))

    # Dual confusion matrices (counts)
    plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names)

    # Test confusion matrix as row percentages
    plt.figure(figsize=(12, 10))
    cm = confusion_matrix(test_targets, test_preds)

    # Normalize each row (true class) to percentages
    cm_percent = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    sns.heatmap(
        cm_percent,
        annot=True,
        fmt='.1f',
        cmap='RdYlBu_r',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Percentage (%)'},
        linewidths=0.5
    )

    plt.title('TPA-Net - Test Confusion Matrix (%)', fontsize=18, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.ylabel('True Label', fontsize=14)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    # Footer with headline numbers
    model_info = f'TPA-Net | Params: {total_params:,} | F1: {test_f1:.3f} | Confidence: {(1-test_unc)*100:.1f}%'
    plt.figtext(0.02, 0.02, model_info, fontsize=10,
                bbox=dict(boxstyle="round,pad=0.3", facecolor='lightblue', alpha=0.7))

    plt.tight_layout()
    plt.savefig('tpa_net_detailed_confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Training curves
    plot_comprehensive_tpa_results(trainer, "TPA_Net")

    # Attention analysis
    logger.info("\n🔍 Generating Comprehensive Attention Analysis...")
    visualize_tpa_attention(model, test_loader, class_names, trainer.device, num_samples=2)

    # Model efficiency analysis
    logger.info(f"\n⚡ Model Efficiency Analysis:")
    logger.info(f"   - Parameters per Class: {total_params // num_classes:,}")
    logger.info(f"   - Accuracy per 1K Parameters: {test_acc / (total_params / 1000):.3f}")
    logger.info(f"   - F1 Score per 1M Parameters: {test_f1 / (total_params / 1000000):.3f}")

    # Save results
    logger.info(f"\n💾 Saving Comprehensive Results...")

    # Model weights plus configuration, histories and evaluation results
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': dict(model_config),
        'training_history': {
            'train_losses': trainer.train_losses,
            'val_losses': trainer.val_losses,
            'train_accuracies': trainer.train_accuracies,
            'val_accuracies': trainer.val_accuracies,
            'val_f1_scores': trainer.val_f1_scores,
            'uncertainty_scores': trainer.uncertainty_scores
        },
        'test_results': {
            'test_accuracy': test_acc,
            'test_f1': test_f1,
            'test_uncertainty': test_unc,
            'validation_accuracy': val_acc,
            'validation_f1': val_f1,
            'validation_uncertainty': val_unc,
            'confusion_matrix': cm.tolist(),
            'class_names': class_names.tolist()
        },
        'model_analysis': {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'accuracy_gap': accuracy_gap,
            'best_val_f1': best_f1
        }
    }, 'tpa_net_comprehensive_model.pth')

    logger.info("\n✅ TPA-NET TRAINING COMPLETED SUCCESSFULLY!")
    logger.info(f"📁 Files saved:")
    logger.info(f"   - tpa_net_comprehensive_model.pth: Complete model with metadata")
    logger.info(f"   - tpa_net_confusion_matrices.png: Dual confusion matrices")
    logger.info(f"   - tpa_net_comprehensive_results.png: Training curves")
    logger.info(f"   - tpa_net_attention_analysis.png: Attention visualizations")
    logger.info(f"   - tpa_net_detailed_confusion_matrix.png: Detailed test matrix")

    logger.info(f"\n🏆 FINAL PERFORMANCE SUMMARY:")
    logger.info(f"🎯 Test F1 Score: {test_f1:.4f}")
    logger.info(f"🎯 Test Accuracy: {test_acc:.2f}%")
    logger.info(f"🎯 Model Confidence: {(1-test_unc)*100:.1f}%")
    logger.info(f"🎯 Total Parameters: {total_params:,}")

    return trainer, model, train_dataset, val_targets, val_preds, test_targets, test_preds

if __name__ == "__main__":
    # main() returns None when no usable sequences were found, so only
    # unpack the result tuple when the run actually completed.
    results = main()
    if results is not None:
        trainer, model, dataset, val_targets, val_preds, test_targets, test_preds = results

In [None]:
# ConvLSTM + Vision Transformer Hybrid Model for Diver Sign Language Recognition

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import logging
import random
from collections import Counter
import math
import warnings
warnings.filterwarnings('ignore')

# Logging setup: configure the root logger at INFO level and create the
# module-level logger used by the functions in this cell.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, so it
    # is switched off explicitly in case it was enabled earlier in the session.
    torch.backends.cudnn.benchmark = False

class EnhancedDiverSignDataset(Dataset):
    """Scaled, label-encoded sequence dataset with optional augmentation.

    Sequences are scaled per feature (last axis), labels are integer-encoded,
    and per-class loss weights are derived from the "effective number of
    samples" of each class.

    Args:
        sequences: Array of sequences. It is flattened on its last axis for
            scaling, and each item is treated as ``[time, features]`` by the
            augmentation code.
        labels: Raw labels, one per sequence.
        label_encoder: Fitted encoder to reuse; when omitted a new
            ``LabelEncoder`` is fitted on ``labels``.
        scaler: Fitted scaler to reuse; when omitted a new ``StandardScaler``
            is fitted on ``sequences``.
        train_mode: Enables random augmentation in ``__getitem__``.
        augment_prob: Probability of augmenting a sample in train mode.
    """
    def __init__(self, sequences, labels, label_encoder=None, scaler=None,
                 train_mode=True, augment_prob=0.4):

        # Label encoding: fit a new encoder or reuse the supplied one
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            encoded_labels = self.label_encoder.fit_transform(labels)
        else:
            self.label_encoder = label_encoder
            encoded_labels = self.label_encoder.transform(labels)

        # Scaling works on 2-D data, so flatten all leading axes first
        reshaped = sequences.reshape(-1, sequences.shape[-1])
        if scaler is None:
            self.scaler = StandardScaler()
            scaled_reshaped = self.scaler.fit_transform(reshaped)
        else:
            self.scaler = scaler
            scaled_reshaped = self.scaler.transform(reshaped)
        scaled_sequences = scaled_reshaped.reshape(sequences.shape)

        self.sequences = torch.FloatTensor(scaled_sequences)
        self.labels = torch.LongTensor(encoded_labels)
        self.train_mode = train_mode
        self.augment_prob = augment_prob

        # Enhanced class weights using effective number of samples
        self.class_weights = self._compute_class_weights(encoded_labels)

    def _compute_class_weights(self, encoded_labels, beta=0.9999):
        """Return one loss weight per class from effective sample numbers.

        The weight of a class with ``n`` samples is proportional to
        ``(1 - beta) / (1 - beta**n)``. Weights of the classes present in
        this split are normalised to a mean of 1. Classes known to the
        encoder but absent from this split get the neutral weight 1.0, so
        the vector always has one entry per class and never divides by zero.
        """
        encoded = np.asarray(encoded_labels, dtype=np.int64).ravel()

        # Size the vector from the encoder when possible so that a split
        # missing some classes still yields a full-length weight vector.
        known_classes = getattr(self.label_encoder, 'classes_', None)
        num_classes = len(known_classes) if known_classes is not None else 0
        if encoded.size:
            num_classes = max(num_classes, int(encoded.max()) + 1)

        counts = np.bincount(encoded, minlength=num_classes).astype(np.float64)
        present = counts > 0

        weights = np.ones(num_classes, dtype=np.float64)
        if present.any():
            effective_nums = (1.0 - np.power(beta, counts[present])) / (1.0 - beta)
            raw = 1.0 / effective_nums
            weights[present] = raw * present.sum() / raw.sum()

        return torch.tensor(weights, dtype=torch.float32)

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sequence, label)``; the sequence may be augmented."""
        # Clone so that in-place augmentation never touches the stored data
        sequence = self.sequences[idx].clone()
        label = self.labels[idx]

        # Apply augmentation only during training, with probability augment_prob
        if self.train_mode and random.random() < self.augment_prob:
            sequence = self._apply_hybrid_augmentation(sequence)

        return sequence, label

    def _apply_hybrid_augmentation(self, sequence):
        """Apply one or two randomly chosen augmentations to a sequence.

        Args:
            sequence: Tensor indexed as ``[time, features]``. Some
                augmentations modify it in place.

        Returns:
            The augmented tensor, clamped to ``[-5, 5]``.
        """
        aug_techniques = ['noise', 'temporal_mask', 'feature_dropout', 'mixup', 'gaussian_blur']
        selected_augs = random.sample(aug_techniques, k=random.randint(1, 2))

        for aug_type in selected_augs:
            if aug_type == 'noise':
                # Gaussian noise scaled to 5% of the sequence's own std
                noise_std = torch.std(sequence) * 0.05
                sequence = sequence + torch.randn_like(sequence) * noise_std

            elif aug_type == 'temporal_mask':
                # Attenuate a contiguous window of 5-15 time steps
                mask_length = random.randint(5, 15)
                mask_start = random.randint(0, max(0, sequence.size(0) - mask_length))
                sequence[mask_start:mask_start + mask_length] *= 0.1

            elif aug_type == 'feature_dropout':
                # Attenuate up to a quarter of the feature channels. The
                # upper bound is floored at 1 so inputs with fewer than four
                # features do not make randint(1, 0) raise ValueError.
                num_features = sequence.size(1)
                if num_features > 0:
                    max_drop = max(1, num_features // 4)
                    num_features_to_drop = random.randint(1, max_drop)
                    features_to_drop = random.sample(range(num_features), num_features_to_drop)
                    sequence[:, features_to_drop] *= 0.1

            elif aug_type == 'mixup':
                # Blend the sequence with a time-permuted copy of itself
                alpha = 0.2
                lam = float(np.random.beta(alpha, alpha))
                rand_index = torch.randperm(sequence.size(0))
                sequence = lam * sequence + (1 - lam) * sequence[rand_index]

            elif aug_type == 'gaussian_blur':
                # 1D Gaussian blur for temporal smoothing
                sequence = self._gaussian_blur_1d(sequence, 3, 0.5)

        return torch.clamp(sequence, -5, 5)  # Prevent extreme values

    def _gaussian_blur_1d(self, tensor, kernel_size, sigma):
        """Blur each feature channel along the time axis.

        Args:
            tensor: Tensor indexed as ``[time, features]``.
            kernel_size: Width of the Gaussian kernel.
            sigma: Standard deviation of the Gaussian.

        Returns:
            Tensor with the blurred values; the sequence is zero-padded by
            ``kernel_size // 2`` on both ends, which preserves the length for
            odd kernel sizes.
        """
        channels = tensor.size(1)

        # Build the kernel with the input's dtype/device so the convolution
        # also works for non-float32 or non-CPU inputs.
        x = torch.arange(kernel_size, dtype=tensor.dtype, device=tensor.device) - kernel_size // 2
        gaussian_kernel = torch.exp(-x.pow(2) / (2 * sigma**2))
        gaussian_kernel = gaussian_kernel / gaussian_kernel.sum()

        # Depthwise convolution: one copy of the kernel per feature channel
        blurred = F.conv1d(
            tensor.transpose(0, 1).unsqueeze(0),
            gaussian_kernel.unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1),
            padding=kernel_size // 2,
            groups=channels
        )

        return blurred.squeeze(0).transpose(0, 1)

    def get_label_encoder(self):
        """Return the label encoder used by this dataset."""
        return self.label_encoder

    def get_scaler(self):
        """Return the feature scaler used by this dataset."""
        return self.scaler

    def get_class_weights(self):
        """Return the per-class loss weights as a float tensor."""
        return self.class_weights

class ConvLSTMCell(nn.Module):
    """Single LSTM step whose four gates come from one 1-D convolution.

    Args:
        input_dim: Channels of the step input.
        hidden_dim: Channels of the hidden and cell states.
        kernel_size: Convolution kernel width (padding is ``kernel_size // 2``).
        bias: Whether the gate convolution has a bias term.
    """
    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.bias = bias

        # One convolution emits all four gate pre-activations at once
        self.conv = nn.Conv1d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one step.

        Args:
            input_tensor: Step input, concatenated with the hidden state on
                the channel axis (dim 1).
            cur_state: Tuple ``(hidden, cell)`` from the previous step.

        Returns:
            Tuple ``(next_hidden, next_cell)``.
        """
        prev_hidden, prev_cell = cur_state

        gate_preacts = self.conv(torch.cat((input_tensor, prev_hidden), dim=1))

        # Channel layout of the convolution output: input, forget, output, candidate
        in_pre, forget_pre, out_pre, cand_pre = gate_preacts.split(self.hidden_dim, dim=1)

        in_gate = torch.sigmoid(in_pre)
        forget_gate = torch.sigmoid(forget_pre)
        out_gate = torch.sigmoid(out_pre)
        candidate = torch.tanh(cand_pre)

        next_cell = forget_gate * prev_cell + in_gate * candidate
        next_hidden = out_gate * torch.tanh(next_cell)
        return next_hidden, next_cell

    def init_hidden(self, batch_size, device):
        """Return zero-filled ``(hidden, cell)`` states of shape [batch, hidden_dim, 1]."""
        state_shape = (batch_size, self.hidden_dim, 1)
        hidden = torch.zeros(state_shape, device=device)
        cell = torch.zeros(state_shape, device=device)
        return (hidden, cell)

class ConvLSTM(nn.Module):
    """Stack of ConvLSTM cells unrolled over time, optionally in both directions.

    In bidirectional mode the same cell of a layer is run forward and
    backward over the sequence and the two outputs are concatenated on the
    channel axis, so deeper layers receive twice the hidden width.

    Args:
        input_dim: Feature width of the input sequence.
        hidden_dims: Hidden width per layer (a list, or one value for all layers).
        kernel_sizes: Kernel width per layer (a list, or one value for all layers).
        num_layers: Number of stacked layers.
        bidirectional: Whether each layer also scans the sequence backwards.
        dropout: Dropout probability applied between layers.
    """
    def __init__(self, input_dim, hidden_dims, kernel_sizes, num_layers,
                 bidirectional=True, dropout=0.2):
        super().__init__()

        self.input_dim = input_dim
        if isinstance(hidden_dims, list):
            self.hidden_dims = hidden_dims
        else:
            self.hidden_dims = [hidden_dims] * num_layers
        if isinstance(kernel_sizes, list):
            self.kernel_sizes = kernel_sizes
        else:
            self.kernel_sizes = [kernel_sizes] * num_layers
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout

        # Build one cell per layer; layers above the first consume the
        # previous layer's output (doubled in width when bidirectional).
        cells = []
        for layer in range(num_layers):
            if layer == 0:
                in_channels = self.input_dim
            elif self.bidirectional:
                in_channels = self.hidden_dims[layer - 1] * 2
            else:
                in_channels = self.hidden_dims[layer - 1]

            cells.append(ConvLSTMCell(
                input_dim=in_channels,
                hidden_dim=self.hidden_dims[layer],
                kernel_size=self.kernel_sizes[layer]
            ))

        self.cell_list = nn.ModuleList(cells)
        self.dropout_layers = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers)])

    @staticmethod
    def _scan(cell, layer_input, time_steps, state):
        """Run ``cell`` over ``time_steps``; return per-step hiddens and the final state."""
        hidden, cell_state = state
        hiddens = []
        for t in time_steps:
            # Slice keeps a trailing length-1 axis: [batch, channels, 1]
            hidden, cell_state = cell(layer_input[:, :, t:t+1], (hidden, cell_state))
            hiddens.append(hidden)
        return hiddens, (hidden, cell_state)

    def forward(self, input_tensor):
        """
        Unroll every layer over the sequence.
        Args:
            input_tensor: [batch, seq_len, features]
        Returns:
            layer_output_list: Per-layer outputs, each [batch, channels, seq_len]
            last_state_list: Per-layer final (h, c) of the forward scan
        """
        batch_size, seq_len, _ = input_tensor.shape
        device = input_tensor.device

        # Channels-first layout for the convolutional cells
        layer_input = input_tensor.transpose(1, 2)

        layer_output_list = []
        last_state_list = []
        last_layer = self.num_layers - 1

        for layer_idx, cell in enumerate(self.cell_list):
            fwd_hiddens, fwd_state = self._scan(
                cell, layer_input, range(seq_len), cell.init_hidden(batch_size, device)
            )
            layer_output = torch.cat(fwd_hiddens, dim=2)

            if self.bidirectional:
                bwd_hiddens, _ = self._scan(
                    cell, layer_input, reversed(range(seq_len)),
                    cell.init_hidden(batch_size, device)
                )
                # Backward hiddens were collected last-to-first; restore time order
                bwd_hiddens.reverse()
                layer_output = torch.cat([layer_output, torch.cat(bwd_hiddens, dim=2)], dim=1)

            # Dropout only between layers, never after the top one
            if layer_idx < last_layer:
                layer_output = self.dropout_layers[layer_idx](layer_output)

            layer_output_list.append(layer_output)
            last_state_list.append(fwd_state)
            layer_input = layer_output

        return layer_output_list, last_state_list

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention split across several heads.

    Args:
        embed_dim: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability on the attention map and on the output.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # A single bias-free projection produces queries, keys and values
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` ([batch, tokens, embed_dim]).

        Returns:
            Tuple of the projected output ([batch, tokens, embed_dim]) and the
            attention map after dropout ([batch, heads, tokens, tokens]).
        """
        batch, tokens, channels = x.shape

        # [batch, tokens, 3*C] -> [3, batch, heads, tokens, head_dim]
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        attn = self.attn_dropout(F.softmax(scores, dim=-1))

        # Merge the heads back into the channel axis
        context = torch.matmul(attn, values).transpose(1, 2).reshape(batch, tokens, channels)
        output = self.proj_dropout(self.proj(context))

        return output, attn

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, both residual.

    Args:
        embed_dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability used in the attention and the MLP.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)

        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_width = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return the transformed tokens and this layer's attention map."""
        # Attention sub-layer on the normalised input, added back residually
        attended, attn_weights = self.attn(self.norm1(x))
        x = x + attended

        # Feed-forward sub-layer, likewise residual
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out, attn_weights

class VisionTransformer(nn.Module):
    """Transformer encoder over a token sequence with learned positions.

    Args:
        seq_len: Number of positions covered by the positional embedding.
        embed_dim: Token width.
        num_heads: Attention heads per block.
        num_layers: Number of stacked transformer blocks.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability (positions, attention and MLP).
    """
    def __init__(self, seq_len, embed_dim, num_heads, num_layers,
                 mlp_ratio=4.0, dropout=0.1):
        super().__init__()

        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers

        # Learned positional embedding, initialised with small Gaussian noise
        initial_positions = torch.randn(1, seq_len, embed_dim) * 0.02
        self.pos_embed = nn.Parameter(initial_positions)
        self.pos_dropout = nn.Dropout(dropout)

        # Encoder stack
        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout))
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Encode a token sequence.
        Args:
            x: [batch, seq_len, embed_dim]
        Returns:
            x: Encoded sequence after the final layer norm
            attn_weights: One attention map per block, in block order
        """
        num_tokens = x.shape[1]

        # Use only the first num_tokens positions of the embedding table
        tokens = self.pos_dropout(x + self.pos_embed[:, :num_tokens, :])

        attention_maps = []
        for encoder_layer in self.blocks:
            tokens, layer_attention = encoder_layer(tokens)
            attention_maps.append(layer_attention)

        return self.norm(tokens), attention_maps

class CrossModalAttention(nn.Module):
    """Fuse ConvLSTM and ViT token streams with bidirectional cross-attention.

    Args:
        conv_dim: Feature width of the ConvLSTM stream.
        vit_dim: Feature width of the ViT stream.
        fusion_dim: Common width both streams are projected to.
        num_heads: Heads of each cross-attention module.
    """
    def __init__(self, conv_dim, vit_dim, fusion_dim, num_heads=8):
        super().__init__()

        self.conv_dim = conv_dim
        self.vit_dim = vit_dim
        self.fusion_dim = fusion_dim
        self.num_heads = num_heads

        # Bring both streams to the shared fusion width
        self.conv_proj = nn.Linear(conv_dim, fusion_dim)
        self.vit_proj = nn.Linear(vit_dim, fusion_dim)

        # One attention module per direction
        self.conv_to_vit_attn = nn.MultiheadAttention(fusion_dim, num_heads, batch_first=True)
        self.vit_to_conv_attn = nn.MultiheadAttention(fusion_dim, num_heads, batch_first=True)

        # The two attended streams are concatenated, normalised and reduced
        doubled = fusion_dim * 2
        self.fusion_norm = nn.LayerNorm(doubled)
        self.fusion_mlp = nn.Sequential(
            nn.Linear(doubled, fusion_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(fusion_dim, fusion_dim),
        )

    def forward(self, conv_features, vit_features):
        """
        Fuse the two streams.
        Args:
            conv_features: [batch, seq_len, conv_dim]
            vit_features: [batch, seq_len, vit_dim]
        Returns:
            fused_features: [batch, seq_len, fusion_dim]
        """
        conv_tokens = self.conv_proj(conv_features)
        vit_tokens = self.vit_proj(vit_features)

        # ConvLSTM tokens query the ViT tokens ...
        conv_attended = self.conv_to_vit_attn(conv_tokens, vit_tokens, vit_tokens)[0]
        # ... and ViT tokens query the ConvLSTM tokens
        vit_attended = self.vit_to_conv_attn(vit_tokens, conv_tokens, conv_tokens)[0]

        stacked = torch.cat((conv_attended, vit_attended), dim=-1)
        reduced = self.fusion_mlp(self.fusion_norm(stacked))

        # Residual path from both projected inputs
        return reduced + conv_tokens + vit_tokens

class ConvLSTM_ViT_Hybrid(nn.Module):
    """Two-branch classifier: ConvLSTM and transformer features fused by cross-attention.

    The input is projected to ``hidden_dim`` and fed to both branches. Their
    outputs are fused, pooled over time with three learned-weight pooling
    strategies, then passed to a classification head and an uncertainty head.

    Args:
        input_dim: Feature width of the raw input.
        hidden_dim: Internal width shared by all components.
        num_classes: Number of output classes.
        seq_len: Sequence length covered by the positional embedding.
        dropout: Base dropout probability.
    """
    def __init__(self, input_dim=69, hidden_dim=128, num_classes=10,
                 seq_len=150, dropout=0.2):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.seq_len = seq_len

        half_dim = hidden_dim // 2
        quarter_dim = hidden_dim // 4

        # Shared input projection
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Branch 1: two bidirectional ConvLSTM layers
        self.convlstm = ConvLSTM(
            input_dim=hidden_dim,
            hidden_dims=[hidden_dim, hidden_dim],
            kernel_sizes=[3, 5],
            num_layers=2,
            bidirectional=True,
            dropout=dropout,
        )

        # Branch 2: four-block transformer encoder
        self.vit = VisionTransformer(
            seq_len=seq_len,
            embed_dim=hidden_dim,
            num_heads=8,
            num_layers=4,
            mlp_ratio=4.0,
            dropout=dropout,
        )

        # Fusion: the ConvLSTM branch is twice as wide because it is bidirectional
        self.cross_modal_fusion = CrossModalAttention(
            conv_dim=hidden_dim * 2,
            vit_dim=hidden_dim,
            fusion_dim=hidden_dim,
            num_heads=8,
        )

        # Three temporal pooling strategies ...
        self.adaptive_pool = nn.ModuleList([
            nn.AdaptiveAvgPool1d(1),
            nn.AdaptiveMaxPool1d(1),
            nn.AdaptiveAvgPool1d(4),
        ])

        # ... and a small network producing one softmax weight per strategy
        self.pool_attention = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 3),
            nn.Softmax(dim=-1),
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, half_dim),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(half_dim, quarter_dim),
            nn.GELU(),
            nn.Dropout(dropout * 0.3),
            nn.Linear(quarter_dim, num_classes),
        )

        # Uncertainty head: a sigmoid-bounded scalar per sample
        self.uncertainty_head = nn.Sequential(
            nn.Linear(hidden_dim, quarter_dim),
            nn.GELU(),
            nn.Linear(quarter_dim, 1),
            nn.Sigmoid(),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise Linear, LayerNorm and Conv1d modules in registration order."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv1d):
                # Only the convolution weight is re-drawn; its bias is left as created
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x, return_features=False):
        """
        Classify a batch of sequences.
        Args:
            x: Input tensor [batch, seq_len, input_dim]
            return_features: Also return a dict of intermediate tensors
        Returns:
            logits: Class scores
            uncertainty: Output of the uncertainty head
            features: Intermediate tensors (only when return_features is True)
        """
        # Unpacking doubles as a check that the input is three-dimensional
        _batch, _steps, _ = x.shape

        projected = self.input_projection(x)

        # ConvLSTM branch: keep the top layer, back to [batch, seq_len, conv_dim]
        conv_layer_outputs, _ = self.convlstm(projected)
        conv_features = conv_layer_outputs[-1].transpose(1, 2)

        # Transformer branch
        vit_features, vit_attention_weights = self.vit(projected)

        # Cross-modal fusion
        fused_features = self.cross_modal_fusion(conv_features, vit_features)

        # Pool over time; the pooling layers expect channels-first input
        channels_first = fused_features.transpose(1, 2)
        pooled_features = []
        for pool in self.adaptive_pool:
            pooled = pool(channels_first)
            if pool.output_size == 1:
                pooled = pooled.squeeze(-1)
            else:
                # Multi-bin pooling is averaged across its bins
                pooled = pooled.transpose(1, 2).mean(dim=1)
            pooled_features.append(pooled)

        # Blend the pooled vectors with their learned weights
        pool_weights = self.pool_attention(torch.cat(pooled_features, dim=-1))
        weighted = [
            weight.unsqueeze(-1) * feature
            for weight, feature in zip(pool_weights.unbind(-1), pooled_features)
        ]
        final_features = sum(weighted)

        logits = self.classifier(final_features)
        uncertainty = self.uncertainty_head(final_features)

        if not return_features:
            return logits, uncertainty

        features = {
            'conv_features': conv_features,
            'vit_features': vit_features,
            'fused_features': fused_features,
            'vit_attention': vit_attention_weights,
            'pool_weights': pool_weights
        }
        return logits, uncertainty, features

class HybridModelTrainer:
    """Training and validation driver for the ConvLSTM-ViT hybrid model.

    Builds a class-weighted, label-smoothed cross-entropy loss, an auxiliary
    MSE loss for the uncertainty head, an AdamW optimizer with one parameter
    group per model component, and a cosine warm-restart scheduler.

    Args:
        model: Module whose forward returns ``(logits, uncertainty)``.
        class_weights: Optional per-class weight tensor for the loss.
        device: Device the model and batches are moved to.
    """
    def __init__(self, model, class_weights=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model
        self.device = device
        self.model.to(device)

        # Loss functions
        if class_weights is not None:
            class_weights = class_weights.to(device)

        self.classification_criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
        self.uncertainty_criterion = nn.MSELoss()

        # Component-specific learning rates: route every parameter to the
        # group of the sub-module that owns it.
        grouped = {'convlstm': [], 'vit': [], 'fusion': [], 'classifier': []}
        for name, param in model.named_parameters():
            grouped[self._param_group_key(name)].append(param)

        self.optimizer = optim.AdamW([
            {'params': grouped['convlstm'], 'lr': 5e-4, 'weight_decay': 0.01},
            {'params': grouped['vit'], 'lr': 3e-4, 'weight_decay': 0.005},
            {'params': grouped['fusion'], 'lr': 8e-4, 'weight_decay': 0.01},
            {'params': grouped['classifier'], 'lr': 1e-3, 'weight_decay': 0.01}
        ], betas=(0.9, 0.95))

        # Learning rate scheduler (stepped once per epoch in train())
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=20, T_mult=2, eta_min=1e-6
        )

        # Metrics tracking, one entry per completed epoch
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.val_f1_scores = []

    @staticmethod
    def _param_group_key(name):
        """Map a dotted parameter name to its optimizer group.

        The decision uses whole path components, outermost first, rather than
        a substring of the full name. A plain ``'vit' in name`` test would
        misroute fusion parameters such as ``cross_modal_fusion.vit_proj.weight``
        into the ViT group.
        """
        for part in name.split('.'):
            if part == 'convlstm':
                return 'convlstm'
            if part == 'vit':
                return 'vit'
            if part.startswith('cross_modal') or 'pool' in part:
                return 'fusion'
        return 'classifier'

    def train_epoch(self, dataloader):
        """Run one optimisation pass over ``dataloader``.

        Returns:
            Tuple ``(avg_loss, accuracy_percent, weighted_f1, preds, targets)``.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        for batch_idx, (data, target) in enumerate(tqdm(dataloader, desc="Training")):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            logits, uncertainty = self.model(data)

            # Classification loss
            clf_loss = self.classification_criterion(logits, target)

            # Uncertainty loss: the head is regressed onto 1 - max softmax
            # probability. Squeeze only the trailing axis so a batch of one
            # keeps shape [1] and matches the target without broadcasting.
            pred_probs = F.softmax(logits, dim=1)
            pred_confidence = pred_probs.max(dim=1)[0]
            uncertainty_target = 1.0 - pred_confidence
            unc_loss = self.uncertainty_criterion(uncertainty.squeeze(-1), uncertainty_target.detach())

            # Combined loss
            total_loss_batch = clf_loss + 0.1 * unc_loss

            # Explicit L2 penalty (applied in addition to AdamW's weight decay)
            l2_reg = sum(param.pow(2).sum() for param in self.model.parameters())
            total_loss_batch += 1e-5 * l2_reg

            total_loss_batch.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Statistics
            total_loss += total_loss_batch.item()
            _, predicted = logits.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1, all_preds, all_targets

    def validate(self, dataloader):
        """Evaluate on ``dataloader`` without gradient tracking.

        The reported loss is the classification loss only.

        Returns:
            Tuple ``(avg_loss, accuracy_percent, weighted_f1, preds, targets)``.
        """
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for data, target in tqdm(dataloader, desc="Validating"):
                data, target = data.to(self.device), target.to(self.device)

                logits, uncertainty = self.model(data)
                loss = self.classification_criterion(logits, target)

                total_loss += loss.item()
                _, predicted = logits.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1, all_preds, all_targets

    def train(self, train_loader, val_loader, epochs=150):
        """Train with per-epoch validation, checkpointing and early stopping.

        The checkpoint ``convlstm_vit_hybrid_best_model.pth`` is rewritten
        whenever the validation F1 improves. Training stops after 25 epochs
        without improvement.

        Returns:
            The best validation F1 score observed.
        """
        best_val_f1 = 0
        patience = 25
        patience_counter = 0

        logger.info(f"🚀 ConvLSTM-ViT Hybrid Training Started: {epochs} epochs")
        logger.info(f"   - Model Parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        for epoch in range(epochs):
            # Training
            train_loss, train_acc, train_f1, _, _ = self.train_epoch(train_loader)

            # Validation
            val_loss, val_acc, val_f1, _, _ = self.validate(val_loader)

            # Learning rate scheduling
            self.scheduler.step()

            # Save metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            self.val_f1_scores.append(val_f1)

            # Progress logging every tenth epoch; the LR shown is that of
            # the first (ConvLSTM) parameter group
            if (epoch + 1) % 10 == 0:
                logger.info(f'Epoch {epoch+1}/{epochs}:')
                logger.info(f'  Train: Loss={train_loss:.3f}, Acc={train_acc:.1f}%, F1={train_f1:.3f}')
                logger.info(f'  Val:   Loss={val_loss:.3f}, Acc={val_acc:.1f}%, F1={val_f1:.3f}')
                logger.info(f'  LR: {self.optimizer.param_groups[0]["lr"]:.2e}')

            # Save best model
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'epoch': epoch,
                    'val_f1': val_f1,
                    'model_type': 'ConvLSTM_ViT_Hybrid'
                }, "convlstm_vit_hybrid_best_model.pth")
                logger.info(f"  ✅ New best model saved: F1={val_f1:.3f}")
                patience_counter = 0
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                logger.info(f"Early stopping triggered at epoch {epoch+1}")
                break

        logger.info(f"🎯 Training completed. Best F1 Score: {best_val_f1:.3f}")
        return best_val_f1

def visualize_hybrid_attention(model, dataloader, class_names, device, num_samples=2):
    """Plot per-sample diagnostics of the hybrid model, one figure row per sample.

    Each row shows the mean ConvLSTM activation per time step, the last
    transformer layer's attention averaged over heads and queries, the mean
    fused activation per time step, and the learned pooling weights. The
    figure is saved to ``convlstm_vit_hybrid_attention_analysis.png`` and shown.

    Args:
        model: Model accepting ``return_features=True`` in its forward call.
        dataloader: Iterable yielding ``(data, targets)`` batches.
        class_names: Sequence indexed by the integer target class.
        device: Device the batches are moved to.
        num_samples: Number of samples (figure rows) to draw.
    """
    model.eval()

    # squeeze=False keeps the axes grid two-dimensional even for a single row
    fig, axes = plt.subplots(num_samples, 4, figsize=(20, 6*num_samples), squeeze=False)

    def draw_curve(ax, values, color, title, ylabel):
        """Draw one filled time-series panel."""
        ax.plot(values, linewidth=2, color=color)
        ax.fill_between(range(len(values)), values, alpha=0.3, color=color)
        ax.set_title(title)
        ax.set_xlabel('Time Step')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)

    drawn = 0
    with torch.no_grad():
        for data, targets in dataloader:
            if drawn >= num_samples:
                break

            data = data.to(device)
            # Take only as many samples from this batch as are still needed
            take = min(data.size(0), num_samples - drawn)

            for i in range(take):
                target_class = targets[i].item()

                # Forward a single sample and collect intermediate tensors
                _, _, features = model(data[i:i+1], return_features=True)
                row = axes[drawn]

                # ConvLSTM activation averaged over channels
                conv_curve = features['conv_features'][0].mean(dim=-1).cpu().numpy()
                draw_curve(row[0], conv_curve, 'blue',
                           f'ConvLSTM Features\nClass: {class_names[target_class]}',
                           'Feature Activation')

                # Last-layer attention, averaged over heads and then over queries
                attn_curve = features['vit_attention'][-1][0].mean(dim=0).mean(dim=0).cpu().numpy()
                draw_curve(row[1], attn_curve, 'red',
                           f'ViT Attention Weights\nClass: {class_names[target_class]}',
                           'Attention Weight')

                # Fused activation averaged over channels
                fused_curve = features['fused_features'][0].mean(dim=-1).cpu().numpy()
                draw_curve(row[2], fused_curve, 'green',
                           f'Fused Features\nClass: {class_names[target_class]}',
                           'Feature Activation')

                # Pooling weights as an annotated bar chart
                pool_ax = row[3]
                pool_weights = features['pool_weights'][0].cpu().numpy()
                bars = pool_ax.bar(['Avg Pool', 'Max Pool', 'Adaptive Pool'], pool_weights)
                pool_ax.set_title(f'Pooling Strategy Weights\nClass: {class_names[target_class]}')
                pool_ax.set_ylabel('Weight')
                for bar, weight in zip(bars, pool_weights):
                    x_center = bar.get_x() + bar.get_width()/2.
                    pool_ax.text(x_center, bar.get_height() + 0.01,
                                 f'{weight:.3f}', ha='center', va='bottom')

                drawn += 1

    plt.tight_layout()
    plt.savefig('convlstm_vit_hybrid_attention_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()

def plot_comprehensive_results(trainer, model_name="ConvLSTM_ViT_Hybrid"):
    """Draw a 2x2 dashboard of the training history recorded on ``trainer``.

    Panels: loss curves, accuracy curves, validation F1 and the per-epoch
    train-minus-validation accuracy gap. The figure is saved as
    ``<model_name.lower()>_training_results.png`` and then displayed.

    Args:
        trainer: Object exposing ``train_losses``, ``val_losses``,
            ``train_accuracies``, ``val_accuracies`` and ``val_f1_scores``
            as per-epoch sequences.
        model_name: Label used in the panel titles and the output file name.
    """
    figure, axes_grid = plt.subplots(2, 2, figsize=(15, 10))
    (loss_ax, acc_ax), (f1_ax, gap_ax) = axes_grid

    # The x axis is 1-based epoch numbers, sized by the training-loss history.
    epochs = range(1, len(trainer.train_losses) + 1)

    def finish_panel(ax, title_suffix, y_label, with_legend=True):
        """Apply the title, axis labels, optional legend and grid to one panel."""
        ax.set_title(f'{model_name} - {title_suffix}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        if with_legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Panel 1: training vs. validation loss
    loss_ax.plot(epochs, trainer.train_losses, 'b-', label='Training Loss', linewidth=2)
    loss_ax.plot(epochs, trainer.val_losses, 'r-', label='Validation Loss', linewidth=2)
    finish_panel(loss_ax, 'Loss Curves', 'Loss')

    # Panel 2: training vs. validation accuracy
    acc_ax.plot(epochs, trainer.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
    acc_ax.plot(epochs, trainer.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
    finish_panel(acc_ax, 'Accuracy Curves', 'Accuracy (%)')

    # Panel 3: validation F1 (single curve, so no legend)
    f1_ax.plot(epochs, trainer.val_f1_scores, 'green', linewidth=2)
    finish_panel(f1_ax, 'Validation F1 Score', 'F1 Score', with_legend=False)

    # Panel 4: generalisation gap with two horizontal reference lines
    accuracy_gaps = [
        train_acc - val_acc
        for train_acc, val_acc in zip(trainer.train_accuracies, trainer.val_accuracies)
    ]
    gap_ax.plot(epochs, accuracy_gaps, 'purple', linewidth=2)
    gap_ax.axhline(y=10, color='orange', linestyle='--', label='Warning Threshold')
    gap_ax.axhline(y=5, color='green', linestyle='--', label='Good Threshold')
    finish_panel(gap_ax, 'Train-Validation Gap', 'Accuracy Gap (%)')

    plt.tight_layout()
    plt.savefig(f'{model_name.lower()}_training_results.png', dpi=150)
    plt.show()

def prepare_dataset_for_hybrid(csv_path, sequence_length=150, overlap_ratio=0.35):
    """Load a labelled CSV and cut it into fixed-length, overlapping windows.

    Pipeline:
        1. Read ``csv_path``; every column except ``'class'`` is a feature.
        2. Fill missing feature values (forward fill, backward fill, then 0)
           and replace +/-inf with 0.
        3. Drop rows outside the 1.5 * IQR fences, one feature column at a
           time; each column's fences are computed on the rows that survived
           the previous columns.
        4. Keep only classes that still have at least ``4 * sequence_length``
           rows.
        5. Slide a window over each class's remaining rows and keep windows
           that contain no NaN and whose overall standard deviation exceeds
           ``1e-6``.

    Args:
        csv_path: Path of the CSV file to read.
        sequence_length: Number of rows per window.
        overlap_ratio: Fraction of a window shared with the next one; the
            stride is ``max(1, int(sequence_length * (1 - overlap_ratio)))``.

    Returns:
        tuple: ``(sequences, labels, n_features)`` where ``sequences`` is the
        ``np.array`` built from the kept windows, ``labels`` is the list of
        the matching class values and ``n_features`` is the number of
        feature columns.
    """
    logger.info(f"📊 Loading data from: {csv_path}")

    df = pd.read_csv(csv_path)
    feature_columns = [col for col in df.columns if col != 'class']

    # Advanced data cleaning
    original_size = len(df)
    # .ffill()/.bfill() replace the deprecated fillna(method=...) spelling.
    df[feature_columns] = df[feature_columns].ffill().bfill().fillna(0)
    df[feature_columns] = df[feature_columns].replace([np.inf, -np.inf], 0)

    # Remove statistical outliers (row-wise, sequentially per column).
    # NOTE(review): dropping rows means later windows can span rows that were
    # not adjacent in the CSV -- confirm this is acceptable for the signal.
    for col in feature_columns:
        q1 = df[col].quantile(0.25)
        q3 = df[col].quantile(0.75)
        iqr = q3 - q1
        df = df[df[col].between(q1 - 1.5 * iqr, q3 + 1.5 * iqr)]

    # Guard the percentage so an empty CSV does not divide by zero.
    retained_pct = len(df) / original_size * 100 if original_size else 0.0
    logger.info(f"After cleaning: {len(df)} samples ({retained_pct:.1f}% retained)")

    # Filter classes with sufficient data
    min_samples = sequence_length * 4
    class_counts = df['class'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples].index
    df = df[df['class'].isin(valid_classes)]

    logger.info(f"Valid classes after filtering: {list(valid_classes)}")

    sequences = []
    labels = []

    # Create sequences with adaptive stride
    stride = max(1, int(sequence_length * (1 - overlap_ratio)))

    for class_name in valid_classes:
        # Convert this class's rows to an array once; windows are then plain
        # positional slices instead of repeated DataFrame indexing.
        class_values = df.loc[df['class'] == class_name, feature_columns].values

        for start in range(0, len(class_values) - sequence_length + 1, stride):
            window = class_values[start:start + sequence_length]

            if np.any(np.isnan(window)):
                continue
            # Quality check: ensure the window has reasonable variance
            if np.std(window) > 1e-6:
                sequences.append(window)
                labels.append(class_name)

    sequences = np.array(sequences)
    logger.info(f"Total sequences created: {len(sequences)}")
    logger.info(f"Final class distribution: {dict(Counter(labels))}")

    return sequences, labels, len(feature_columns)

def plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names):
    """Render validation and test confusion matrices side by side.

    The left heatmap shows the validation split and the right one the test
    split; both use ``target_names`` as tick labels. The figure is saved to
    ``convlstm_vit_hybrid_confusion_matrices.png`` and then displayed.

    Args:
        val_targets: True labels of the validation split.
        val_preds: Predicted labels of the validation split.
        test_targets: True labels of the test split.
        test_preds: Predicted labels of the test split.
        target_names: Class names used for both heatmap axes.
    """
    figure, axes = plt.subplots(1, 2, figsize=(20, 8))

    # (title, colour map, true labels, predicted labels), left to right
    panels = (
        ('Validation Confusion Matrix', 'Blues', val_targets, val_preds),
        ('Test Confusion Matrix', 'Greens', test_targets, test_preds),
    )

    for ax, (title, colormap, y_true, y_pred) in zip(axes, panels):
        matrix = confusion_matrix(y_true, y_pred)
        sns.heatmap(
            matrix,
            annot=True,
            fmt='d',
            cmap=colormap,
            xticklabels=target_names,
            yticklabels=target_names,
            ax=ax,
        )
        ax.set_title(title, fontsize=16, fontweight='bold')
        ax.set_ylabel('True Label', fontsize=12)
        ax.set_xlabel('Predicted Label', fontsize=12)
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig('convlstm_vit_hybrid_confusion_matrices.png', dpi=150)
    plt.show()

def main():
    """Run the ConvLSTM-ViT hybrid pipeline end to end.

    Steps: seed the RNGs, build windowed sequences from ``CSV_PATH``, make a
    stratified 70/15/15 train/validation/test split, train the model, reload
    the best checkpoint, evaluate on validation and test, print reports,
    draw the diagnostic plots and save a final checkpoint containing the
    weights, the model configuration, the training history and the metrics.

    Returns:
        tuple | None: ``(trainer, model, train_dataset, val_targets,
        val_preds, test_targets, test_preds)`` on success, or ``None`` when
        no valid sequences could be built from the CSV.
    """
    set_seed(42)

    # Configuration parameters
    CSV_PATH = ""  # Update with your dataset path
    SEQUENCE_LENGTH = 150
    BATCH_SIZE = 16          # Optimized for hybrid model
    EPOCHS = 150             # Extended training
    VAL_SIZE = 0.15          # 15% for validation
    TEST_SIZE = 0.15         # 15% for test (70% remaining for training)
    # Defined once so the model and the saved 'model_config' cannot drift apart.
    HIDDEN_DIM = 128
    DROPOUT = 0.25

    logger.info("🚀 ConvLSTM + Vision Transformer Hybrid Model Training Started!")
    logger.info("📊 Data split: 70% Train, 15% Validation, 15% Test")
    logger.info("="*80)

    # Enhanced data preparation
    sequences, labels, feature_count = prepare_dataset_for_hybrid(
        CSV_PATH, SEQUENCE_LENGTH, overlap_ratio=0.35
    )

    if len(sequences) == 0:
        logger.error("❌ No valid sequences found!")
        return

    # First split: train+val (85%) and test (15%)
    X_temp, X_test, y_temp, y_test = train_test_split(
        sequences, labels, test_size=TEST_SIZE, random_state=42, stratify=labels
    )

    # Second split: train (70%) and val (15%) from remaining 85%.
    # VAL_SIZE is a fraction of the full data, so rescale it to the remainder.
    val_size_adjusted = VAL_SIZE / (1 - TEST_SIZE)
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted, random_state=42, stratify=y_temp
    )

    logger.info("\n📊 Dataset Summary:")
    logger.info(f"   - Input Features: {feature_count}")
    logger.info(f"   - Sequence Length: {SEQUENCE_LENGTH}")
    logger.info(f"   - Training sequences: {len(X_train):,} ({len(X_train)/len(sequences)*100:.1f}%)")
    logger.info(f"   - Validation sequences: {len(X_val):,} ({len(X_val)/len(sequences)*100:.1f}%)")
    logger.info(f"   - Test sequences: {len(X_test):,} ({len(X_test)/len(sequences)*100:.1f}%)")

    # Create enhanced datasets. The label encoder and scaler are fitted on the
    # training split only and reused for validation and test.
    train_dataset = EnhancedDiverSignDataset(X_train, y_train, train_mode=True, augment_prob=0.6)
    val_dataset = EnhancedDiverSignDataset(
        X_val, y_val, train_dataset.get_label_encoder(),
        train_dataset.get_scaler(), train_mode=False
    )
    test_dataset = EnhancedDiverSignDataset(
        X_test, y_test, train_dataset.get_label_encoder(),
        train_dataset.get_scaler(), train_mode=False
    )

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                             num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                           num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=2, pin_memory=True)

    class_names = train_dataset.get_label_encoder().classes_
    num_classes = len(class_names)

    logger.info(f"   - Number of Classes: {num_classes}")
    logger.info(f"   - Class Names: {list(class_names)}")

    # Create ConvLSTM-ViT Hybrid model
    model = ConvLSTM_ViT_Hybrid(
        input_dim=feature_count,
        hidden_dim=HIDDEN_DIM,    # Optimized size
        num_classes=num_classes,
        seq_len=SEQUENCE_LENGTH,
        dropout=DROPOUT
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info("\n🤖 ConvLSTM-ViT Hybrid Model Configuration:")
    logger.info(f"   - Total Parameters: {total_params:,}")
    logger.info("   - ConvLSTM: 2 layers, bidirectional")
    logger.info("   - Vision Transformer: 4 layers, 8 heads")
    logger.info("   - Cross-Modal Fusion: 8 heads")
    logger.info(f"   - Hidden Dimension: {HIDDEN_DIM}")
    logger.info("   - Uncertainty Estimation: Enabled")

    # Initialize trainer
    trainer = HybridModelTrainer(
        model=model,
        class_weights=train_dataset.get_class_weights()
    )

    logger.info(f"\n🔥 Starting Hybrid Training ({EPOCHS} epochs)...")
    logger.info("="*80)

    # Train the model
    best_f1 = trainer.train(train_loader, val_loader, epochs=EPOCHS)

    # Load best model for evaluation (presumably written by trainer.train --
    # verify against the trainer). map_location lets a checkpoint saved on a
    # GPU be restored on whatever device the trainer is using, including CPU.
    checkpoint = torch.load("convlstm_vit_hybrid_best_model.pth",
                            map_location=trainer.device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Final validation evaluation
    val_loss, val_acc, val_f1, val_preds, val_targets = trainer.validate(val_loader)

    # Final test evaluation
    test_loss, test_acc, test_f1, test_preds, test_targets = trainer.validate(test_loader)

    # Training analysis: gap measured on the last recorded epoch, not on the
    # reloaded best checkpoint.
    final_train_acc = trainer.train_accuracies[-1]
    final_val_acc = trainer.val_accuracies[-1]
    accuracy_gap = final_train_acc - final_val_acc

    logger.info("\n📊 ConvLSTM-ViT Hybrid Final Results:")
    logger.info("="*60)
    logger.info("   📊 VALIDATION:")
    logger.info(f"     - Accuracy: {val_acc:.2f}%")
    logger.info(f"     - F1 Score: {val_f1:.4f}")
    logger.info("   📊 TEST:")
    logger.info(f"     - Accuracy: {test_acc:.2f}%")
    logger.info(f"     - F1 Score: {test_f1:.4f}")
    logger.info("   📊 TRAINING ANALYSIS:")
    logger.info(f"     - Train-Validation Gap: {accuracy_gap:.2f}%")
    logger.info(f"     - Best Validation F1: {best_f1:.4f}")

    # Performance assessment
    if test_f1 >= 0.80:
        performance_status = "🎉 Excellent Performance!"
    elif test_f1 >= 0.70:
        performance_status = "✅ Good Performance!"
    elif test_f1 >= 0.60:
        performance_status = "👍 Satisfactory Performance"
    else:
        performance_status = "⚠️ Performance Needs Improvement"

    logger.info(f"   🎯 Overall Assessment: {performance_status}")

    # Classification reports (class_names is the encoder's classes_ array)
    target_names = class_names

    print("\n📋 VALIDATION Classification Report:")
    print("="*80)
    print(classification_report(val_targets, val_preds, target_names=target_names, digits=3))

    print("\n📋 TEST Classification Report:")
    print("="*80)
    print(classification_report(test_targets, test_preds, target_names=target_names, digits=3))

    # Generate confusion matrices
    plot_confusion_matrices(val_targets, val_preds, test_targets, test_preds, target_names)

    # Plot training results
    plot_comprehensive_results(trainer, "ConvLSTM_ViT_Hybrid")

    # Visualize attention mechanisms
    logger.info("\n🔍 Generating Attention Analysis...")
    visualize_hybrid_attention(model, test_loader, class_names, trainer.device, num_samples=2)

    # Model efficiency analysis
    logger.info("\n⚡ Model Efficiency Analysis:")
    logger.info(f"   - Parameters per Class: {total_params // num_classes:,}")
    logger.info(f"   - Accuracy per 1K Parameters: {test_acc / (total_params / 1000):.3f}")
    logger.info(f"   - F1 Score per 1M Parameters: {test_f1 / (total_params / 1000000):.3f}")

    # Save comprehensive results
    logger.info("\n💾 Saving Results...")

    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'input_dim': feature_count,
            'hidden_dim': HIDDEN_DIM,
            'num_classes': num_classes,
            'seq_len': SEQUENCE_LENGTH,
            'dropout': DROPOUT
        },
        'training_history': {
            'train_losses': trainer.train_losses,
            'val_losses': trainer.val_losses,
            'train_accuracies': trainer.train_accuracies,
            'val_accuracies': trainer.val_accuracies,
            'val_f1_scores': trainer.val_f1_scores
        },
        'final_results': {
            'test_accuracy': test_acc,
            'test_f1': test_f1,
            'validation_accuracy': val_acc,
            'validation_f1': val_f1,
            'best_val_f1': best_f1,
            'accuracy_gap': accuracy_gap
        }
    }, 'convlstm_vit_hybrid_final_model.pth')

    logger.info("✅ ConvLSTM-ViT Hybrid Training Completed Successfully!")
    logger.info("📁 Files saved:")
    logger.info("   - convlstm_vit_hybrid_final_model.pth: Complete model")
    logger.info("   - convlstm_vit_hybrid_confusion_matrices.png: Confusion matrices")
    logger.info("   - convlstm_vit_hybrid_training_results.png: Training curves")
    logger.info("   - convlstm_vit_hybrid_attention_analysis.png: Attention visualization")

    logger.info("\n🏆 FINAL PERFORMANCE SUMMARY:")
    logger.info(f"🎯 Test F1 Score: {test_f1:.4f}")
    logger.info(f"🎯 Test Accuracy: {test_acc:.2f}%")
    logger.info(f"🎯 Model Parameters: {total_params:,}")

    return trainer, model, train_dataset, val_targets, val_preds, test_targets, test_preds

if __name__ == "__main__":
    # main() returns None when no valid sequences could be built, so only
    # unpack when a full result tuple came back (unpacking None would raise
    # TypeError right after the error was already logged).
    results = main()
    if results is not None:
        trainer, model, dataset, val_targets, val_preds, test_targets, test_preds = results

In [None]:
# ConvLSTM + Vision Transformer Hybrid Model with Enhanced 5-Fold Cross Validation - COMPLETE CODE

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import os
import logging
import random
from collections import Counter
import math
import warnings
from scipy import stats
warnings.filterwarnings('ignore')

# Logging setup: INFO-level root configuration plus a module-level logger.
# NOTE: logging.basicConfig() does nothing when the root logger already has
# handlers, so running this after an earlier configuration is harmless.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer passed unchanged to every random number generator.
    """
    # Same seeding order as before: random, NumPy, torch (CPU), torch (CUDA).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True

def prepare_dataset_for_hybrid(csv_path, sequence_length=150, overlap_ratio=0.35):
    """Load a labelled CSV and cut it into fixed-length, overlapping windows.

    Pipeline:
        1. Read ``csv_path``; every column except ``'class'`` is a feature.
        2. Fill missing feature values (forward fill, backward fill, then 0)
           and replace +/-inf with 0.
        3. Drop rows outside the 1.5 * IQR fences, one feature column at a
           time; each column's fences are computed on the rows that survived
           the previous columns.
        4. Keep only classes that still have at least ``4 * sequence_length``
           rows.
        5. Slide a window over each class's remaining rows and keep windows
           that contain no NaN and whose overall standard deviation exceeds
           ``1e-6``.

    Args:
        csv_path: Path of the CSV file to read.
        sequence_length: Number of rows per window.
        overlap_ratio: Fraction of a window shared with the next one; the
            stride is ``max(1, int(sequence_length * (1 - overlap_ratio)))``.

    Returns:
        tuple: ``(sequences, labels, n_features)`` where ``sequences`` is the
        ``np.array`` built from the kept windows, ``labels`` is the list of
        the matching class values and ``n_features`` is the number of
        feature columns.
    """
    logger.info(f"📊 Loading data from: {csv_path}")

    df = pd.read_csv(csv_path)
    feature_columns = [col for col in df.columns if col != 'class']

    # Advanced data cleaning
    original_size = len(df)
    # .ffill()/.bfill() replace the deprecated fillna(method=...) spelling.
    df[feature_columns] = df[feature_columns].ffill().bfill().fillna(0)
    df[feature_columns] = df[feature_columns].replace([np.inf, -np.inf], 0)

    # Remove statistical outliers (row-wise, sequentially per column).
    # NOTE(review): dropping rows means later windows can span rows that were
    # not adjacent in the CSV -- confirm this is acceptable for the signal.
    for col in feature_columns:
        q1 = df[col].quantile(0.25)
        q3 = df[col].quantile(0.75)
        iqr = q3 - q1
        df = df[df[col].between(q1 - 1.5 * iqr, q3 + 1.5 * iqr)]

    # Guard the percentage so an empty CSV does not divide by zero.
    retained_pct = len(df) / original_size * 100 if original_size else 0.0
    logger.info(f"After cleaning: {len(df)} samples ({retained_pct:.1f}% retained)")

    # Filter classes with sufficient data
    min_samples = sequence_length * 4
    class_counts = df['class'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples].index
    df = df[df['class'].isin(valid_classes)]

    logger.info(f"Valid classes after filtering: {list(valid_classes)}")

    sequences = []
    labels = []

    # Create sequences with adaptive stride
    stride = max(1, int(sequence_length * (1 - overlap_ratio)))

    for class_name in valid_classes:
        # Convert this class's rows to an array once; windows are then plain
        # positional slices instead of repeated DataFrame indexing.
        class_values = df.loc[df['class'] == class_name, feature_columns].values

        for start in range(0, len(class_values) - sequence_length + 1, stride):
            window = class_values[start:start + sequence_length]

            if np.any(np.isnan(window)):
                continue
            # Quality check: ensure the window has reasonable variance
            if np.std(window) > 1e-6:
                sequences.append(window)
                labels.append(class_name)

    sequences = np.array(sequences)
    logger.info(f"Total sequences created: {len(sequences)}")
    logger.info(f"Final class distribution: {dict(Counter(labels))}")

    return sequences, labels, len(feature_columns)

class EnhancedDiverSignDataset(Dataset):
    """Scaled sequence dataset with class weights and optional augmentation.

    Each item is a ``(sequence, label)`` pair: the scaled window as a float
    tensor and the encoded class index. When ``train_mode`` is true, a
    window is augmented with probability ``augment_prob`` by one or two
    randomly chosen techniques (noise, temporal mask, feature dropout,
    within-sequence mixup, Gaussian blur).
    """
    def __init__(self, sequences, labels, label_encoder=None, scaler=None,
                 train_mode=True, augment_prob=0.4):
        """Encode labels, scale features and compute class weights.

        Args:
            sequences: Array whose last axis is the feature axis; it is
                flattened to 2-D for scaling and reshaped back afterwards.
            labels: One class value per sequence.
            label_encoder: Fitted encoder to reuse; a new ``LabelEncoder`` is
                fitted on ``labels`` when ``None``.
            scaler: Fitted scaler to reuse; a new ``StandardScaler`` is
                fitted on ``sequences`` when ``None``.
            train_mode: Enables random augmentation in ``__getitem__``.
            augment_prob: Probability of augmenting an item in train mode.
        """
        # Label encoding
        if label_encoder is None:
            self.label_encoder = LabelEncoder()
            encoded_labels = self.label_encoder.fit_transform(labels)
        else:
            self.label_encoder = label_encoder
            encoded_labels = self.label_encoder.transform(labels)

        # Advanced data scaling: scalers work on 2-D data, so collapse all
        # leading axes into rows and restore the original shape afterwards.
        flat = sequences.reshape(-1, sequences.shape[-1])
        if scaler is None:
            self.scaler = StandardScaler()
            flat_scaled = self.scaler.fit_transform(flat)
        else:
            self.scaler = scaler
            flat_scaled = self.scaler.transform(flat)
        scaled_sequences = flat_scaled.reshape(sequences.shape)

        self.sequences = torch.FloatTensor(scaled_sequences)
        self.labels = torch.LongTensor(encoded_labels)
        self.train_mode = train_mode
        self.augment_prob = augment_prob

        # Enhanced class weights using effective number of samples. The number
        # of classes comes from the encoder, not from the labels present in
        # this split, so a split that lacks a class still gets one weight per
        # class instead of failing.
        self.class_weights = self._effective_number_weights(
            encoded_labels, len(self.label_encoder.classes_)
        )

    @staticmethod
    def _effective_number_weights(encoded_labels, num_classes, beta=0.9999):
        """Return per-class weights from the effective number of samples.

        A class with ``n`` samples has effective number
        ``(1 - beta**n) / (1 - beta)`` and raw weight equal to its inverse.
        Classes with no samples get weight 0. Weights are rescaled so they
        sum to ``num_classes``; if every class is empty, all-ones is returned.
        """
        counts = Counter(int(label) for label in encoded_labels)

        raw_weights = []
        for class_index in range(num_classes):
            n_samples = counts.get(class_index, 0)
            if n_samples == 0:
                # Absent class: effective number would be 0, so avoid 1/0.
                raw_weights.append(0.0)
                continue
            effective_num = (1 - beta ** n_samples) / (1 - beta)
            raw_weights.append(1.0 / effective_num)

        weight_sum = sum(raw_weights)
        if weight_sum == 0:
            return torch.ones(num_classes)
        return torch.FloatTensor([w * num_classes / weight_sum for w in raw_weights])

    def __len__(self):
        """Return the number of sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(sequence, label)`` for ``idx``, augmented at random in train mode."""
        # Clone so in-place augmentations never touch the stored tensor.
        sequence = self.sequences[idx].clone()
        label = self.labels[idx]

        # Apply sophisticated augmentation during training
        if self.train_mode and random.random() < self.augment_prob:
            sequence = self._apply_hybrid_augmentation(sequence)

        return sequence, label

    def _apply_hybrid_augmentation(self, sequence):
        """Apply one or two random augmentations and clamp the result to [-5, 5].

        ``sequence`` is indexed as ``[time, feature]``. Some techniques modify
        it in place, so callers should pass a copy.
        """
        aug_techniques = ['noise', 'temporal_mask', 'feature_dropout', 'mixup', 'gaussian_blur']
        selected_augs = random.sample(aug_techniques, k=random.randint(1, 2))

        for aug_type in selected_augs:
            if aug_type == 'noise':
                # Adaptive Gaussian noise: 5% of the sequence's own std
                noise_std = torch.std(sequence) * 0.05
                noise = torch.randn_like(sequence) * noise_std
                sequence = sequence + noise

            elif aug_type == 'temporal_mask':
                # Temporal masking (similar to SpecAugment): attenuate a
                # contiguous run of time steps instead of zeroing it.
                mask_length = random.randint(5, 15)
                mask_start = random.randint(0, max(0, sequence.size(0) - mask_length))
                sequence[mask_start:mask_start + mask_length] *= 0.1

            elif aug_type == 'feature_dropout':
                # Feature channel dropout. The upper bound is floored at 1 so
                # inputs with fewer than 4 features do not call randint(1, 0).
                max_features_to_drop = max(1, sequence.size(1) // 4)
                num_features_to_drop = random.randint(1, max_features_to_drop)
                features_to_drop = random.sample(range(sequence.size(1)), num_features_to_drop)
                sequence[:, features_to_drop] *= 0.1

            elif aug_type == 'mixup':
                # Temporal mixup within sequence: blend with a time-shuffled
                # copy of itself. Cast lambda to a Python float so the blend
                # does not depend on NumPy-scalar/tensor promotion rules.
                alpha = 0.2
                lam = float(np.random.beta(alpha, alpha))
                rand_index = torch.randperm(sequence.size(0))
                sequence = lam * sequence + (1 - lam) * sequence[rand_index]

            elif aug_type == 'gaussian_blur':
                # 1D Gaussian blur for temporal smoothing
                kernel_size = 3
                sigma = 0.5
                sequence = self._gaussian_blur_1d(sequence, kernel_size, sigma)

        return torch.clamp(sequence, -5, 5)  # Prevent extreme values

    def _gaussian_blur_1d(self, tensor, kernel_size, sigma):
        """Blur a ``[time, feature]`` tensor along time with a Gaussian kernel.

        Every feature channel is convolved independently (grouped
        convolution) with the same normalised kernel; the padding of
        ``kernel_size // 2`` keeps the length unchanged for odd kernel sizes.
        """
        channels = tensor.size(1)

        # Create Gaussian kernel centred on zero and normalised to sum to 1
        x = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
        gaussian_kernel = torch.exp(-x.pow(2) / (2 * sigma**2))
        gaussian_kernel = gaussian_kernel / gaussian_kernel.sum()

        # Apply convolution for each feature channel
        blurred = F.conv1d(
            tensor.transpose(0, 1).unsqueeze(0),
            gaussian_kernel.unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1),
            padding=kernel_size // 2,
            groups=channels
        )

        return blurred.squeeze(0).transpose(0, 1)

    def get_label_encoder(self):
        """Return the label encoder used by this dataset."""
        return self.label_encoder

    def get_scaler(self):
        """Return the feature scaler used by this dataset."""
        return self.scaler

    def get_class_weights(self):
        """Return the per-class weight tensor."""
        return self.class_weights

class ConvLSTMCell(nn.Module):
    """One LSTM step whose four gates are produced by a single 1-D convolution.

    The input and the previous hidden state are concatenated along the
    channel axis and convolved into ``4 * hidden_dim`` channels, which are
    split into the input, forget, output and candidate gates.
    """
    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        """Store the sizes and build the gate convolution.

        Args:
            input_dim: Channels of the input tensor.
            hidden_dim: Channels of the hidden and cell states.
            kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
            bias: Whether the convolution has a bias term.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size // 2
        self.bias = bias

        # One convolution computes all four gates at once.
        self.conv = nn.Conv1d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance the cell by one step.

        Args:
            input_tensor: Input for this step, channels on dim 1.
            cur_state: Tuple ``(h, c)`` of the previous hidden and cell state.

        Returns:
            tuple: ``(h_next, c_next)``.
        """
        prev_h, prev_c = cur_state

        # Stack input and hidden state on the channel axis, then gate them.
        gates = self.conv(torch.cat((input_tensor, prev_h), dim=1))
        in_gate, forget_gate, out_gate, candidate = gates.split(self.hidden_dim, dim=1)

        # Standard LSTM update: forget old memory, write the gated candidate.
        next_c = torch.sigmoid(forget_gate) * prev_c + torch.sigmoid(in_gate) * torch.tanh(candidate)
        next_h = torch.sigmoid(out_gate) * torch.tanh(next_c)

        return next_h, next_c

    def init_hidden(self, batch_size, device):
        """Return zero ``(h, c)`` states of shape ``[batch_size, hidden_dim, 1]``."""
        state_shape = (batch_size, self.hidden_dim, 1)
        hidden = torch.zeros(state_shape, device=device)
        cell = torch.zeros(state_shape, device=device)
        return hidden, cell

class ConvLSTM(nn.Module):
    """Stack of ConvLSTM layers scanned over time, optionally in both directions.

    In bidirectional mode the same cell of a layer is run forward and
    backward over the sequence and the two outputs are concatenated on the
    channel axis, so each layer after the first receives twice the previous
    layer's hidden size as input.
    """
    def __init__(self, input_dim, hidden_dims, kernel_sizes, num_layers,
                 bidirectional=True, dropout=0.2):
        """Build one cell and one dropout module per layer.

        Args:
            input_dim: Feature size of the input sequence.
            hidden_dims: Hidden size per layer, or a single value used for all.
            kernel_sizes: Kernel size per layer, or a single value used for all.
            num_layers: Number of stacked layers.
            bidirectional: Whether to also scan the sequence in reverse.
            dropout: Dropout probability applied between layers.
        """
        super().__init__()

        def per_layer(value):
            """Broadcast a scalar setting to one entry per layer."""
            return value if isinstance(value, list) else [value] * num_layers

        self.input_dim = input_dim
        self.hidden_dims = per_layer(hidden_dims)
        self.kernel_sizes = per_layer(kernel_sizes)
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout

        # Create ConvLSTM layers
        direction_factor = 2 if self.bidirectional else 1
        cells = []
        for layer in range(num_layers):
            if layer == 0:
                layer_input_dim = self.input_dim
            else:
                layer_input_dim = self.hidden_dims[layer - 1] * direction_factor
            cells.append(ConvLSTMCell(
                input_dim=layer_input_dim,
                hidden_dim=self.hidden_dims[layer],
                kernel_size=self.kernel_sizes[layer]
            ))

        self.cell_list = nn.ModuleList(cells)
        self.dropout_layers = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers)])

    def _scan(self, cell, layer_input, time_steps, batch_size, device):
        """Run ``cell`` over ``time_steps`` starting from a zero state.

        Returns the hidden states in visiting order and the final ``(h, c)``.
        """
        h, c = cell.init_hidden(batch_size, device)
        hidden_states = []
        for t in time_steps:
            # Single timestep slice: [batch, features, 1]
            h, c = cell(layer_input[:, :, t:t+1], (h, c))
            hidden_states.append(h)
        return hidden_states, (h, c)

    def forward(self, input_tensor):
        """Run all layers over a ``[batch, seq_len, features]`` input.

        Returns:
            tuple: ``(layer_output_list, last_state_list)``. Each layer output
            is ``[batch, channels, seq_len]``; each state is the final
            ``(h, c)`` of that layer's forward-direction scan.
        """
        batch_size, seq_len, _ = input_tensor.shape
        device = input_tensor.device

        # Transform input for conv processing: [batch, features, seq_len]
        layer_input = input_tensor.transpose(1, 2)

        layer_output_list = []
        last_state_list = []
        final_layer = self.num_layers - 1

        for layer_idx, cell in enumerate(self.cell_list):
            # Forward direction
            forward_states, forward_final = self._scan(
                cell, layer_input, range(seq_len), batch_size, device
            )
            layer_output = torch.cat(forward_states, dim=2)

            if self.bidirectional:
                # Backward direction: scan in reverse, then restore time order
                backward_states, _ = self._scan(
                    cell, layer_input, reversed(range(seq_len)), batch_size, device
                )
                backward_output = torch.cat(backward_states[::-1], dim=2)
                layer_output = torch.cat([layer_output, backward_output], dim=1)

            # Apply dropout between layers only (not after the last one)
            if layer_idx < final_layer:
                layer_output = self.dropout_layers[layer_idx](layer_output)

            layer_output_list.append(layer_output)
            last_state_list.append(forward_final)
            layer_input = layer_output

        return layer_output_list, last_state_list

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over ``[B, N, C]`` inputs."""
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        """Build the fused QKV projection, the output projection and dropouts.

        Args:
            embed_dim: Channel size ``C``; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout: Dropout probability for attention weights and output.
        """
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # A single bias-free linear layer yields queries, keys and values.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(output, attn)``.

        ``output`` has the input's shape; ``attn`` holds the per-head
        attention weights after dropout, shaped ``[B, heads, N, N]``.
        """
        batch, tokens, channels = x.shape

        # [B, N, 3C] -> [3, B, heads, N, head_dim], then unpack Q, K, V
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        attn = self.attn_dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, heads merged back into the channel axis
        context = torch.matmul(attn, value).transpose(1, 2).reshape(batch, tokens, channels)
        output = self.proj_dropout(self.proj(context))

        return output, attn

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with a residual."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        """Build the two layer norms, the attention module and the MLP.

        Args:
            embed_dim: Channel size of the tokens.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden size as a multiple of ``embed_dim``.
            dropout: Dropout probability used in attention and MLP.
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)

        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        mlp_layers = [
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Return ``(output, attn_weights)`` for the input tokens ``x``."""
        # Attention sub-layer: normalise, attend, add back to the input
        attended, attn_weights = self.attn(self.norm1(x))
        after_attention = x + attended

        # Feed-forward sub-layer with its own residual connection
        output = after_attention + self.mlp(self.norm2(after_attention))

        return output, attn_weights

class VisionTransformer(nn.Module):
    """Transformer encoder with a learned positional embedding over time steps."""
    def __init__(self, seq_len, embed_dim, num_heads, num_layers,
                 mlp_ratio=4.0, dropout=0.1):
        """Create the positional embedding, the blocks and the final norm.

        Args:
            seq_len: Number of positions covered by the positional embedding.
            embed_dim: Channel size of the tokens.
            num_heads: Attention heads per block.
            num_layers: Number of transformer blocks.
            mlp_ratio: MLP hidden size as a multiple of ``embed_dim``.
            dropout: Dropout probability.
        """
        super().__init__()

        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers

        # Positional embedding, initialised with small random values
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len, embed_dim) * 0.02)
        self.pos_dropout = nn.Dropout(dropout)

        # Transformer blocks
        encoder_blocks = []
        for _ in range(num_layers):
            encoder_blocks.append(TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout))
        self.blocks = nn.ModuleList(encoder_blocks)

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode ``x`` of shape ``[B, N, C]``.

        Returns:
            tuple: the normalised output tokens and a list with one
            attention-weight tensor per block, in block order.
        """
        _, num_tokens, _ = x.shape

        # Add the first ``num_tokens`` positional vectors, then dropout
        tokens = self.pos_dropout(x + self.pos_embed[:, :num_tokens, :])

        # Process through transformer blocks, collecting attention maps
        attn_weights_list = []
        for encoder_block in self.blocks:
            tokens, block_attention = encoder_block(tokens)
            attn_weights_list.append(block_attention)

        return self.norm(tokens), attn_weights_list

class CrossModalAttention(nn.Module):
    """Fuse ConvLSTM and ViT feature sequences with bidirectional cross-attention."""
    def __init__(self, conv_dim, vit_dim, fusion_dim, num_heads=8):
        """Build the projections, the two cross-attention modules and the fusion MLP.

        Args:
            conv_dim: Channel size of the ConvLSTM features.
            vit_dim: Channel size of the ViT features.
            fusion_dim: Shared channel size used for attention and the output.
            num_heads: Heads in each cross-attention module.
        """
        super().__init__()

        self.conv_dim = conv_dim
        self.vit_dim = vit_dim
        self.fusion_dim = fusion_dim
        self.num_heads = num_heads

        # Project inputs to same dimension
        self.conv_proj = nn.Linear(conv_dim, fusion_dim)
        self.vit_proj = nn.Linear(vit_dim, fusion_dim)

        # Cross-attention layers (batch-first inputs)
        self.conv_to_vit_attn = nn.MultiheadAttention(fusion_dim, num_heads, batch_first=True)
        self.vit_to_conv_attn = nn.MultiheadAttention(fusion_dim, num_heads, batch_first=True)

        # Fusion layers operate on the two attended streams concatenated
        concat_dim = fusion_dim * 2
        self.fusion_norm = nn.LayerNorm(concat_dim)
        self.fusion_mlp = nn.Sequential(
            nn.Linear(concat_dim, fusion_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(fusion_dim, fusion_dim)
        )

    def forward(self, conv_features, vit_features):
        """Return fused features with ``fusion_dim`` channels.

        Each stream queries the other one; the attended streams are
        concatenated, normalised and passed through the MLP, and both
        projected inputs are added back as a residual.
        """
        conv_tokens = self.conv_proj(conv_features)
        vit_tokens = self.vit_proj(vit_features)

        # ConvLSTM queries attend over ViT keys/values, and vice versa
        conv_attended = self.conv_to_vit_attn(conv_tokens, vit_tokens, vit_tokens)[0]
        vit_attended = self.vit_to_conv_attn(vit_tokens, conv_tokens, conv_tokens)[0]

        # Concatenate and fuse
        stacked = self.fusion_norm(torch.cat([conv_attended, vit_attended], dim=-1))
        fused = self.fusion_mlp(stacked)

        # Residual connection
        return fused + conv_tokens + vit_tokens

class ConvLSTM_ViT_Hybrid(nn.Module):
    """Two-branch sequence classifier built from a ConvLSTM and a transformer.

    The input is projected to ``hidden_dim`` and sent through both branches.
    Their outputs are fused by ``CrossModalAttention``, pooled over the
    sequence with a learned mixture of three pooling operators, and mapped to
    class logits plus one sigmoid "uncertainty" value per sample.
    """
    def __init__(self, input_dim=69, hidden_dim=128, num_classes=10,
                 seq_len=150, dropout=0.2):
        super().__init__()

        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.num_classes, self.seq_len = num_classes, seq_len

        half_dim = hidden_dim // 2
        quarter_dim = hidden_dim // 4

        # NOTE: sub-modules are registered in a fixed order; both the random
        # initialisation and _initialize_weights() walk them in that order.

        # Shared input stem feeding both branches
        self.input_projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Branch 1: bidirectional two-layer ConvLSTM
        self.convlstm = ConvLSTM(
            input_dim=hidden_dim,
            hidden_dims=[hidden_dim, hidden_dim],
            kernel_sizes=[3, 5],
            num_layers=2,
            bidirectional=True,
            dropout=dropout,
        )

        # Branch 2: transformer encoder over the same projected sequence
        self.vit = VisionTransformer(
            seq_len=seq_len,
            embed_dim=hidden_dim,
            num_heads=8,
            num_layers=4,
            mlp_ratio=4.0,
            dropout=dropout,
        )

        # Fusion: the ConvLSTM branch is declared as 2 * hidden_dim wide
        # because it is bidirectional.
        self.cross_modal_fusion = CrossModalAttention(
            conv_dim=hidden_dim * 2,
            vit_dim=hidden_dim,
            fusion_dim=hidden_dim,
            num_heads=8,
        )

        # Three candidate ways of collapsing the sequence axis
        self.adaptive_pool = nn.ModuleList([
            nn.AdaptiveAvgPool1d(1),
            nn.AdaptiveMaxPool1d(1),
            nn.AdaptiveAvgPool1d(4),
        ])

        # Produces one softmax weight per pooling operator from their
        # concatenated outputs (hence the 3 * hidden_dim input).
        self.pool_attention = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 3),
            nn.Softmax(dim=-1),
        )

        # Classification head: hidden -> hidden/2 -> hidden/4 -> num_classes
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, half_dim),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(half_dim, quarter_dim),
            nn.GELU(),
            nn.Dropout(dropout * 0.3),
            nn.Linear(quarter_dim, num_classes),
        )

        # Scalar head squashed into (0, 1) by a sigmoid
        self.uncertainty_head = nn.Sequential(
            nn.Linear(hidden_dim, quarter_dim),
            nn.GELU(),
            nn.Linear(quarter_dim, 1),
            nn.Sigmoid(),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Apply Xavier / constant / Kaiming init to Linear, LayerNorm and Conv1d layers."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
                continue
            if isinstance(layer, nn.LayerNorm):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
                continue
            if isinstance(layer, nn.Conv1d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x, return_features=False):
        """Run both branches, fuse, pool and classify.

        Args:
            x: 3-D input tensor; presumably ``(batch, seq_len, input_dim)``
                given the Linear stem - TODO confirm against the data loader.
            return_features: When true, also return a dict of intermediate
                tensors.

        Returns:
            ``(logits, uncertainty)``, or ``(logits, uncertainty, features)``
            when ``return_features`` is set.
        """
        # The unpacking only serves as a rank check: non 3-D inputs fail here.
        _batch, _steps, _channels = x.shape

        projected = self.input_projection(x)

        # ConvLSTM branch: use the last entry of its outputs, with axes 1 and 2
        # swapped (presumably to get batch, seq_len, conv_dim - depends on
        # ConvLSTM, which is defined elsewhere).
        conv_outputs, _ = self.convlstm(projected)
        conv_features = conv_outputs[-1].transpose(1, 2)

        # Transformer branch
        vit_features, vit_attention_weights = self.vit(projected)

        # Cross-modal fusion
        fused_features = self.cross_modal_fusion(conv_features, vit_features)

        # The 1-D pooling layers reduce the last axis, so move the sequence there.
        channels_first = fused_features.transpose(1, 2)
        pooled_features = []
        for pool in self.adaptive_pool:
            pooled = pool(channels_first)
            if pool.output_size == 1:
                pooled = pooled.squeeze(-1)
            else:
                # Multi-bin pooling: average the bins back into one vector
                pooled = pooled.transpose(1, 2).mean(dim=1)
            pooled_features.append(pooled)

        # Learned convex combination of the three pooled vectors
        pool_weights = self.pool_attention(torch.cat(pooled_features, dim=-1))
        weighted = [
            weight.unsqueeze(-1) * feature
            for weight, feature in zip(pool_weights.unbind(-1), pooled_features)
        ]
        final_features = sum(weighted)

        logits = self.classifier(final_features)
        uncertainty = self.uncertainty_head(final_features)

        if not return_features:
            return logits, uncertainty

        return logits, uncertainty, {
            'conv_features': conv_features,
            'vit_features': vit_features,
            'fused_features': fused_features,
            'vit_attention': vit_attention_weights,
            'pool_weights': pool_weights
        }

class HybridModelTrainer:
    """Trainer for the ConvLSTM-ViT hybrid model.

    Wraps a model whose forward pass returns ``(logits, uncertainty)`` and
    offers single-epoch training, validation, an early-stopping training loop
    and a reset hook for cross-validation.

    Args:
        model: ``nn.Module`` returning ``(logits, uncertainty)``.
        class_weights: Optional per-class weight tensor for the cross-entropy
            loss; moved to ``device`` when given.
        device: Device the model and batches are moved to.
    """

    # (group name, learning rate, weight decay) in optimizer param-group order.
    _GROUP_SETTINGS = (
        ('convlstm', 5e-4, 0.01),
        ('vit', 3e-4, 0.005),
        ('fusion', 8e-4, 0.01),
        ('classifier', 1e-3, 0.01),
    )

    def __init__(self, model, class_weights=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model
        self.device = device
        self.model.to(device)

        # Loss functions
        if class_weights is not None:
            class_weights = class_weights.to(device)

        self.classification_criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
        self.uncertainty_criterion = nn.MSELoss()

        # Snapshot of the best weights seen by train(); None until one exists.
        self.best_model_state = None

        self._build_optimizer()
        self._reset_history()

    @staticmethod
    def _param_group_for(name):
        """Map a parameter name to one of the optimizer group names."""
        # 'cross_modal' must be tested first: names such as
        # 'cross_modal_fusion.vit_proj.weight' also contain 'vit' and would
        # otherwise be trained with the ViT learning rate.
        if 'cross_modal' in name:
            return 'fusion'
        if 'convlstm' in name:
            return 'convlstm'
        if 'vit' in name:
            return 'vit'
        if 'pool' in name:
            return 'fusion'
        return 'classifier'

    def _build_optimizer(self):
        """Create AdamW with component-specific settings and its LR scheduler."""
        grouped = {group: [] for group, _, _ in self._GROUP_SETTINGS}
        for name, param in self.model.named_parameters():
            grouped[self._param_group_for(name)].append(param)

        self.optimizer = optim.AdamW(
            [
                {'params': grouped[group], 'lr': lr, 'weight_decay': weight_decay}
                for group, lr, weight_decay in self._GROUP_SETTINGS
            ],
            betas=(0.9, 0.95),
        )

        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=20, T_mult=2, eta_min=1e-6
        )

    def _reset_history(self):
        """Start empty per-epoch metric histories."""
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.val_f1_scores = []

    def _snapshot_state(self):
        """Return an independent copy of the model's state dict.

        ``state_dict()`` hands out tensors that share storage with the live
        parameters, so a shallow ``dict.copy()`` would keep changing while
        training continues. Cloning each tensor freezes the snapshot.
        """
        return {key: value.detach().clone() for key, value in self.model.state_dict().items()}

    def train_epoch(self, dataloader):
        """Train for one epoch.

        Returns:
            ``(avg_loss, accuracy_percent, weighted_f1, predictions, targets)``.
        """
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        for data, target in tqdm(dataloader, desc="Training", leave=False):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            logits, uncertainty = self.model(data)

            # Classification loss
            clf_loss = self.classification_criterion(logits, target)

            # Uncertainty loss: regress the head towards 1 - max softmax
            # probability. The target is computed on detached logits so no
            # gradient flows into the classifier through it.
            pred_confidence = torch.softmax(logits.detach(), dim=1).max(dim=1)[0]
            uncertainty_target = 1.0 - pred_confidence
            # reshape(-1) keeps a 1-D tensor even for a batch of one, where a
            # bare squeeze() would produce a 0-d tensor.
            unc_loss = self.uncertainty_criterion(uncertainty.reshape(-1), uncertainty_target)

            # Combined loss
            total_loss_batch = clf_loss + 0.1 * unc_loss

            # Explicit L2 penalty (applied in addition to AdamW's weight decay)
            l2_reg = sum(param.pow(2).sum() for param in self.model.parameters())
            total_loss_batch += 1e-5 * l2_reg

            total_loss_batch.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            # Statistics
            total_loss += total_loss_batch.item()
            _, predicted = logits.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')

        return avg_loss, accuracy, f1, all_preds, all_targets

    def validate(self, dataloader):
        """Evaluate without gradient tracking.

        Returns:
            ``(avg_loss, accuracy_percent, f1, precision, recall, predictions,
            targets, uncertainties)``; F1, precision and recall are weighted
            averages. The loss is the classification loss only.
        """
        # Imported here so the method does not depend on a later notebook cell
        # having pulled these two names into module scope.
        from sklearn.metrics import precision_score, recall_score

        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []
        all_uncertainties = []

        with torch.no_grad():
            for data, target in tqdm(dataloader, desc="Validating", leave=False):
                data, target = data.to(self.device), target.to(self.device)

                logits, uncertainty = self.model(data)
                loss = self.classification_criterion(logits, target)

                total_loss += loss.item()
                _, predicted = logits.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(target.cpu().numpy())
                all_uncertainties.extend(uncertainty.cpu().numpy().flatten())

        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        f1 = f1_score(all_targets, all_preds, average='weighted')
        precision = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)

        return avg_loss, accuracy, f1, precision, recall, all_preds, all_targets, all_uncertainties

    def train(self, train_loader, val_loader, epochs=100):
        """Train with early stopping on validation F1.

        Stops after 20 consecutive epochs without a new best validation F1 and
        reloads the best weights before returning.

        Returns:
            The best validation F1 observed (0 if no epoch exceeded 0).
        """
        best_val_f1 = 0
        patience = 20
        patience_counter = 0
        # Never restore a snapshot left over from an earlier run or fold.
        self.best_model_state = None

        for epoch in range(epochs):
            # Training
            train_loss, train_acc, train_f1, _, _ = self.train_epoch(train_loader)

            # Validation
            val_loss, val_acc, val_f1, val_precision, val_recall, _, _, _ = self.validate(val_loader)

            # Learning rate scheduling (stepped once per epoch)
            self.scheduler.step()

            # Save metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)
            self.val_f1_scores.append(val_f1)

            # Save best model
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                patience_counter = 0
                self.best_model_state = self._snapshot_state()
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                logger.info(f"      Early stopping at epoch {epoch + 1}")
                break

        # Load best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        return best_val_f1

    def reset_model(self):
        """Reset weights, optimizer, scheduler, history and best snapshot for a new fold."""
        self.model.apply(self._reset_weights)

        self._build_optimizer()
        self._reset_history()
        self.best_model_state = None

    def _reset_weights(self, m):
        """Re-initialise one module (used through ``nn.Module.apply``).

        The module's own reset hook runs first so that layers and tensors the
        custom rules below do not cover (recurrent/conv biases, attention input
        projections, normalisation statistics) do not carry trained values
        into the next fold. Parameters held directly as ``nn.Parameter`` on a
        module without a reset hook are NOT touched.
        """
        for hook_name in ('reset_parameters', '_reset_parameters'):
            hook = getattr(m, hook_name, None)
            if callable(hook):
                hook()
                break

        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

def analyze_kfold_results(fold_results, class_names):
    """Summarise per-fold metrics, log them and return the aggregate statistics.

    Args:
        fold_results: Sequence of dicts holding ``{val,test}_{accuracy,
            precision,recall,f1}`` for each fold.
        class_names: Accepted for interface symmetry; not used here.

    Returns:
        Dict with ``<split>_<metric>_mean`` / ``_std`` entries plus the
        stability coefficient (std / mean of test F1) and two verdict strings.
    """
    metric_names = ('accuracy', 'precision', 'recall', 'f1')

    # Gather every series and its mean / population std before logging anything.
    series, mean, std = {}, {}, {}
    for split in ('val', 'test'):
        for metric in metric_names:
            key = f'{split}_{metric}'
            series[key] = [fold[key] for fold in fold_results]
            mean[key] = np.mean(series[key])
            std[key] = np.std(series[key])

    logger.info("\n🎯 ENHANCED 5-FOLD CROSS VALIDATION RESULTS")
    logger.info("=" * 80)
    for split, heading in (('val', "📊 VALIDATION METRICS:"), ('test', "\n📊 TEST METRICS:")):
        accuracies = series[f'{split}_accuracy']
        logger.info(heading)
        logger.info(f"   - Accuracy: {mean[f'{split}_accuracy']:.2f}% ± {std[f'{split}_accuracy']:.2f}%")
        logger.info(f"   - Precision: {mean[f'{split}_precision']:.4f} ± {std[f'{split}_precision']:.4f}")
        logger.info(f"   - Recall: {mean[f'{split}_recall']:.4f} ± {std[f'{split}_recall']:.4f}")
        logger.info(f"   - F1 Score: {mean[f'{split}_f1']:.4f} ± {std[f'{split}_f1']:.4f}")
        logger.info(f"   - Range: [{min(accuracies):.2f}% - {max(accuracies):.2f}%]")

    test_f1_mean = mean['test_f1']
    test_f1_std = std['test_f1']

    # Performance verdict: first tier whose floor the mean test F1 reaches.
    performance_status = "⚠️ Performance Needs Improvement"
    for floor, verdict in (
        (0.85, "🏆 Outstanding Performance!"),
        (0.80, "🎉 Excellent Performance!"),
        (0.70, "✅ Good Performance!"),
        (0.60, "👍 Satisfactory Performance"),
    ):
        if test_f1_mean >= floor:
            performance_status = verdict
            break

    logger.info(f"\n🎯 Overall Assessment: {performance_status}")

    # Stability: coefficient of variation of the test F1 across folds.
    if test_f1_mean > 0:
        cv_coefficient_test_f1 = test_f1_std / test_f1_mean
    else:
        cv_coefficient_test_f1 = float('inf')

    stability_status = "⚠️ Variable"
    for ceiling, verdict in (
        (0.05, "🔒 Extremely Stable"),
        (0.1, "🔒 Very Stable"),
        (0.2, "✅ Stable"),
    ):
        if cv_coefficient_test_f1 < ceiling:
            stability_status = verdict
            break

    logger.info(f"📊 Model Stability: {stability_status} (CV = {cv_coefficient_test_f1:.3f})")

    summary = {}
    for key in series:
        summary[f'{key}_mean'] = mean[key]
        summary[f'{key}_std'] = std[key]
    summary['stability_coefficient'] = cv_coefficient_test_f1
    summary['performance_status'] = performance_status
    summary['stability_status'] = stability_status
    return summary

def plot_individual_fold_confusion_matrices(fold_results, class_names):
    """Draw one validation and one test confusion matrix per fold.

    The figure has two rows (validation on top, test below) and one column
    per fold; it is saved to 'individual_fold_confusion_matrices.png' and
    then shown.
    """
    n_folds = len(fold_results)
    # squeeze=False keeps the axes grid 2-D even for a single fold.
    fig, axes = plt.subplots(2, n_folds, figsize=(5*n_folds, 10), squeeze=False)

    # (grid row, key prefix in the fold dict, title word, colormap, x-axis label)
    panels = (
        (0, 'val', 'Validation', 'Blues', ''),
        (1, 'test', 'Test', 'Greens', 'Predicted Label'),
    )

    for column, result in enumerate(fold_results):
        for row, prefix, split_name, colormap, x_label in panels:
            ax = axes[row, column]
            matrix = confusion_matrix(result[f'{prefix}_targets'],
                                      result[f'{prefix}_predictions'])
            sns.heatmap(matrix, annot=True, fmt='d', cmap=colormap,
                        xticklabels=class_names, yticklabels=class_names,
                        ax=ax, cbar_kws={'shrink': 0.8})

            accuracy = result[f'{prefix}_accuracy']
            ax.set_title(f'Fold {column + 1} - {split_name} CM\n'
                         f'Acc: {accuracy:.1f}%',
                         fontsize=12, fontweight='bold')
            # Only the left-most column carries the y-axis caption.
            ax.set_ylabel('True Label' if column == 0 else '')
            ax.set_xlabel(x_label)
            ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig('individual_fold_confusion_matrices.png', dpi=150, bbox_inches='tight')
    plt.show()

def plot_kfold_results(fold_results, model_name="ConvLSTM_ViT_Hybrid"):
    """Plot a 2x2 overview of K-fold cross-validation results.

    Panels: validation accuracy per fold, test accuracy per fold, a grouped
    bar chart of several metrics per fold, and a violin plot of the metric
    distributions. The figure is saved to
    '<model_name lower-cased>_5fold_results_detailed.png' and shown.

    Args:
        fold_results: Sequence of per-fold dicts with ``val_accuracy``,
            ``test_accuracy``, ``val_f1``, ``test_f1`` and ``val_precision``.
        model_name: Used in the panel titles and the output file name.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    fold_nums = list(range(1, len(fold_results) + 1))
    val_accuracies = [fold['val_accuracy'] for fold in fold_results]
    test_accuracies = [fold['test_accuracy'] for fold in fold_results]

    # Accuracies are plotted as percentages. F1 and precision are rescaled by
    # 100 (as the violin panel and the precision bars already did) so that all
    # series share one axis scale; plotting raw F1 next to percentages made
    # the F1 bars effectively invisible.
    val_f1_pct = [fold['val_f1'] * 100 for fold in fold_results]
    test_f1_pct = [fold['test_f1'] * 100 for fold in fold_results]
    val_precision_pct = [fold['val_precision'] * 100 for fold in fold_results]

    def _accuracy_panel(ax, accuracies, color, split):
        """Bar chart of one accuracy series with a mean line and value labels."""
        bars = ax.bar(fold_nums, accuracies, alpha=0.7, color=color, label=split)
        ax.axhline(y=np.mean(accuracies), color='red', linestyle='--',
                   label=f'Mean: {np.mean(accuracies):.2f}±{np.std(accuracies):.2f}%')
        ax.set_title(f'{model_name} - {split} Accuracy by Fold', fontsize=14, fontweight='bold')
        ax.set_xlabel('Fold')
        ax.set_ylabel('Accuracy (%)')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Add value labels on bars
        for bar, acc in zip(bars, accuracies):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{acc:.1f}%', ha='center', va='bottom', fontsize=10)

    _accuracy_panel(ax1, val_accuracies, 'blue', 'Validation')
    _accuracy_panel(ax2, test_accuracies, 'green', 'Test')

    # All metrics comparison (five bars per fold, centred on the fold position)
    x = np.arange(len(fold_nums))
    width = 0.15

    ax3.bar(x - width*2, val_accuracies, width, label='Val Acc', alpha=0.8, color='lightblue')
    ax3.bar(x - width, test_accuracies, width, label='Test Acc', alpha=0.8, color='lightgreen')
    ax3.bar(x, val_f1_pct, width, label='Val F1', alpha=0.8, color='blue')
    ax3.bar(x + width, test_f1_pct, width, label='Test F1', alpha=0.8, color='green')
    ax3.bar(x + width*2, val_precision_pct, width, label='Val Prec', alpha=0.8, color='orange')

    ax3.set_title(f'{model_name} - All Metrics by Fold', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Fold')
    ax3.set_ylabel('Score (%)')
    ax3.set_xticks(x)
    ax3.set_xticklabels(fold_nums)
    ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax3.grid(True, alpha=0.3)

    # Performance distribution violin plot
    data_for_violin = [val_accuracies, test_accuracies, val_f1_pct, test_f1_pct]
    labels_violin = ['Val Acc', 'Test Acc', 'Val F1×100', 'Test F1×100']

    ax4.violinplot(data_for_violin, positions=range(len(labels_violin)),
                   showmeans=True, showmedians=True)
    ax4.set_xticks(range(len(labels_violin)))
    ax4.set_xticklabels(labels_violin)
    ax4.set_title(f'{model_name} - Performance Distribution', fontsize=14, fontweight='bold')
    ax4.set_ylabel('Score')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'{model_name.lower()}_5fold_results_detailed.png', dpi=150, bbox_inches='tight')
    plt.show()

def create_fold_summary_table(fold_results, class_names):
    """Build, print and return a per-fold metrics table.

    One row per fold is followed by a 'Mean' and a 'Std' row. All cells are
    pre-formatted strings: accuracies with two decimals and a percent sign,
    every other metric with four decimals.

    Args:
        fold_results: Sequence of per-fold dicts (``fold``, validation/test
            metrics and ``best_val_f1``).
        class_names: Accepted for interface symmetry; not used here.

    Returns:
        The summary as a ``pandas.DataFrame``.
    """
    # (output column, key in the fold dict, rendered as a percentage?)
    column_spec = [
        ('Val_Accuracy', 'val_accuracy', True),
        ('Val_Precision', 'val_precision', False),
        ('Val_Recall', 'val_recall', False),
        ('Val_F1', 'val_f1', False),
        ('Test_Accuracy', 'test_accuracy', True),
        ('Test_Precision', 'test_precision', False),
        ('Test_Recall', 'test_recall', False),
        ('Test_F1', 'test_f1', False),
        ('Best_Val_F1', 'best_val_f1', False),
    ]

    def _render(value, as_percent, prefix=''):
        text = f"{value:.2f}%" if as_percent else f"{value:.4f}"
        return prefix + text

    table_rows = []
    for fold in fold_results:
        row = {'Fold': f"Fold {fold['fold']}"}
        for column, key, as_percent in column_spec:
            row[column] = _render(fold[key], as_percent)
        table_rows.append(row)

    # Raw series per metric, feeding the two aggregate rows below.
    series = {key: [fold[key] for fold in fold_results] for _, key, _ in column_spec}

    for row_name, aggregate, prefix in (('Mean', np.mean, ''), ('Std', np.std, '±')):
        row = {'Fold': row_name}
        for column, key, as_percent in column_spec:
            row[column] = _render(aggregate(series[key]), as_percent, prefix)
        table_rows.append(row)

    results_df = pd.DataFrame(table_rows)

    # Display as a formatted table
    rule = '=' * 120
    banner = '🎯 DETAILED 5-FOLD CROSS VALIDATION RESULTS TABLE'
    print(f"\n{rule}")
    print(f"{banner:^120}")
    print(rule)
    print(results_df.to_string(index=False))
    print(rule)

    return results_df

def create_aggregated_confusion_matrix(all_fold_predictions, class_names):
    """Pool predictions of all folds and report them as confusion matrices.

    Draws the pooled validation and test confusion matrices side by side,
    saves the figure to '5fold_aggregated_confusion_matrices_enhanced.png',
    shows it, and prints a classification report for each split.

    Args:
        all_fold_predictions: Sequence of dicts with ``val_preds``,
            ``val_targets``, ``test_preds`` and ``test_targets``.
        class_names: Tick labels / report names, one per class.
    """
    # Concatenate every fold's predictions and targets, split by split.
    pooled = {key: [] for key in ('val_preds', 'val_targets', 'test_preds', 'test_targets')}
    for fold_pred in all_fold_predictions:
        for key, bucket in pooled.items():
            bucket.extend(fold_pred[key])

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    # (axis, split prefix, colormap, panel title)
    panels = (
        (ax1, 'val', 'Blues', 'Aggregated Validation Confusion Matrix (5-Fold CV)'),
        (ax2, 'test', 'Greens', 'Aggregated Test Confusion Matrix (5-Fold CV)'),
    )
    for ax, split, colormap, title in panels:
        matrix = confusion_matrix(pooled[f'{split}_targets'], pooled[f'{split}_preds'])
        sns.heatmap(matrix, annot=True, fmt='d', cmap=colormap,
                    xticklabels=class_names, yticklabels=class_names, ax=ax,
                    cbar_kws={'shrink': 0.8})
        ax.set_title(title, fontsize=16, fontweight='bold')
        ax.set_ylabel('True Label', fontsize=12)
        ax.set_xlabel('Predicted Label', fontsize=12)
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig('5fold_aggregated_confusion_matrices_enhanced.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Text reports, validation first
    reports = (
        ('val', "\n📋 AGGREGATED VALIDATION Classification Report (5-Fold CV):"),
        ('test', "\n📋 AGGREGATED TEST Classification Report (5-Fold CV):"),
    )
    for split, heading in reports:
        print(heading)
        print("=" * 90)
        print(classification_report(pooled[f'{split}_targets'], pooled[f'{split}_preds'],
                                    target_names=class_names, digits=4))

def _plot_single_fold_curves(trainer, fold):
    """Draw loss, accuracy, validation-F1 and train-val gap curves for one fold."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    epochs = range(1, len(trainer.train_losses) + 1)

    # Loss curves
    ax1.plot(epochs, trainer.train_losses, 'b-', label='Training Loss', linewidth=2)
    ax1.plot(epochs, trainer.val_losses, 'r-', label='Validation Loss', linewidth=2)
    ax1.set_title(f'Fold {fold+1} - Loss Curves', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Accuracy curves
    ax2.plot(epochs, trainer.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
    ax2.plot(epochs, trainer.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
    ax2.set_title(f'Fold {fold+1} - Accuracy Curves', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # F1 Score progression
    ax3.plot(epochs, trainer.val_f1_scores, 'green', linewidth=2, label='Validation F1')
    ax3.set_title(f'Fold {fold+1} - Validation F1 Score', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('F1 Score')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Train-validation gap: positive values mean training accuracy is ahead
    gaps = [t - v for t, v in zip(trainer.train_accuracies, trainer.val_accuracies)]
    ax4.plot(epochs, gaps, 'purple', linewidth=2, label='Train-Val Gap')
    ax4.axhline(y=10, color='orange', linestyle='--', label='Warning Threshold (10%)')
    ax4.axhline(y=5, color='green', linestyle='--', label='Good Threshold (5%)')
    ax4.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax4.set_title(f'Fold {fold+1} - Overfitting Analysis', fontsize=14, fontweight='bold')
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Accuracy Gap (%)')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(f'fold_{fold+1}_training_curves_detailed.png', dpi=150, bbox_inches='tight')
    plt.show()


def _plot_all_folds_curves(fold_trainers):
    """Overlay validation loss and validation F1 of every fold in one figure."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    colors = ['blue', 'red', 'green', 'orange', 'purple']

    for i, trainer in enumerate(fold_trainers):
        # Cycle through the palette so more than five folds do not raise
        # IndexError.
        color = colors[i % len(colors)]
        epochs = range(1, len(trainer.val_losses) + 1)
        ax1.plot(epochs, trainer.val_losses, color=color,
                 label=f'Fold {i+1}', linewidth=2, alpha=0.7)
        ax2.plot(epochs, trainer.val_f1_scores, color=color,
                 label=f'Fold {i+1}', linewidth=2, alpha=0.7)

    ax1.set_title('Validation Loss - All Folds', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    ax2.set_title('Validation F1 Score - All Folds', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('F1 Score')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('all_folds_training_curves_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()


def plot_fold_training_curves(fold_trainers, fold=None):
    """Plot training curves for one fold or a comparison across all folds.

    Args:
        fold_trainers: Sequence of trainer objects exposing ``train_losses``,
            ``val_losses``, ``train_accuracies``, ``val_accuracies`` and
            ``val_f1_scores`` histories.
        fold: Zero-based index into ``fold_trainers``. When given, a detailed
            2x2 figure for that fold is saved to
            'fold_<fold+1>_training_curves_detailed.png'; when ``None``, the
            validation loss and F1 of all folds are overlaid and saved to
            'all_folds_training_curves_comparison.png'.
    """
    if fold is not None:
        _plot_single_fold_curves(fold_trainers[fold], fold)
    else:
        _plot_all_folds_curves(fold_trainers)

def plot_uncertainty_analysis(fold_results, class_names):
    """Plot prediction-uncertainty diagnostics pooled over all folds.

    Draws a 2x2 figure:
      1. validation vs. test uncertainty distributions,
      2. validation uncertainty split by prediction correctness,
      3. the same split for the test set,
      4. a calibration curve relating ``1 - uncertainty`` (used as a
         confidence score) to accuracy on the validation set.

    The figure is saved as ``uncertainty_analysis_5fold.png`` and displayed.

    Args:
        fold_results: Iterable of per-fold dicts providing the keys
            ``val_predictions``, ``val_targets``, ``val_uncertainties`` and
            their ``test_*`` counterparts.
        class_names: Unused here; kept so existing callers keep working.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    # Pool uncertainties and per-sample correctness flags across folds
    all_val_uncertainties = []
    all_test_uncertainties = []
    all_val_correct = []
    all_test_correct = []

    for fold in fold_results:
        all_val_uncertainties.extend(fold['val_uncertainties'])
        all_test_uncertainties.extend(fold['test_uncertainties'])
        all_val_correct.extend(
            1 if p == t else 0
            for p, t in zip(fold['val_predictions'], fold['val_targets'])
        )
        all_test_correct.extend(
            1 if p == t else 0
            for p, t in zip(fold['test_predictions'], fold['test_targets'])
        )

    # Convert once; the original rebuilt these arrays on every bin iteration
    val_unc = np.asarray(all_val_uncertainties)
    test_unc = np.asarray(all_test_uncertainties)
    val_hit = np.asarray(all_val_correct)
    test_hit = np.asarray(all_test_correct)

    # Uncertainty distribution
    ax1.hist(val_unc, bins=50, alpha=0.7, label='Validation', color='blue', density=True)
    ax1.hist(test_unc, bins=50, alpha=0.7, label='Test', color='green', density=True)
    ax1.set_title('Prediction Uncertainty Distribution', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Uncertainty Score')
    ax1.set_ylabel('Density')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Uncertainty vs correctness, one panel per split
    panels = (
        (ax2, val_unc, val_hit, 'Validation: Uncertainty vs Prediction Correctness'),
        (ax3, test_unc, test_hit, 'Test: Uncertainty vs Prediction Correctness'),
    )
    for axis, uncertainties, hits, title in panels:
        axis.hist(uncertainties[hits == 1], bins=30, alpha=0.7,
                  label='Correct', color='green', density=True)
        axis.hist(uncertainties[hits == 0], bins=30, alpha=0.7,
                  label='Incorrect', color='red', density=True)
        axis.set_title(title, fontsize=14, fontweight='bold')
        axis.set_xlabel('Uncertainty Score')
        axis.set_ylabel('Density')
        axis.legend()
        axis.grid(True, alpha=0.3)

    # Uncertainty calibration (validation set)
    uncertainty_bins = np.linspace(0, 1, 11)
    last_bin = len(uncertainty_bins) - 2
    bin_acc_val = []
    bin_conf_val = []

    for i in range(len(uncertainty_bins) - 1):
        lower, upper = uncertainty_bins[i], uncertainty_bins[i + 1]
        # The final bin is closed on the right so an uncertainty of exactly
        # 1.0 is not silently dropped.
        below_upper = (val_unc <= upper) if i == last_bin else (val_unc < upper)
        mask = (val_unc >= lower) & below_upper
        if not mask.any():
            # Skip empty bins: plotting them as (0, 0) would pull the
            # calibration curve through the origin and misrepresent the model.
            continue
        bin_acc_val.append(np.mean(val_hit[mask]))
        bin_conf_val.append(1 - np.mean(val_unc[mask]))  # Convert uncertainty to confidence

    ax4.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
    ax4.plot(bin_conf_val, bin_acc_val, 'ro-', label='Model Calibration')
    ax4.set_title('Uncertainty Calibration', fontsize=14, fontweight='bold')
    ax4.set_xlabel('Confidence Score')
    ax4.set_ylabel('Accuracy')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    ax4.set_xlim([0, 1])
    ax4.set_ylim([0, 1])

    plt.tight_layout()
    plt.savefig('uncertainty_analysis_5fold.png', dpi=150, bbox_inches='tight')
    plt.show()

def run_5fold_cross_validation(sequences, labels, feature_count, config):
    """Train and evaluate one model per fold against fixed validation/test splits.

    The data is split 70/15/15 (train/val/test) with stratification. A
    ``StratifiedKFold`` is then run over the training portion only; each fold
    trains on that fold's training indices (the fold's held-out indices are
    not used) and is evaluated on the same fixed validation and test sets.

    Args:
        sequences: Indexable collection of input sequences (indexed with an
            integer index array, so presumably a NumPy array — confirm).
        labels: Class label per sequence.
        feature_count: Passed to the model as ``input_dim``.
        config: Dict providing ``batch_size``, ``hidden_dim``,
            ``sequence_length``, ``dropout`` and ``epochs``.

    Returns:
        Tuple ``(fold_results, fold_trainers, all_fold_predictions,
        class_names)`` where ``class_names`` are the classes of a
        ``LabelEncoder`` fitted on all labels.
    """
    # First separate test set (15% of total data)
    X_remaining, X_test, y_remaining, y_test = train_test_split(
        sequences, labels, test_size=0.15, random_state=42, stratify=labels
    )

    # Then separate validation set (15% of total data) from remaining 85%
    # 15% out of 85% = 15/85 ≈ 0.176
    X_train, X_val, y_train, y_val = train_test_split(
        X_remaining, y_remaining, test_size=0.176, random_state=42, stratify=y_remaining
    )

    # Initialize StratifiedKFold for train data only (70% of total)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    fold_results = []
    fold_trainers = []
    all_fold_predictions = []

    logger.info(f"\n🎯 Starting Enhanced 5-Fold Cross Validation with 15%-15%-70% Split (150 Epochs)")
    logger.info(f"   - Total samples: {len(sequences):,}")
    logger.info(f"   - Train samples: {len(X_train):,} (70%)")
    logger.info(f"   - Validation samples: {len(X_val):,} (15%)")
    logger.info(f"   - Test samples: {len(X_test):,} (15%)")
    logger.info("="*80)

    le_global = LabelEncoder()
    le_global.fit(labels)
    num_classes = len(le_global.classes_)

    # Obtain the label encoder and scaler from the TRAINING partition. The
    # previous code took them from the validation split, so normalization
    # statistics applied to the training data came from evaluation data.
    # Constructing the dataset without an encoder/scaler is the same call the
    # original used to obtain them (there on the validation split).
    preprocessing_source = EnhancedDiverSignDataset(
        X_train, y_train, train_mode=False, augment_prob=0.0
    )
    label_encoder = preprocessing_source.get_label_encoder()
    scaler = preprocessing_source.get_scaler()
    del preprocessing_source  # only needed for its encoder/scaler

    # Validation and test data are identical for every fold, so build the
    # datasets and loaders once instead of once per fold.
    val_dataset = EnhancedDiverSignDataset(
        X_val, y_val, label_encoder, scaler,
        train_mode=False, augment_prob=0.0
    )
    test_dataset = EnhancedDiverSignDataset(
        X_test, y_test, label_encoder, scaler,
        train_mode=False, augment_prob=0.0
    )
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                            shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'],
                             shuffle=False, num_workers=2, pin_memory=True)

    # Fold-independent percentages relative to the total dataset
    val_percentage = (len(X_val) / len(sequences)) * 100
    test_percentage = (len(X_test) / len(sequences)) * 100

    for fold, (train_idx, _) in enumerate(skf.split(X_train, y_train)):
        logger.info(f"\n🔥 FOLD {fold + 1}/5")
        logger.info("="*50)

        # Only this fold's training indices are used; the fold's held-out
        # indices are discarded because evaluation uses the fixed val/test sets.
        X_train_fold = X_train[train_idx]
        y_train_fold = [y_train[i] for i in train_idx]

        train_percentage = (len(X_train_fold) / len(sequences)) * 100

        logger.info(f"   - Train: {len(X_train_fold):,} samples ({train_percentage:.1f}% of total)")
        logger.info(f"   - Val: {len(X_val):,} samples ({val_percentage:.1f}% of total)")
        logger.info(f"   - Test: {len(X_test):,} samples ({test_percentage:.1f}% of total)")

        # Create train dataset for current fold
        train_dataset = EnhancedDiverSignDataset(
            X_train_fold, y_train_fold, label_encoder, scaler,
            train_mode=True, augment_prob=0.6
        )
        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                                  shuffle=True, num_workers=2, pin_memory=True)

        # Create a fresh model for the current fold
        model = ConvLSTM_ViT_Hybrid(
            input_dim=feature_count,
            hidden_dim=config['hidden_dim'],
            num_classes=num_classes,
            seq_len=config['sequence_length'],
            dropout=config['dropout']
        )

        # Create trainer
        trainer = HybridModelTrainer(
            model=model,
            class_weights=train_dataset.get_class_weights()
        )

        logger.info(f"   - Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Train model
        logger.info(f"   🚀 Training Fold {fold + 1} for {config['epochs']} epochs...")
        best_val_f1 = trainer.train(train_loader, val_loader, epochs=config['epochs'])

        # Evaluate on validation set
        (val_loss, val_acc, val_f1, val_precision, val_recall,
         val_preds, val_targets, val_uncertainties) = trainer.validate(val_loader)

        # Evaluate on test set
        (test_loss, test_acc, test_f1, test_precision, test_recall,
         test_preds, test_targets, test_uncertainties) = trainer.validate(test_loader)

        # Store results
        fold_result = {
            'fold': fold + 1,
            'train_size': len(X_train_fold),
            'val_size': len(X_val),
            'test_size': len(X_test),
            'train_percentage': train_percentage,
            'val_percentage': val_percentage,
            'test_percentage': test_percentage,
            'val_accuracy': val_acc,
            'val_precision': val_precision,
            'val_recall': val_recall,
            'val_f1': val_f1,
            'test_accuracy': test_acc,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1': test_f1,
            'best_val_f1': best_val_f1,
            'val_predictions': val_preds,
            'val_targets': val_targets,
            'test_predictions': test_preds,
            'test_targets': test_targets,
            'val_uncertainties': val_uncertainties,
            'test_uncertainties': test_uncertainties
        }

        fold_results.append(fold_result)
        fold_trainers.append(trainer)
        all_fold_predictions.append({
            'val_preds': val_preds,
            'val_targets': val_targets,
            'test_preds': test_preds,
            'test_targets': test_targets
        })

        logger.info(f"   ✅ Fold {fold + 1} Results:")
        logger.info(f"      - Validation: Acc={val_acc:.2f}%, Prec={val_precision:.4f}, Rec={val_recall:.4f}, F1={val_f1:.4f}")
        logger.info(f"      - Test: Acc={test_acc:.2f}%, Prec={test_precision:.4f}, Rec={test_recall:.4f}, F1={test_f1:.4f}")
        logger.info(f"      - Best Val F1: {best_val_f1:.4f}")

    return fold_results, fold_trainers, all_fold_predictions, le_global.classes_

def main():
    """Run the full experiment: data prep, 5-fold training, analysis, export.

    Steps, in order: seed RNGs, build sequences from ``config['csv_path']``,
    run ``run_5fold_cross_validation``, produce plots and summary tables,
    save a checkpoint and CSV reports, log a t-test / confidence interval
    over per-fold test F1, and generate architecture documentation for the
    best fold's model.

    Returns:
        ``None`` if no sequences could be prepared; otherwise the tuple
        ``(fold_results, fold_trainers, all_fold_predictions, summary_stats,
        class_names)``.
    """
    # ``stats`` is not among the imports visible at the top of this file, so
    # import it here to keep the significance analysis below self-contained.
    from scipy import stats

    set_seed(42)

    # Configuration parameters with 150 epochs
    config = {
        'csv_path': "",  # Update with your dataset path
        'sequence_length': 150,
        'batch_size': 16,
        'epochs': 150,           # Increased to 150 epochs
        'hidden_dim': 128,
        'dropout': 0.25
    }

    logger.info("🚀 Enhanced ConvLSTM + Vision Transformer Hybrid Model - 150 Epochs Training!")
    logger.info(f"   Training Configuration: {config['epochs']} epochs per fold")
    logger.info("="*80)

    # Enhanced data preparation
    sequences, labels, feature_count = prepare_dataset_for_hybrid(
        config['csv_path'], config['sequence_length'], overlap_ratio=0.35
    )

    if len(sequences) == 0:
        logger.error("❌ No valid sequences found!")
        return

    logger.info(f"\n📊 Dataset Summary:")
    logger.info(f"   - Input Features: {feature_count}")
    logger.info(f"   - Sequence Length: {config['sequence_length']}")
    logger.info(f"   - Total sequences: {len(sequences):,}")
    logger.info(f"   - Classes: {len(set(labels))}")
    logger.info(f"   - Class distribution: {dict(Counter(labels))}")
    logger.info(f"\n📊 Data Split Strategy:")
    logger.info(f"   - Train Set: 70% of total data (split into 5 folds)")
    logger.info(f"   - Validation Set: 15% of total data (fixed across all folds)")
    logger.info(f"   - Test Set: 15% of total data (fixed across all folds)")
    # Each fold trains on 4 of the 5 parts of the 70% train set (= 56%);
    # the message previously said "70%/5", which would be 14%.
    logger.info(f"   - Per fold: ~56% train (4/5 of 70%), 15% validation, 15% test")
    logger.info(f"   - Training Duration: {config['epochs']} epochs per fold")

    # Run enhanced 5-fold cross validation with 15%-15%-70% split
    fold_results, fold_trainers, all_fold_predictions, class_names = run_5fold_cross_validation(
        sequences, labels, feature_count, config
    )

    # Comprehensive analysis
    summary_stats = analyze_kfold_results(fold_results, class_names)

    # Generate comprehensive visualizations
    logger.info(f"\n📊 Generating Comprehensive Visualizations...")

    # 1. Individual fold confusion matrices
    plot_individual_fold_confusion_matrices(fold_results, class_names)

    # 2. Enhanced K-fold results
    plot_kfold_results(fold_results, "ConvLSTM_ViT_Hybrid_150_Epochs")

    # 3. Training curves for all folds
    plot_fold_training_curves(fold_trainers)

    # 4. Best performing fold training curves
    # NOTE(review): the "best" fold is chosen by TEST F1, so the saved model
    # is selected on test data; consider selecting on validation F1 instead.
    # Cast to a plain int so the index serializes cleanly in the checkpoint.
    best_fold_idx = int(np.argmax([fold['test_f1'] for fold in fold_results]))
    logger.info(f"   📈 Best performing fold: {best_fold_idx + 1} (Test F1: {fold_results[best_fold_idx]['test_f1']:.4f})")
    plot_fold_training_curves(fold_trainers, fold=best_fold_idx)

    # 5. Aggregated confusion matrices
    create_aggregated_confusion_matrix(all_fold_predictions, class_names)

    # 6. Uncertainty analysis
    plot_uncertainty_analysis(fold_results, class_names)

    # 7. Create comprehensive fold summary table
    results_df = create_fold_summary_table(fold_results, class_names)

    # Save comprehensive results
    logger.info(f"\n💾 Saving Comprehensive Results...")

    # Save best model from best fold
    best_trainer = fold_trainers[best_fold_idx]
    torch.save({
        'fold_results': fold_results,
        'summary_statistics': summary_stats,
        'best_fold': best_fold_idx + 1,
        'best_model_state_dict': best_trainer.model.state_dict(),
        'model_config': {
            'input_dim': feature_count,
            'hidden_dim': config['hidden_dim'],
            'num_classes': len(class_names),
            'seq_len': config['sequence_length'],
            'dropout': config['dropout']
        },
        'training_config': config,
        'class_names': class_names,
        'split_info': {
            'train_split': 0.70,
            'val_split': 0.15,
            'test_split': 0.15,
            'epochs': config['epochs'],
            'description': '70% train (5-fold), 15% val (fixed), 15% test (fixed) - 150 epochs'
        }
    }, 'enhanced_convlstm_vit_hybrid_150_epochs_complete_results.pth')

    # Save detailed results to CSV
    results_df.to_csv('enhanced_convlstm_vit_hybrid_150_epochs_detailed_results.csv', index=False)

    # Create additional detailed analysis with split information
    detailed_results = [
        {
            'Fold': fold['fold'],
            'Epochs': config['epochs'],
            'Train_Size': fold['train_size'],
            'Val_Size': fold['val_size'],
            'Test_Size': fold['test_size'],
            'Train_Percentage': fold['train_percentage'],
            'Val_Percentage': fold['val_percentage'],
            'Test_Percentage': fold['test_percentage'],
            'Val_Accuracy': fold['val_accuracy'],
            'Val_Precision': fold['val_precision'],
            'Val_Recall': fold['val_recall'],
            'Val_F1': fold['val_f1'],
            'Test_Accuracy': fold['test_accuracy'],
            'Test_Precision': fold['test_precision'],
            'Test_Recall': fold['test_recall'],
            'Test_F1': fold['test_f1'],
            'Best_Val_F1': fold['best_val_f1'],
            'Avg_Val_Uncertainty': np.mean(fold['val_uncertainties']),
            'Avg_Test_Uncertainty': np.mean(fold['test_uncertainties']),
            'Val_Uncertainty_Std': np.std(fold['val_uncertainties']),
            'Test_Uncertainty_Std': np.std(fold['test_uncertainties'])
        }
        for fold in fold_results
    ]

    detailed_df = pd.DataFrame(detailed_results)
    detailed_df.to_csv('enhanced_fold_analysis_150_epochs_with_uncertainty.csv', index=False)

    # Statistical significance analysis
    if len(fold_results) >= 5:
        test_scores = [fold['test_f1'] for fold in fold_results]

        # One-sample t-test against a baseline
        # NOTE(review): 0.5 is chance level only for a balanced two-class
        # problem; for more classes a lower baseline may be appropriate.
        baseline = 0.5  # Random performance
        t_stat, p_value = stats.ttest_1samp(test_scores, baseline)

        logger.info(f"\n📊 Statistical Significance Analysis:")
        logger.info(f"   - T-statistic vs random (0.5): {t_stat:.4f}")
        logger.info(f"   - P-value: {p_value:.6f}")

        if p_value < 0.001:
            significance = "Highly Significant (p < 0.001) 🎯"
        elif p_value < 0.01:
            significance = "Significant (p < 0.01) ✅"
        elif p_value < 0.05:
            significance = "Marginally Significant (p < 0.05) 👍"
        else:
            significance = "Not Significant (p ≥ 0.05) ⚠️"

        logger.info(f"   - Significance vs Random: {significance}")

        # Two-sided t-based confidence interval for the mean test F1
        confidence_level = 0.95
        alpha = 1 - confidence_level
        n = len(test_scores)
        mean_score = np.mean(test_scores)
        sem = stats.sem(test_scores)
        t_critical = stats.t.ppf(1 - alpha/2, n-1)
        ci_lower = mean_score - t_critical * sem
        ci_upper = mean_score + t_critical * sem

        logger.info(f"   - 95% Confidence Interval: [{ci_lower:.4f}, {ci_upper:.4f}]")

    # Final comprehensive summary
    logger.info(f"\n🏆 FINAL COMPREHENSIVE 5-FOLD CROSS VALIDATION SUMMARY (150 Epochs):")
    logger.info("="*80)
    logger.info(f"📊 Training Configuration:")
    logger.info(f"   - Epochs per fold: {config['epochs']}")
    # Derive the total from the folds actually run instead of assuming 5
    logger.info(f"   - Total training epochs: {config['epochs'] * len(fold_results)}")
    logger.info(f"   - Batch size: {config['batch_size']}")
    logger.info(f"   - Hidden dimension: {config['hidden_dim']}")
    logger.info(f"   - Dropout rate: {config['dropout']}")
    logger.info(f"📊 Data Split Summary:")
    logger.info(f"   - Train Set (per fold): {np.mean([fold['train_size'] for fold in fold_results]):.0f} samples ({np.mean([fold['train_percentage'] for fold in fold_results]):.1f}%)")
    logger.info(f"   - Validation Set: {fold_results[0]['val_size']:,} samples (15.0%)")
    logger.info(f"   - Test Set: {fold_results[0]['test_size']:,} samples (15.0%)")
    logger.info(f"🎯 Average Test Performance:")
    logger.info(f"   - Accuracy: {summary_stats['test_accuracy_mean']:.2f}% ± {summary_stats['test_accuracy_std']:.2f}%")
    logger.info(f"   - Precision: {summary_stats['test_precision_mean']:.4f} ± {summary_stats['test_precision_std']:.4f}")
    logger.info(f"   - Recall: {summary_stats['test_recall_mean']:.4f} ± {summary_stats['test_recall_std']:.4f}")
    logger.info(f"   - F1 Score: {summary_stats['test_f1_mean']:.4f} ± {summary_stats['test_f1_std']:.4f}")
    logger.info(f"🔒 Model Stability: {summary_stats['stability_status']}")
    logger.info(f"🎖️ Performance Level: {summary_stats['performance_status']}")
    logger.info(f"📁 Best Fold: {best_fold_idx + 1} (Test F1: {fold_results[best_fold_idx]['test_f1']:.4f})")

    logger.info(f"\n✅ Enhanced 5-Fold Cross Validation with 150 Epochs Completed Successfully!")

    # Generate Model Architecture Documentation
    logger.info(f"\n🏗️ GENERATING MODEL ARCHITECTURE DOCUMENTATION")
    logger.info("="*80)

    # Get the best model for architecture analysis
    best_model = fold_trainers[best_fold_idx].model

    # Generate complete architecture documentation
    generate_complete_architecture_documentation(
        best_model,
        input_shape=(config['batch_size'], config['sequence_length'], feature_count)
    )

    return fold_results, fold_trainers, all_fold_predictions, summary_stats, class_names

def create_model_architecture_table(model, input_shape=(16, 150, 69)):
    """Build and log a component-level architecture table for the hybrid model.

    If ``torchinfo`` is importable, its summary is logged first (failures are
    logged as warnings). Then one row per model component is assembled with
    its parameter count, logged, and returned.

    Args:
        model: Model exposing the submodules ``input_projection``,
            ``convlstm``, ``vit``, ``cross_modal_fusion``, ``adaptive_pool``,
            ``pool_attention``, ``classifier``, ``uncertainty_head`` and the
            attribute ``num_classes``.
        input_shape: ``(batch, sequence_length, features)``; elements 1 and 2
            are used for the displayed shapes.

    Returns:
        Tuple ``(df_arch, total_params)``: the table as a DataFrame and the
        sum of the listed components' parameter counts.
    """
    logger.info("🏗️ Creating ConvLSTM-ViT Hybrid Model Architecture Table")
    logger.info("="*80)

    try:
        # Optional dependency: only used for the detailed summary
        from torchinfo import summary as detailed_summary
    except ImportError:
        logger.warning("⚠️ torchinfo not available. Install with: pip install torchinfo")
        detailed_summary = None

    # Method 1: Using torchinfo for detailed summary (if available)
    if detailed_summary is not None:
        logger.info("\n📊 DETAILED MODEL SUMMARY:")
        logger.info("-"*60)
        try:
            model_stats = detailed_summary(
                model,
                input_size=input_shape,
                col_names=["input_size", "output_size", "num_params", "trainable"],
                verbose=0
            )
            logger.info(str(model_stats))
        except Exception as e:
            # Best effort: the manual breakdown below does not depend on it
            logger.warning(f"Detailed summary failed: {e}")

    # Method 2: Manual architecture breakdown
    logger.info(f"\n📋 COMPONENT-WISE ARCHITECTURE BREAKDOWN:")
    logger.info("-"*60)

    seq_len = input_shape[1]
    n_features = input_shape[2]

    def count_params(*modules):
        """Total number of parameter elements across the given modules."""
        return sum(p.numel() for module in modules for p in module.parameters())

    # (component, layer type, modules, input shape, output shape, description)
    # The hidden widths 128 / 256 are fixed display values, as before.
    component_specs = [
        ('Input Projection', 'Linear + LayerNorm + GELU + Dropout',
         (model.input_projection,),
         f'[B, {seq_len}, {n_features}]', f'[B, {seq_len}, 128]',
         'Projects input features to hidden dimension'),
        ('ConvLSTM Branch', 'Bidirectional ConvLSTM (2 layers)',
         (model.convlstm,),
         f'[B, {seq_len}, 128]', f'[B, {seq_len}, 256]',
         'Local temporal-spatial pattern extraction'),
        ('Vision Transformer', 'Multi-head Attention (4 layers, 8 heads)',
         (model.vit,),
         f'[B, {seq_len}, 128]', f'[B, {seq_len}, 128]',
         'Global sequential dependency modeling'),
        # The sequence length here used to be hard-coded to 150, which
        # disagreed with the other rows whenever input_shape[1] != 150.
        ('Cross-Modal Fusion', 'Cross-Attention + MLP',
         (model.cross_modal_fusion,),
         f'[B, {seq_len}, 256] + [B, {seq_len}, 128]', f'[B, {seq_len}, 128]',
         'Fuses ConvLSTM and ViT features'),
        ('Adaptive Pooling', 'Multi-strategy Pooling + Attention',
         (model.adaptive_pool, model.pool_attention),
         f'[B, {seq_len}, 128]', '[B, 128]',
         'Attention-weighted feature aggregation'),
        ('Classification Head', 'Multi-layer MLP + Dropout',
         (model.classifier,),
         '[B, 128]', f'[B, {model.num_classes}]',
         'Final classification with uncertainty'),
        ('Uncertainty Head', 'MLP + Sigmoid',
         (model.uncertainty_head,),
         '[B, 128]', '[B, 1]',
         'Prediction confidence estimation'),
    ]

    architecture_data = []
    total_params = 0
    for component, layer_type, modules, in_shape, out_shape, description in component_specs:
        n_params = count_params(*modules)
        architecture_data.append({
            'Component': component,
            'Layer Type': layer_type,
            'Input Shape': in_shape,
            'Output Shape': out_shape,
            'Parameters': f'{n_params:,}',
            'Description': description
        })
        total_params += n_params

    # Create and display architecture table
    df_arch = pd.DataFrame(architecture_data)
    logger.info("\n" + df_arch.to_string(index=False, max_colwidth=30))

    logger.info(f"\n🔢 TOTAL PARAMETERS: {total_params:,}")
    logger.info(f"🎯 PARAMETERS PER CLASS: {total_params // model.num_classes:,}")

    return df_arch, total_params

def create_layer_wise_parameter_table(model):
    """Log and return the 15 largest leaf layers by parameter count.

    Only modules without children that own at least one parameter are
    listed. Each row holds the module's qualified name, class name, total
    and trainable parameter counts (comma-formatted strings), and its share
    of all model parameters.

    Args:
        model: Object providing ``named_modules()`` and ``parameters()``.

    Returns:
        DataFrame with at most 15 rows, sorted by parameter count descending.
    """
    logger.info(f"\n📊 LAYER-WISE PARAMETER BREAKDOWN:")
    logger.info("-"*80)

    # Loop-invariant: previously recomputed for every listed layer
    model_total = sum(p.numel() for p in model.parameters())

    ranked_rows = []  # (raw parameter count, display row)
    for name, module in model.named_modules():
        if list(module.children()):
            continue  # Leaf modules only
        layer_params = list(module.parameters())
        num_params = sum(p.numel() for p in layer_params)
        if num_params <= 0:
            continue
        trainable_params = sum(p.numel() for p in layer_params if p.requires_grad)
        ranked_rows.append((num_params, {
            'Layer Name': name,
            'Layer Type': module.__class__.__name__,
            'Parameters': f'{num_params:,}',
            'Trainable': f'{trainable_params:,}',
            'Percentage': f'{(num_params/model_total)*100:.2f}%'
        }))

    # Sort on the raw count rather than re-parsing the formatted string;
    # the sort is stable, so ties keep traversal order as before.
    ranked_rows.sort(key=lambda item: item[0], reverse=True)

    # Show top 15 layers
    df_layers = pd.DataFrame([row for _, row in ranked_rows[:15]])
    logger.info("TOP 15 LAYERS BY PARAMETER COUNT:")
    logger.info(df_layers.to_string(index=False))

    return df_layers

def create_computational_complexity_table(model, input_shape=(16, 150, 69)):
    """Log and return rough, parameter-count-based complexity figures.

    The FLOP figure is a coarse heuristic (parameters x sequence length x 2),
    and both size figures assume 4 bytes per value.

    Args:
        model: Object whose ``parameters()`` yields tensors with ``numel()``.
        input_shape: Sequence whose elements 1 and 2 are used as sequence
            length and feature count.

    Returns:
        Dict mapping metric name to a formatted string value.
    """
    logger.info("\n⚡ COMPUTATIONAL COMPLEXITY ANALYSIS:")
    logger.info("-" * 60)

    bytes_per_value = 4
    bytes_per_mib = 1024**2
    steps = input_shape[1]
    width = input_shape[2]

    # Parameter-based estimation
    param_count = 0
    for tensor in model.parameters():
        param_count += tensor.numel()

    flop_estimate = param_count * steps * 2  # Rough estimate
    model_mb = param_count * bytes_per_value / bytes_per_mib
    sample_mb = (param_count + steps * width) * bytes_per_value / bytes_per_mib

    complexity_data = {}
    complexity_data['Total Parameters'] = f'{param_count:,}'
    complexity_data['Estimated FLOPs'] = f'{flop_estimate:,}'
    complexity_data['Model Size (MB)'] = f'{model_mb:.2f}'
    complexity_data['Memory per Sample (MB)'] = f'{sample_mb:.2f}'

    logger.info("COMPUTATIONAL METRICS:")
    for metric, shown in complexity_data.items():
        logger.info(f"  {metric}: {shown}")

    return complexity_data

def _escape_latex_text(value):
    """Return ``value`` as a string with common LaTeX special characters escaped."""
    return str(value).translate(str.maketrans({
        '&': '\\&',
        '%': '\\%',
        '$': '\\$',
        '#': '\\#',
        '_': '\\_',
    }))


def create_academic_table_latex(df_arch, total_params, model):
    """Render the architecture table as a LaTeX ``table*`` and save it.

    The LaTeX source is logged, written to ``model_architecture_table.tex``
    (a failed write is logged as a warning, not raised), and returned.

    Args:
        df_arch: DataFrame with the columns ``Component``, ``Layer Type``,
            ``Input Shape``, ``Output Shape``, ``Parameters``, ``Description``.
        total_params: Integer shown in the "Total Parameters" row.
        model: Object providing ``input_dim`` and ``num_classes`` for the
            table footnote.

    Returns:
        The LaTeX table source as a string.
    """
    columns = ('Component', 'Layer Type', 'Input Shape',
               'Output Shape', 'Parameters', 'Description')

    header = f"""
\\begin{{table*}}[ht]
\\centering
\\caption{{Architecture of the Proposed ConvLSTM-ViT Hybrid Model}}
\\label{{tab:model_architecture}}
\\begin{{tabular}}{{|l|l|l|l|r|l|}}
\\hline
\\textbf{{Component}} & \\textbf{{Layer Type}} & \\textbf{{Input Shape}} & \\textbf{{Output Shape}} & \\textbf{{Parameters}} & \\textbf{{Description}} \\\\
\\hline
"""

    # Every cell is escaped (previously only '&' in three of the columns),
    # so values containing %, $, # or _ no longer break compilation.
    body_rows = [
        " & ".join(_escape_latex_text(row[col]) for col in columns) + " \\\\\n\\hline\n"
        for _, row in df_arch.iterrows()
    ]

    # NOTE(review): the sequence length 150 in the footnote is hard-coded;
    # it is not derivable from this function's arguments.
    footer = f"""\\multicolumn{{4}}{{|l|}}{{\\textbf{{Total Parameters}}}} & \\textbf{{{total_params:,}}} & - \\\\
\\hline
\\multicolumn{{6}}{{|l|}}{{B: Batch size, Input dimensions: [Batch, 150, {model.input_dim}], Output classes: {model.num_classes}}} \\\\
\\hline
\\end{{tabular}}
\\end{{table*}}
"""

    latex_table = header + "".join(body_rows) + footer

    logger.info(f"\n📝 LATEX TABLE FOR ACADEMIC PAPER:")
    logger.info("-"*60)
    logger.info(latex_table)

    # Save to file; explicit encoding keeps the output platform-independent
    try:
        with open('model_architecture_table.tex', 'w', encoding='utf-8') as f:
            f.write(latex_table)
    except OSError as e:
        logger.warning(f"Could not save LaTeX file: {e}")
    else:
        logger.info("💾 LaTeX table saved as 'model_architecture_table.tex'")

    return latex_table

def create_csv_export(df_arch, df_layers, complexity_data):
    """Write the architecture, layer and complexity tables to CSV files.

    Produces ``model_architecture_summary.csv``, ``model_layer_details.csv``
    and ``model_complexity_metrics.csv`` in the working directory. Any
    exception raised along the way is logged as a warning instead of
    propagating.

    Args:
        df_arch: DataFrame holding the component-level architecture table.
        df_layers: DataFrame holding the per-layer parameter table.
        complexity_data: Mapping of metric name to value; written as two
            columns, ``Metric`` and ``Value``.
    """
    try:
        # The two ready-made tables go out unchanged
        frame_targets = (
            (df_arch, 'model_architecture_summary.csv'),
            (df_layers, 'model_layer_details.csv'),
        )
        for frame, target in frame_targets:
            frame.to_csv(target, index=False)

        # The metrics mapping becomes a two-column table
        metric_pairs = list(complexity_data.items())
        pd.DataFrame(metric_pairs, columns=['Metric', 'Value']).to_csv(
            'model_complexity_metrics.csv', index=False
        )

        logger.info("\n💾 EXPORTED FILES:")
        for listed_name in (
            'model_architecture_summary.csv',
            'model_layer_details.csv',
            'model_complexity_metrics.csv',
            'model_architecture_table.tex',
        ):
            logger.info("  - " + listed_name)

    except Exception as e:
        logger.warning(f"Could not export CSV files: {e}")

def generate_complete_architecture_documentation(model, input_shape=(16, 150, 69)):
    """Generate complete architecture documentation for academic paper.

    Runs the individual table builders, exports their results, logs a
    per-component parameter breakdown and logs a plain-text summary table.

    Args:
        model: Model to document. The function reads the sub-modules
            ``input_projection``, ``convlstm``, ``vit``,
            ``cross_modal_fusion``, ``adaptive_pool``, ``pool_attention``,
            ``classifier`` and ``uncertainty_head`` as well as the
            attributes ``input_dim`` and ``num_classes``.
        input_shape: Shape forwarded to the table builders. Its second entry
            is used as the sequence length in the summary table (the default
            matches the ``[B, 150, input_dim]`` layout of that table).

    Returns:
        Tuple ``(df_arch, df_layers, complexity_data, latex_table)`` as
        produced by the respective helper functions.
    """

    def count_params(module):
        """Return the total number of elements over all parameters of *module*."""
        return sum(p.numel() for p in module.parameters())

    logger.info("🚀 GENERATING COMPLETE MODEL ARCHITECTURE DOCUMENTATION")
    logger.info("="*80)

    # 1. Main architecture table
    df_arch, total_params = create_model_architecture_table(model, input_shape)

    # 2. Layer-wise breakdown
    df_layers = create_layer_wise_parameter_table(model)

    # 3. Computational complexity
    complexity_data = create_computational_complexity_table(model, input_shape)

    # 4. LaTeX table for paper
    latex_table = create_academic_table_latex(df_arch, total_params, model)

    # 5. Export all data
    create_csv_export(df_arch, df_layers, complexity_data)

    # 6. Component analysis
    logger.info("\n📊 COMPONENT ANALYSIS:")
    logger.info("-" * 60)

    components = {
        'Input Processing': count_params(model.input_projection),
        'ConvLSTM Branch': count_params(model.convlstm),
        'Vision Transformer': count_params(model.vit),
        'Cross-Modal Fusion': count_params(model.cross_modal_fusion),
        'Pooling Mechanism': count_params(model.adaptive_pool) +
                             count_params(model.pool_attention),
        'Classification Head': count_params(model.classifier),
        'Uncertainty Head': count_params(model.uncertainty_head)
    }

    # Sum over the listed components only; this may differ from total_params
    # returned by create_model_architecture_table.
    total_params_check = sum(components.values())

    for component, params in components.items():
        # Guard against a zero total so the breakdown never raises
        # ZeroDivisionError.
        percentage = (params / total_params_check) * 100 if total_params_check else 0.0
        logger.info(f"  {component:20}: {params:8,} parameters ({percentage:5.2f}%)")

    logger.info(f"  {'Total':20}: {total_params_check:8,} parameters (100.00%)")

    # 7. Academic table format for direct use
    logger.info("\n📋 ACADEMIC PAPER TABLE FORMAT:")
    logger.info("="*80)

    # Take the sequence length from the caller-supplied shape instead of
    # hard-coding it, so the table stays correct for non-default inputs.
    seq_len = input_shape[1]
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 4 bytes per parameter, i.e. 32-bit storage is assumed for the size estimate.
    model_size_mb = count_params(model) * 4 / (1024**2)

    academic_table = f"""
Table 1: ConvLSTM-ViT Hybrid Model Architecture

Component              | Layer Type                        | Input Shape      | Output Shape     | Parameters
-----------------------|-----------------------------------|------------------|------------------|-------------
Input Projection      | Linear + LayerNorm + GELU        | [B, {seq_len}, {model.input_dim}]    | [B, {seq_len}, 128]   | {components['Input Processing']:,}
ConvLSTM Branch        | Bidirectional ConvLSTM (2 layers)| [B, {seq_len}, 128]   | [B, {seq_len}, 256]   | {components['ConvLSTM Branch']:,}
Vision Transformer     | Multi-head Attention (4 layers)  | [B, {seq_len}, 128]   | [B, {seq_len}, 128]   | {components['Vision Transformer']:,}
Cross-Modal Fusion     | Cross-Attention + MLP            | [B, {seq_len}, 384]   | [B, {seq_len}, 128]   | {components['Cross-Modal Fusion']:,}
Adaptive Pooling       | Multi-strategy + Attention       | [B, {seq_len}, 128]   | [B, 128]        | {components['Pooling Mechanism']:,}
Classification Head    | Multi-layer MLP + Dropout        | [B, 128]        | [B, {model.num_classes}]         | {components['Classification Head']:,}
Uncertainty Head       | MLP + Sigmoid                     | [B, 128]        | [B, 1]          | {components['Uncertainty Head']:,}
-----------------------|-----------------------------------|------------------|------------------|-------------
TOTAL                  |                                   |                  |                  | {total_params_check:,}

B: Batch size, Sequence length: {seq_len}, Input features: {model.input_dim}, Classes: {model.num_classes}
Total trainable parameters: {trainable_params:,}
Model size: {model_size_mb:.2f} MB
"""

    logger.info(academic_table)

    logger.info("\n✅ ARCHITECTURE DOCUMENTATION COMPLETE!")
    logger.info(f"🎯 Model has {total_params:,} parameters across {len(df_arch)} main components")
    logger.info("📊 Documentation files saved for academic publication")

    return df_arch, df_layers, complexity_data, latex_table

# Script entry point: call main() only when this file is executed directly
# (not when it is imported) and keep its return value in `results`.
if __name__ == "__main__":
    results = main()