<a href="https://colab.research.google.com/github/fjadidi2001/Insurance/blob/main/Insurance_ICCKe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Insurance Claims Prediction - Complete ML Pipeline with xLSTM
## Step-by-step implementation for telematics data analysis with advanced neural architecture

In [11]:
# =============================================================================
# STEP 1: IMPORT REQUIRED LIBRARIES
# =============================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

# Machine Learning Libraries
from sklearn.model_selection import train_test_split, RandomizedSearchCV, cross_val_score
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.feature_selection import mutual_info_classif, SelectKBest
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

# Deep Learning Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

# Evaluation Metrics
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, matthews_corrcoef, confusion_matrix,
    roc_curve, precision_recall_curve, average_precision_score,
    classification_report
)

# Handle Imbalanced Data
from imblearn.over_sampling import SMOTE

# Model Interpretation
try:
    import shap
    import lime
    import lime.lime_tabular
    INTERPRETATION_AVAILABLE = True
except ImportError:
    INTERPRETATION_AVAILABLE = False
    print("SHAP and LIME not available. Install with: pip install shap lime")


SHAP and LIME not available. Install with: pip install shap lime


In [12]:
# =============================================================================
# STEP 1.5: xLSTM IMPLEMENTATION
# =============================================================================

class xLSTMCell(nn.Module):
    """
    Extended LSTM Cell with exponential gating and matrix memory
    Based on xLSTM architecture for enhanced sequence modeling
    """
    def __init__(self, input_size, hidden_size, memory_size=64):
        super(xLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.memory_size = memory_size

        # Standard LSTM gates
        self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.candidate_gate = nn.Linear(input_size + hidden_size, hidden_size)

        # xLSTM extensions
        self.exp_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.matrix_memory = nn.Parameter(torch.randn(memory_size, hidden_size))
        self.memory_gate = nn.Linear(input_size + hidden_size, memory_size)

        # Weight computation layers
        self.weight_transform = nn.Linear(hidden_size, hidden_size)
        self.weight_attention = nn.MultiheadAttention(hidden_size, num_heads=4)

    def forward(self, x, hidden_state, cell_state):
        combined = torch.cat([x, hidden_state], dim=1)

        # Standard LSTM computations
        forget = torch.sigmoid(self.forget_gate(combined))
        input_gate_val = torch.sigmoid(self.input_gate(combined))
        output = torch.sigmoid(self.output_gate(combined))
        candidate = torch.tanh(self.candidate_gate(combined))

        # xLSTM extensions
        exp_gate_val = torch.exp(self.exp_gate(combined))
        memory_weights = torch.softmax(self.memory_gate(combined), dim=1)

        # Matrix memory interaction
        memory_contribution = torch.mm(memory_weights, self.matrix_memory)

        # Enhanced cell state update
        cell_state = forget * cell_state + input_gate_val * candidate + 0.1 * memory_contribution
        cell_state = cell_state * exp_gate_val  # Exponential gating

        # Enhanced hidden state
        hidden_state = output * torch.tanh(cell_state)

        # Weight computation using attention
        weight_features = self.weight_transform(hidden_state).unsqueeze(0)
        attended_weights, attention_weights = self.weight_attention(
            weight_features, weight_features, weight_features
        )

        return hidden_state, cell_state, attended_weights.squeeze(0), attention_weights

class xLSTMWeightPredictor(nn.Module):
    """
    xLSTM-based model for predicting feature weights and insurance claims
    """
    def __init__(self, input_size, hidden_size=128, num_layers=2, memory_size=64):
        super(xLSTMWeightPredictor, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Feature embedding
        self.feature_embedding = nn.Linear(input_size, hidden_size)

        # xLSTM layers
        self.xlstm_cells = nn.ModuleList([
            xLSTMCell(hidden_size if i == 0 else hidden_size, hidden_size, memory_size)
            for i in range(num_layers)
        ])

        # Weight prediction layers
        self.weight_predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size // 2, input_size),
            nn.Softmax(dim=1)
        )

        # Classification layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size // 2, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, return_weights=False):
        batch_size = x.size(0)

        # Feature embedding
        embedded = self.feature_embedding(x)

        # Initialize hidden and cell states
        hidden_states = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_layers)]
        cell_states = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_layers)]

        # Pass through xLSTM layers
        current_input = embedded
        all_attention_weights = []

        for i, xlstm_cell in enumerate(self.xlstm_cells):
            hidden_states[i], cell_states[i], weighted_features, attention_weights = xlstm_cell(
                current_input, hidden_states[i], cell_states[i]
            )
            current_input = hidden_states[i]
            all_attention_weights.append(attention_weights)

        # Final hidden state
        final_hidden = hidden_states[-1]

        # Predict feature weights
        feature_weights = self.weight_predictor(final_hidden)

        # Classification prediction
        classification_output = self.classifier(final_hidden)

        if return_weights:
            return classification_output, feature_weights, all_attention_weights
        return classification_output

