In [None]:
# %% ======================= CELL 1.5: GRID SEARCH =======================
import itertools
from copy import deepcopy
import warnings
warnings.filterwarnings('ignore')

def run_single_config(df_train, df_val, cont_cols, cat_cols, label_col, num_classes, base_cfg, config):
    """Train one FT-Transformer configuration and summarise its best epoch.

    Args:
        df_train, df_val, cont_cols, cat_cols, label_col, num_classes:
            Forwarded unchanged to ``run_ft_transformer``.
        base_cfg: Base configuration dict. It is never modified.
        config: Overrides for this run (one grid point); its keys win over
            ``base_cfg``. It is never modified.

    Returns:
        dict with the keys ``config``, ``best_val_f1``, ``best_train_f1``,
        ``best_val_loss``, ``overfitting_gap`` (train F1 minus validation F1
        at the best epoch), ``best_epoch`` (1-indexed, as reported by the
        training history), ``history``, ``status`` (``'success'`` or
        ``'failed'``) and ``error`` (``None`` on success, the exception
        message on failure).

    Exceptions raised while training or while reading the history are caught
    and reported as a ``'failed'`` result so a sweep can keep going.
    """
    # Deep-copy the merged config: a plain {**a, **b} merge shares mutable
    # values (e.g. the `mlp_hidden` lists) with the grid and the base config,
    # so any in-place change made during training would leak into later runs
    # and into the `config` recorded in the results.
    current_cfg = deepcopy({**base_cfg, **config})

    try:
        model, scaler, history = run_ft_transformer(
            df_train=df_train, df_val=df_val,
            cont_cols=cont_cols, cat_cols=cat_cols,
            label_col=label_col, num_classes=num_classes,
            cfg=current_cfg
        )

        # `best_epoch` is 1-indexed. Validate it before converting, because
        # a value of 0 would become index -1 and silently report the metrics
        # of the *last* epoch instead of the best one.
        best_epoch = history["best_epoch"]
        n_epochs = len(history["val_f1"])
        if not 1 <= best_epoch <= n_epochs:
            raise ValueError(
                f"best_epoch={best_epoch} is outside the recorded epochs 1..{n_epochs}"
            )
        idx = best_epoch - 1

        best_val_f1 = history["val_f1"][idx]
        best_train_f1 = history["train_f1"][idx]
        best_val_loss = history["val_loss"][idx]

        return {
            'config': config,
            'best_val_f1': best_val_f1,
            'best_train_f1': best_train_f1,
            'best_val_loss': best_val_loss,
            # Positive gap = the model scores better on train than on validation.
            'overfitting_gap': best_train_f1 - best_val_f1,
            'best_epoch': best_epoch,
            'history': history,
            'status': 'success',
            # Same key set as the failure result, so results tabulate cleanly.
            'error': None
        }

    except Exception as e:
        # Deliberate best-effort: one bad grid point must not abort the sweep.
        print(f"Config {config} failed with error: {str(e)}")
        return {
            'config': config,
            'best_val_f1': -1.0,
            'best_train_f1': -1.0,
            'best_val_loss': float('inf'),
            'overfitting_gap': 0.0,
            'best_epoch': 0,
            'history': None,
            'status': 'failed',
            'error': str(e)
        }

