# TabTransformer for Classification Tasks

This notebook implements a TabTransformer model for tabular classification tasks, using PyTorch-Tabular. The implementation includes:

- A scikit-learn compatible TabTransformerClassifier
- Hyperparameter optimization with Optuna
- Model training with early stopping
- Performance evaluation and feature importance analysis
- Batch prediction functionality for memory efficiency

## Import Required Libraries

First, we'll import all the necessary packages for our TabTransformer implementation.

In [55]:
# Core packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import time
import logging
import warnings
import joblib

# PyTorch and PyTorch-Tabular
import torch
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import TabTransformerConfig

# Scikit-learn compatibility
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.utils.class_weight import compute_class_weight

# Optimization
import optuna
from optuna.pruners import MedianPruner

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress warnings
warnings.filterwarnings("ignore")

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x23cfea66e90>

## Create TabTransformer Classifier

We'll implement a TabTransformerClassifier class that extends scikit-learn's BaseEstimator and ClassifierMixin
to make the TabTransformer model compatible with scikit-learn workflows, including cross-validation and pipelines.

In [56]:
class TabTransformerClassifier(BaseEstimator, ClassifierMixin):
    """
    A scikit-learn compatible wrapper for PyTorch-Tabular's TabTransformer implementation.
    
    Parameters
    ----------
    categorical_cols : list
        List of categorical column names
    continuous_cols : list
        List of continuous column names
    embedding_dims : dict or None, default=None
        Dictionary of embedding dimensions for categorical columns
    num_heads : int, default=4
        Number of attention heads
    num_attn_blocks : int, default=4
        Number of transformer blocks
    attn_dropout : float, default=0.1
        Dropout rate for attention
    ff_dropout : float, default=0.1
        Dropout rate for feed-forward
    mlp_dropout : float, default=0.1
        Dropout rate for MLP
    lr : float, default=1e-3
        Learning rate
    weight_decay : float, default=0.0
        Weight decay for optimizer
    batch_size : int, default=64
        Batch size for training
    max_epochs : int, default=100
        Maximum number of training epochs
    patience : int, default=10
        Patience for early stopping
    target_col : str
        Name of the target column
    device : str, default='auto'
        Device to use for training ('cpu', 'cuda', or 'auto')
    model_dir : str, default='./models'
        Directory to save models
    """
    
    def __init__(self, 
                 categorical_cols, 
                 continuous_cols, 
                 embedding_dims=None, 
                 num_heads=4, 
                 num_attn_blocks=4, 
                 attn_dropout=0.1,
                 ff_dropout=0.1, 
                 mlp_dropout=0.1, 
                 lr=1e-3, 
                 weight_decay=0.0, 
                 batch_size=64, 
                 max_epochs=100, 
                 patience=10,
                 target_col='target', 
                 device='auto', 
                 model_dir='./models'):
        
        self.categorical_cols = categorical_cols
        self.continuous_cols = continuous_cols
        self.embedding_dims = embedding_dims
        self.num_heads = num_heads
        self.num_attn_blocks = num_attn_blocks
        self.attn_dropout = attn_dropout
        self.ff_dropout = ff_dropout
        self.mlp_dropout = mlp_dropout
        self.lr = lr
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.max_epochs = max_epochs
        self.patience = patience
        self.target_col = target_col
        self.device = device
        self.model_dir = model_dir
        
        self.model = None
        self.label_encoder = None
        self.class_weights = None
        self.classes_ = None
        
    # In the _create_model_configs method of TabTransformerClassifier class
    def _create_model_configs(self, X, y=None):
        """Create the configuration objects for PyTorch-Tabular."""
        
        # Determine output dimension based on number of classes
        if y is not None:
            n_classes = len(np.unique(y))
            self.classes_ = np.unique(y)
        else:
            n_classes = len(self.classes_)
        
        # Set output dimension based on binary or multi-class
        output_dim = 1 if n_classes == 2 else n_classes
        
        # Create TabTransformer specific config - removing embed_categorical parameter
        model_config = TabTransformerConfig(
            task="classification",
            learning_rate=self.lr,
            # embed_categorical=True,  # Remove this line
            embedding_dims=self.embedding_dims,
            num_heads=self.num_heads,
            num_attn_blocks=self.num_attn_blocks,
            attn_dropout=self.attn_dropout,
            ff_dropout=self.ff_dropout,
            # mlp_dropout=self.mlp_dropout,
            # output_dim=output_dim,
            # categorical_cardinality=self._get_cat_cardinality(X) if self.embedding_dims is None else None
        )
        
        # Create data configuration
        data_config = DataConfig(
            target=self.target_col,
            categorical_cols=self.categorical_cols,
            continuous_cols=self.continuous_cols,
            continuous_feature_transform="box-cox",
            normalize_continuous_features=True
        )
        
        # Create optimizer configuration
        optimizer_config = OptimizerConfig(
            optimizer="Adam",
            learning_rate=self.lr,
            weight_decay=self.weight_decay
        )
        
        # Create trainer configuration
        trainer_config = TrainerConfig(
            auto_lr_find=False,  # Use fixed learning rate
            batch_size=self.batch_size,
            max_epochs=self.max_epochs,
            early_stopping="valid_loss",
            early_stopping_patience=self.patience,
            checkpoints="best",
            load_best=True,
        )
        
        return model_config, data_config, optimizer_config, trainer_config
    
    def _get_cat_cardinality(self, X):
        """Get cardinality of categorical variables."""
        cardinality = {}
        for col in self.categorical_cols:
            cardinality[col] = len(X[col].unique())
        return cardinality
    
    def fit(self, X, y):
        """
        Fit the TabTransformer model.
        
        Parameters
        ----------
        X : pandas DataFrame
            Training data
        y : array-like
            Target values
            
        Returns
        -------
        self : object
            Returns self
        """
        logger.info("Starting model fitting")
        start_time = time.time()
        
        # Ensure target column exists in the DataFrame for PyTorch-Tabular
        X_train = X.copy()
        
        # Encode target if needed
        if y.dtype == object or isinstance(y, (list, pd.Series)) and isinstance(y[0], str):
            self.label_encoder = LabelEncoder()
            y_encoded = self.label_encoder.fit_transform(y)
            self.classes_ = self.label_encoder.classes_
        else:
            y_encoded = y
            self.classes_ = np.unique(y)
        
        # Add encoded target to DataFrame
        X_train[self.target_col] = y_encoded
        
        # Compute class weights for handling imbalanced data
        self.class_weights = compute_class_weight('balanced', classes=np.unique(y_encoded), y=y_encoded)
        
        # Create model configs
        model_config, data_config, optimizer_config, trainer_config = self._create_model_configs(X, y)
        
        # Create and fit the model
        self.model = TabularModel(
            data_config=data_config,
            model_config=model_config,
            optimizer_config=optimizer_config,
            trainer_config=trainer_config
        )
        
        # Fit the model
        try:
            self.model.fit(train=X_train)
            logger.info(f"Model training complete in {time.time() - start_time:.2f} seconds")
            return self
        except Exception as e:
            logger.error(f"Error fitting model: {str(e)}")
            raise
    
    def predict(self, X, batch_size=None):
        """
        Predict class labels for samples in X.
        
        Parameters
        ----------
        X : pandas DataFrame
            The input data
        batch_size : int, optional
            Batch size for prediction
            
        Returns
        -------
        y_pred : array
            Predicted class labels
        """
        if self.model is None:
            logger.error("Model is not fitted yet. Call 'fit' before using 'predict'.")
            raise ValueError("Model is not fitted yet. Call 'fit' before using 'predict'.")
        
        # Use batch prediction if specified
        if batch_size:
            y_pred = self._batch_predict(X, batch_size=batch_size)
        else:
            # Get raw predictions
            pred_df = self.model.predict(X)
            
            # For binary classification
            if len(self.classes_) == 2:
                y_pred = (pred_df['prediction'].values > 0.5).astype(int)
            else:
                # For multi-class, take argmax
                y_pred = pred_df['prediction'].values.argmax(axis=1)
        
        # Inverse transform labels if encoder was used
        if self.label_encoder is not None:
            y_pred = self.label_encoder.inverse_transform(y_pred)
            
        return y_pred
    
    def predict_proba(self, X, batch_size=None):
        """
        Predict class probabilities for samples in X.
        
        Parameters
        ----------
        X : pandas DataFrame
            The input data
        batch_size : int, optional
            Batch size for prediction
            
        Returns
        -------
        y_proba : array
            Predicted class probabilities
        """
        if self.model is None:
            logger.error("Model is not fitted yet. Call 'fit' before using 'predict_proba'.")
            raise ValueError("Model is not fitted yet. Call 'fit' before using 'predict_proba'.")
        
        # Use batch prediction if specified
        if batch_size:
            return self._batch_predict_proba(X, batch_size=batch_size)
        
        pred_df = self.model.predict(X)
        
        # For binary classification
        if len(self.classes_) == 2:
            probs = pred_df['prediction'].values
            return np.vstack([1-probs, probs]).T
        else:
            # For multi-class
            return pred_df['prediction'].values
    
    def _batch_predict(self, X, batch_size=1000):
        """
        Make predictions in batches to handle large datasets.
        
        Parameters
        ----------
        X : pandas DataFrame
            The input data
        batch_size : int
            Size of each batch
            
        Returns
        -------
        y_pred : array
            Predicted class labels
        """
        total_samples = X.shape[0]
        y_pred_list = []
        
        for i in range(0, total_samples, batch_size):
            end_idx = min(i + batch_size, total_samples)
            batch_X = X.iloc[i:end_idx]
            
            # Get predictions for this batch
            batch_pred = self.predict(batch_X)
            y_pred_list.append(batch_pred)
        
        return np.concatenate(y_pred_list)
    
    def _batch_predict_proba(self, X, batch_size=1000):
        """
        Make probability predictions in batches to handle large datasets.
        
        Parameters
        ----------
        X : pandas DataFrame
            The input data
        batch_size : int
            Size of each batch
            
        Returns
        -------
        y_proba : array
            Predicted class probabilities
        """
        total_samples = X.shape[0]
        y_proba_list = []
        
        for i in range(0, total_samples, batch_size):
            end_idx = min(i + batch_size, total_samples)
            batch_X = X.iloc[i:end_idx]
            
            # Get probability predictions for this batch
            batch_proba = self.predict_proba(batch_X)
            y_proba_list.append(batch_proba)
        
        return np.vstack(y_proba_list)
    
    def score(self, X, y):
        """
        Return the accuracy on the given test data and labels.
        
        Parameters
        ----------
        X : pandas DataFrame
            Test data
        y : array-like
            True labels
            
        Returns
        -------
        score : float
            Accuracy score
        """
        y_pred = self.predict(X)
        return accuracy_score(y, y_pred)
    
    def save(self, filename):
        """
        Save the model to disk.
        
        Parameters
        ----------
        filename : str
            Path to save the model
        """
        if self.model is None:
            logger.error("Model is not fitted yet. Call 'fit' before using 'save'.")
            raise ValueError("Model is not fitted yet. Call 'fit' before using 'save'.")
        
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        
        # Save PyTorch-Tabular model
        self.model.save_model(filename + "_tabtransformer")
        
        # Save other attributes
        model_dict = {
            'categorical_cols': self.categorical_cols,
            'continuous_cols': self.continuous_cols,
            'embedding_dims': self.embedding_dims,
            'num_heads': self.num_heads,
            'num_attn_blocks': self.num_attn_blocks,
            'attn_dropout': self.attn_dropout,
            'ff_dropout': self.ff_dropout,
            'mlp_dropout': self.mlp_dropout,
            'lr': self.lr,
            'weight_decay': self.weight_decay,
            'batch_size': self.batch_size,
            'max_epochs': self.max_epochs,
            'patience': self.patience,
            'target_col': self.target_col,
            'device': self.device,
            'model_dir': self.model_dir,
            'classes_': self.classes_,
            'label_encoder': self.label_encoder,
            'class_weights': self.class_weights
        }
        
        joblib.dump(model_dict, filename + "_attributes.pkl")
        logger.info(f"Model saved to {filename}")
        
    @classmethod
    def load(cls, filename):
        """
        Load the model from disk.
        
        Parameters
        ----------
        filename : str
            Path to the saved model
            
        Returns
        -------
        model : TabTransformerClassifier
            The loaded model
        """
        # Load attributes
        model_dict = joblib.load(filename + "_attributes.pkl")
        
        # Create an instance
        instance = cls(
            categorical_cols=model_dict['categorical_cols'],
            continuous_cols=model_dict['continuous_cols'],
            embedding_dims=model_dict['embedding_dims'],
            num_heads=model_dict['num_heads'],
            num_attn_blocks=model_dict['num_attn_blocks'],
            attn_dropout=model_dict['attn_dropout'],
            ff_dropout=model_dict['ff_dropout'],
            mlp_dropout=model_dict['mlp_dropout'],
            lr=model_dict['lr'],
            weight_decay=model_dict['weight_decay'],
            batch_size=model_dict['batch_size'],
            max_epochs=model_dict['max_epochs'],
            patience=model_dict['patience'],
            target_col=model_dict['target_col'],
            device=model_dict['device'],
            model_dir=model_dict['model_dir']
        )
        
        # Set other attributes
        instance.classes_ = model_dict['classes_']
        instance.label_encoder = model_dict['label_encoder']
        instance.class_weights = model_dict['class_weights']
        
        # Load PyTorch-Tabular model
        instance.model = TabularModel.load_model(filename + "_tabtransformer")
        
        return instance