class xLSTMTrainer:
    """
    Trainer class for xLSTM model
    """
    def __init__(self, model, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=10
        )
        self.criterion = nn.BCELoss()

    def train_epoch(self, dataloader):
        self.model.train()
        total_loss = 0
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            predictions, weights, _ = self.model(batch_x, return_weights=True)

            # Classification loss
            class_loss = self.criterion(predictions.squeeze(), batch_y.float())

            # Weight regularization (encourage sparse, meaningful weights)
            weight_reg = torch.mean(torch.sum(weights * torch.log(weights + 1e-8), dim=1))

            # Total loss
            total_loss_batch = class_loss + 0.01 * weight_reg

            total_loss_batch.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += total_loss_batch.item()

        return total_loss / len(dataloader)

    def validate(self, dataloader):
        self.model.eval()
        total_loss = 0
        predictions = []
        targets = []

        with torch.no_grad():
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)

                pred, weights, _ = self.model(batch_x, return_weights=True)
                loss = self.criterion(pred.squeeze(), batch_y.float())

                total_loss += loss.item()
                predictions.extend(pred.cpu().numpy())
                targets.extend(batch_y.cpu().numpy())

        return total_loss / len(dataloader), np.array(predictions), np.array(targets)

    def train(self, train_loader, val_loader, epochs=100, early_stopping_patience=20):
        train_losses = []
        val_losses = []
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(epochs):
            # Training
            train_loss = self.train_epoch(train_loader)

            # Validation
            val_loss, val_preds, val_targets = self.validate(val_loader)

            train_losses.append(train_loss)
            val_losses.append(val_loss)

            # Learning rate scheduling
            self.scheduler.step(val_loss)

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Save best model
                torch.save(self.model.state_dict(), 'best_xlstm_model.pth')
            else:
                patience_counter += 1

            if epoch % 10 == 0:
                val_auc = roc_auc_score(val_targets, val_preds)
                print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.4f}')

            if patience_counter >= early_stopping_patience:
                print(f'Early stopping at epoch {epoch}')
                break

        # Load best model
        self.model.load_state_dict(torch.load('best_xlstm_model.pth'))

        return train_losses, val_losses

In [10]:





# =============================================================================
# STEP 2: DATA LOADING AND INITIAL EXPLORATION
# =============================================================================

def load_and_explore_data(file_path):
    """
    Load the telematics dataset and perform initial exploration
    """
    print("=" * 50)
    print("STEP 2: DATA LOADING AND EXPLORATION")
    print("=" * 50)

    # Load the dataset
    df = pd.read_csv(file_path)

    print(f"Dataset shape: {df.shape}")
    print(f"Columns: {df.columns.tolist()}")

    # Display basic info
    print("\nDataset Info:")
    print(df.info())

    print("\nFirst 5 rows:")
    print(df.head())

    print("\nBasic statistics:")
    print(df.describe())

    # Check for missing values
    missing_values = df.isnull().sum()
    print(f"\nMissing values per column:")
    print(missing_values[missing_values > 0])
    print(f"Total missing values: {df.isnull().sum().sum()}")

    # Check for duplicate rows
    duplicates = df.duplicated().sum()
    print(f"Duplicate rows: {duplicates}")

    return df

# =============================================================================
# STEP 3: ENHANCED DATA PREPROCESSING
# =============================================================================