def grid_search_ft_transformer(df_train, df_val, cont_cols, cat_cols, label_col, num_classes,
                              base_cfg, param_grid, max_runs=None, sort_by='val_f1'):
    """
    Perform grid search for FT-Transformer hyperparameters.

    Args:
        df_train, df_val, cont_cols, cat_cols, label_col, num_classes:
            Passed through to ``run_single_config`` for every run.
        base_cfg: Base configuration dict. Its ``"seed"`` entry (default 42)
            also seeds the random sub-sampling of the grid.
        param_grid: Dictionary with parameter names and lists of values to try
        max_runs: Maximum number of configurations to run (None for all).
            When the grid is larger, ``max_runs`` combinations are sampled
            without replacement.
        sort_by: Metric to sort results by. Accepts ``'val_f1'``,
            ``'val_loss'`` and ``'overfitting_gap'``, or any numeric result
            key directly (``'best_val_f1'``, ``'best_train_f1'``,
            ``'best_val_loss'``, ``'best_epoch'``). Loss and overfitting gap
            are ranked ascending (smaller is better); the rest descending.

    Returns:
        The successful result dicts, best first. If no run succeeded, the
        list of failed results is returned instead so the errors can be
        inspected; check each result's ``'status'``.

    Raises:
        ValueError: If ``sort_by`` is not a known metric. Raised before any
            training so a typo does not cost a whole sweep.
    """
    # The result dicts store metrics under 'best_*' keys; map the short names
    # from the public interface onto them.
    metric_aliases = {'val_f1': 'best_val_f1', 'val_loss': 'best_val_loss'}
    higher_is_better = {'best_val_f1', 'best_train_f1', 'best_epoch'}
    lower_is_better = {'best_val_loss', 'overfitting_gap'}
    sort_key = metric_aliases.get(sort_by, sort_by)
    if sort_key not in higher_is_better | lower_is_better:
        raise ValueError(
            f"Unknown sort_by={sort_by!r}; expected one of "
            f"{sorted(set(metric_aliases) | higher_is_better | lower_is_better)}"
        )

    # Generate all parameter combinations
    param_names = list(param_grid.keys())
    all_combinations = list(itertools.product(*param_grid.values()))

    print("Starting Grid Search...")
    print(f"Parameter grid: {param_grid}")
    print(f"Total possible combinations: {len(all_combinations)}")

    if max_runs and len(all_combinations) > max_runs:
        print(f"Sampling {max_runs} random combinations from {len(all_combinations)} total...")
        # Seeded so the same subset is drawn on every re-run.
        rng = np.random.RandomState(base_cfg.get("seed", 42))
        indices = rng.choice(len(all_combinations), size=max_runs, replace=False)
        combinations = [all_combinations[i] for i in indices]
    else:
        combinations = all_combinations

    results = []

    for i, comb in enumerate(combinations):
        config_dict = dict(zip(param_names, comb))
        print(f"\n{'='*60}")
        print(f"Run {i+1}/{len(combinations)}: {config_dict}")
        print(f"{'='*60}")

        result = run_single_config(
            df_train, df_val, cont_cols, cat_cols, label_col, num_classes,
            base_cfg, config_dict
        )

        if result['status'] == 'success':
            print(f"✓ Val F1: {result['best_val_f1']:.4f} | "
                  f"Train F1: {result['best_train_f1']:.4f} | "
                  f"Overfitting: {result['overfitting_gap']:.4f} | "
                  f"Best Epoch: {result['best_epoch']}")
        else:
            print(f"✗ Failed: {result.get('error', 'Unknown error')}")

        results.append(result)

    # Filter successful runs and sort
    successful_results = [r for r in results if r['status'] == 'success']

    if not successful_results:
        print("No successful runs! Check your parameter ranges.")
        return results

    # Best first: descending for scores, ascending for loss / overfitting gap.
    successful_results.sort(key=lambda r: r[sort_key], reverse=sort_key in higher_is_better)

    print(f"\n{'#'*80}")
    print(f"GRID SEARCH COMPLETE - Top 5 Configurations (sorted by {sort_by}):")
    print(f"{'#'*80}")

    for i, result in enumerate(successful_results[:5]):
        print(f"\nRank {i+1}:")
        print(f"  Config: {result['config']}")
        print(f"  Val F1: {result['best_val_f1']:.4f}")
        print(f"  Train F1: {result['best_train_f1']:.4f}")
        print(f"  Overfitting gap: {result['overfitting_gap']:.4f}")
        print(f"  Val Loss: {result['best_val_loss']:.4f}")
        print(f"  Best Epoch: {result['best_epoch']}")

    return successful_results