In [57]:
import numpy as np
import pandas as pd
import logging
import time
import json
import os
import optuna
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_class_weight, compute_sample_weight  # Added compute_sample_weight
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from scikeras.wrappers import KerasClassifier
# Imports for models used within the function
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier,
                              ExtraTreesClassifier, AdaBoostClassifier)
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neural_network import MLPClassifier
from xgboost import XGBClassifier
from catboost import CatBoostClassifier
import lightgbm as lgb
import joblib  # For saving models
from sklearn.calibration import CalibratedClassifierCV  # For calibration step

# Assume build_keras_model function is defined elsewhere
# Assume logger is configured globally
# logger = logging.getLogger(__name__)
# Assume save_feature_importance function is defined
BOOSTING_EARLY_STOPPING_PATIENCE = 50
def optimize_model(X, y, timestamp, model_type, n_trials=30, n_jobs_optuna=1):
    """
    Optimizes hyperparameters for a given model type using Optuna,
    then trains and saves the final model with best parameters.
    Includes class_weight='balanced' or equivalent strategies.
    Correctly handles Optuna-specific trial parameters during final instantiation.
    Attempts calibration after successful model saving.
    """
    logger.info(f"Starting {model_type} optimization ({n_trials} trials)...")
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    if not isinstance(y, (np.ndarray, pd.Series)):
        y = np.array(y)
    if isinstance(X, np.ndarray):
        X = pd.DataFrame(X)  # Ensure DataFrame for consistent .iloc

    n_classes = len(np.unique(y))
    n_features = X.shape[1]
    y_keras = to_categorical(y, num_classes=n_classes) if model_type == 'keras_mlp' else y

    KERAS_EPOCHS = 150  # Reduced Keras epochs slightly
    KERAS_PATIENCE = 25  # Increased Keras patience slightly
    OPTUNA_TIMEOUT_PER_MODEL = 3600  # Default 1 hour
    # Increase timeout for complex models
    if model_type in ['xgboost', 'catboost', 'randomforest', 'gradientboosting', 'extratrees', 'lightgbm']:
        OPTUNA_TIMEOUT_PER_MODEL = 7200  # Increase to 2 hours
    logger.info(f"Optuna timeout for {model_type}: {OPTUNA_TIMEOUT_PER_MODEL}s.")

    # --- Optuna Objective Function ---
    def objective(trial):
        model = None
        fit_params = {}
        use_gpu = False
        is_keras = False

        # --- Model Definitions for Optuna Trial (Includes custom weights/params) ---
        if model_type == 'xgboost':
            tree_method = trial.suggest_categorical('tree_method', ['hist', 'gpu_hist'])
            param = {
                'objective': 'multi:softprob', # Keep for probabilities
                'num_class': n_classes,
                'eval_metric': 'mlogloss',
                'n_estimators': trial.suggest_int('n_estimators', 300, 5000, step=100), # Wider range
                'max_depth': trial.suggest_int('max_depth', 4, 26, step=1), # Wider range, finer step
                'learning_rate': trial.suggest_float('learning_rate', 0.007, 0.15, log=True), # Slightly lower min
                'subsample': trial.suggest_float('subsample', 0.5, 1.0), # Allow full subsample
                'colsample_bytree': trial.suggest_float('colsample_bytree', 0.4, 1.0), # Wider range
                'min_child_weight': trial.suggest_int('min_child_weight', 1, 20), # Wider range (regularization)
                'gamma': trial.suggest_float('gamma', 1e-7, 1.0, log=True), # Wider upper bound (regularization)
                'reg_alpha': trial.suggest_float('reg_alpha', 1e-7, 20.0, log=True), # Wider range (regularization)
                'reg_lambda': trial.suggest_float('reg_lambda', 1e-7, 20.0, log=True), # Wider range (regularization)
                'random_state': 42,
                'booster': 'gbtree',
                'tree_method': tree_method,
                # 'n_jobs': 1 # Handled by tree_method check
            }
            if tree_method == 'gpu_hist':
                param['gpu_id'] = 0
                # param.pop('n_jobs', None) # n_jobs not used with gpu_hist
            else:
                param['n_jobs'] = 1 # Use 1 core for CPU hist for stability within Optuna
                param.pop('gpu_id', None)

            model = XGBClassifier(**param)
            # Early stopping will be added in the CV loop fit call
            fit_params = {} # Reset, ES handled in loop

        elif model_type == 'catboost':
            task_type = trial.suggest_categorical('task_type', ['CPU', 'GPU'])
            # More nuanced class weight options - focus on boosting High (0) and Medium (2) slightly
            class_weight_options = [
                None,
                'Balanced',
                # Note: Optuna often saves dict keys as strings, ensure final fit handles int conversion if needed
                {0: 1.2, 1: 1.0, 2: 1.1}, # Boost High slightly more
                {0: 1.3, 1: 1.0, 2: 1.2}, # Boost High more, Medium slightly
                {0: 1.1, 1: 1.0, 2: 1.2}, # Boost Medium slightly more
            ]
            chosen_class_weight_config = trial.suggest_categorical('class_weight_config', class_weight_options)

            param = {
                'iterations': trial.suggest_int('iterations', 300, 5500, step=100), # Wider range
                'depth': trial.suggest_int('depth', 4, 24), # Wider range
                'learning_rate': trial.suggest_float('learning_rate', 0.007, 0.15, log=True), # Slightly lower min
                'l2_leaf_reg': trial.suggest_float('l2_leaf_reg', 0.5, 30.0, log=True), # Wider range (regularization)
                'random_strength': trial.suggest_float('random_strength', 1e-3, 10.0, log=True), # Exploration range
                'border_count': trial.suggest_categorical('border_count', [32, 64, 128, 254]), # Added 32
                'bagging_temperature': trial.suggest_float('bagging_temperature', 0.0, 0.9), # Wider range
                'loss_function': 'MultiClass',
                'eval_metric': 'Accuracy', # Using Accuracy as metric, can change to MultiClass
                'random_seed': 42,
                'thread_count': -1, # Use all available CPU cores if task_type='CPU'
                'verbose': False,
                'task_type': task_type,
                # 'auto_class_weights': None, # Set based on choice below
                # 'class_weights': None # Set based on choice below
            }
            # Apply class weight strategy
            if isinstance(chosen_class_weight_config, dict):
                param['class_weights'] = chosen_class_weight_config
                trial.set_user_attr("class_weight_info", chosen_class_weight_config) # Log the dict
            elif chosen_class_weight_config == 'Balanced':
                param['auto_class_weights'] = 'Balanced'
                trial.set_user_attr("class_weight_info", 'Balanced')
            else: # None case
                trial.set_user_attr("class_weight_info", 'None')


            if task_type == 'GPU':
                param['devices'] = '0'
                param.pop('thread_count', None) # Not needed for GPU

            model = CatBoostClassifier(**param)
            fit_params = {'early_stopping_rounds': BOOSTING_EARLY_STOPPING_PATIENCE, 'verbose': False}

        elif model_type == 'randomforest':
            # More class weight dictionary options
            class_weight_choices = [
                'balanced',
                'balanced_subsample',
                # Note: Optuna often saves dict keys as strings, ensure final fit handles int conversion if needed
                {0: 1.1, 1: 1.0, 2: 1.1},
                {0: 1.2, 1: 1.0, 2: 1.2},
                {0: 1.3, 1: 1.0, 2: 1.1},
            ]
            class_weight = trial.suggest_categorical('class_weight', class_weight_choices)
            param = {
                'n_estimators': trial.suggest_int('n_estimators', 300, 5500, step=100), # Wider range
                'max_depth': trial.suggest_int('max_depth', 10, 100, step=2), # Allow deeper trees, rely on leaf constraints
                # 'max_depth': trial.suggest_categorical('max_depth', [10, 20, 30, 40, 50, 60, None]), # Alternative: includes None
                'min_samples_split': trial.suggest_int('min_samples_split', 2, 50), # Wider range (regularization)
                'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 30), # Wider range (regularization)
                'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', 0.6, 0.8]), # Added float options
                'bootstrap': True, # Usually best for RF
                'class_weight': class_weight, # Apply choice
                'random_state': 42,
                'n_jobs': n_jobs_optuna,
                'criterion': trial.suggest_categorical('criterion', ['gini', 'entropy']),
                'min_impurity_decrease': trial.suggest_float('min_impurity_decrease', 0.0, 0.05) # Add slight pruning
            }
            model = RandomForestClassifier(**param)
            fit_params = {} # RF doesn't have special fit params here

        elif model_type == 'extratrees':
            # More class weight dictionary options
            class_weight_choices = [
                'balanced',
                'balanced_subsample',
                {0: 1.1, 1: 1.0, 2: 1.1},
                {0: 1.2, 1: 1.0, 2: 1.2},
                {0: 1.3, 1: 1.0, 2: 1.1},
            ]
            class_weight = trial.suggest_categorical('class_weight', class_weight_choices)
            param = {
                'n_estimators': trial.suggest_int('n_estimators', 300, 5500, step=500), # Wider range
                'max_depth': trial.suggest_int('max_depth', 10, 80, step=2), # Allow deeper
                # 'max_depth': trial.suggest_categorical('max_depth', [10, 20, 30, 40, 50, 60, 70, None]), # Alternative
                'min_samples_split': trial.suggest_int('min_samples_split', 2, 40), # Wider range
                'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 30), # Wider range
                'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', 0.6, 0.8]), # Added floats
                'bootstrap': trial.suggest_categorical('bootstrap', [False, True]), # Tune bootstrap for ET
                'class_weight': class_weight, # Apply choice
                'random_state': 42,
                'n_jobs': n_jobs_optuna,
                'criterion': trial.suggest_categorical('criterion', ['gini', 'entropy']),
                'min_impurity_decrease': trial.suggest_float('min_impurity_decrease', 0.0, 0.05) # Add slight pruning
            }
            model = ExtraTreesClassifier(**param)
            fit_params = {}

        elif model_type == 'gradientboosting':
            param = {
                'n_estimators': trial.suggest_int('n_estimators', 200, 3500, step=100), # Wider range
                'learning_rate': trial.suggest_float('learning_rate', 0.007, 0.15, log=True), # Slightly lower min
                'max_depth': trial.suggest_int('max_depth', 3, 14), # Wider range
                'min_samples_split': trial.suggest_int('min_samples_split', 5, 50), # Wider range (regularization)
                'min_samples_leaf': trial.suggest_int('min_samples_leaf', 3, 40), # Wider range (regularization)
                'subsample': trial.suggest_float('subsample', 0.5, 1.0), # Allow 1.0
                'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', 0.7, 0.9]), # Added floats
                'random_state': 42,
                'loss': 'log_loss', # Keep log_loss for predict_proba
                'min_weight_fraction_leaf': trial.suggest_float('min_weight_fraction_leaf', 0.0, 0.2), # Wider range
                # 'ccp_alpha': trial.suggest_float('ccp_alpha', 0.0, 0.1) # Add cost-complexity pruning
            }
            model = GradientBoostingClassifier(**param)
            # Sample weights applied in the CV loop fit call below
            fit_params = {}

        elif model_type == 'adaboost':
            base_depth = trial.suggest_int('base_estimator_max_depth', 1, 6) # Allow slightly deeper base trees
            # More class weight options
            class_weights_options = [
                'balanced',
                {0: 1.2, 1: 1.0, 2: 1.1},
                {0: 1.3, 1: 1.0, 2: 1.2},
                {0: 1.1, 1: 1.0, 2: 1.3},
            ]
            weight_choice = trial.suggest_categorical('class_weight_choice', class_weights_options)
            param_ada = {
                'n_estimators': trial.suggest_int('n_estimators', 100, 5000, step=50), # Wider range
                'learning_rate': trial.suggest_float('learning_rate', 0.01, 2.0, log=True), # Wider range
                'algorithm':'SAMME', # Try both
                'random_state': 42
            }
            # Apply class weight choice to the base estimator
            base_est = DecisionTreeClassifier(max_depth=base_depth, random_state=42, class_weight=weight_choice)
            model = AdaBoostClassifier(estimator=base_est, **param_ada)
            # Log info for final model reconstruction
            trial.set_user_attr("base_estimator_max_depth", base_depth)
            trial.set_user_attr("class_weight_info", weight_choice if isinstance(weight_choice, dict) else 'balanced' if weight_choice=='balanced' else 'None')
            trial.set_user_attr("algorithm", param_ada['algorithm']) # Log algorithm choice
            fit_params = {}

        elif model_type == 'lightgbm':
            # More class weight options
            class_weight_options = [
                None,
                'balanced',
                {0: 1.2, 1: 1.0, 2: 1.1}, # Boost High slightly more
                {0: 1.3, 1: 1.0, 2: 1.2}, # Boost High more, Medium slightly
                {0: 1.1, 1: 1.0, 2: 1.3}, # Boost Medium more
            ]
            class_weight = trial.suggest_categorical('class_weight_option', class_weight_options)
            param = {
                'objective': 'multiclass',
                'num_class': n_classes,
                'metric': 'multi_logloss', # Standard metric for multi-class probabilities
                'n_estimators': trial.suggest_int('n_estimators', 300, 5500, step=100), # Wider range
                'learning_rate': trial.suggest_float('learning_rate', 0.007, 0.15, log=True), # Slightly lower min
                'num_leaves': trial.suggest_int('num_leaves', 20, 500, step=5), # Wider range (key param)
                'max_depth': trial.suggest_int('max_depth', 5, 50), # Wider range, can be -1 if num_leaves is constrained
                'subsample': trial.suggest_float('subsample', 0.5, 1.0), # Allow 1.0
                'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0), # Allow 1.0
                'reg_alpha': trial.suggest_float('reg_alpha', 1e-8, 20.0, log=True), # Wider range (regularization)
                'reg_lambda': trial.suggest_float('reg_lambda', 1e-8, 20.0, log=True), # Wider range (regularization)
                'min_child_samples': trial.suggest_int('min_child_samples', 3, 60), # Wider range (regularization)
                'class_weight': class_weight, # Apply choice
                'random_state': 42,
                'n_jobs': n_jobs_optuna,
                'verbose': -1,
                'boosting_type': trial.suggest_categorical('boosting_type', ['gbdt', 'dart']) # Keep both
                # Consider adding 'min_split_gain'
                # 'min_split_gain': trial.suggest_float('min_split_gain', 0.0, 0.1)
            }
            model = lgb.LGBMClassifier(**param)
            # Use specific early stopping callback for LGBM
            fit_params = {'callbacks': [lgb.early_stopping(BOOSTING_EARLY_STOPPING_PATIENCE, verbose=False)]}
            

        else:
            logger.error(f"Unsupported model type: {model_type}")
            raise ValueError(f"Unsupported: {model_type}")

        # --- Cross-validation ---
        scores = []
        is_dataframe = isinstance(X, pd.DataFrame)
        try:
            for fold, (train_idx, valid_idx) in enumerate(skf.split(X, y)):
                # --- Use .iloc consistently for pandas objects ---
                if is_dataframe: # If X is a DataFrame, assume y is a Series
                    X_train_fold = X.iloc[train_idx]
                    X_valid_fold = X.iloc[valid_idx]
                    # Select training labels based on position
                    y_train_fold = y_keras[train_idx] if is_keras else y.iloc[train_idx] # ***MODIFIED***
                    # Select validation labels based on position
                    y_valid_fold_orig = y.iloc[valid_idx] # ***MODIFIED***
                else: # If X is a numpy array, assume y is also numpy array
                    X_train_fold = X[train_idx]
                    X_valid_fold = X[valid_idx]
                    y_train_fold = y_keras[train_idx] if is_keras else y[train_idx]
                    y_valid_fold_orig = y[valid_idx]
                current_fit_params = fit_params.copy()

                # --- Handle Sample Weights for Models That Need It in Fit ---
                fold_sample_weight = None
                if model_type in ['gradientboosting']:  
                    # Calculate balanced weights
                    sample_weight = compute_sample_weight('balanced', y=y_train_fold)
                    # Apply custom emphasis based on strategy (e.g., boost High/Medium)
                    emphasis_weights = {0: 1.1, 1: 1.0, 2: 1.1}  # Example emphasis
                    for cls_idx, weight_multiplier in emphasis_weights.items():
                         # Ensure y_train_fold is numpy for boolean indexing if it was a Series
                         y_train_fold_np = y_train_fold.values if isinstance(y_train_fold, pd.Series) else y_train_fold
                         sample_weight[y_train_fold_np == cls_idx] *= weight_multiplier
                    fold_sample_weight = sample_weight
                    current_fit_params['sample_weight'] = fold_sample_weight
                    logger.debug(f"Trial {trial.number} Fold {fold+1}: Applied sample weights for {model_type}")

                try:
                    # Pass eval_set for models that use it with callbacks/early stopping
                    eval_set = [(X_valid_fold, y_valid_fold_orig)]
                    
                    # --- XGBoost Specific Fit Call ---
                    if model_type == 'xgboost':
                         model.fit(X_train_fold, y_train_fold,
                                   eval_set=eval_set,
                                   # Pass directly
                                   verbose=False) # Pass other relevant args directly if needed
                    # --- LightGBM Specific Fit Call (already seemed correct) ---
                    elif model_type == 'lightgbm':
                         # Note: LGBM uses callbacks for early stopping, passed via fit_params
                         current_fit_params['eval_set'] = eval_set
                         current_fit_params['eval_metric'] = 'multi_logloss' # Or match objective metric
                         model.fit(X_train_fold, y_train_fold, **current_fit_params)
                    # --- CatBoost Specific Fit Call (already seemed correct) ---
                    elif model_type == 'catboost':
                         current_fit_params['eval_set'] = eval_set
                         # Early stopping rounds already part of CatBoost init/params
                         model.fit(X_train_fold, y_train_fold, **current_fit_params)
                    # --- Default Fit Call for other models ---
                    else:
                         # Pass sample_weight if applicable (e.g., for GB)
                         model.fit(X_train_fold, y_train_fold, **current_fit_params)
                    
                    # Predict and score
                    y_pred = model.predict(X_valid_fold)
                    if is_keras and y_pred.ndim > 1 and y_pred.shape[1] > 1:
                        y_pred = np.argmax(y_pred, axis=1)
                    score = accuracy_score(y_valid_fold_orig, y_pred)
                    scores.append(score)
                    logger.debug(f"Trial {trial.number} Fold {fold+1} Score: {score:.5f}")

                except ValueError as ve:
                    logger.warning(f"CV fold {fold+1} VAL ERROR {model_type} trial {trial.number}: {ve}")
                    return 0.0
                except Exception as e:
                    logger.error(f"CV fold {fold+1} EXCEPTION {model_type} trial {trial.number}: {e}", exc_info=True)
                    scores = []
                    break  # Log full traceback
        except Exception as outer_e:
            logger.error(f"Outer CV error {model_type} trial {trial.number}: {outer_e}", exc_info=True)
            return 0.0
        if not scores:
            logger.error(f"Cross-validation failed completely for {model_type} trial {trial.number}")
            return 0.0
        mean_score = np.mean(scores)
        logger.debug(f"Trial {trial.number} ({model_type}) completed. Avg CV Score: {mean_score:.5f}")
        return mean_score







    # --- Run Optuna Study ---
    study_name = f"{model_type}_opt_{timestamp}"
    storage_name = f"sqlite:///optuna_trials/{study_name}.db"
    study = optuna.create_study(direction='maximize', study_name=study_name, storage=storage_name, 
                              load_if_exists=True, pruner=optuna.pruners.MedianPruner(n_warmup_steps=5))
    completed_trials = len([t for t in study.trials if t.state==optuna.trial.TrialState.COMPLETE])
    trials_to_run = n_trials-completed_trials
    
    if trials_to_run > 0:
        logger.info(f"Setting Optuna timeout {OPTUNA_TIMEOUT_PER_MODEL}s.")
        try:
            study.optimize(objective, n_trials=trials_to_run, timeout=OPTUNA_TIMEOUT_PER_MODEL, n_jobs=1)
        except Exception as opt_e:
            logger.error(f"Optuna optimize fail {model_type}: {opt_e}", exc_info=True)
            return None, -1, {}
    else:
        logger.info(f"Study {study_name} has {completed_trials} trials. Skip optimize.")

    # --- Retrieve Results ---
    try:
        if not any(t.state == optuna.trial.TrialState.COMPLETE for t in study.trials):
            logger.error(f"Optuna study {model_type} no successful trials.")
            return None, -1, {}
        best_trial = study.best_trial
        best_params = best_trial.params
        best_cv_score = best_trial.value
    except ValueError:
        logger.error(f"Optuna study {model_type} no best trial.")
        return None, -1, {}
    except Exception as res_e:
        logger.error(f"Error get Optuna results {model_type}: {res_e}", exc_info=True)
        return None, -1, {}
    logger.info(f"Opt complete {model_type}. Best CV score: {best_cv_score:.5f}. Best params: {best_params}")

    # --- Save Study Summary ---
    try:
        summary_file = f'optuna_trials/{model_type}_study_summary_{timestamp}.txt'
        params_json = best_params.copy()
        if model_type=='adaboost' and "base_estimator_max_depth" in best_trial.user_attrs:
            params_json['base_estimator_max_depth'] = best_trial.user_attrs["base_estimator_max_depth"]
            params_json['class_weight_info'] = best_trial.user_attrs.get("class_weight_info", "N/A")
        if model_type=='xgboost' and 'tree_method' in best_params:
            params_json['tree_method'] = best_params['tree_method']
        if model_type=='catboost' and 'task_type' in best_params:
            params_json['task_type'] = best_params['task_type']
        with open(summary_file, 'w') as f:
            f.write(f"Optuna Summary: {model_type}\nTS: {timestamp}\nBest Trial: {best_trial.number}\nScore: {best_cv_score:.5f}\n\nParams:\n")
            json.dump(params_json, f, indent=4)
        logger.info(f"Saved Optuna summary: {summary_file}")
    except Exception as file_e:
        logger.warning(f"Could not save Optuna summary {model_type}: {file_e}")

    # --- Train final model ---
    final_model = None
    final_fit_params = {}  # Reset for final fit
    try:
        logger.info(f"Instantiating final {model_type} model...")
        # Clean best_params from Optuna-specific args before final instantiation
        params_for_final = best_params.copy()
        optuna_internal_params = ['class_weight_option', 'class_weight_choice', 'class_weight_idx', 
                                 'class_weight_strategy', 'use_smote', 'smote_k', 
                                 'use_focal_loss', 'focal_gamma']  # Params used only in objective logic
        for p in optuna_internal_params:
            params_for_final.pop(p, None)

        # Inside optimize_model, after Optuna, in elif model_type == 'adaboost':
        if model_type == 'adaboost':
            # Clean best_params from Optuna-specific args before final instantiation
            params_for_final = best_params.copy()
            # List internal params used only during Optuna trials
            optuna_internal_params = ['class_weight_choice', 'base_estimator_max_depth']
            for p in optuna_internal_params:
                # --- THESE LINES REMOVE THE BAD PARAMETERS ---
                params_for_final.pop(p, None)
                # --- END OF REMOVAL ---

            # Retrieve correct values from Optuna trial attributes
            best_d = best_trial.user_attrs.get('base_estimator_max_depth', 1)
            weight_info_raw = best_trial.user_attrs.get("class_weight_info", 'balanced')
            
            # --- *** ADDED: Convert dictionary keys if needed *** ---
            weight_info_processed = weight_info_raw
            if isinstance(weight_info_raw, dict):
                try:
                    # Convert string keys ('0', '1', ...) to integers (0, 1, ...)
                    weight_info_processed = {int(k): v for k, v in weight_info_raw.items()}
                    logger.info(f"Converted AdaBoost class_weight keys to int: {weight_info_processed}")
                except ValueError as e:
                     logger.error(f"Error converting AdaBoost class_weight keys: {e}. Using raw: {weight_info_raw}")
                     weight_info_processed = weight_info_raw # Fallback to raw if conversion fails
            # --- *** END KEY CONVERSION *** ---
            
            logger.info(f"Reconstruct AdaBoost DT(max_depth={best_d}, class_weight={weight_info_processed}) using SAMME")
            # Create the base estimator correctly
            base_est_inst = DecisionTreeClassifier(max_depth=best_d, random_state=42, class_weight=weight_info_processed)

            final_p_ada = params_for_final # Use the cleaned dictionary for AdaBoost itself
            final_p_ada['algorithm'] = 'SAMME'
            # --- FINAL MODEL TRAINING WILL STILL HAPPEN USING base_est_inst and final_p_ada ---
            final_model = AdaBoostClassifier(estimator=base_est_inst, **final_p_ada)

        elif model_type == 'xgboost':
            final_params_xgb = params_for_final.copy()
            final_params_xgb['objective'] = 'multi:softprob'
            final_params_xgb['num_class'] = n_classes
            final_params_xgb['n_jobs'] = 1
            logger.info("XGBoost final model - balancing via sample_weight in fit.")
            final_model = XGBClassifier(**final_params_xgb)
            # Prepare sample weights for fit step
            sample_weights_xgb = compute_sample_weight('balanced', y=y)  # Start with balanced
            emphasis_weights = {0: 2, 1: 1.0, 2: 2}  # Emphasize High/Medium
            for cls_idx, weight_multiplier in emphasis_weights.items():
                sample_weights_xgb[y == cls_idx] *= weight_multiplier
            final_fit_params['sample_weight'] = sample_weights_xgb

        elif model_type == 'catboost':
            final_params_cat = params_for_final.copy()
            final_params_cat['loss_function'] = 'MultiClass'
            final_params_cat['verbose'] = False
            # Re-apply class weight strategy based on best trial's choice
            chosen_weight = best_params.get('class_weight_option')
            if isinstance(chosen_weight, dict):
                final_params_cat['class_weights'] = chosen_weight
                logger.info(f"CatBoost using custom weights: {chosen_weight}")
            elif chosen_weight == 'Balanced':
                final_params_cat['auto_class_weights'] = 'Balanced'
                logger.info("CatBoost using auto_class_weights=Balanced")
            else:
                logger.info("CatBoost using default balancing or no weights.")
            final_model = CatBoostClassifier(**final_params_cat)

        elif model_type == 'randomforest':
            final_params_rf = params_for_final.copy()
            final_params_rf['n_jobs'] = n_jobs_optuna
            
            # --- *** ADDED: Process class_weight dictionary keys *** ---
            class_weight_raw = best_params.get('class_weight', 'balanced')
            class_weight_processed = class_weight_raw
            if isinstance(class_weight_raw, dict):
                 try:
                     # Convert string keys ('0', '1', ...) to integers (0, 1, ...)
                     class_weight_processed = {int(k): v for k, v in class_weight_raw.items()}
                     logger.info(f"Converted RF class_weight keys to int: {class_weight_processed}")
                 except ValueError as e:
                      logger.error(f"Error converting RF class_weight keys: {e}. Using raw: {class_weight_raw}")
                      class_weight_processed = class_weight_raw # Fallback
            # --- *** END KEY CONVERSION *** ---
            
            final_params_rf['class_weight'] = best_params.get('class_weight', 'balanced')  # Use optimized or default balanced
            logger.info(f"RF final model using class_weight={final_params_rf['class_weight']}")
            final_model = RandomForestClassifier(**final_params_rf)

        elif model_type == 'extratrees':
            final_params_et = params_for_final.copy()
            final_params_et['n_jobs'] = n_jobs_optuna
            
            # --- *** ADDED: Process class_weight dictionary keys *** ---
            class_weight_raw = best_params.get('class_weight', 'balanced')
            class_weight_processed = class_weight_raw
            if isinstance(class_weight_raw, dict):
                 try:
                     # Convert string keys ('0', '1', ...) to integers (0, 1, ...)
                     class_weight_processed = {int(k): v for k, v in class_weight_raw.items()}
                     logger.info(f"Converted ET class_weight keys to int: {class_weight_processed}")
                 except ValueError as e:
                      logger.error(f"Error converting ET class_weight keys: {e}. Using raw: {class_weight_raw}")
                      class_weight_processed = class_weight_raw # Fallback
            # --- *** END KEY CONVERSION *** ---
            
            final_params_et['class_weight'] = best_params.get('class_weight', 'balanced')  # Use optimized or default balanced
            logger.info(f"ET final model using class_weight={final_params_et['class_weight']}")
            final_model = ExtraTreesClassifier(**final_params_et)

        elif model_type == 'gradientboosting':
            final_params_gb = params_for_final.copy()
            logger.info("GradientBoosting final model - applying sample_weight in fit")
            final_model = GradientBoostingClassifier(**final_params_gb)
            sample_weights_gb = compute_sample_weight('balanced', y=y)
            emphasis_weights = {0: 1.6, 1: 1.0, 2: 1.5}
            for cls_idx, mult in emphasis_weights.items():
                sample_weights_gb[y == cls_idx] *= mult
            final_fit_params['sample_weight'] = sample_weights_gb

        # Inside optimize_model, after Optuna, in elif model_type == 'knn':

        elif model_type == 'lightgbm':
            final_params_lgbm = params_for_final.copy()
            final_params_lgbm['objective'] = 'multiclass'
            final_params_lgbm['num_class'] = n_classes
            final_params_lgbm['n_jobs'] = n_jobs_optuna
            
            # --- *** ADDED: Refined Key Conversion for LGBM *** ---
            class_weight_value_to_use = 'balanced' # Default
            # Use 'class_weight_option' key from Optuna params for LGBM
            if 'class_weight_option' in best_params:
                class_weight_raw = best_params['class_weight_option'] # Get raw value from Optuna result
                class_weight_processed = class_weight_raw

                if isinstance(class_weight_raw, dict):
                    logger.info(f"Raw class_weight dict found for LGBM: {class_weight_raw}")
                    try:
                        if all(isinstance(k, int) for k in class_weight_raw.keys()):
                            logger.info("LGBM class_weight keys appear to be integers already.")
                            class_weight_processed = class_weight_raw
                        else:
                            logger.info("Attempting conversion of LGBM class_weight keys to int...")
                            class_weight_processed = {int(k): v for k, v in class_weight_raw.items()}
                            logger.info(f"Successfully converted LGBM class_weight keys to int: {class_weight_processed}")
                    except Exception as e_gen:
                         logger.error(f"Error processing LGBM class_weight dict: {e_gen}. Using 'balanced'.")
                         class_weight_processed = 'balanced'
                class_weight_value_to_use = class_weight_processed
            else:
                 logger.info("No 'class_weight_option' found in best_params for LGBM, using default 'balanced'.")
                 class_weight_value_to_use = 'balanced'
            
            final_params_lgbm['class_weight'] = class_weight_value_to_use
            logger.info(f"LGBM final model using class_weight={final_params_lgbm['class_weight']}")
            final_model = lgb.LGBMClassifier(**final_params_lgbm)

        # --- Fit the final model ---
        if final_model is not None:
            logger.info(f"Fitting final {model_type} model...")
            start_fit_time = time.time()
            model_fitted_successfully = False
            try:
                # Fit using specific params if they exist (like sample_weight)
                if final_fit_params:
                    logger.info(f"Fitting {model_type} with additional fit parameters: {list(final_fit_params.keys())}")
                    final_model.fit(X, y, **final_fit_params)  # Pass original y and weights dict
                else:
                    final_model.fit(X, y)  # Fit standard models

                fit_duration = time.time() - start_fit_time
                logger.info(f"Final {model_type} fitted in {fit_duration:.2f}s.")
                model_fitted_successfully = True

            except Exception as fit_e:
                logger.error(f"Error during final fit for {model_type}: {fit_e}", exc_info=True)
                # Keep going to return score/params, but model will be None

            # --- Save model and importance only if fit succeeded ---
            if model_fitted_successfully:
                model_path = f'models/{model_type}_{timestamp}.joblib'
                logger.info(f"Saving final {model_type} model...")
                try:
                    if isinstance(final_model, KerasClassifier):
                        tf_model_save_path = f'models/{model_type}_tfmodel_{timestamp}'
                        try:
                            final_model.model_.save(tf_model_save_path)
                            logger.info(f"Saved Keras TF model: {tf_model_save_path}")
                        except Exception as k_save_err:
                            logger.warning(f"Keras TF save fail ({k_save_err}), try joblib...")
                            joblib.dump(final_model, model_path)
                            logger.info(f"Saved Keras wrapper: {model_path}")
                    else:
                        joblib.dump(final_model, model_path)
                        logger.info(f"Saved final {model_type} via joblib: {model_path}")
                except Exception as save_err:
                    logger.error(f"Failed save model {model_type}: {save_err}", exc_info=True)

                # --- Attempt Calibration AFTER saving base model ---
                if model_type not in ['knn', 'mlp']:  # Models less suitable or needing sample_weight for calibration fit
                    try:
                        logger.info(f"Attempting calibration for {model_type}...")
                        # Use 'estimator' argument, not 'base_estimator'
                        calibrated_model = CalibratedClassifierCV(
                            estimator=final_model,
                            cv=3,
                            method='isotonic',
                            n_jobs=n_jobs_optuna,
                            ensemble=False
                        )
                        calibrated_model.fit(X, y)  # Calibrate on the full training data
                        calibrated_path = f'calibrated_models/{model_type}_calibrated_{timestamp}.joblib'
                        if not os.path.exists('calibrated_models'):
                            os.makedirs('calibrated_models')
                        joblib.dump(calibrated_model, calibrated_path)
                        logger.info(f"Saved calibrated model: {calibrated_path}")
                    except Exception as cal_err:
                        logger.warning(f"Calibration failed for {model_type}: {cal_err}", exc_info=False)

                # --- Save Importance ---
                feat_names = list(X.columns) if isinstance(X, pd.DataFrame) else None
                if feat_names:
                    logger.info(f"Saving importance {model_type}...")
                    save_feature_importance(final_model, feat_names, timestamp, model_type)
                else:
                    logger.warning(f"No feat names for importance {model_type}.")

            else:  # Fit failed
                final_model = None  # Ensure model is None if fit failed

        else:  # Instantiation failed
            logger.error(f"Could not instantiate final model {model_type}.")
            return None, best_cv_score, best_params

    except Exception as final_e:
        logger.error(f"Failed final instantiate/fit/save process {model_type}: {final_e}", exc_info=True)
        # Return score/params from Optuna, but model is None
        return None, best_cv_score, best_params

    # Return potentially None model if fit/save failed, but score/params if Optuna succeeded
    return final_model, best_cv_score, best_params