def preprocess_data(df):
    """
    Enhanced preprocessing with feature engineering
    """
    print("=" * 50)
    print("STEP 3: ENHANCED DATA PREPROCESSING")
    print("=" * 50)

    # Create target variable: Claims with amount > 1000
    if 'NB_Claim' in df.columns and 'AMT_Claim' in df.columns:
        df['ClaimYN'] = ((df['NB_Claim'] >= 1) & (df['AMT_Claim'] > 1000)).astype(int)
        df = df.drop(['NB_Claim', 'AMT_Claim'], axis=1)
    else:
        # If target already exists, use it
        if 'ClaimYN' not in df.columns:
            print("Warning: No target variable found. Creating dummy target.")
            df['ClaimYN'] = np.random.binomial(1, 0.3, len(df))

    print(f"Target variable distribution:")
    print(df['ClaimYN'].value_counts())
    print(f"Positive class ratio: {df['ClaimYN'].mean():.4f}")

    # Handle missing values with multiple strategies
    print("\nHandling missing values...")
    initial_rows = len(df)

    # For numerical columns, use median imputation
    numerical_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    numerical_cols.remove('ClaimYN')  # Remove target

    for col in numerical_cols:
        if df[col].isnull().sum() > 0:
            df[col].fillna(df[col].median(), inplace=True)

    # For categorical columns, use mode imputation
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    for col in categorical_cols:
        if df[col].isnull().sum() > 0:
            df[col].fillna(df[col].mode()[0], inplace=True)

    print(f"Rows after handling missing values: {len(df)} (changed {initial_rows - len(df)})")

    # Feature Engineering
    print("\nPerforming feature engineering...")

    # Create interaction features for numerical columns
    if len(numerical_cols) >= 2:
        # Create a few key interactions (limit to avoid curse of dimensionality)
        for i in range(min(3, len(numerical_cols))):
            for j in range(i+1, min(i+3, len(numerical_cols))):
                col1, col2 = numerical_cols[i], numerical_cols[j]
                df[f'{col1}_x_{col2}'] = df[col1] * df[col2]

    # Create polynomial features for top numerical features
    if len(numerical_cols) >= 1:
        for col in numerical_cols[:3]:  # Top 3 numerical features
            df[f'{col}_squared'] = df[col] ** 2
            df[f'{col}_log'] = np.log1p(np.abs(df[col]))

    # Encode categorical variables
    categorical_columns = df.select_dtypes(include=['object']).columns.tolist()
    print(f"Categorical columns for encoding: {categorical_columns}")

    if categorical_columns:
        df_encoded = pd.get_dummies(df, columns=categorical_columns, drop_first=True)
    else:
        df_encoded = df.copy()

    print(f"Shape after encoding: {df_encoded.shape}")

    # Separate features and target
    X = df_encoded.drop('ClaimYN', axis=1)
    y = df_encoded['ClaimYN']

    print(f"Final features shape: {X.shape}")
    print(f"Target shape: {y.shape}")

    return X, y

# =============================================================================
# STEP 4: FEATURE SCALING WITH MULTIPLE SCALERS
# =============================================================================

def scale_features(X_train, X_test, scaler_type='standard'):
    """
    Scale features using different scalers
    """
    print("=" * 50)
    print("STEP 4: FEATURE SCALING")
    print("=" * 50)

    if scaler_type == 'standard':
        scaler = StandardScaler()
    elif scaler_type == 'minmax':
        scaler = MinMaxScaler()
    else:
        scaler = StandardScaler()

    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    print(f"Features scaled using {scaler_type} scaler")
    print(f"Training set shape: {X_train_scaled.shape}")
    print(f"Test set shape: {X_test_scaled.shape}")

    return X_train_scaled, X_test_scaled, scaler

# =============================================================================
# STEP 5: HANDLE CLASS IMBALANCE WITH SMOTE
# =============================================================================

def handle_imbalance(X_train, y_train, method='smote'):
    """
    Handle class imbalance using SMOTE or other methods
    """
    print("=" * 50)
    print("STEP 5: HANDLING CLASS IMBALANCE")
    print("=" * 50)

    print("Original class distribution:")
    print(pd.Series(y_train).value_counts())
    print(f"Original positive class ratio: {pd.Series(y_train).mean():.4f}")

    if method == 'smote':
        # Apply SMOTE
        smote = SMOTE(random_state=42, k_neighbors=min(5, sum(y_train) - 1))
        X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
    else:
        # No resampling
        X_resampled, y_resampled = X_train, y_train

    print("\nClass distribution after resampling:")
    print(pd.Series(y_resampled).value_counts())
    print(f"New positive class ratio: {pd.Series(y_resampled).mean():.4f}")

    return X_resampled, y_resampled

# =============================================================================
# STEP 6: ENHANCED FEATURE SELECTION
# =============================================================================