def plot_grid_search_results(results, top_k=10):
    """Plot a 2x2 comparison of the first ``top_k`` successful configurations.

    Panels: train-vs-validation F1 scatter, overfitting gap, validation F1
    and validation loss. Configurations are labelled by their 1-based position
    in ``results`` (after dropping failed runs) on every panel, so a number on
    the scatter refers to the same configuration as that tick on the bars.

    Args:
        results: Result dicts with ``best_val_f1``, ``best_train_f1``,
            ``best_val_loss`` and ``overfitting_gap``. Entries whose
            ``status`` is present and not ``'success'`` are skipped.
        top_k: Maximum number of configurations to draw.

    Returns:
        The matplotlib figure, or ``None`` when there is nothing to plot.
    """
    if not results:
        print("No results to plot")
        return

    # Failed runs carry placeholder metrics (-1.0 F1, infinite loss) that
    # would distort or break the bar charts, so they are left out.
    plottable = [r for r in results if r.get('status', 'success') == 'success']
    if not plottable:
        print("No results to plot")
        return

    top_results = plottable[:top_k]
    positions = list(range(len(top_results)))
    rank_labels = [str(pos + 1) for pos in positions]

    val_f1s = [r['best_val_f1'] for r in top_results]
    train_f1s = [r['best_train_f1'] for r in top_results]
    gaps = [r['overfitting_gap'] for r in top_results]
    val_losses = [r['best_val_loss'] for r in top_results]

    # Create summary plot
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # Val F1 vs Train F1; the dashed diagonal marks train == validation.
    ax1.scatter(val_f1s, train_f1s, alpha=0.6, s=60)
    for label, v, t in zip(rank_labels, val_f1s, train_f1s):
        ax1.annotate(label, (v, t), xytext=(5, 5), textcoords='offset points', fontsize=8)
    ax1.plot([0, 1], [0, 1], 'r--', alpha=0.3)
    ax1.set_xlabel('Validation F1'); ax1.set_ylabel('Train F1')
    ax1.set_title('Train vs Validation F1'); ax1.grid(True, alpha=0.3)

    # Overfitting gap
    ax2.bar(positions, gaps, alpha=0.7)
    ax2.set_xlabel('Configuration'); ax2.set_ylabel('Overfitting Gap (Train F1 - Val F1)')
    ax2.set_title('Overfitting Gap by Configuration'); ax2.grid(True, alpha=0.3)

    # Validation F1 comparison
    ax3.bar(positions, val_f1s, alpha=0.7, color='green')
    ax3.set_xlabel('Configuration'); ax3.set_ylabel('Validation F1')
    ax3.set_title('Validation F1 by Configuration'); ax3.grid(True, alpha=0.3)

    # Validation Loss comparison
    ax4.bar(positions, val_losses, alpha=0.7, color='orange')
    ax4.set_xlabel('Configuration'); ax4.set_ylabel('Validation Loss')
    ax4.set_title('Validation Loss by Configuration'); ax4.grid(True, alpha=0.3)

    # Tick the bars with the same 1-based numbers used on the scatter plot.
    for bar_ax in (ax2, ax3, ax4):
        bar_ax.set_xticks(positions)
        bar_ax.set_xticklabels(rank_labels)

    plt.tight_layout()
    plt.show()

    return fig

In [None]:
# %% ======================= CELL 2.5: GRID SEARCH EXECUTION =======================

# Main search space. Every knob here targets overfitting: dropout rates,
# model capacity and optimiser regularisation. Key order is significant,
# because it fixes the order in which combinations are enumerated.
PARAM_GRID = dict(
    # --- dropout rates ---
    attn_dropout=[0.2, 0.3, 0.4],            # attention dropout
    ff_dropout=[0.3, 0.4, 0.5],              # feed-forward dropout
    token_dropout=[0.2, 0.3, 0.4],           # token dropout
    mlp_dropout=[0.4, 0.5, 0.6],             # MLP head dropout

    # --- model capacity ---
    d_token=[128, 192, 256],                 # token dimension
    n_layers=[4, 6],                         # number of layers
    mlp_hidden=[[256, 128], [512, 256]],     # hidden sizes of the MLP head

    # --- optimiser ---
    weight_decay=[1e-2, 2e-3, 5e-3],
    lr=[1e-4, 3e-4, 1e-3],
)