## Data Preparation

Next, we'll create functions to preprocess tabular data, including categorical encoding, feature scaling, and train-test splitting with stratification.

In [58]:
def prepare_data(df, target_col, categorical_cols=None, continuous_cols=None, 
                 test_size=0.2, val_size=0.1, random_state=42):
    """
    Prepare data for TabTransformer model by splitting and preprocessing.
    
    Parameters
    ----------
    df : pandas DataFrame
        Input DataFrame
    target_col : str
        Name of target column
    categorical_cols : list, optional
        List of categorical column names. If None, will be auto-detected
    continuous_cols : list, optional
        List of continuous column names. If None, will be auto-detected
    test_size : float, default=0.2
        Proportion of data to use for testing
    val_size : float, default=0.1
        Proportion of training data to use for validation
    random_state : int, default=42
        Random seed for reproducibility
        
    Returns
    -------
    dict
        Dictionary containing train, validation, and test splits, along with column information
    """
    logger.info("Starting data preparation")
    
    # Create copy of the dataframe
    df = df.copy()
    
    # Separate features and target
    X = df.drop(columns=[target_col])
    y = df[target_col]
    
    # Auto-detect column types if not provided
    if categorical_cols is None and continuous_cols is None:
        categorical_cols = []
        continuous_cols = []
        
        for col in X.columns:
            # Check if column has few unique values or is object type
            if X[col].dtype == 'object' or X[col].nunique() < 10:
                categorical_cols.append(col)
            else:
                continuous_cols.append(col)
                
        logger.info(f"Auto-detected {len(categorical_cols)} categorical columns and {len(continuous_cols)} continuous columns")
    
    # First split into train+val and test
    X_train_val, X_test, y_train_val, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )
    
    # Then split train+val into train and validation
    val_ratio = val_size / (1 - test_size)  # Adjusted validation ratio
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=val_ratio, random_state=random_state, stratify=y_train_val
    )
    
    # Add target back to dataframes for PyTorch-Tabular
    train_df = X_train.copy()
    train_df[target_col] = y_train
    
    val_df = X_val.copy()
    val_df[target_col] = y_val
    
    test_df = X_test.copy()
    test_df[target_col] = y_test
    
    logger.info(f"Data split: train={train_df.shape}, val={val_df.shape}, test={test_df.shape}")
    
    # Return dictionary with all information
    return {
        'train_df': train_df,
        'val_df': val_df,
        'test_df': test_df,
        'X_train': X_train,
        'y_train': y_train,
        'X_val': X_val,
        'y_val': y_val,
        'X_test': X_test,
        'y_test': y_test,
        'categorical_cols': categorical_cols,
        'continuous_cols': continuous_cols,
        'target_col': target_col
    }