def select_features(X_train, y_train, feature_names, top_k=15):
    """
    Enhanced feature selection using multiple methods
# =============================================================================
# STEP 1.5: xLSTM IMPLEMENTATION
# =============================================================================

class xLSTMCell(nn.Module):
    """
    Extended LSTM Cell with exponential gating and matrix memory
    Based on xLSTM architecture for enhanced sequence modeling
    """
…
    print(f"Features scaled using {scaler_type} scaler")
    print(f"Training set shape: {X_train_scaled.shape}")
    prin
    """
    print("=" * 50)
    print("STEP 6: ENHANCED FEATURE SELECTION")
    print("=" * 50)

    # Calculate mutual information scores
    mi_scores = mutual_info_classif(X_train, y_train, random_state=42)

    # Create feature importance DataFrame
    feature_importance = pd.DataFrame({
        'feature': feature_names,
        'mutual_info': mi_scores
    })

    # Add correlation with target
    correlations = []
    for i in range(X_train.shape[1]):
        corr = np.corrcoef(X_train[:, i], y_train)[0, 1]
        correlations.append(abs(corr) if not np.isnan(corr) else 0)

    feature_importance['correlation'] = correlations

    # Combined score (weighted average)
    feature_importance['combined_score'] = (
        0.7 * feature_importance['mutual_info'] +
        0.3 * feature_importance['correlation']
    )

    feature_importance = feature_importance.sort_values('combined_score', ascending=False)

    print(f"Top {top_k} features by combined score:")
    print(feature_importance.head(top_k))

    # Select top k features
    top_features = feature_importance.head(top_k)['feature'].tolist()
    selected_indices = [feature_names.index(feat) for feat in top_features]

    # Visualize feature importance
    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    plt.barh(range(min(15, len(feature_importance))),
             feature_importance.head(15)['mutual_info'])
    plt.yticks(range(min(15, len(feature_importance))),
               feature_importance.head(15)['feature'])
    plt.xlabel('Mutual Information Score')
    plt.title('Top Features by Mutual Information')

    plt.subplot(2, 2, 2)
    plt.barh(range(min(15, len(feature_importance))),
             feature_importance.head(15)['correlation'])
    plt.yticks(range(min(15, len(feature_importance))),
               feature_importance.head(15)['feature'])
    plt.xlabel('Absolute Correlation with Target')
    plt.title('Top Features by Correlation')

    plt.subplot(2, 2, 3)
    plt.barh(range(top_k), top_features[::-1])
    plt.yticks(range(top_k), [f.replace('_', ' ') for f in top_features[::-1]])
    plt.xlabel('Combined Score')
    plt.title(f'Selected Top {top_k} Features')

    plt.subplot(2, 2, 4)
    plt.scatter(feature_importance['mutual_info'], feature_importance['correlation'], alpha=0.6)
    plt.xlabel('Mutual Information')
    plt.ylabel('Correlation')
    plt.title('Feature Selection Space')

    # Highlight selected features
    selected_features_data = feature_importance.head(top_k)
    plt.scatter(selected_features_data['mutual_info'],
               selected_features_data['correlation'],
               color='red', s=100, alpha=0.8, label='Selected')
    plt.legend()

    plt.tight_layout()
    plt.show()

    return selected_indices, feature_importance

# =============================================================================
# STEP 7: ENHANCED MODEL TRAINING
# =============================================================================