# Reduced search space (two values per knob, 32 combinations) for a fast
# smoke test of the sweep.
QUICK_PARAM_GRID = dict(
    attn_dropout=[0.2, 0.3],
    ff_dropout=[0.3, 0.4],
    token_dropout=[0.2, 0.3],
    weight_decay=[2e-3, 5e-3],
    d_token=[192, 256],
)

# Run grid search
grid_results = grid_search_ft_transformer(
    df_train=df_train,
    df_val=df_val,
    cont_cols=cont_cols,
    cat_cols=cat_cols,
    label_col=TARGET_COL,
    num_classes=num_classes,
    base_cfg=CFG,  # Your base configuration
    param_grid=PARAM_GRID,  # or QUICK_PARAM_GRID for faster testing
    max_runs=20,    # Limit number of runs for practicality
    # Name the result-dict key explicitly: results store the metric as
    # 'best_val_f1', and the sort indexes the result dicts by this name.
    sort_by='best_val_f1'
)

# When every run fails the search hands back the failed results, so a plain
# truthiness check would plot and announce a failed run as "best".
successful_runs = [r for r in grid_results if r.get('status') == 'success']

# Plot results
if successful_runs:
    plot_grid_search_results(successful_runs, top_k=10)

    # Get best configuration (results are sorted best-first)
    best_run = successful_runs[0]
    best_config = best_run['config']
    print(f"\n🎯 BEST CONFIGURATION:")
    print(f"Parameters: {best_config}")
    print(f"Validation F1: {best_run['best_val_f1']:.4f}")
    print(f"Overfitting gap: {best_run['overfitting_gap']:.4f}")

    # You can now retrain with the best configuration
    print("\nTo use the best configuration, update your CFG:")
    print("CFG.update({")
    for k, v in best_config.items():
        # repr() keeps the printed snippet valid Python for string values too.
        print(f"    {k!r}: {v!r},")
    print("})")

In [None]:
# Alternative search space with heavier regularisation than the main grid:
# higher dropout, smaller architectures, stronger weight decay.
# NOTE(review): neither d_token candidate (64, 128) is divisible by
# n_heads=6. If the attention implementation requires
# d_token % n_heads == 0, every combination using 6 heads will fail. Confirm
# against run_ft_transformer before sweeping this grid.
ENHANCED_REGULARIZATION_GRID = dict(
    # --- dropout rates ---
    attn_dropout=[0.4, 0.5],
    ff_dropout=[0.5, 0.6],
    token_dropout=[0.4, 0.5],
    mlp_dropout=[0.5, 0.6],

    # --- architecture ---
    d_token=[64, 128],
    n_layers=[2, 3, 4],
    n_heads=[4, 6],
    mlp_hidden=[[128, 64], [256, 128]],

    # --- training ---
    weight_decay=[1e-2, 2e-2],
    lr=[5e-5, 1e-4],
    batch_size=[128, 192],        # smaller batches can act as a regulariser
    grad_clip_norm=[0.5, 1.0],
)

# Candidate search space for time-series specific regularisation.
# Per the original author's notes, `layer_dropout` still has to be
# implemented in the transformer and `label_smoothing` requires a change to
# the loss function, so those two keys are placeholders until that is done.
TIME_SERIES_GRID = dict(
    token_dropout=[0.3, 0.4, 0.5],
    layer_dropout=[0.1, 0.2],       # needs a transformer change first
    label_smoothing=[0.1, 0.2],     # needs a loss-function change first
)