def process_categorical_features(df, categorical_cols, encoding='label', max_categories=20):
    """
    Process categorical features with various encoding strategies.
    
    Parameters
    ----------
    df : pandas DataFrame
        Input DataFrame
    categorical_cols : list
        List of categorical column names
    encoding : str, default='label'
        Encoding method ('label', 'one-hot', 'target')
    max_categories : int, default=20
        Maximum number of categories to keep; others will be grouped
        
    Returns
    -------
    pandas DataFrame
        DataFrame with processed categorical columns
    dict
        Dictionary of encoders
    """
    df_processed = df.copy()
    encoders = {}
    
    if encoding == 'label':
        for col in categorical_cols:
            # Check if too many categories
            if df[col].nunique() > max_categories:
                logger.info(f"Column {col} has {df[col].nunique()} categories (>max_categories). "
                           f"Top {max_categories} will be kept, others grouped.")
                
                # Find top N categories
                top_cats = df[col].value_counts().nlargest(max_categories).index.tolist()
                
                # Replace rare categories with 'Other'
                df_processed[col] = df[col].apply(lambda x: x if x in top_cats else 'Other')
            
            # Apply label encoding
            le = LabelEncoder()
            df_processed[col] = le.fit_transform(df[col].astype(str))
            encoders[col] = le
            
    elif encoding == 'one-hot':
        for col in categorical_cols:
            # Check if too many categories
            if df[col].nunique() > max_categories:
                logger.info(f"Column {col} has {df[col].nunique()} categories (>max_categories). "
                           f"Top {max_categories} will be kept, others grouped.")
                
                # Find top N categories
                top_cats = df[col].value_counts().nlargest(max_categories).index.tolist()
                
                # Replace rare categories with 'Other'
                df_processed[col] = df[col].apply(lambda x: x if x in top_cats else 'Other')
            
            # Apply one-hot encoding
            one_hot = pd.get_dummies(df_processed[col], prefix=col, drop_first=False)
            df_processed = pd.concat([df_processed, one_hot], axis=1)
            df_processed.drop(col, axis=1, inplace=True)
            
    elif encoding == 'target':
        logger.warning("Target encoding should be done carefully to avoid data leakage. "
                      "Make sure to use it only within cross-validation folds.")
        # Target encoding implementation would go here
    
    return df_processed, encoders