def train_models(X_train, y_train):
    """
    Train multiple models with hyperparameter tuning
    """
    print("=" * 50)
    print("STEP 7: ENHANCED MODEL TRAINING")
    print("=" * 50)

    models = {}

    # 1. Gradient Boosting Classifier
    print("Training Gradient Boosting Classifier...")
    gb_params = {
        'n_estimators': [50, 100, 150],
        'learning_rate': [0.01, 0.1, 0.2],
        'max_depth': [3, 5, 7],
        'min_samples_split': [2, 5, 10],
        'subsample': [0.8, 0.9, 1.0]
    }

    gb_classifier = GradientBoostingClassifier(random_state=42)
    gb_search = RandomizedSearchCV(
        gb_classifier, gb_params, n_iter=15, cv=3,
        scoring='roc_auc', random_state=42, n_jobs=-1
    )
    gb_search.fit(X_train, y_train)
    models['Gradient Boosting'] = gb_search.best_estimator_

    print(f"Best GB parameters: {gb_search.best_params_}")
    print(f"Best GB CV score: {gb_search.best_score_:.4f}")

    # 2. Random Forest Classifier
    print("\nTraining Random Forest Classifier...")
    rf_params = {
        'n_estimators': [50, 100, 200],
        'max_depth': [5, 10, 15, None],
        'min_samples_split': [2, 5, 10],
        'min_samples_leaf': [1, 2, 4]
    }

    rf_classifier = RandomForestClassifier(random_state=42, n_jobs=-1)
    rf_search = RandomizedSearchCV(
        rf_classifier, rf_params, n_iter=10, cv=3,
        scoring='roc_auc', random_state=42, n_jobs=-1
    )
    rf_search.fit(X_train, y_train)
    models['Random Forest'] = rf_search.best_estimator_

    print(f"Best RF parameters: {rf_search.best_params_}")
    print(f"Best RF CV score: {rf_search.best_score_:.4f}")

    # 3. Neural Network
    print("\nTraining Neural Network...")
    nn_params = {
        'hidden_layer_sizes': [(50,), (100,), (100, 50), (150, 75)],
        'activation': ['relu', 'tanh'],
        'alpha': [0.0001, 0.001, 0.01],
        'learning_rate': ['constant', 'adaptive']
    }

    nn_classifier = MLPClassifier(max_iter=1000, random_state=42)
    nn_search = RandomizedSearchCV(
        nn_classifier, nn_params, n_iter=8, cv=3,
        scoring='roc_auc', random_state=42, n_jobs=-1
    )
    nn_search.fit(X_train, y_train)
    models['Neural Network'] = nn_search.best_estimator_

    print(f"Best NN parameters: {nn_search.best_params_}")
    print(f"Best NN CV score: {nn_search.best_score_:.4f}")

    # 4. Logistic Regression
    print("\nTraining Logistic Regression...")
    lr_params = {
        'C': [0.1, 1.0, 10.0],
        'penalty': ['l1', 'l2'],
        'solver': ['liblinear', 'saga']
    }

    lr_classifier = LogisticRegression(random_state=42, max_iter=1000)
    lr_search = RandomizedSearchCV(
        lr_classifier, lr_params, n_iter=8, cv=3,
        scoring='roc_auc', random_state=42, n_jobs=-1
    )
    lr_search.fit(X_train, y_train)
    models['Logistic Regression'] = lr_search.best_estimator_

    print(f"Best LR parameters: {lr_search.best_params_}")
    print(f"Best LR CV score: {lr_search.best_score_:.4f}")

    return models

# =============================================================================
# STEP 7.5: xLSTM MODEL TRAINING
# =============================================================================

def train_xlstm_model(X_train, y_train, X_val, y_val, feature_names):
    """
    Train xLSTM model for weight prediction and classification
    """
    print("=" * 50)
    print("STEP 7.5: xLSTM MODEL TRAINING")
    print("=" * 50)

    # Convert to PyTorch tensors
    X_train_tensor = torch.FloatTensor(X_train)
    y_train_tensor = torch.LongTensor(y_train)
    X_val_tensor = torch.FloatTensor(X_val)
    y_val_tensor = torch.LongTensor(y_val)

    # Create data loaders
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

    # Initialize model
    input_size = X_train.shape[1]
    xlstm_model = xLSTMWeightPredictor(
        input_size=input_size,
        hidden_size=128,
        num_layers=2,
        memory_size=64
    )

    # Initialize trainer
    trainer = xLSTMTrainer(xlstm_model)

    # Train model
    print("Training xLSTM model...")
    train_losses, val_losses = trainer.train(
        train_loader, val_loader,
        epochs=100, early_stopping_patience=15
    )

    # Plot training history
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('xLSTM Training History')
    plt.legend()

    # Get feature weights
    xlstm_model.eval()
    with torch.no_grad():
        sample_input = X_train_tensor[:100]  # Use sample for weight analysis
        _, feature_weights, attention_weights = xlstm_model(sample_input, return_weights=True)

        # Average weights across samples
        avg_weights = torch.mean(feature_weights, dim=0).cpu().numpy()

    # Visualize feature weights
    plt.subplot(1, 2, 2)
    top_k = min(15, len(avg_weights))
    weight_df = pd.DataFrame({
        'feature': feature_names,
        'weight': avg_weights
    }).sort_values('weight', ascending=False)

    plt.barh(range(top_k), weight_df.head(top_k)['weight'])
    plt.yticks(range(top_k), [f.replace('_', ' ') for f in weight_df.head(top_k)['feature']])
    plt.xlabel('xLSTM Feature Weight')
    plt.title('Feature Importance by xLSTM')

    plt.tight_layout()
    plt.show()

    return xlstm_model, trainer, weight_df

# =============

SyntaxError: invalid syntax (ipython-input-1152845505.py, line 1128)