def apply_categorical_encoders(df, encoders, encoding='label'):
    """
    Apply pre-fitted categorical encoders to new data.
    
    Parameters
    ----------
    df : pandas DataFrame
        Input DataFrame
    encoders : dict
        Dictionary of fitted encoders
    encoding : str, default='label'
        Encoding method used
        
    Returns
    -------
    pandas DataFrame
        DataFrame with encoded categorical columns
    """
    df_processed = df.copy()
    
    if encoding == 'label':
        for col, encoder in encoders.items():
            # Convert to string and handle unseen categories
            df_processed[col] = df_processed[col].astype(str)
            
            # Handle unseen categories by setting them to most frequent class
            for category in df_processed[col].unique():
                if category not in encoder.classes_:
                    logger.warning(f"Unseen category '{category}' found in column '{col}'. Replacing with most frequent category.")
                    most_frequent = encoder.classes_[0]  # Assumes first class is most frequent
                    df_processed.loc[df_processed[col] == category, col] = most_frequent
            
            # Apply encoding
            df_processed[col] = encoder.transform(df_processed[col])
            
    # Other encoding methods would be implemented here
    
    return df_processed

## Hyperparameter Optimization with Optuna

Now, we'll implement an Optuna-based hyperparameter optimization function that searches for optimal TabTransformer configurations.

In [59]:
def optimize_tabtransformer_hyperparameters(data_dict, n_trials=50, timeout=3600, direction='maximize'):
    """
    Optimize TabTransformer hyperparameters using Optuna.
    
    Parameters
    ----------
    data_dict : dict
        Dictionary containing the data splits and column information
    n_trials : int, default=50
        Number of optimization trials
    timeout : int, default=3600
        Timeout in seconds
    direction : str, default='maximize'
        Optimization direction ('maximize' or 'minimize')
        
    Returns
    -------
    dict
        Dictionary with best hyperparameters
    """
    logger.info(f"Starting hyperparameter optimization with {n_trials} trials")
    
    # Extract data from dictionary
    X_train = data_dict['X_train']
    y_train = data_dict['y_train']
    X_val = data_dict['X_val']
    y_val = data_dict['y_val']
    categorical_cols = data_dict['categorical_cols']
    continuous_cols = data_dict['continuous_cols']
    target_col = data_dict['target_col']
    
    # Create study
    study = optuna.create_study(
        pruner=MedianPruner(n_warmup_steps=10),
        direction=direction,
        study_name="tabtransformer_optimization"
    )
    
    def objective(trial):
        """Objective function for hyperparameter optimization."""
        
        # Sample hyperparameters
        num_heads = trial.suggest_int('num_heads', 2, 8)
        num_attn_blocks = trial.suggest_int('num_attn_blocks', 1, 6)
        attn_dropout = trial.suggest_float('attn_dropout', 0.0, 0.5, step=0.1)
        ff_dropout = trial.suggest_float('ff_dropout', 0.0, 0.5, step=0.1)
        mlp_dropout = trial.suggest_float('mlp_dropout', 0.0, 0.5, step=0.1)
        lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
        weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True)
        batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256])
        
        # Create model with sampled hyperparameters
        model = TabTransformerClassifier(
            categorical_cols=categorical_cols,
            continuous_cols=continuous_cols,
            num_heads=num_heads,
            num_attn_blocks=num_attn_blocks,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            mlp_dropout=mlp_dropout,
            lr=lr,
            weight_decay=weight_decay,
            batch_size=batch_size,
            max_epochs=20,  # Use fewer epochs for faster trials
            patience=5,
            target_col=target_col
        )
        
        try:
            # Add target column to X_train for PyTorch-Tabular
            train_df = X_train.copy()
            train_df[target_col] = y_train
            
            # Fit the model
            model.fit(X_train, y_train)
            
            # Evaluate on validation set
            val_score = model.score(X_val, y_val)
            
            # Clean up to free memory
            del model
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
            return val_score
            
        except Exception as e:
            logger.warning(f"Trial failed with error: {str(e)}")
            # Return a poor score
            return -1.0 if direction == 'maximize' else float('inf')
    
    # Run optimization
    study.optimize(objective, n_trials=n_trials, timeout=timeout)
    
    # Get best parameters
    best_params = study.best_params
    logger.info(f"Best hyperparameters: {best_params}")
    logger.info(f"Best score: {study.best_value:.4f}")
    
    # Create visualization if Optuna has visualization capabilities
    try:
        import optuna.visualization as vis
        importance = vis.plot_param_importances(study)
        history = vis.plot_optimization_history(study)
        logger.info("Created Optuna visualizations")
    except:
        logger.warning("Optuna visualization not available")
    
    return best_params

## Model Training with Early Stopping

Now, let's set up TabTransformer training with early stopping to prevent overfitting, including training loop monitoring and checkpoint saving.

In [60]:
def train_tabtransformer_with_monitoring(data_dict, hyperparams=None, model_dir='./models', 
                                        monitor_metric='val_loss', early_stopping=True, 
                                        patience=10, max_epochs=100):
    """
    Train TabTransformer model with monitoring and early stopping.
    
    Parameters
    ----------
    data_dict : dict
        Dictionary containing the data splits and column information
    hyperparams : dict, optional
        Dictionary of hyperparameters
    model_dir : str, default='./models'
        Directory to save model checkpoints
    monitor_metric : str, default='val_loss'
        Metric to monitor for early stopping
    early_stopping : bool, default=True
        Whether to apply early stopping
    patience : int, default=10
        Number of epochs to wait for improvement before stopping
    max_epochs : int, default=100
        Maximum number of training epochs
        
    Returns
    -------
    TabTransformerClassifier
        Trained model
    dict
        Training history
    """
    logger.info("Starting TabTransformer training with monitoring")
    
    # Extract data from dictionary
    X_train = data_dict['X_train']
    y_train = data_dict['y_train']
    X_val = data_dict['X_val'] 
    y_val = data_dict['y_val']
    categorical_cols = data_dict['categorical_cols']
    continuous_cols = data_dict['continuous_cols']
    target_col = data_dict['target_col']
    
    # Set default hyperparameters if not provided
    if hyperparams is None:
        hyperparams = {
            'num_heads': 4,
            'num_attn_blocks': 4,
            'attn_dropout': 0.1,
            'ff_dropout': 0.1,
            'mlp_dropout': 0.1,
            'lr': 1e-3,
            'weight_decay': 1e-5,
            'batch_size': 64
        }
    
    # Create model
    model = TabTransformerClassifier(
        categorical_cols=categorical_cols,
        continuous_cols=continuous_cols,
        num_heads=hyperparams.get('num_heads', 4),
        num_attn_blocks=hyperparams.get('num_attn_blocks', 4),
        attn_dropout=hyperparams.get('attn_dropout', 0.1),
        ff_dropout=hyperparams.get('ff_dropout', 0.1),
        mlp_dropout=hyperparams.get('mlp_dropout', 0.1),
        lr=hyperparams.get('lr', 1e-3),
        weight_decay=hyperparams.get('weight_decay', 1e-5),
        batch_size=hyperparams.get('batch_size', 64),
        max_epochs=max_epochs,
        patience=patience if early_stopping else max_epochs,
        target_col=target_col,
        model_dir=model_dir
    )
    
    # Create directory if it doesn't exist
    os.makedirs(model_dir, exist_ok=True)
    
    # Train the model
    start_time = time.time()
    model.fit(X_train, y_train)
    training_time = time.time() - start_time
    
    # Evaluate on validation set
    val_score = model.score(X_val, y_val)
    logger.info(f"Validation score: {val_score:.4f}")
    logger.info(f"Training completed in {training_time:.2f} seconds")
    
    # Get training history from PyTorch-Tabular
    history = {
        'training_time': training_time,
        'val_score': val_score
    }
    
    if hasattr(model.model, 'trainer') and hasattr(model.model.trainer, 'logger'):
        # Extract metrics from PyTorch Lightning logger if available
        try:
            metrics = model.model.trainer.logger.metrics
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    history[key] = value
        except:
            logger.warning("Could not extract detailed training metrics")
    
    return model, history

## Performance Evaluation

Let's create functions to evaluate model performance using various metrics including accuracy, ROC-AUC, confusion matrix, and classification report.

In [61]:
def evaluate_classification_model(model, X, y, threshold=0.5, class_names=None):
    """
    Comprehensively evaluate a classification model.
    
    Parameters
    ----------
    model : TabTransformerClassifier
        Trained classification model
    X : pandas DataFrame
        Features
    y : array-like
        True labels
    threshold : float, default=0.5
        Decision threshold for binary classification
    class_names : list, optional
        List of class names
        
    Returns
    -------
    dict
        Dictionary containing evaluation metrics and plots
    """
    logger.info("Evaluating model performance")
    
    # Check if binary or multiclass
    n_classes = len(np.unique(y))
    is_binary = n_classes == 2
    
    # Get predictions
    y_pred = model.predict(X)
    
    # Get probabilities
    y_proba = model.predict_proba(X)
    
    # Calculate metrics
    metrics = {}
    metrics['accuracy'] = accuracy_score(y, y_pred)
    
    if is_binary:
        # Binary classification metrics
        metrics['roc_auc'] = roc_auc_score(y, y_proba[:, 1])
    else:
        # Multiclass classification metrics
        metrics['roc_auc'] = roc_auc_score(y, y_proba, multi_class='ovr', average='macro')
    
    # Classification report
    if class_names is not None and len(class_names) == n_classes:
        report = classification_report(y, y_pred, target_names=class_names, output_dict=True)
    else:
        report = classification_report(y, y_pred, output_dict=True)
    
    metrics['classification_report'] = report
    
    # Calculate confusion matrix
    cm = confusion_matrix(y, y_pred)
    
    # Log results
    logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
    logger.info(f"ROC-AUC: {metrics['roc_auc']:.4f}")
    
    # Generate plots
    plots = {}
    
    # Confusion Matrix Plot
    plt.figure(figsize=(10, 8))
    if class_names is not None:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    else:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix')
    plots['confusion_matrix'] = plt.gcf()
    plt.close()
    
    # Class-specific metrics plot
    plt.figure(figsize=(12, 6))
    metrics_df = pd.DataFrame({
        'Precision': [report[str(i)]['precision'] for i in range(n_classes)],
        'Recall': [report[str(i)]['recall'] for i in range(n_classes)],
        'F1-Score': [report[str(i)]['f1-score'] for i in range(n_classes)]
    })
    
    if class_names is not None:
        metrics_df.index = class_names
    
    metrics_df.plot(kind='bar')
    plt.title('Class-specific Metrics')
    plt.ylabel('Score')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plots['class_metrics'] = plt.gcf()
    plt.close()
    
    # ROC curve for binary classification
    if is_binary:
        from sklearn.metrics import roc_curve
        fpr, tpr, _ = roc_curve(y, y_proba[:, 1])
        
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, label=f'ROC curve (AUC = {metrics["roc_auc"]:.3f})')
        plt.plot([0, 1], [0, 1], 'k--', label='Random')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve')
        plt.legend()
        plots['roc_curve'] = plt.gcf()
        plt.close()
    
    return {
        'metrics': metrics,
        'plots': plots,
        'confusion_matrix': cm,
        'classification_report': report
    }

def plot_classification_metrics(evaluation_results, save_dir=None):
    """
    Display or save the classification metrics plots.
    
    Parameters
    ----------
    evaluation_results : dict
        Results from evaluate_classification_model
    save_dir : str, optional
        Directory to save plots. If None, plots are displayed
    """
    plots = evaluation_results['plots']
    
    for name, fig in plots.items():
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(os.path.join(save_dir, f"{name}.png"))
        else:
            plt.figure(fig.number)
            plt.show()

## Feature Importance Analysis

Now, let's implement methods to extract and visualize feature importance from the trained TabTransformer model.

In [62]:
def calculate_feature_importance(model, X, y, method='permutation', n_repeats=10):
    """
    Calculate feature importance for TabTransformer model.
    
    Parameters
    ----------
    model : TabTransformerClassifier
        Trained model
    X : pandas DataFrame
        Features
    y : array-like
        Target values
    method : str, default='permutation'
        Method to calculate importance ('permutation' or 'shap')
    n_repeats : int, default=10
        Number of times to permute features
        
    Returns
    -------
    pandas DataFrame
        DataFrame with feature importance scores
    """
    logger.info(f"Calculating feature importance using {method} method")
    
    if method == 'permutation':
        from sklearn.inspection import permutation_importance
        
        # Calculate permutation importance
        result = permutation_importance(
            model, X, y, 
            n_repeats=n_repeats, 
            random_state=42,
            n_jobs=-1  # Use all available cores
        )
        
        # Create DataFrame with results
        importance_df = pd.DataFrame({
            'Feature': X.columns,
            'Importance Mean': result.importances_mean,
            'Importance Std': result.importances_std
        })
        
        # Sort by importance
        importance_df = importance_df.sort_values('Importance Mean', ascending=False).reset_index(drop=True)
        
    elif method == 'shap':
        try:
            import shap
            
            # Create a background dataset for SHAP
            background = X.sample(min(100, len(X)), random_state=42)
            
            # Create explainer
            explainer = shap.Explainer(model.predict, background)
            
            # Calculate SHAP values
            shap_values = explainer(X.sample(min(500, len(X)), random_state=42))
            
            # Calculate mean absolute SHAP value for each feature
            feature_importance = np.abs(shap_values.values).mean(axis=0)
            
            # Create DataFrame with results
            importance_df = pd.DataFrame({
                'Feature': X.columns,
                'Importance Mean': feature_importance,
                'Importance Std': np.abs(shap_values.values).std(axis=0)
            })
            
            # Sort by importance
            importance_df = importance_df.sort_values('Importance Mean', ascending=False).reset_index(drop=True)
            
        except ImportError:
            logger.warning("SHAP not installed. Falling back to permutation importance.")
            return calculate_feature_importance(model, X, y, method='permutation', n_repeats=n_repeats)
    else:
        raise ValueError(f"Unsupported importance method: {method}")
    
    return importance_df

def plot_feature_importance(importance_df, top_n=20, plot_title="Feature Importance"):
    """
    Plot feature importance.
    
    Parameters
    ----------
    importance_df : pandas DataFrame
        DataFrame with feature importance from calculate_feature_importance
    top_n : int, default=20
        Number of top features to plot
    plot_title : str, default="Feature Importance"
        Title for the plot
        
    Returns
    -------
    matplotlib.figure.Figure
        The feature importance plot
    """
    # Get top N features
    df_plot = importance_df.head(top_n).copy()
    
    # Create plot
    plt.figure(figsize=(12, 8))
    plt.barh(
        range(len(df_plot)), 
        df_plot['Importance Mean'],
        xerr=df_plot['Importance Std'],
        align='center',
        alpha=0.8
    )
    plt.yticks(range(len(df_plot)), df_plot['Feature'])
    plt.xlabel('Importance')
    plt.ylabel('Feature')
    plt.title(plot_title)
    plt.tight_layout()
    
    return plt.gcf()

## Batch Prediction Implementation

Let's create a memory-efficient batch prediction system for handling large datasets that wouldn't fit in memory.

In [63]:
class BatchPredictor:
    """
    A memory-efficient batch prediction system for large datasets.
    
    Parameters
    ----------
    model : TabTransformerClassifier
        Trained model
    batch_size : int, default=1000
        Size of each prediction batch
    """
    
    def __init__(self, model, batch_size=1000):
        self.model = model
        self.batch_size = batch_size
    
    def predict(self, X, output_file=None, include_proba=False):
        """
        Make predictions in batches.
        
        Parameters
        ----------
        X : pandas DataFrame or str
            Input features or path to CSV/parquet file
        output_file : str, optional
            Path to save predictions. If None, returns predictions directly
        include_proba : bool, default=False
            Whether to include class probabilities in the output
            
        Returns
        -------
        pandas DataFrame or None
            DataFrame with predictions if output_file is None, otherwise None
        """
        # Check if X is a file path
        if isinstance(X, str):
            return self._predict_from_file(X, output_file, include_proba)
        else:
            return self._predict_from_dataframe(X, output_file, include_proba)
    
    def _predict_from_dataframe(self, X, output_file, include_proba):
        """Make predictions from DataFrame."""
        logger.info(f"Making batch predictions on DataFrame with {X.shape[0]} rows")
        
        total_rows = X.shape[0]
        results = []
        
        # Process in batches
        for start_idx in range(0, total_rows, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_rows)
            batch = X.iloc[start_idx:end_idx]
            
            # Create result DataFrame for this batch
            batch_result = pd.DataFrame({
                'prediction': self.model.predict(batch)
            })
            
            # Add probabilities if requested
            if include_proba:
                probas = self.model.predict_proba(batch)
                
                # Handle binary vs multiclass
                if probas.shape[1] == 2:  # Binary
                    batch_result['probability'] = probas[:, 1]
                else:  # Multiclass
                    for i in range(probas.shape[1]):
                        batch_result[f'probability_class_{i}'] = probas[:, i]
            
            results.append(batch_result)
            
            # Log progress
            if (start_idx // self.batch_size) % 10 == 0:
                logger.info(f"Processed {end_idx}/{total_rows} rows ({end_idx/total_rows*100:.1f}%)")
        
        # Combine results
        result_df = pd.concat(results, ignore_index=True)
        
        # Save to file if requested
        if output_file:
            # Determine file format
            if output_file.endswith('.csv'):
                result_df.to_csv(output_file, index=False)
            elif output_file.endswith('.parquet'):
                result_df.to_parquet(output_file, index=False)
            else:
                result_df.to_csv(output_file, index=False)  # Default to CSV
                
            logger.info(f"Predictions saved to {output_file}")
            return None
        else:
            return result_df
    
    def _predict_from_file(self, file_path, output_file, include_proba):
        """Make predictions from file path."""
        # Check file format
        if file_path.endswith('.csv'):
            reader = pd.read_csv
        elif file_path.endswith('.parquet'):
            reader = pd.read_parquet
        else:
            raise ValueError(f"Unsupported file format: {file_path}")
        
        # Create output file if needed
        if output_file:
            # Choose writer based on output format
            if output_file.endswith('.csv'):
                write_header = True
                
                def write_batch(df, path, mode):
                    nonlocal write_header
                    df.to_csv(path, mode=mode, header=write_header, index=False)
                    write_header = False
                    
            elif output_file.endswith('.parquet'):
                raise ValueError("Cannot append to parquet files in batch mode. Use CSV output or in-memory processing.")
            else:
                # Default to CSV
                write_header = True
                
                def write_batch(df, path, mode):
                    nonlocal write_header
                    df.to_csv(path, mode=mode, header=write_header, index=False)
                    write_header = False
            
            # Create new file
            if output_file:
                if os.path.exists(output_file):
                    os.remove(output_file)
        
        # Process file in batches
        logger.info(f"Making batch predictions from file: {file_path}")
        
        batch_idx = 0
        total_rows = 0
        
        for batch_df in reader(file_path, chunksize=self.batch_size):
            # Make predictions
            batch_result = pd.DataFrame({
                'prediction': self.model.predict(batch_df)
            })
            
            # Add probabilities if requested
            if include_proba:
                probas = self.model.predict_proba(batch_df)
                
                # Handle binary vs multiclass
                if probas.shape[1] == 2:  # Binary
                    batch_result['probability'] = probas[:, 1]
                else:  # Multiclass
                    for i in range(probas.shape[1]):
                        batch_result[f'probability_class_{i}'] = probas[:, i]
            
            # Write to file if output file provided
            if output_file:
                write_batch(batch_result, output_file, 'a' if batch_idx > 0 else 'w')
            else:
                # In-memory processing - not recommended for large files
                if batch_idx == 0:
                    result_df = batch_result
                else:
                    result_df = pd.concat([result_df, batch_result], ignore_index=True)
            
            # Update counters and log progress
            batch_idx += 1
            total_rows += len(batch_df)
            
            if batch_idx % 10 == 0:
                logger.info(f"Processed {batch_idx} batches, {total_rows} rows")
        
        logger.info(f"Completed processing {total_rows} rows in {batch_idx} batches")
        
        if output_file:
            return None
        else:
            return result_df

## Example Usage

Finally, let's demonstrate the end-to-end usage of the TabTransformer classifier on a sample classification dataset.

In [64]:
def run_tabtransformer_example():
    """
    End-to-end example of TabTransformer classifier on a sample dataset.
    """
    logger.info("Starting TabTransformer example")
    
    # Load a sample dataset (UCI Adult Income)
    from sklearn.datasets import fetch_openml
    
    df = pd.read_csv('train.csv')
    X = df.drop(columns=['salary_category'])
    y = df['salary_category']
    
    # Quick preprocessing
    # Convert target to binary (>50K = 1, <=50K = 0)
    y = (y == '>50K').astype(int)
    
    # Identify categorical and continuous columns
    categorical_cols = X.select_dtypes(include=['object', 'category']).columns.tolist()
    continuous_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()
    
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42, stratify=y_train)
    
    # Prepare data dictionary
    data_dict = {
        'X_train': X_train,
        'y_train': y_train,
        'X_val': X_val,
        'y_val': y_val,
        'X_test': X_test,
        'y_test': y_test,
        'categorical_cols': categorical_cols,
        'continuous_cols': continuous_cols,
        'target_col': 'target'  # Will be used when adding target to DataFrame
    }
    
    # Sample hyperparameters (for quick demo, skip Optuna optimization)
    hyperparams = {
        'num_heads': 4,
        'num_attn_blocks': 3,
        'attn_dropout': 0.2,
        'ff_dropout': 0.2,
        'mlp_dropout': 0.1,
        'lr': 5e-3,
        'weight_decay': 1e-5,
        'batch_size': 128
    }
    
    # Train model
    model, history = train_tabtransformer_with_monitoring(
        data_dict, 
        hyperparams=hyperparams, 
        model_dir='./models',
        max_epochs=10  # Using small number of epochs for demo
    )
    
    # Evaluate model
    eval_results = evaluate_classification_model(
        model, 
        X_test, 
        y_test, 
        class_names=['<=50K', '>50K']
    )
    
    # Print metrics
    metrics = eval_results['metrics']
    print("\n===== Model Performance =====")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"ROC-AUC: {metrics['roc_auc']:.4f}")
    print("\n===== Classification Report =====")
    report = metrics['classification_report']
    print(f"Precision (>50K): {report['1']['precision']:.4f}")
    print(f"Recall (>50K): {report['1']['recall']:.4f}")
    print(f"F1-score (>50K): {report['1']['f1-score']:.4f}")
    
    # Plot confusion matrix
    plt.figure(eval_results['plots']['confusion_matrix'].number)
    plt.title('Confusion Matrix')
    plt.show()
    
    # Calculate feature importance
    importance = calculate_feature_importance(model, X_test, y_test, method='permutation', n_repeats=5)
    
    # Plot feature importance
    importance_plot = plot_feature_importance(importance, top_n=10)
    plt.title('Feature Importance')
    plt.show()
    
    # Demonstrate batch prediction
    batch_predictor = BatchPredictor(model, batch_size=1000)
    
    # Make predictions on test set
    print("\n===== Batch Prediction Example =====")
    predictions = batch_predictor.predict(X_test.iloc[:100], include_proba=True)
    print(predictions.head())
    
    return model, eval_results, importance


if __name__ == "__main__":
    model, eval_results, importance = run_tabtransformer_example()

2025-04-26 01:33:50,217 - INFO - Starting TabTransformer example
2025-04-26 01:33:50,282 - INFO - Starting TabTransformer training with monitoring
2025-04-26 01:33:50,283 - INFO - Starting model fitting


TypeError: OptimizerConfig.__init__() got an unexpected keyword argument 'learning_rate'

In [None]:
print(eval_results)

## Conclusion

In this notebook, we've implemented a complete TabTransformer classification system using PyTorch-Tabular, with the following components:

1. A scikit-learn compatible TabTransformerClassifier
2. Data preparation utilities for tabular data
3. Hyperparameter optimization with Optuna
4. Model training with early stopping
5. Comprehensive performance evaluation
6. Feature importance analysis
7. Memory-efficient batch prediction

The TabTransformer architecture leverages the power of self-attention mechanisms for tabular data, often outperforming traditional models like XGBoost and deep neural networks on structured data tasks.

To use this implementation in your projects:
1. Install the required dependencies
2. Adapt the data preparation for your specific dataset
3. Run hyperparameter optimization to find optimal configurations
4. Train the model with early stopping
5. Evaluate and analyze the model performance
6. Use batch prediction for efficient inference on large datasets