In [None]:
"""
TabPFN (Tabular Prior-Fitted Networks) - Complete Implementation
================================================================

A comprehensive implementation of TabPFN with advanced features including:
- Sophisticated synthetic data generation with structural causal models
- State-of-the-art transformer architecture with cross-attention
- Efficient training and inference pipelines
- Real-world dataset integration capabilities
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
from typing import Tuple, Optional, List, Dict, Any, Union
from dataclasses import dataclass
import warnings
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import json
import os

warnings.filterwarnings("ignore")

# ===========================
# Configuration Classes
# ===========================

@dataclass
class TabPFNConfig:
    """Configuration for TabPFN model and training"""
    # Model architecture
    max_features: int = 100
    max_samples: int = 1024
    hidden_dim: int = 512
    num_layers: int = 12
    num_heads: int = 8
    max_classes: int = 10
    dropout: float = 0.0
    
    # Training
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    batch_size: int = 512
    num_epochs: int = 100
    warmup_steps: int = 1000
    gradient_clip: float = 1.0
    
    # Data generation
    min_features: int = 3
    min_samples: int = 50
    max_samples_train: int = 512
    num_train_tasks: int = 100000
    num_val_tasks: int = 10000
    
    # Advanced features
    use_feature_embedding: bool = True
    use_positional_encoding: bool = True
    use_cross_attention: bool = True
    
    def save(self, path: str):
        """Save configuration to JSON file"""
        with open(path, 'w') as f:
            json.dump(self.__dict__, f, indent=2)
    
    @classmethod
    def load(cls, path: str):
        """Load configuration from JSON file"""
        with open(path, 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)

# ===========================
# Advanced Data Generation
# ===========================

class StructuralCausalModel:
    """Advanced SCM for generating diverse synthetic datasets"""
    
    def __init__(self, seed: Optional[int] = None):
        if seed is not None:
            np.random.seed(seed)
        
        self.causal_mechanisms = {
            'linear': self._linear_mechanism,
            'polynomial': self._polynomial_mechanism,
            'interaction': self._interaction_mechanism,
            'threshold': self._threshold_mechanism,
            'periodic': self._periodic_mechanism,
            'mixture': self._mixture_mechanism
        }
    
    def _linear_mechanism(self, X: np.ndarray, indices: np.ndarray, 
                         params: Dict[str, Any]) -> np.ndarray:
        """Linear combination with random weights"""
        weights = params.get('weights', np.random.randn(len(indices)))
        # Ensure weights match the number of indices
        if len(weights) != len(indices):
            weights = np.random.randn(len(indices))
        return X[:, indices] @ weights
    
    def _polynomial_mechanism(self, X: np.ndarray, indices: np.ndarray,
                            params: Dict[str, Any]) -> np.ndarray:
        """Polynomial transformations"""
        degree = params.get('degree', np.random.randint(2, 4))
        weights = params.get('weights', np.random.randn(len(indices)))
        # Ensure weights match the number of indices
        if len(weights) != len(indices):
            weights = np.random.randn(len(indices))
        result = np.zeros(X.shape[0])
        
        for i, idx in enumerate(indices):
            result += weights[i] * (X[:, idx] ** degree)
        
        return result
    
    def _interaction_mechanism(self, X: np.ndarray, indices: np.ndarray,
                             params: Dict[str, Any]) -> np.ndarray:
        """Multiplicative interactions between features"""
        if len(indices) < 2:
            return self._linear_mechanism(X, indices, params)
        
        result = np.zeros(X.shape[0])
        for i in range(len(indices)):
            for j in range(i + 1, len(indices)):
                weight = np.random.randn()
                result += weight * X[:, indices[i]] * X[:, indices[j]]
        
        return result
    
    def _threshold_mechanism(self, X: np.ndarray, indices: np.ndarray,
                           params: Dict[str, Any]) -> np.ndarray:
        """Piecewise linear with thresholds"""
        base = self._linear_mechanism(X, indices, params)
        threshold = params.get('threshold', np.random.randn())
        scale_high = params.get('scale_high', 2.0)
        scale_low = params.get('scale_low', 0.5)
        
        return np.where(base > threshold, base * scale_high, base * scale_low)
    
    def _periodic_mechanism(self, X: np.ndarray, indices: np.ndarray,
                          params: Dict[str, Any]) -> np.ndarray:
        """Sinusoidal transformations"""
        frequency = params.get('frequency', np.random.uniform(0.5, 2.0))
        phase = params.get('phase', np.random.uniform(0, 2 * np.pi))
        weights = params.get('weights', np.random.randn(len(indices)))
        # Ensure weights match the number of indices
        if len(weights) != len(indices):
            weights = np.random.randn(len(indices))
        
        linear_combo = X[:, indices] @ weights
        return np.sin(frequency * linear_combo + phase)
    
    def _mixture_mechanism(self, X: np.ndarray, indices: np.ndarray,
                         params: Dict[str, Any]) -> np.ndarray:
        """Mixture of different mechanisms"""
        mechanisms = ['linear', 'polynomial', 'interaction', 'periodic']
        chosen = np.random.choice(mechanisms, size=2, replace=False)
        
        result = 0
        for mech_name in chosen:
            mech_func = self.causal_mechanisms[mech_name]
            if mech_name != 'mixture':  # Avoid recursion
                result += mech_func(X, indices, params) * np.random.uniform(0.3, 0.7)
        
        return result
    
    def generate_dataset(self, num_samples: int, num_features: int,
                        num_classes: int, complexity: float = 0.1,
                        noise_level: float = 0.1) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a complete synthetic dataset"""
        
        # Generate base features with diverse distributions
        X = self._generate_features(num_samples, num_features)
        
        # Select causal features
        num_causal = max(2, int(num_features * complexity))
        causal_indices = np.random.choice(num_features, size=num_causal, replace=False)
        
        # Generate target using multiple mechanisms
        y_continuous = self._generate_target(X, causal_indices, num_mechanisms=3)
        
        # Add noise
        y_continuous += np.random.normal(0, noise_level, num_samples)
        
        # Convert to classes
        y = self._continuous_to_classes(y_continuous, num_classes)
        
        return X.astype(np.float32), y.astype(np.int64)
    
    def _generate_features(self, num_samples: int, num_features: int) -> np.ndarray:
        """Generate features with diverse distributions"""
        X = np.zeros((num_samples, num_features))
        
        distributions = [
            ('normal', lambda: np.random.normal(0, np.random.uniform(0.5, 2.0), num_samples)),
            ('uniform', lambda: np.random.uniform(-2, 2, num_samples)),
            ('exponential', lambda: np.random.exponential(np.random.uniform(0.5, 2.0), num_samples) - 1),
            ('beta', lambda: np.random.beta(2, 5, num_samples) * 4 - 1),
            ('gamma', lambda: np.random.gamma(2, 2, num_samples) - 2),
            ('laplace', lambda: np.random.laplace(0, np.random.uniform(0.5, 1.5), num_samples))
        ]
        
        for i in range(num_features):
            dist_name, dist_func = distributions[i % len(distributions)]
            X[:, i] = dist_func()
            
            # Add correlations between some features
            if i > 0 and np.random.random() < 0.3:
                correlation_strength = np.random.uniform(-0.8, 0.8)
                X[:, i] = correlation_strength * X[:, i-1] + np.sqrt(1 - correlation_strength**2) * X[:, i]
        
        return X
    
    def _generate_target(self, X: np.ndarray, causal_indices: np.ndarray,
                        num_mechanisms: int = 3) -> np.ndarray:
        """Generate target variable using multiple causal mechanisms"""
        y_components = []
        
        for _ in range(num_mechanisms):
            # Select mechanism and subset of causal features
            mechanism_name = np.random.choice(list(self.causal_mechanisms.keys()))
            mechanism = self.causal_mechanisms[mechanism_name]
            
            subset_size = np.random.randint(1, min(len(causal_indices), 5) + 1)
            subset_indices = np.random.choice(causal_indices, size=subset_size, replace=False)
            
            # Generate component with random parameters matching the subset size
            params = self._generate_mechanism_params(mechanism_name, num_features=len(subset_indices))
            component = mechanism(X, subset_indices, params)
            
            # Apply random scaling
            component *= np.random.uniform(0.5, 2.0)
            y_components.append(component)
        
        # Combine components
        return np.sum(y_components, axis=0)
    
    def _generate_mechanism_params(self, mechanism_name: str, num_features: int = None) -> Dict[str, Any]:
        """Generate random parameters for a mechanism"""
        params = {}
        
        if mechanism_name in ['linear', 'polynomial', 'periodic']:
            # Make sure weights match the number of features that will be used
            if num_features is not None:
                params['weights'] = np.random.randn(num_features)
            else:
                params['weights'] = np.random.randn(np.random.randint(1, 5))
        
        if mechanism_name == 'polynomial':
            params['degree'] = np.random.randint(2, 4)
        
        if mechanism_name == 'threshold':
            params['threshold'] = np.random.randn()
            params['scale_high'] = np.random.uniform(1.5, 3.0)
            params['scale_low'] = np.random.uniform(0.1, 0.7)
        
        if mechanism_name == 'periodic':
            params['frequency'] = np.random.uniform(0.5, 3.0)
            params['phase'] = np.random.uniform(0, 2 * np.pi)
        
        return params
    
    def _continuous_to_classes(self, y_continuous: np.ndarray, num_classes: int) -> np.ndarray:
        """Convert continuous values to class labels"""
        if num_classes == 2:
            # Binary classification
            threshold = np.percentile(y_continuous, 50)
            return (y_continuous > threshold).astype(np.int64)
        else:
            # Multi-class classification
            percentiles = np.linspace(0, 100, num_classes + 1)[1:-1]
            thresholds = np.percentile(y_continuous, percentiles)
            return np.digitize(y_continuous, thresholds).astype(np.int64)

# ===========================
# TabPFN Dataset
# ===========================

class TabPFNDataset(Dataset):
    """Dataset for meta-learning on synthetic tabular data"""
    
    def __init__(self, config: TabPFNConfig, num_tasks: int, split: str = 'train'):
        self.config = config
        self.num_tasks = num_tasks
        self.split = split
        self.scm = StructuralCausalModel()
        
        # Different settings for train/val
        if split == 'train':
            self.max_samples = config.max_samples_train
        else:
            self.max_samples = config.max_samples
    
    def __len__(self):
        return self.num_tasks
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Set random seed for reproducibility
        np.random.seed(idx if self.split == 'val' else None)
        
        # Sample task parameters
        num_features = np.random.randint(self.config.min_features, self.config.max_features + 1)
        num_samples = np.random.randint(self.config.min_samples, self.max_samples + 1)
        num_classes = np.random.randint(2, self.config.max_classes + 1)
        # complexity = np.random.uniform(0.2, 0.8)
        # noise_level = np.random.uniform(0.05, 0.2)
        
        complexity = 0.1
        noise_level = 0.0
        
        # Generate data
        X, y = self.scm.generate_dataset(
            num_samples=num_samples,
            num_features=num_features,
            num_classes=num_classes,
            complexity=complexity,
            noise_level=noise_level
        )
        
        # Normalize features
        X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)
        
        # Split into context and query
        split_idx = np.random.randint(num_samples // 4, 3 * num_samples // 4)
        
        context_X = torch.from_numpy(X[:split_idx]).float()
        context_y = torch.from_numpy(y[:split_idx]).long()
        query_X = torch.from_numpy(X[split_idx:]).float()
        query_y = torch.from_numpy(y[split_idx:]).long()
        
        return {
            'context_X': context_X,
            'context_y': context_y,
            'query_X': query_X,
            'query_y': query_y,
            'num_features': num_features,
            'num_classes': num_classes,
            'num_context': split_idx,
            'num_query': num_samples - split_idx
        }

def collate_tabpfn_batch(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Custom collate function for TabPFN batches"""
    batch_size = len(batch)
    
    # Find maximum dimensions
    max_features = max(item['num_features'] for item in batch)
    max_context = max(item['num_context'] for item in batch)
    max_query = max(item['num_query'] for item in batch)
    max_classes = max(item['num_classes'] for item in batch)
    
    # Get the configured max features from any item (they should all have the same config)
    config_max_features = batch[0]['context_X'].size(-1) if hasattr(batch[0], 'config_max_features') else max_features
    
    # Initialize padded tensors
    context_X = torch.zeros(batch_size, max_context, max_features)
    context_y = torch.zeros(batch_size, max_context, dtype=torch.long)
    query_X = torch.zeros(batch_size, max_query, max_features)
    query_y = torch.zeros(batch_size, max_query, dtype=torch.long)
    
    # Masks for padding
    context_mask = torch.zeros(batch_size, max_context, dtype=torch.bool)
    query_mask = torch.zeros(batch_size, max_query, dtype=torch.bool)
    feature_mask = torch.zeros(batch_size, max_features, dtype=torch.bool)
    
    # Fill tensors
    for i, item in enumerate(batch):
        n_features = item['num_features']
        n_context = item['num_context']
        n_query = item['num_query']
        
        context_X[i, :n_context, :n_features] = item['context_X']
        context_y[i, :n_context] = item['context_y']
        query_X[i, :n_query, :n_features] = item['query_X']
        query_y[i, :n_query] = item['query_y']
        
        context_mask[i, :n_context] = True
        query_mask[i, :n_query] = True
        feature_mask[i, :n_features] = True
    
    return {
        'context_X': context_X,
        'context_y': context_y,
        'query_X': query_X,
        'query_y': query_y,
        'context_mask': context_mask,
        'query_mask': query_mask,
        'feature_mask': feature_mask,
        'num_classes': torch.tensor([item['num_classes'] for item in batch]),
        'max_classes': max_classes,
        'max_features': max_features
    }

# ===========================
# Model Components
# ===========================

class FeatureEncoder(nn.Module):
    """Encode variable-length features to fixed dimension"""
    
    def __init__(self, max_features: int, hidden_dim: int):
        super().__init__()
        self.max_features = max_features
        self.hidden_dim = hidden_dim
        
        # Feature embedding
        self.feature_embed = nn.Linear(max_features, hidden_dim)
        
        # Feature-wise attention
        self.feature_attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.Sigmoid()
        )
        
    def forward(self, x: torch.Tensor, feature_mask: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, actual_features]
        # feature_mask: [batch_size, actual_features]
        
        batch_size, seq_len, actual_features = x.shape
        
        # Pad features to max_features if necessary
        if actual_features < self.max_features:
            padding = torch.zeros(batch_size, seq_len, self.max_features - actual_features, 
                                device=x.device, dtype=x.dtype)
            x = torch.cat([x, padding], dim=-1)
            
            # Extend feature mask
            mask_padding = torch.zeros(batch_size, self.max_features - actual_features, 
                                     device=feature_mask.device, dtype=feature_mask.dtype)
            feature_mask_extended = torch.cat([feature_mask, mask_padding], dim=-1)
        else:
            feature_mask_extended = feature_mask
        
        # Apply feature mask
        x = x * feature_mask_extended.unsqueeze(1).float()
        
        # Embed features
        embedded = self.feature_embed(x)
        
        # Apply feature-wise attention
        attention = self.feature_attention(embedded)
        
        return embedded * attention

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding"""
    
    def __init__(self, hidden_dim: int, max_len: int = 5000):
        super().__init__()
        
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * 
                           (-math.log(10000.0) / hidden_dim))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, hidden_dim]
        return x + self.pe[:x.size(1)]

class CrossAttentionLayer(nn.Module):
    """Cross-attention between query and context"""
    
    def __init__(self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.0):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query: torch.Tensor, context: torch.Tensor,
                context_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Create attention mask
        if context_mask is not None:
            # context_mask: [batch_size, context_len]
            # Convert to attention mask format
            attn_mask = ~context_mask  # True positions are masked
        else:
            attn_mask = None
        
        # Apply cross-attention
        attended, _ = self.attention(
            query, context, context,
            key_padding_mask=attn_mask
        )
        
        # Residual connection and normalization
        return self.norm(query + self.dropout(attended))

# ===========================
# Main TabPFN Model
# ===========================

class TabPFN(nn.Module):
    """Tabular Prior-Fitted Network"""
    
    def __init__(self, config: TabPFNConfig):
        super().__init__()
        self.config = config
        
        # Feature encoding
        self.feature_encoder = FeatureEncoder(config.max_features, config.hidden_dim)
        
        # Label embedding
        self.label_embed = nn.Embedding(config.max_classes, config.hidden_dim)
        
        # Positional encoding
        if config.use_positional_encoding:
            self.pos_encoder = PositionalEncoding(config.hidden_dim, config.max_samples)
        
        # Type embeddings (context vs query)
        self.context_type_embed = nn.Parameter(torch.randn(config.hidden_dim))
        self.query_type_embed = nn.Parameter(torch.randn(config.hidden_dim))
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim * 4,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
        
        # Cross-attention layers (if enabled)
        if config.use_cross_attention:
            self.cross_attention = nn.ModuleList([
                CrossAttentionLayer(config.hidden_dim, config.num_heads, config.dropout)
                for _ in range(3)
            ])
        
        # Output head
        self.output_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.max_classes)
        )
        
        self._init_weights()
    
    def _init_weights(self):
        """Initialize weights"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.02)
    
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Forward pass
        
        Args:
            batch: Dictionary containing:
                - context_X: [batch_size, context_len, actual_features]
                - context_y: [batch_size, context_len]
                - query_X: [batch_size, query_len, actual_features]
                - context_mask: [batch_size, context_len]
                - query_mask: [batch_size, query_len]
                - feature_mask: [batch_size, actual_features]
        
        Returns:
            logits: [batch_size, query_len, max_classes]
        """
        context_X = batch['context_X']
        context_y = batch['context_y']
        query_X = batch['query_X']
        context_mask = batch['context_mask']
        query_mask = batch['query_mask']
        feature_mask = batch['feature_mask']
        
        batch_size = context_X.size(0)
        context_len = context_X.size(1)
        query_len = query_X.size(1)
        
        # Encode features (FeatureEncoder will handle padding internally)
        context_encoded = self.feature_encoder(context_X, feature_mask)
        query_encoded = self.feature_encoder(query_X, feature_mask)
        
        # Add label embeddings to context
        context_labels = self.label_embed(context_y)
        context_encoded = context_encoded + context_labels
        
        # Add type embeddings
        context_encoded = context_encoded + self.context_type_embed
        query_encoded = query_encoded + self.query_type_embed
        
        # Add positional encoding
        if self.config.use_positional_encoding:
            context_encoded = self.pos_encoder(context_encoded)
            query_encoded = self.pos_encoder(query_encoded)
        
        # Apply masks
        context_encoded = context_encoded * context_mask.unsqueeze(-1).float()
        query_encoded = query_encoded * query_mask.unsqueeze(-1).float()
        
        # Process with transformer
        if self.config.use_cross_attention:
            # Process context first
            context_processed = self.transformer(
                context_encoded,
                src_key_padding_mask=~context_mask
            )
            
            # Apply cross-attention from query to context
            query_processed = query_encoded
            for cross_attn in self.cross_attention:
                query_processed = cross_attn(
                    query_processed, context_processed, context_mask
                )
            
            # Final transformer processing
            combined = torch.cat([context_processed, query_processed], dim=1)
            combined_mask = torch.cat([context_mask, query_mask], dim=1)
            
            output = self.transformer(
                combined,
                src_key_padding_mask=~combined_mask
            )
            
            # Extract query outputs
            output = output[:, context_len:]
        else:
            # Standard transformer processing
            combined = torch.cat([context_encoded, query_encoded], dim=1)
            combined_mask = torch.cat([context_mask, query_mask], dim=1)
            
            output = self.transformer(
                combined,
                src_key_padding_mask=~combined_mask
            )
            
            # Extract query outputs
            output = output[:, context_len:]
        
        # Generate predictions
        logits = self.output_head(output)
        
        return logits

# ===========================
# Training Components
# ===========================

class TabPFNTrainer:
    """Trainer for TabPFN model"""
    
    def __init__(self, model: TabPFN, config: TabPFNConfig, device: torch.device):
        self.model = model.to(device)
        self.config = config
        self.device = device
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            betas=(0.9, 0.98)
        )
        
        # Loss function
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        
        # Learning rate scheduler
        self.scheduler = self._create_scheduler()
        
        # Metrics
        self.train_metrics = {'loss': [], 'accuracy': []}
        self.val_metrics = {'loss': [], 'accuracy': []}
        
        # Best model tracking
        self.best_val_loss = float('inf')
        self.best_epoch = 0
    
    def _create_scheduler(self):
        """Create learning rate scheduler with warmup"""
        def lr_lambda(step):
            if step < self.config.warmup_steps:
                return step / self.config.warmup_steps
            else:
                return 0.5 * (1 + math.cos(math.pi * (step - self.config.warmup_steps) / 
                                          (self.config.num_epochs * 1000)))
        
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
    
    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Single training step"""
        self.model.train()
        
        # Move batch to device
        batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v 
                for k, v in batch.items()}
        
        # Forward pass
        logits = self.model(batch)
        
        # Compute loss
        query_y = batch['query_y']
        query_mask = batch['query_mask']
        
        # Flatten for loss computation
        logits_flat = logits.view(-1, logits.size(-1))
        labels_flat = query_y.view(-1)
        mask_flat = query_mask.view(-1)
        
        # Compute loss only on valid positions
        loss_all = self.criterion(logits_flat, labels_flat)
        loss = (loss_all * mask_flat.float()).sum() / mask_flat.float().sum()
        
        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
        
        # Optimizer step
        self.optimizer.step()
        self.scheduler.step()
        
        # Compute accuracy
        with torch.no_grad():
            predictions = torch.argmax(logits_flat, dim=-1)
            correct = (predictions == labels_flat) * mask_flat
            accuracy = correct.float().sum() / mask_flat.float().sum()
        
        return {
            'loss': loss.item(),
            'accuracy': accuracy.item(),
            'lr': self.optimizer.param_groups[0]['lr']
        }
    
    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate model on a dataset"""
        self.model.eval()
        
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        
        with torch.no_grad():
            for batch in dataloader:
                # Move batch to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v 
                        for k, v in batch.items()}
                
                # Forward pass
                logits = self.model(batch)
                
                # Compute loss and accuracy
                query_y = batch['query_y']
                query_mask = batch['query_mask']
                
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = query_y.view(-1)
                mask_flat = query_mask.view(-1)
                
                loss_all = self.criterion(logits_flat, labels_flat)
                loss = (loss_all * mask_flat.float()).sum()
                
                predictions = torch.argmax(logits_flat, dim=-1)
                correct = ((predictions == labels_flat) * mask_flat).sum()
                
                total_loss += loss.item()
                total_correct += correct.item()
                total_samples += mask_flat.float().sum().item()
        
        return {
            'loss': total_loss / total_samples,
            'accuracy': total_correct / total_samples
        }
    
    def train_epoch(self, train_loader: DataLoader, val_loader: DataLoader, 
                   epoch: int) -> Dict[str, float]:
        """Train for one epoch"""
        # Training
        train_loss = 0.0
        train_acc = 0.0
        train_steps = 0
        
        for batch in train_loader:
            metrics = self.train_step(batch)
            train_loss += metrics['loss']
            train_acc += metrics['accuracy']
            train_steps += 1
            
            if train_steps % 100 == 0:
                print(f"  Step {train_steps}: Loss={metrics['loss']:.4f}, "
                      f"Acc={metrics['accuracy']:.4f}, LR={metrics['lr']:.6f}")
        
        avg_train_loss = train_loss / train_steps
        avg_train_acc = train_acc / train_steps
        
        # Validation
        val_metrics = self.evaluate(val_loader)
        
        # Update metrics history
        self.train_metrics['loss'].append(avg_train_loss)
        self.train_metrics['accuracy'].append(avg_train_acc)
        self.val_metrics['loss'].append(val_metrics['loss'])
        self.val_metrics['accuracy'].append(val_metrics['accuracy'])
        
        # Check for best model
        if val_metrics['loss'] < self.best_val_loss:
            self.best_val_loss = val_metrics['loss']
            self.best_epoch = epoch
            self.save_checkpoint(f'tabpfn_best.pt')
        
        return {
            'train_loss': avg_train_loss,
            'train_accuracy': avg_train_acc,
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy']
        }
    
    def save_checkpoint(self, path: str):
        """Save model checkpoint"""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
            'train_metrics': self.train_metrics,
            'val_metrics': self.val_metrics,
            'best_val_loss': self.best_val_loss,
            'best_epoch': self.best_epoch
        }, path)
    
    def load_checkpoint(self, path: str):
        """Load model checkpoint"""
        checkpoint = torch.load(path, map_location=self.device)
        
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.train_metrics = checkpoint['train_metrics']
        self.val_metrics = checkpoint['val_metrics']
        self.best_val_loss = checkpoint['best_val_loss']
        self.best_epoch = checkpoint['best_epoch']

# ===========================
# Inference and Application
# ===========================

class TabPFNClassifier:
    """User-friendly interface for TabPFN"""
    
    def __init__(self, model_path: Optional[str] = None, device: Optional[torch.device] = None):
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Load pre-trained model if path provided
        if model_path:
            self.load_model(model_path)
        else:
            # Initialize with default config
            self.config = TabPFNConfig()
            self.model = TabPFN(self.config)
            self.model.to(self.device)
    
    def load_model(self, path: str):
        """Load pre-trained model"""
        checkpoint = torch.load(path, map_location=self.device)
        self.config = checkpoint['config']
        self.model = TabPFN(self.config)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
    
    def fit_predict(self, X_train: np.ndarray, y_train: np.ndarray, 
                   X_test: np.ndarray) -> np.ndarray:
        """
        Fit on training data and predict on test data
        
        Args:
            X_train: Training features [n_train_samples, n_features]
            y_train: Training labels [n_train_samples]
            X_test: Test features [n_test_samples, n_features]
            
        Returns:
            predictions: Predicted labels [n_test_samples]
        """
        # Normalize features
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)
        
        # Convert to tensors
        context_X = torch.from_numpy(X_train_scaled).float()
        context_y = torch.from_numpy(y_train).long()
        query_X = torch.from_numpy(X_test_scaled).float()
        
        # Pad features if necessary
        n_features = X_train.shape[1]
        if n_features < self.config.max_features:
            pad_width = self.config.max_features - n_features
            context_X = F.pad(context_X, (0, pad_width))
            query_X = F.pad(query_X, (0, pad_width))
        
        # Create batch
        batch = {
            'context_X': context_X.unsqueeze(0),
            'context_y': context_y.unsqueeze(0),
            'query_X': query_X.unsqueeze(0),
            'context_mask': torch.ones(1, len(X_train), dtype=torch.bool),
            'query_mask': torch.ones(1, len(X_test), dtype=torch.bool),
            'feature_mask': torch.zeros(1, self.config.max_features, dtype=torch.bool)
        }
        batch['feature_mask'][0, :n_features] = True
        
        # Move to device
        batch = {k: v.to(self.device) for k, v in batch.items()}
        
        # Get predictions
        self.model.eval()
        with torch.no_grad():
            logits = self.model(batch)
            predictions = torch.argmax(logits, dim=-1).squeeze(0)
        
        return predictions.cpu().numpy()
    
    def predict_proba(self, X_train: np.ndarray, y_train: np.ndarray, 
                     X_test: np.ndarray) -> np.ndarray:
        """
        Get probability predictions
        
        Returns:
            probabilities: Class probabilities [n_test_samples, n_classes]
        """
        # Similar to fit_predict but return probabilities
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)
        
        context_X = torch.from_numpy(X_train_scaled).float()
        context_y = torch.from_numpy(y_train).long()
        query_X = torch.from_numpy(X_test_scaled).float()
        
        n_features = X_train.shape[1]
        if n_features < self.config.max_features:
            pad_width = self.config.max_features - n_features
            context_X = F.pad(context_X, (0, pad_width))
            query_X = F.pad(query_X, (0, pad_width))
        
        batch = {
            'context_X': context_X.unsqueeze(0),
            'context_y': context_y.unsqueeze(0),
            'query_X': query_X.unsqueeze(0),
            'context_mask': torch.ones(1, len(X_train), dtype=torch.bool),
            'query_mask': torch.ones(1, len(X_test), dtype=torch.bool),
            'feature_mask': torch.zeros(1, self.config.max_features, dtype=torch.bool)
        }
        batch['feature_mask'][0, :n_features] = True
        
        batch = {k: v.to(self.device) for k, v in batch.items()}
        
        self.model.eval()
        with torch.no_grad():
            logits = self.model(batch)
            probabilities = F.softmax(logits, dim=-1).squeeze(0)
            
            # Only return probabilities for actual classes in training data
            unique_classes = np.unique(y_train)
            probabilities = probabilities[:, unique_classes]
        
        return probabilities.cpu().numpy()

# ===========================
# Training Script
# ===========================

def train_tabpfn(config: Optional[TabPFNConfig] = None, 
                num_epochs: Optional[int] = None) -> TabPFN:
    """
    Train a TabPFN model
    
    Args:
        config: Model configuration (uses default if None)
        num_epochs: Number of training epochs (overrides config if provided)
    
    Returns:
        Trained TabPFN model
    """
    # Use default config if not provided
    if config is None:
        config = TabPFNConfig()
    
    if num_epochs is not None:
        config.num_epochs = num_epochs
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create datasets
    print("Creating datasets...")
    train_dataset = TabPFNDataset(config, config.num_train_tasks, split='train')
    val_dataset = TabPFNDataset(config, config.num_val_tasks, split='val')
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_tabpfn_batch,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_tabpfn_batch,
        pin_memory=True
    )
    
    # Create model
    print("Creating model...")
    model = TabPFN(config)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Create trainer
    trainer = TabPFNTrainer(model, config, device)
    
    # Training loop
    print("\nStarting training...")
    for epoch in range(config.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
        print("-" * 50)
        
        metrics = trainer.train_epoch(train_loader, val_loader, epoch)
        
        print(f"Train Loss: {metrics['train_loss']:.4f}, "
              f"Train Acc: {metrics['train_accuracy']:.4f}")
        print(f"Val Loss: {metrics['val_loss']:.4f}, "
              f"Val Acc: {metrics['val_accuracy']:.4f}")
        
        # Early stopping
        if epoch - trainer.best_epoch > 10:
            print(f"\nEarly stopping at epoch {epoch + 1}")
            break
    
    print(f"\nTraining completed! Best model at epoch {trainer.best_epoch + 1}")
    
    # Load best model
    trainer.load_checkpoint('tabpfn_best.pt')
    
    return model

# ===========================
# Example Usage
# ===========================

def example_synthetic_data():
    """Example using synthetic data"""
    print("=== TabPFN Synthetic Data Example ===\n")
    
    # Generate synthetic dataset
    scm = StructuralCausalModel()
    X, y = scm.generate_dataset(
        num_samples=1000,
        num_features=20,
        num_classes=3,
        complexity=0.1
    )
    
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    
    print(f"Dataset shape: {X.shape}")
    print(f"Number of classes: {len(np.unique(y))}")
    print(f"Train/Test split: {len(X_train)}/{len(X_test)}\n")
    
    # Create and use TabPFN classifier
    classifier = TabPFNClassifier()
    
    # Make predictions
    predictions = classifier.fit_predict(X_train, y_train, X_test)
    
    # Calculate accuracy
    accuracy = (predictions == y_test).mean()
    print(f"Test Accuracy: {accuracy:.4f}")
    
    # Get probability predictions
    proba = classifier.predict_proba(X_train, y_train, X_test)
    print(f"Probability shape: {proba.shape}")
    
    return classifier

def example_real_data():
    """Example using real datasets"""
    print("\n=== TabPFN Real Data Example ===\n")
    
    from sklearn.datasets import load_breast_cancer, load_wine, load_iris
    from sklearn.metrics import accuracy_score, classification_report
    
    datasets = {
        'Iris': load_iris(),
        'Wine': load_wine(),
        'Breast Cancer': load_breast_cancer()
    }
    
    # Load pre-trained model (assume it exists)
    # In practice, you would train the model first
    classifier = TabPFNClassifier()
    
    for name, data in datasets.items():
        print(f"\n{name} Dataset:")
        print(f"  Samples: {data.data.shape[0]}")
        print(f"  Features: {data.data.shape[1]}")
        print(f"  Classes: {len(np.unique(data.target))}")
        
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            data.data, data.target, test_size=0.3, random_state=42
        )
        
        # Make predictions
        predictions = classifier.fit_predict(X_train, y_train, X_test)
        
        # Evaluate
        accuracy = accuracy_score(y_test, predictions)
        print(f"  Test Accuracy: {accuracy:.4f}")

def example_training():
    """Example of training TabPFN"""
    print("\n=== TabPFN Training Example ===\n")
    
    # Create custom configuration
    config = TabPFNConfig(
        max_features=50,
        hidden_dim=256,
        num_layers=6,
        num_heads=8,
        learning_rate=1e-4,
        batch_size=32,
        num_epochs=10,  # Small for demo
        num_train_tasks=10000,  # Reduced for demo
        num_val_tasks=1000
    )
    
    # Train model
    model = train_tabpfn(config)
    
    return model

if __name__ == "__main__":
    # Run examples
    print("TabPFN Complete Implementation\n")
    
    # Example 1: Synthetic data
    classifier = example_synthetic_data()
    
    # Example 2: Real data
    example_real_data()
    
    # Example 3: Training (uncomment to run)
    model = example_training()
    
    print("\n\nAll examples completed!")

TabPFN Complete Implementation

=== TabPFN Synthetic Data Example ===

Dataset shape: (1000, 20)
Number of classes: 3
Train/Test split: 800/200

Test Accuracy: 0.3550
Probability shape: (200, 3)

=== TabPFN Real Data Example ===


Iris Dataset:
  Samples: 150
  Features: 4
  Classes: 3
  Test Accuracy: 0.4222

Wine Dataset:
  Samples: 178
  Features: 13
  Classes: 3
  Test Accuracy: 0.3889

Breast Cancer Dataset:
  Samples: 569
  Features: 30
  Classes: 2
  Test Accuracy: 0.6316

=== TabPFN Training Example ===

Using device: cuda
Creating datasets...
Creating model...
Model parameters: 5,680,522

Starting training...

Epoch 1/10
--------------------------------------------------
  Step 100: Loss=1.6992, Acc=0.2040, LR=0.000010
  Step 200: Loss=1.7419, Acc=0.2077, LR=0.000020
  Step 300: Loss=1.7836, Acc=0.1967, LR=0.000030
Train Loss: 1.6961, Train Acc: 0.2141
Val Loss: 1.6620, Val Acc: 0.2190

Epoch 2/10
--------------------------------------------------
  Step 100: Loss=1.5432, Acc=

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL __main__.TabPFNConfig was not an allowed global by default. Please use `torch.serialization.add_safe_globals([TabPFNConfig])` or the `torch.serialization.safe_globals([TabPFNConfig])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [6]:
"""
Improved TabPFN Training with Better Configurations and Monitoring
"""

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Optional
import time
from tqdm import tqdm

# Import the TabPFN components from the main implementation
# Assuming the main TabPFN code is already loaded

def create_optimized_config(scale: str = 'medium') -> TabPFNConfig:
    """
    Create optimized configurations for different scales
    
    Args:
        scale: 'small', 'medium', 'large', or 'xlarge'
    """
    configs = {
        'small': TabPFNConfig(
            # Model size
            max_features=32,
            hidden_dim=128,
            num_layers=4,
            num_heads=4,
            max_classes=10,
            dropout=0.1,
            
            # Training
            learning_rate=1e-3,
            weight_decay=1e-5,
            batch_size=64,
            num_epochs=50,
            warmup_steps=500,
            gradient_clip=0.5,
            
            # Data generation
            min_features=3,
            min_samples=20,
            max_samples=200,
            max_samples_train=150,
            num_train_tasks=20000,
            num_val_tasks=2000,
            
            # Features
            use_feature_embedding=True,
            use_positional_encoding=False,
            use_cross_attention=False  # Disable for faster training
        ),
        
        'medium': TabPFNConfig(
            # Model size
            max_features=64,
            hidden_dim=256,
            num_layers=6,
            num_heads=8,
            max_classes=10,
            dropout=0.1,
            
            # Training
            learning_rate=2e-3,
            weight_decay=1e-5,
            batch_size=32,
            num_epochs=100,
            warmup_steps=1000,
            gradient_clip=1.0,
            
            # Data generation
            min_features=3,
            min_samples=30,
            max_samples=300,
            max_samples_train=250,
            num_train_tasks=50000,
            num_val_tasks=5000,
            
            # Features
            use_feature_embedding=True,
            use_positional_encoding=True,
            use_cross_attention=True
        ),
        
        'large': TabPFNConfig(
            # Model size
            max_features=100,
            hidden_dim=512,
            num_layers=8,
            num_heads=8,
            max_classes=10,
            dropout=0.0,
            
            # Training
            learning_rate=1e-3,
            weight_decay=1e-5,
            batch_size=16,
            num_epochs=200,
            warmup_steps=2000,
            gradient_clip=1.0,
            
            # Data generation  
            min_features=3,
            min_samples=50,
            max_samples=500,
            max_samples_train=400,
            num_train_tasks=100000,
            num_val_tasks=10000,
            
            # Features
            use_feature_embedding=True,
            use_positional_encoding=True,
            use_cross_attention=True
        ),
        
        'xlarge': TabPFNConfig(
            # Model size
            max_features=100,
            hidden_dim=768,
            num_layers=12,
            num_heads=12,
            max_classes=10,
            dropout=0.0,
            
            # Training
            learning_rate=5e-4,
            weight_decay=1e-5,
            batch_size=8,
            num_epochs=300,
            warmup_steps=4000,
            gradient_clip=1.0,
            
            # Data generation
            min_features=3,
            min_samples=50,
            max_samples=1000,
            max_samples_train=800,
            num_train_tasks=200000,
            num_val_tasks=20000,
            
            # Features
            use_feature_embedding=True,
            use_positional_encoding=True,
            use_cross_attention=True
        )
    }
    
    return configs[scale]

class ImprovedTabPFNTrainer(TabPFNTrainer):
    """Enhanced trainer with moving-average monitoring and history tracking"""

    def __init__(self, model: TabPFN, config: TabPFNConfig, device: torch.device):
        super().__init__(model, config, device)
        self.step_count = 0
        self.best_accuracy = 0.0
        self.train_loss_ma: List[float] = []
        self.train_acc_ma: List[float] = []
        # 에폭별 기록용 히스토리
        self.history: Dict[str, List[float]] = {
            'train_loss': [], 'train_accuracy': [],
            'val_loss': [],   'val_accuracy': [],
            'lr': []
        }
        self.scheduler = self._create_cosine_scheduler_with_warmup()

    def _create_cosine_scheduler_with_warmup(self):
        def lr_lambda(step):
            if step < self.config.warmup_steps:
                return step / self.config.warmup_steps
            total_steps = self.config.num_epochs * (self.config.num_train_tasks // self.config.batch_size)
            progress = (step - self.config.warmup_steps) / max(1, total_steps - self.config.warmup_steps)
            return 0.5 * (1 + np.cos(np.pi * progress))
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        self.model.train()
        self.step_count += 1

        batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                 for k, v in batch.items()}

        logits = self.model(batch)
        query_y, query_mask = batch['query_y'], batch['query_mask']
        logits_flat = logits.view(-1, logits.size(-1))
        labels_flat = query_y.view(-1)
        mask_flat = query_mask.view(-1).float()

        loss_all = self.criterion(logits_flat, labels_flat)
        loss = (loss_all * mask_flat).sum() / mask_flat.sum()

        if hasattr(self.model, 'cross_attention') and self.config.use_cross_attention:
            l2_reg = sum(torch.norm(p, 2) for layer in self.model.cross_attention for p in layer.parameters())
            loss = loss + 1e-5 * l2_reg

        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
        self.optimizer.step()
        self.scheduler.step()

        with torch.no_grad():
            preds = torch.argmax(logits_flat, dim=-1)
            correct = (preds == labels_flat).float() * mask_flat
            accuracy = correct.sum() / mask_flat.sum()

        # 이동 평균 업데이트
        self.train_loss_ma.append(loss.item())
        self.train_acc_ma.append(accuracy.item())
        if len(self.train_loss_ma) > 100:
            self.train_loss_ma.pop(0)
            self.train_acc_ma.pop(0)

        return {
            'loss': loss.item(),
            'accuracy': accuracy.item(),
            'lr': self.optimizer.param_groups[0]['lr'],
            'grad_norm': grad_norm,
            'loss_ma': np.mean(self.train_loss_ma),
            'acc_ma': np.mean(self.train_acc_ma)
        }

    def train_epoch_with_progress(self, train_loader, val_loader, epoch: int):
        # 이동 평균 리스트 초기화
        self.train_loss_ma.clear()
        self.train_acc_ma.clear()

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        last_metrics = None
        for batch in pbar:
            last_metrics = self.train_step(batch)
            pbar.set_postfix({
                'loss_ma': f"{last_metrics['loss_ma']:.4f}",
                'acc_ma':  f"{last_metrics['acc_ma']:.4f}",
                'lr':      f"{last_metrics['lr']:.6f}"
            })

        # 검증
        val_metrics = self.evaluate(val_loader)

        # 베스트 모델 저장
        if val_metrics['accuracy'] > self.best_accuracy:
            self.best_accuracy = val_metrics['accuracy']
            self.best_epoch = epoch
            self.save_checkpoint('tabpfn_best.pt')

        # 히스토리 기록
        self.history['train_loss'].append(np.mean(self.train_loss_ma))
        self.history['train_accuracy'].append(np.mean(self.train_acc_ma))
        self.history['val_loss'].append(val_metrics['loss'])
        self.history['val_accuracy'].append(val_metrics['accuracy'])
        self.history['lr'].append(last_metrics['lr'] if last_metrics else self.scheduler.get_last_lr()[0])

        return {
            'train_loss': self.history['train_loss'][-1],
            'train_accuracy': self.history['train_accuracy'][-1],
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy']
        }

def visualize_training_progress(trainer: ImprovedTabPFNTrainer,
                                save_path: str = 'training_progress.png'):
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # Loss curves
    axes[0, 0].plot(trainer.history['train_loss'], label='Train Loss')
    axes[0, 0].plot(trainer.history['val_loss'],   label='Val Loss')
    axes[0, 0].set(title='Loss Curves', xlabel='Epoch', ylabel='Loss')
    axes[0, 0].legend(); axes[0, 0].grid(True)

    # Accuracy curves
    axes[0, 1].plot(trainer.history['train_accuracy'], label='Train Acc')
    axes[0, 1].plot(trainer.history['val_accuracy'],   label='Val Acc')
    axes[0, 1].set(title='Accuracy Curves', xlabel='Epoch', ylabel='Accuracy')
    axes[0, 1].legend(); axes[0, 1].grid(True)

    # Learning rate schedule
    axes[1, 0].plot(trainer.history['lr'])
    axes[1, 0].set(title='Learning Rate Schedule', xlabel='Epoch', ylabel='LR')
    axes[1, 0].grid(True)

    # Best validation accuracy over time
    best_accs = []
    current_best = 0.0
    for acc in trainer.history['val_accuracy']:
        current_best = max(current_best, acc)
        best_accs.append(current_best)
    axes[1, 1].plot(best_accs)
    axes[1, 1].set(title='Best Val Accuracy', xlabel='Epoch', ylabel='Accuracy')
    axes[1, 1].grid(True)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()

def train_tabpfn_improved(scale: str = 'small', 
                         num_epochs: Optional[int] = None,
                         checkpoint_path: Optional[str] = None) -> TabPFN:
    """
    Improved training function with better defaults and monitoring
    
    Args:
        scale: Model scale ('small', 'medium', 'large', 'xlarge')
        num_epochs: Override number of epochs
        checkpoint_path: Path to resume from checkpoint
    """
    # Get configuration
    config = create_optimized_config(scale)
    if num_epochs is not None:
        config.num_epochs = num_epochs
    
    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Training {scale} model with {config.num_train_tasks} training tasks")
    
    # Create datasets with proper workers
    num_workers = 4 if device.type == 'cuda' else 0
    
    print("Creating datasets...")
    train_dataset = TabPFNDataset(config, config.num_train_tasks, split='train')
    val_dataset = TabPFNDataset(config, config.num_val_tasks, split='val')
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_tabpfn_batch,
        pin_memory=True if device.type == 'cuda' else False,
        persistent_workers=True if num_workers > 0 else False
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,  # Larger batch for validation
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_tabpfn_batch,
        pin_memory=True if device.type == 'cuda' else False,
        persistent_workers=True if num_workers > 0 else False
    )
    
    # Create model
    print(f"\nCreating {scale} model...")
    model = TabPFN(config)
    
    # Initialize weights with better strategy
    def init_weights(module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=1.0)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
    
    model.apply(init_weights)
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    
    # Create trainer
    trainer = ImprovedTabPFNTrainer(model, config, device)
    
    # Load checkpoint if provided
    start_epoch = 0
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        trainer.load_checkpoint(checkpoint_path)
        start_epoch = trainer.best_epoch + 1
    
    # Training loop
    print(f"\nStarting training from epoch {start_epoch + 1}...")
    print("=" * 70)
    
    early_stopping_patience = 20
    epochs_without_improvement = 0
    
    for epoch in range(start_epoch, config.num_epochs):
        start_time = time.time()
        
        # Train epoch
        metrics = trainer.train_epoch_with_progress(train_loader, val_loader, epoch)
        
        # Print epoch summary
        epoch_time = time.time() - start_time
        print(f"\nEpoch {epoch + 1}/{config.num_epochs} - {epoch_time:.1f}s")
        print(f"Train Loss: {metrics['train_loss']:.4f}, Train Acc: {metrics['train_accuracy']:.4f}")
        print(f"Val Loss: {metrics['val_loss']:.4f}, Val Acc: {metrics['val_accuracy']:.4f}")
        print(f"Best Val Acc: {trainer.best_accuracy:.4f} (Epoch {trainer.best_epoch + 1})")
        print("-" * 70)
        
        # Early stopping
        if epoch == trainer.best_epoch:
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        
        if epochs_without_improvement >= early_stopping_patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs")
            break
        
        # Save periodic checkpoints
        if (epoch + 1) % 10 == 0:
            trainer.save_checkpoint(f'tabpfn_epoch_{epoch+1}.pt')
    
    print(f"\nTraining completed! Best model achieved {trainer.best_accuracy:.4f} validation accuracy")
    
    # Visualize progress
    if len(trainer.train_metrics['loss']) > 0:
        visualize_training_progress(trainer)
    
    # Load best model
    trainer.load_checkpoint('tabpfn_best.pt')
    
    return model

def diagnose_training_issues(model: TabPFN, train_loader: DataLoader, device: torch.device):
    """Diagnose potential training issues"""
    print("\n=== Training Diagnostics ===")
    
    model.eval()
    with torch.no_grad():
        # Check a few batches
        for i, batch in enumerate(train_loader):
            if i >= 5:
                break
                
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v 
                    for k, v in batch.items()}
            
            # Get model outputs
            logits = model(batch)
            
            # Analyze outputs
            print(f"\nBatch {i+1}:")
            print(f"  Logits range: [{logits.min().item():.3f}, {logits.max().item():.3f}]")
            print(f"  Logits mean: {logits.mean().item():.3f}")
            print(f"  Logits std: {logits.std().item():.3f}")
            
            # Check predictions distribution
            predictions = torch.argmax(logits, dim=-1)
            unique_preds = torch.unique(predictions)
            print(f"  Unique predictions: {unique_preds.tolist()}")
            
            # Check if model is collapsed
            if len(unique_preds) == 1:
                print("  WARNING: Model is predicting only one class!")
            
            # Check attention weights if available
            if hasattr(model, 'transformer') and hasattr(model.transformer.layers[0].self_attn, 'attention_weights'):
                attn = model.transformer.layers[0].self_attn.attention_weights
                print(f"  Attention weights range: [{attn.min().item():.3f}, {attn.max().item():.3f}]")

# Example usage functions
def example_improved_training():
    """Example of improved training"""
    print("=== TabPFN Improved Training Example ===\n")
    
    # Train a medium-scale model
    model = train_tabpfn_improved(
        scale='small',
        num_epochs=50  # Override for faster demo
    )
    
    return model

def example_resume_training():
    """Example of resuming training from checkpoint"""
    print("=== TabPFN Resume Training Example ===\n")
    
    # Resume training from checkpoint
    model = train_tabpfn_improved(
        scale='small',
        checkpoint_path='tabpfn_best.pt'
    )
    
    return model

def example_different_scales():
    """Compare different model scales"""
    print("=== TabPFN Scale Comparison ===\n")
    
    scales = ['small', 'medium', 'large']
    results = {}
    
    for scale in scales:
        print(f"\nTraining {scale} model...")
        config = create_optimized_config(scale)
        print(f"  Hidden dim: {config.hidden_dim}")
        print(f"  Num layers: {config.num_layers}")
        print(f"  Train tasks: {config.num_train_tasks}")
        
        # Just print config, don't actually train (for demo)
        results[scale] = config
    
    return results

if __name__ == "__main__":
    # Run improved training
    model = example_improved_training()
    
    # Show different scales
    # scales = example_different_scales()
    
    print("\nImproved training completed!")

=== TabPFN Improved Training Example ===

Using device: cuda
Training small model with 20000 training tasks
Creating datasets...

Creating small model...
Model parameters: 833,482
Trainable parameters: 833,482

Starting training from epoch 1...


Epoch 1: 100%|██████████| 313/313 [00:08<00:00, 36.18it/s, loss_ma=1.6816, acc_ma=0.2152, lr=0.000626]



Epoch 1/50 - 9.4s
Train Loss: 1.6816, Train Acc: 0.2152
Val Loss: 1.6657, Val Acc: 0.2192
Best Val Acc: 0.2192 (Epoch 1)
----------------------------------------------------------------------


Epoch 2: 100%|██████████| 313/313 [00:09<00:00, 34.74it/s, loss_ma=1.6791, acc_ma=0.2164, lr=0.001000]



Epoch 2/50 - 9.5s
Train Loss: 1.6791, Train Acc: 0.2164
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 3: 100%|██████████| 313/313 [00:08<00:00, 36.31it/s, loss_ma=1.6764, acc_ma=0.2182, lr=0.000998]



Epoch 3/50 - 9.1s
Train Loss: 1.6764, Train Acc: 0.2182
Val Loss: 1.6646, Val Acc: 0.2182
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 4: 100%|██████████| 313/313 [00:08<00:00, 35.03it/s, loss_ma=1.6928, acc_ma=0.2146, lr=0.000994]



Epoch 4/50 - 9.4s
Train Loss: 1.6928, Train Acc: 0.2146
Val Loss: 1.6647, Val Acc: 0.2199
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 5: 100%|██████████| 313/313 [00:08<00:00, 34.86it/s, loss_ma=1.6876, acc_ma=0.2151, lr=0.000988]



Epoch 5/50 - 9.5s
Train Loss: 1.6876, Train Acc: 0.2151
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 6: 100%|██████████| 313/313 [00:08<00:00, 34.89it/s, loss_ma=1.6722, acc_ma=0.2190, lr=0.000980]



Epoch 6/50 - 9.4s
Train Loss: 1.6722, Train Acc: 0.2190
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 7: 100%|██████████| 313/313 [00:08<00:00, 34.91it/s, loss_ma=1.6765, acc_ma=0.2192, lr=0.000969]



Epoch 7/50 - 9.4s
Train Loss: 1.6765, Train Acc: 0.2192
Val Loss: 1.6646, Val Acc: 0.2195
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 8: 100%|██████████| 313/313 [00:08<00:00, 35.15it/s, loss_ma=1.6896, acc_ma=0.2156, lr=0.000957]



Epoch 8/50 - 9.4s
Train Loss: 1.6896, Train Acc: 0.2156
Val Loss: 1.6646, Val Acc: 0.2187
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 9: 100%|██████████| 313/313 [00:08<00:00, 35.26it/s, loss_ma=1.6892, acc_ma=0.2163, lr=0.000943]



Epoch 9/50 - 9.4s
Train Loss: 1.6892, Train Acc: 0.2163
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 10: 100%|██████████| 313/313 [00:08<00:00, 35.54it/s, loss_ma=1.6767, acc_ma=0.2179, lr=0.000927]



Epoch 10/50 - 9.3s
Train Loss: 1.6767, Train Acc: 0.2179
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 11: 100%|██████████| 313/313 [00:08<00:00, 35.86it/s, loss_ma=1.6827, acc_ma=0.2163, lr=0.000909]



Epoch 11/50 - 9.2s
Train Loss: 1.6827, Train Acc: 0.2163
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 12: 100%|██████████| 313/313 [00:08<00:00, 36.29it/s, loss_ma=1.6778, acc_ma=0.2188, lr=0.000890]



Epoch 12/50 - 9.1s
Train Loss: 1.6778, Train Acc: 0.2188
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 13: 100%|██████████| 313/313 [00:08<00:00, 36.74it/s, loss_ma=1.6847, acc_ma=0.2162, lr=0.000868]



Epoch 13/50 - 9.0s
Train Loss: 1.6847, Train Acc: 0.2162
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 14: 100%|██████████| 313/313 [00:08<00:00, 35.79it/s, loss_ma=1.6781, acc_ma=0.2186, lr=0.000846]



Epoch 14/50 - 9.2s
Train Loss: 1.6781, Train Acc: 0.2186
Val Loss: 1.6646, Val Acc: 0.2195
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 15: 100%|██████████| 313/313 [00:08<00:00, 34.95it/s, loss_ma=1.6755, acc_ma=0.2174, lr=0.000821]



Epoch 15/50 - 9.4s
Train Loss: 1.6755, Train Acc: 0.2174
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 16: 100%|██████████| 313/313 [00:08<00:00, 36.03it/s, loss_ma=1.6764, acc_ma=0.2184, lr=0.000796]



Epoch 16/50 - 9.2s
Train Loss: 1.6764, Train Acc: 0.2184
Val Loss: 1.6646, Val Acc: 0.2199
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 17: 100%|██████████| 313/313 [00:07<00:00, 39.88it/s, loss_ma=1.6896, acc_ma=0.2159, lr=0.000769]



Epoch 17/50 - 8.3s
Train Loss: 1.6896, Train Acc: 0.2159
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 18: 100%|██████████| 313/313 [00:07<00:00, 40.05it/s, loss_ma=1.6702, acc_ma=0.2182, lr=0.000741]



Epoch 18/50 - 8.3s
Train Loss: 1.6702, Train Acc: 0.2182
Val Loss: 1.6646, Val Acc: 0.2195
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 19: 100%|██████████| 313/313 [00:07<00:00, 39.62it/s, loss_ma=1.6803, acc_ma=0.2163, lr=0.000712]



Epoch 19/50 - 8.4s
Train Loss: 1.6803, Train Acc: 0.2163
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 20: 100%|██████████| 313/313 [00:07<00:00, 40.95it/s, loss_ma=1.6808, acc_ma=0.2172, lr=0.000682]



Epoch 20/50 - 8.1s
Train Loss: 1.6808, Train Acc: 0.2172
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 21: 100%|██████████| 313/313 [00:07<00:00, 42.02it/s, loss_ma=1.6779, acc_ma=0.2170, lr=0.000651]



Epoch 21/50 - 7.9s
Train Loss: 1.6779, Train Acc: 0.2170
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------


Epoch 22: 100%|██████████| 313/313 [00:08<00:00, 35.09it/s, loss_ma=1.6667, acc_ma=0.2203, lr=0.000620]



Epoch 22/50 - 9.7s
Train Loss: 1.6667, Train Acc: 0.2203
Val Loss: 1.6646, Val Acc: 0.2200
Best Val Acc: 0.2200 (Epoch 2)
----------------------------------------------------------------------

Early stopping triggered after 22 epochs

Training completed! Best model achieved 0.2200 validation accuracy


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` or the `torch.serialization.safe_globals([scalar])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [7]:
"""
TabPFN with Advanced Data Augmentation and Imbalance Handling
Based on the original TabPFN paper's data generation strategies
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
import math
from scipy import stats
from sklearn.preprocessing import StandardScaler, PowerTransformer, QuantileTransformer
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, Matern, RationalQuadratic
import warnings
warnings.filterwarnings('ignore')

# ===========================
# Advanced Prior Specifications (from TabPFN paper)
# ===========================

@dataclass
class PriorConfig:
    """Configuration for data generation priors"""
    # Structural Causal Model parameters
    num_causes_per_feature: Tuple[int, int] = (1, 5)
    edge_probability: float = 0.3
    noise_scale_range: Tuple[float, float] = (0.01, 0.3)
    
    # Function class weights (as in TabPFN paper)
    function_class_weights: Dict[str, float] = None
    
    # Data transformation parameters
    apply_power_transform: float = 0.3
    apply_quantile_transform: float = 0.2
    add_categorical_features: float = 0.3
    
    # Imbalance parameters
    imbalance_ratio_range: Tuple[float, float] = (0.1, 1.0)
    use_focal_loss: bool = True
    focal_alpha: float = 0.25
    focal_gamma: float = 2.0
    
    # Gaussian Process parameters
    use_gp_functions: float = 0.2
    gp_length_scale_range: Tuple[float, float] = (0.1, 2.0)
    
    def __post_init__(self):
        if self.function_class_weights is None:
            # Default weights from TabPFN paper
            self.function_class_weights = {
                'linear': 0.25,
                'polynomial': 0.15,
                'neural_basis': 0.15,
                'decision_tree': 0.15,
                'gaussian_process': 0.10,
                'periodic': 0.10,
                'interaction': 0.10
            }

class AdvancedDataGenerator:
    """
    Advanced data generator implementing TabPFN paper's strategies:
    1. Structural Causal Models with various function classes
    2. Gaussian Process priors
    3. Neural network basis functions
    4. Decision tree-like functions
    5. Data transformations and augmentations
    """
    
    def __init__(self, prior_config: PriorConfig = None):
        self.config = prior_config or PriorConfig()
        self.rng = np.random.RandomState()
        
    def generate_dataset(self, num_samples: int, num_features: int, 
                        num_classes: int, seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
        """Generate a dataset with advanced augmentation strategies"""
        if seed is not None:
            self.rng.seed(seed)
        
        # Generate causal graph structure
        causal_graph = self._generate_causal_graph(num_features)
        
        # Generate base features
        X = self._generate_base_features(num_samples, num_features)
        
        # Apply causal mechanisms
        X = self._apply_causal_mechanisms(X, causal_graph)
        
        # Generate target using sampled function class
        y_continuous, function_info = self._generate_target_with_prior(X, num_features)
        
        # Apply data transformations
        X, transform_info = self._apply_data_transformations(X)
        
        # Convert to classes with potential imbalance
        y, class_info = self._create_imbalanced_classes(y_continuous, num_classes)
        
        # Metadata for training
        metadata = {
            'causal_graph': causal_graph,
            'function_info': function_info,
            'transform_info': transform_info,
            'class_info': class_info
        }
        
        return X.astype(np.float32), y.astype(np.int64), metadata
    
    def _generate_causal_graph(self, num_features: int) -> np.ndarray:
        """Generate a DAG for causal relationships"""
        # Lower triangular matrix ensures DAG
        graph = np.zeros((num_features, num_features))
        
        for i in range(num_features):
            for j in range(i):
                if self.rng.random() < self.config.edge_probability:
                    graph[i, j] = self.rng.uniform(0.5, 2.0) * self.rng.choice([-1, 1])
        
        return graph
    
    def _generate_base_features(self, num_samples: int, num_features: int) -> np.ndarray:
        """Generate diverse base features"""
        X = np.zeros((num_samples, num_features))
        
        # Mix of different distributions (as in TabPFN paper)
        distributions = [
            ('normal', lambda n: self.rng.normal(0, 1, n)),
            ('uniform', lambda n: self.rng.uniform(-2, 2, n)),
            ('exponential', lambda n: self.rng.exponential(1, n) - 1),
            ('student_t', lambda n: stats.t.rvs(df=3, size=n, random_state=self.rng)),
            ('laplace', lambda n: self.rng.laplace(0, 1, n)),
            ('gamma', lambda n: stats.gamma.rvs(2, size=n, random_state=self.rng) - 2),
            ('beta', lambda n: stats.beta.rvs(2, 5, size=n, random_state=self.rng) * 4 - 2),
            ('mixture', lambda n: self._generate_mixture(n))
        ]
        
        for i in range(num_features):
            dist_name, dist_func = distributions[i % len(distributions)]
            X[:, i] = dist_func(num_samples)
            
            # Add some dependencies between features
            if i > 0 and self.rng.random() < 0.3:
                parent_idx = self.rng.randint(0, i)
                correlation = self.rng.uniform(-0.8, 0.8)
                X[:, i] = correlation * X[:, parent_idx] + np.sqrt(1 - correlation**2) * X[:, i]
        
        return X
    
    def _generate_mixture(self, n: int) -> np.ndarray:
        """Generate mixture of Gaussians"""
        n_components = self.rng.randint(2, 5)
        weights = self.rng.dirichlet(np.ones(n_components))
        
        samples = []
        for _ in range(n):
            component = self.rng.choice(n_components, p=weights)
            mean = self.rng.uniform(-3, 3)
            std = self.rng.uniform(0.5, 1.5)
            samples.append(self.rng.normal(mean, std))
        
        return np.array(samples)
    
    def _apply_causal_mechanisms(self, X: np.ndarray, graph: np.ndarray) -> np.ndarray:
        """Apply causal relationships based on graph"""
        X_transformed = X.copy()
        
        for i in range(X.shape[1]):
            parents = np.where(graph[i, :] != 0)[0]
            if len(parents) > 0:
                # Apply causal mechanism
                parent_values = X_transformed[:, parents]
                weights = graph[i, parents]
                
                # Add non-linear transformation
                if self.rng.random() < 0.5:
                    X_transformed[:, i] += np.sum(parent_values * weights, axis=1)
                else:
                    # Non-linear interaction
                    X_transformed[:, i] += np.tanh(np.sum(parent_values * weights, axis=1))
        
        return X_transformed
    
    def _generate_target_with_prior(self, X: np.ndarray, num_features: int) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Generate target using various function classes from TabPFN paper"""
        
        # Sample function class based on weights
        function_classes = list(self.config.function_class_weights.keys())
        probabilities = list(self.config.function_class_weights.values())
        chosen_class = self.rng.choice(function_classes, p=probabilities)
        
        # Select relevant features
        num_relevant = self.rng.randint(
            max(1, int(num_features * 0.1)), 
            max(2, int(num_features * 0.7))
        )
        relevant_features = self.rng.choice(num_features, size=num_relevant, replace=False)
        
        # Generate target based on chosen function class
        if chosen_class == 'linear':
            y, info = self._linear_function(X, relevant_features)
        elif chosen_class == 'polynomial':
            y, info = self._polynomial_function(X, relevant_features)
        elif chosen_class == 'neural_basis':
            y, info = self._neural_basis_function(X, relevant_features)
        elif chosen_class == 'decision_tree':
            y, info = self._decision_tree_function(X, relevant_features)
        elif chosen_class == 'gaussian_process':
            y, info = self._gaussian_process_function(X, relevant_features)
        elif chosen_class == 'periodic':
            y, info = self._periodic_function(X, relevant_features)
        elif chosen_class == 'interaction':
            y, info = self._interaction_function(X, relevant_features)
        else:
            y, info = self._linear_function(X, relevant_features)
        
        # Add noise
        noise_scale = self.rng.uniform(*self.config.noise_scale_range)
        y += self.rng.normal(0, noise_scale, size=y.shape)
        
        info.update({
            'function_class': chosen_class,
            'relevant_features': relevant_features,
            'noise_scale': noise_scale
        })
        
        return y, info
    
    def _linear_function(self, X: np.ndarray, features: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Linear combination with random weights"""
        weights = self.rng.randn(len(features))
        y = X[:, features] @ weights
        return y, {'weights': weights}
    
    def _polynomial_function(self, X: np.ndarray, features: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Polynomial features with interactions"""
        y = np.zeros(X.shape[0])
        terms = []
        
        # Single features with powers
        for feat in features[:min(5, len(features))]:
            power = self.rng.randint(1, 4)
            coef = self.rng.randn()
            y += coef * (X[:, feat] ** power)
            terms.append(('power', feat, power, coef))
        
        # Interactions
        if len(features) >= 2:
            for _ in range(min(3, len(features) // 2)):
                feat1, feat2 = self.rng.choice(features, size=2, replace=False)
                coef = self.rng.randn()
                y += coef * X[:, feat1] * X[:, feat2]
                terms.append(('interaction', feat1, feat2, coef))
        
        return y, {'terms': terms}
    
    def _neural_basis_function(self, X: np.ndarray, features: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Neural network-like basis functions"""
        hidden_size = self.rng.randint(5, 20)
        
        # First layer
        W1 = self.rng.randn(len(features), hidden_size) * 0.5
        b1 = self.rng.randn(hidden_size) * 0.1
        
        # Hidden activation
        hidden = np.tanh(X[:, features] @ W1 + b1)
        
        # Output layer
        W2 = self.rng.randn(hidden_size) * 0.5
        b2 = self.rng.randn() * 0.1
        
        y = hidden @ W2 + b2
        
        return y, {
            'architecture': f'Input({len(features)}) -> Hidden({hidden_size}) -> Output(1)',
            'activation': 'tanh'
        }
    
    def _decision_tree_function(self, X: np.ndarray, features: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Decision tree-like step functions"""
        y = np.zeros(X.shape[0])
        
        # Create tree-like splits
        num_splits = self.rng.randint(3, 8)
        splits = []
        
        for _ in range(num_splits):
            feat = self.rng.choice(features)
            threshold = np.percentile(X[:, feat], self.rng.uniform(20, 80))
            value = self.rng.randn()
            
            # Apply split
            mask = X[:, feat] > threshold
            y[mask] += value
            
            splits.append((feat, threshold, value))
        
        return y, {'splits': splits}
    
    def _gaussian_process_function(self, X: np.ndarray, features: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Gaussian Process prior functions"""
        # Sample kernel
        kernels = [
            RBF(length_scale=self.rng.uniform(*self.config.gp_length_scale_range)),
            Matern(length_scale=self.rng.uniform(*self.config.gp_length_scale_range), nu=1.5),
            RationalQuadratic(length_scale=self.rng.uniform(*self.config.gp_length_scale_range))
        ]
        kernel = self.rng.choice(kernels)
        
        # Sample from GP prior
        X_subset = X[:, features]
        K = kernel(X_subset)
        L = np.linalg.cholesky(K + 1e-6 * np.eye(X.shape[0]))
        y = L @ self.rng.randn(X.shape[0])
        
        return y, {'kernel': str(kernel)}
    
    def _periodic_function(self, X: np.ndarray, features: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Periodic functions"""
        y = np.zeros(X.shape[0])
        components = []
        
        for feat in features[:min(3, len(features))]:
            frequency = self.rng.uniform(0.5, 3.0)
            phase = self.rng.uniform(0, 2 * np.pi)
            amplitude = self.rng.randn()
            
            y += amplitude * np.sin(frequency * X[:, feat] + phase)
            components.append((feat, frequency, phase, amplitude))
        
        return y, {'components': components}
    
    def _interaction_function(self, X: np.ndarray, features: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Complex interaction patterns"""
        y = np.zeros(X.shape[0])
        interactions = []
        
        # Pairwise interactions
        if len(features) >= 2:
            for _ in range(min(5, len(features))):
                feat1, feat2 = self.rng.choice(features, size=2, replace=False)
                
                interaction_type = self.rng.choice(['multiply', 'divide', 'subtract'])
                coef = self.rng.randn()
                
                if interaction_type == 'multiply':
                    y += coef * X[:, feat1] * X[:, feat2]
                elif interaction_type == 'divide':
                    y += coef * X[:, feat1] / (np.abs(X[:, feat2]) + 0.1)
                else:  # subtract
                    y += coef * (X[:, feat1] - X[:, feat2])
                
                interactions.append((feat1, feat2, interaction_type, coef))
        
        # Three-way interactions
        if len(features) >= 3:
            feat1, feat2, feat3 = self.rng.choice(features, size=3, replace=False)
            coef = self.rng.randn() * 0.5
            y += coef * X[:, feat1] * X[:, feat2] * X[:, feat3]
            interactions.append((feat1, feat2, feat3, 'triple', coef))
        
        return y, {'interactions': interactions}
    
    def _apply_data_transformations(self, X: np.ndarray) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Apply various data transformations as in TabPFN paper"""
        X_transformed = X.copy()
        transformations = []
        
        # Power transformation
        if self.rng.random() < self.config.apply_power_transform:
            power_features = self.rng.choice(
                X.shape[1], 
                size=self.rng.randint(1, max(2, X.shape[1] // 3)), 
                replace=False
            )
            pt = PowerTransformer(method='yeo-johnson')
            X_transformed[:, power_features] = pt.fit_transform(X[:, power_features])
            transformations.append(('power', power_features))
        
        # Quantile transformation
        if self.rng.random() < self.config.apply_quantile_transform:
            quantile_features = self.rng.choice(
                X.shape[1], 
                size=self.rng.randint(1, max(2, X.shape[1] // 3)), 
                replace=False
            )
            qt = QuantileTransformer(output_distribution='uniform', n_quantiles=min(100, X.shape[0]))
            X_transformed[:, quantile_features] = qt.fit_transform(X[:, quantile_features])
            transformations.append(('quantile', quantile_features))
        
        # Add categorical features through binning
        if self.rng.random() < self.config.add_categorical_features:
            num_cat_features = self.rng.randint(1, max(2, X.shape[1] // 4))
            cat_features = []
            
            for _ in range(num_cat_features):
                source_feat = self.rng.randint(0, X.shape[1])
                num_bins = self.rng.randint(3, 10)
                
                # Create bins
                bins = np.percentile(X[:, source_feat], np.linspace(0, 100, num_bins + 1))
                bins[0] = -np.inf
                bins[-1] = np.inf
                
                # Digitize
                cat_feat = np.digitize(X[:, source_feat], bins) - 1
                cat_features.append(cat_feat)
            
            # One-hot encode and append
            if cat_features:
                cat_features = np.column_stack(cat_features)
                # Simple encoding: just use the categorical values scaled
                X_transformed = np.hstack([X_transformed, cat_features / num_bins])
                transformations.append(('categorical', num_cat_features))
        
        return X_transformed, {'transformations': transformations}
    
    def _create_imbalanced_classes(self, y_continuous: np.ndarray, 
                                  num_classes: int) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Create potentially imbalanced classes"""
        
        # Decide on imbalance ratio
        imbalance_ratio = self.rng.uniform(*self.config.imbalance_ratio_range)
        
        if num_classes == 2:
            # Binary classification with controlled imbalance
            threshold = np.percentile(y_continuous, 100 * imbalance_ratio)
            y = (y_continuous > threshold).astype(np.int64)
        else:
            # Multi-class with potentially imbalanced distribution
            if self.rng.random() < 0.5:
                # Uniform classes
                percentiles = np.linspace(0, 100, num_classes + 1)
                thresholds = np.percentile(y_continuous, percentiles[1:-1])
                y = np.digitize(y_continuous, thresholds).astype(np.int64)
            else:
                # Imbalanced classes using exponential spacing
                base = imbalance_ratio ** (1 / (num_classes - 1))
                proportions = [base ** i for i in range(num_classes)]
                proportions = np.array(proportions) / sum(proportions)
                
                percentiles = np.cumsum([0] + proportions.tolist()) * 100
                thresholds = np.percentile(y_continuous, percentiles[1:-1])
                y = np.digitize(y_continuous, thresholds).astype(np.int64)
        
        # Compute class weights for loss weighting
        unique_classes, counts = np.unique(y, return_counts=True)
        class_weights = len(y) / (len(unique_classes) * counts)
        class_weights = dict(zip(unique_classes, class_weights))
        
        return y, {
            'imbalance_ratio': imbalance_ratio,
            'class_distribution': dict(zip(*np.unique(y, return_counts=True))),
            'class_weights': class_weights
        }

# ===========================
# Enhanced Loss Functions
# ===========================

class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance"""
    
    def __init__(self, alpha: Optional[torch.Tensor] = None, gamma: float = 2.0, 
                 reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        p = torch.exp(-ce_loss)
        loss = (1 - p) ** self.gamma * ce_loss
        
        if self.alpha is not None:
            if self.alpha.device != inputs.device:
                self.alpha = self.alpha.to(inputs.device)
            loss = self.alpha[targets] * loss
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

class BalancedCrossEntropyLoss(nn.Module):
    """Balanced Cross Entropy Loss with automatic class weight computation"""
    
    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.class_counts = torch.zeros(num_classes)
        self.total_samples = 0
    
    def update_class_weights(self, targets: torch.Tensor):
        """Update class weights based on observed distribution"""
        for c in range(self.num_classes):
            self.class_counts[c] += (targets == c).sum().item()
        self.total_samples += len(targets)
    
    def get_weights(self) -> torch.Tensor:
        """Compute balanced class weights"""
        weights = self.total_samples / (self.num_classes * self.class_counts + 1e-6)
        return weights / weights.sum() * self.num_classes
    
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        self.update_class_weights(targets)
        weights = self.get_weights().to(inputs.device)
        return F.cross_entropy(inputs, targets, weight=weights)

# ===========================
# Enhanced TabPFN Dataset with Augmentation
# ===========================

class AugmentedTabPFNDataset(Dataset):
    """Enhanced dataset with advanced augmentation strategies"""
    
    def __init__(self, config: TabPFNConfig, prior_config: PriorConfig, 
                 num_tasks: int, split: str = 'train'):
        self.config = config
        self.prior_config = prior_config
        self.num_tasks = num_tasks
        self.split = split
        self.generator = AdvancedDataGenerator(prior_config)
        
        # Cache for class weights
        self.class_weights_cache = {}
    
    def __len__(self):
        return self.num_tasks
    
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # Set seed for validation reproducibility
        seed = idx if self.split == 'val' else None
        
        # Sample task parameters
        num_features = np.random.randint(self.config.min_features, self.config.max_features + 1)
        num_samples = np.random.randint(self.config.min_samples, self.config.max_samples + 1)
        num_classes = np.random.randint(2, self.config.max_classes + 1)
        
        # Generate dataset with augmentation
        X, y, metadata = self.generator.generate_dataset(
            num_samples=num_samples,
            num_features=num_features,
            num_classes=num_classes,
            seed=seed
        )
        
        # Additional augmentations for training
        if self.split == 'train':
            X = self._apply_training_augmentations(X)
        
        # Normalize features
        scaler = StandardScaler()
        X = scaler.fit_transform(X)
        
        # Handle variable feature dimensions
        if X.shape[1] > num_features:
            # If transformations added features, keep track
            actual_features = X.shape[1]
        else:
            actual_features = num_features
        
        # Split into context and query
        split_idx = np.random.randint(num_samples // 4, 3 * num_samples // 4)
        
        context_X = torch.from_numpy(X[:split_idx]).float()
        context_y = torch.from_numpy(y[:split_idx]).long()
        query_X = torch.from_numpy(X[split_idx:]).float()
        query_y = torch.from_numpy(y[split_idx:]).long()
        
        # Prepare class weights for focal loss
        class_weights = metadata['class_info'].get('class_weights', {})
        class_weights_tensor = torch.zeros(self.config.max_classes)
        for c, w in class_weights.items():
            if c < self.config.max_classes:
                class_weights_tensor[c] = w
        
        return {
            'context_X': context_X,
            'context_y': context_y,
            'query_X': query_X,
            'query_y': query_y,
            'num_features': actual_features,
            'num_classes': num_classes,
            'num_context': split_idx,
            'num_query': num_samples - split_idx,
            'class_weights': class_weights_tensor,
            'metadata': metadata
        }
    
    def _apply_training_augmentations(self, X: np.ndarray) -> np.ndarray:
        """Apply additional augmentations during training"""
        
        # Random feature permutation
        if np.random.random() < 0.1:
            perm = np.random.permutation(X.shape[1])
            X = X[:, perm]
        
        # Add Gaussian noise
        if np.random.random() < 0.3:
            noise_scale = np.random.uniform(0.01, 0.1)
            X += np.random.randn(*X.shape) * noise_scale
        
        # Random scaling
        if np.random.random() < 0.2:
            scale = np.random.uniform(0.8, 1.2, size=(1, X.shape[1]))
            X *= scale
        
        return X

# ===========================
# Training with Imbalance Handling
# ===========================

class ImbalanceAwareTrainer(TabPFNTrainer):
    """Trainer with focal loss and imbalance handling"""
    
    def __init__(self, model: TabPFN, config: TabPFNConfig, 
                 prior_config: PriorConfig, device: torch.device):
        super().__init__(model, config, device)
        self.prior_config = prior_config
        
        # Loss functions
        if prior_config.use_focal_loss:
            self.criterion = FocalLoss(gamma=prior_config.focal_gamma)
        else:
            self.criterion = BalancedCrossEntropyLoss(config.max_classes)
        
        # Track class distribution
        self.class_distribution = torch.zeros(config.max_classes)
        self.total_samples = 0
    
    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Training step with imbalance awareness"""
        self.model.train()
        
        # Move batch to device
        batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v 
                for k, v in batch.items()}
        
        # Forward pass
        logits = self.model(batch)
        
        # Get class weights if using focal loss
        if self.prior_config.use_focal_loss and 'class_weights' in batch:
            class_weights = batch['class_weights'].mean(dim=0)  # Average over batch
            if self.criterion.alpha is None:
                self.criterion.alpha = class_weights
        
        # Compute loss
        query_y = batch['query_y']
        query_mask = batch['query_mask']
        
        # Flatten for loss computation
        logits_flat = logits.view(-1, logits.size(-1))
        labels_flat = query_y.view(-1)
        mask_flat = query_mask.view(-1)
        
        # Filter valid positions
        if mask_flat.sum() > 0:
            valid_logits = logits_flat[mask_flat]
            valid_labels = labels_flat[mask_flat]
            
            loss = self.criterion(valid_logits, valid_labels)
            
            # Update class distribution tracking
            self._update_class_distribution(valid_labels)
        else:
            loss = torch.tensor(0.0, device=self.device)
        
        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
        
        # Optimizer step
        self.optimizer.step()
        self.scheduler.step()
        
        # Compute accuracy with class-wise metrics
        with torch.no_grad():
            if mask_flat.sum() > 0:
                predictions = torch.argmax(valid_logits, dim=-1)
                accuracy = (predictions == valid_labels).float().mean().item()
                
                # Per-class accuracy
                class_accuracies = {}
                for c in range(self.config.max_classes):
                    mask_c = valid_labels == c
                    if mask_c.sum() > 0:
                        acc_c = (predictions[mask_c] == c).float().mean().item()
                        class_accuracies[c] = acc_c
            else:
                accuracy = 0.0
                class_accuracies = {}
        
        return {
            'loss': loss.item(),
            'accuracy': accuracy,
            'lr': self.optimizer.param_groups[0]['lr'],
            'class_accuracies': class_accuracies
        }
    
    def _update_class_distribution(self, labels: torch.Tensor):
        """Update tracked class distribution"""
        for c in range(self.config.max_classes):
            self.class_distribution[c] += (labels == c).sum().item()
        self.total_samples += len(labels)
    
    def get_class_distribution(self) -> Dict[int, float]:
        """Get normalized class distribution"""
        if self.total_samples == 0:
            return {}
        
        distribution = self.class_distribution / self.total_samples
        return {i: dist.item() for i, dist in enumerate(distribution) if dist > 0}

# ===========================
# Advanced Training Pipeline
# ===========================

def train_tabpfn_with_augmentation(
    scale: str = 'medium',
    num_epochs: Optional[int] = None,
    use_advanced_augmentation: bool = True,
    use_focal_loss: bool = True,
    checkpoint_path: Optional[str] = None
) -> TabPFN:
    """
    Train TabPFN with advanced augmentation and imbalance handling
    
    Args:
        scale: Model scale ('small', 'medium', 'large')
        num_epochs: Number of training epochs
        use_advanced_augmentation: Whether to use advanced data generation
        use_focal_loss: Whether to use focal loss for imbalance
        checkpoint_path: Path to resume from checkpoint
    """
    
    # Create configurations
    config = create_optimized_config(scale)
    if num_epochs is not None:
        config.num_epochs = num_epochs
    
    prior_config = PriorConfig(
        use_focal_loss=use_focal_loss,
        focal_gamma=2.0,
        focal_alpha=0.25,
        imbalance_ratio_range=(0.1, 0.9),
        use_gp_functions=0.2,
        apply_power_transform=0.3,
        apply_quantile_transform=0.2
    )
    
    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Training {scale} model with advanced augmentation")
    print(f"Focal Loss: {use_focal_loss}, Advanced Augmentation: {use_advanced_augmentation}")
    
    # Create datasets
    print("\nCreating augmented datasets...")
    if use_advanced_augmentation:
        train_dataset = AugmentedTabPFNDataset(
            config, prior_config, config.num_train_tasks, split='train'
        )
        val_dataset = AugmentedTabPFNDataset(
            config, prior_config, config.num_val_tasks, split='val'
        )
    else:
        # Fall back to original dataset
        train_dataset = TabPFNDataset(config, config.num_train_tasks, split='train')
        val_dataset = TabPFNDataset(config, config.num_val_tasks, split='val')
    
    # Create dataloaders
    num_workers = 4 if device.type == 'cuda' else 0
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_tabpfn_batch,
        pin_memory=True if device.type == 'cuda' else False
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_tabpfn_batch,
        pin_memory=True if device.type == 'cuda' else False
    )
    
    # Create model
    print(f"\nCreating {scale} model...")
    model = TabPFN(config)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Create trainer
    if use_advanced_augmentation:
        trainer = ImbalanceAwareTrainer(model, config, prior_config, device)
    else:
        trainer = TabPFNTrainer(model, config, device)
    
    # Load checkpoint if provided
    start_epoch = 0
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"\nLoading checkpoint from {checkpoint_path}")
        trainer.load_checkpoint(checkpoint_path)
        start_epoch = trainer.best_epoch + 1
    
    # Training loop with monitoring
    print(f"\nStarting training...")
    print("=" * 80)
    
    best_balanced_accuracy = 0.0
    epochs_without_improvement = 0
    early_stopping_patience = 20
    
    for epoch in range(start_epoch, config.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
        print("-" * 60)
        
        # Training phase
        train_loss = 0.0
        train_acc = 0.0
        train_steps = 0
        class_acc_accumulator = {}
        
        pbar = tqdm(train_loader, desc='Training')
        for batch in pbar:
            metrics = trainer.train_step(batch)
            train_loss += metrics['loss']
            train_acc += metrics['accuracy']
            train_steps += 1
            
            # Accumulate class accuracies
            for c, acc in metrics.get('class_accuracies', {}).items():
                if c not in class_acc_accumulator:
                    class_acc_accumulator[c] = []
                class_acc_accumulator[c].append(acc)
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f"{metrics['loss']:.4f}",
                'acc': f"{metrics['accuracy']:.4f}",
                'lr': f"{metrics['lr']:.6f}"
            })
        
        # Calculate average metrics
        avg_train_loss = train_loss / train_steps
        avg_train_acc = train_acc / train_steps
        
        # Calculate per-class average accuracies
        avg_class_accuracies = {}
        for c, accs in class_acc_accumulator.items():
            avg_class_accuracies[c] = np.mean(accs)
        
        # Validation phase
        val_metrics = trainer.evaluate(val_loader)
        
        # Calculate balanced accuracy
        if hasattr(trainer, 'get_class_distribution'):
            class_dist = trainer.get_class_distribution()
            if avg_class_accuracies:
                balanced_acc = np.mean(list(avg_class_accuracies.values()))
            else:
                balanced_acc = avg_train_acc
        else:
            balanced_acc = avg_train_acc
        
        # Print epoch summary
        print(f"\nTrain Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")
        print(f"Balanced Accuracy: {balanced_acc:.4f}")
        
        # Print class distribution if available
        if hasattr(trainer, 'get_class_distribution'):
            class_dist = trainer.get_class_distribution()
            if class_dist:
                print("\nClass Distribution:")
                for c, prop in sorted(class_dist.items()):
                    acc = avg_class_accuracies.get(c, 0.0)
                    print(f"  Class {c}: {prop:.3f} (Acc: {acc:.3f})")
        
        # Save best model based on balanced accuracy
        if balanced_acc > best_balanced_accuracy:
            best_balanced_accuracy = balanced_acc
            trainer.best_epoch = epoch
            trainer.save_checkpoint('tabpfn_best_balanced.pt')
            print(f"\nNew best balanced accuracy: {best_balanced_accuracy:.4f}")
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        
        # Early stopping
        if epochs_without_improvement >= early_stopping_patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs")
            break
        
        print("-" * 80)
    
    print(f"\nTraining completed!")
    print(f"Best balanced accuracy: {best_balanced_accuracy:.4f}")
    
    # Load best model
    trainer.load_checkpoint('tabpfn_best_balanced.pt')
    
    return model

# ===========================
# Evaluation Utilities
# ===========================

def evaluate_with_class_metrics(model: TabPFN, dataloader: DataLoader, 
                               device: torch.device) -> Dict[str, Any]:
    """Evaluate model with detailed class-wise metrics"""
    model.eval()
    
    all_predictions = []
    all_labels = []
    all_logits = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v 
                    for k, v in batch.items()}
            
            logits = model(batch)
            query_y = batch['query_y']
            query_mask = batch['query_mask']
            
            # Collect valid predictions
            for b in range(logits.size(0)):
                mask = query_mask[b]
                if mask.any():
                    batch_logits = logits[b][mask]
                    batch_labels = query_y[b][mask]
                    batch_preds = torch.argmax(batch_logits, dim=-1)
                    
                    all_logits.append(batch_logits.cpu())
                    all_labels.append(batch_labels.cpu())
                    all_predictions.append(batch_preds.cpu())
    
    # Concatenate all results
    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    all_predictions = torch.cat(all_predictions, dim=0)
    
    # Calculate metrics
    overall_accuracy = (all_predictions == all_labels).float().mean().item()
    
    # Per-class metrics
    unique_classes = torch.unique(all_labels)
    class_metrics = {}
    
    for c in unique_classes:
        mask = all_labels == c
        if mask.sum() > 0:
            true_positives = ((all_predictions == c) & mask).sum().item()
            false_positives = ((all_predictions == c) & ~mask).sum().item()
            false_negatives = ((all_predictions != c) & mask).sum().item()
            
            precision = true_positives / (true_positives + false_positives + 1e-8)
            recall = true_positives / (true_positives + false_negatives + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
            
            class_metrics[c.item()] = {
                'count': mask.sum().item(),
                'accuracy': (all_predictions[mask] == c).float().mean().item(),
                'precision': precision,
                'recall': recall,
                'f1': f1
            }
    
    # Calculate balanced accuracy
    class_accuracies = [m['accuracy'] for m in class_metrics.values()]
    balanced_accuracy = np.mean(class_accuracies)
    
    # Calculate macro-averaged F1
    macro_f1 = np.mean([m['f1'] for m in class_metrics.values()])
    
    return {
        'overall_accuracy': overall_accuracy,
        'balanced_accuracy': balanced_accuracy,
        'macro_f1': macro_f1,
        'class_metrics': class_metrics,
        'predictions': all_predictions,
        'labels': all_labels,
        'logits': all_logits
    }

# ===========================
# Visualization Tools
# ===========================

def visualize_augmentation_effects(prior_config: PriorConfig, num_examples: int = 5):
    """Visualize the effects of data augmentation"""
    import matplotlib.pyplot as plt
    
    generator = AdvancedDataGenerator(prior_config)
    
    fig, axes = plt.subplots(num_examples, 4, figsize=(16, 4 * num_examples))
    
    for i in range(num_examples):
        # Generate dataset
        X, y, metadata = generator.generate_dataset(
            num_samples=200,
            num_features=10,
            num_classes=3
        )
        
        # Plot original features
        axes[i, 0].scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', alpha=0.6)
        axes[i, 0].set_title(f"Original (Function: {metadata['function_info']['function_class']})")
        axes[i, 0].set_xlabel("Feature 1")
        axes[i, 0].set_ylabel("Feature 2")
        
        # Plot class distribution
        unique, counts = np.unique(y, return_counts=True)
        axes[i, 1].bar(unique, counts)
        axes[i, 1].set_title("Class Distribution")
        axes[i, 1].set_xlabel("Class")
        axes[i, 1].set_ylabel("Count")
        
        # Plot feature distributions
        axes[i, 2].boxplot([X[y == c, 0] for c in unique])
        axes[i, 2].set_title("Feature 1 by Class")
        axes[i, 2].set_xlabel("Class")
        axes[i, 2].set_ylabel("Feature 1 Value")
        
        # Plot correlation matrix
        corr = np.corrcoef(X.T)
        im = axes[i, 3].imshow(corr, cmap='coolwarm', vmin=-1, vmax=1)
        axes[i, 3].set_title("Feature Correlations")
        plt.colorbar(im, ax=axes[i, 3])
    
    plt.tight_layout()
    plt.savefig('augmentation_effects.png', dpi=150)
    plt.show()

# ===========================
# Example Usage
# ===========================

def example_advanced_training():
    """Example of training with advanced augmentation"""
    print("=== TabPFN with Advanced Augmentation ===\n")
    
    # Train with all advanced features
    model = train_tabpfn_with_augmentation(
        scale='small',  # Start small for testing
        num_epochs=20,
        use_advanced_augmentation=True,
        use_focal_loss=True
    )
    
    return model

def example_visualize_augmentation():
    """Example of visualizing augmentation effects"""
    print("=== Visualizing Data Augmentation ===\n")
    
    prior_config = PriorConfig(
        use_focal_loss=True,
        imbalance_ratio_range=(0.1, 0.9),
        apply_power_transform=0.5,
        use_gp_functions=0.3
    )
    
    visualize_augmentation_effects(prior_config, num_examples=3)

def example_evaluate_with_metrics():
    """Example of detailed evaluation"""
    print("=== Detailed Model Evaluation ===\n")
    
    # Assume we have a trained model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Create test dataset
    config = create_optimized_config('small')
    prior_config = PriorConfig()
    
    test_dataset = AugmentedTabPFNDataset(
        config, prior_config, num_tasks=1000, split='val'
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=collate_tabpfn_batch
    )
    
    # Load model (assuming it exists)
    model = TabPFN(config).to(device)
    
    # Evaluate
    metrics = evaluate_with_class_metrics(model, test_loader, device)
    
    print(f"Overall Accuracy: {metrics['overall_accuracy']:.4f}")
    print(f"Balanced Accuracy: {metrics['balanced_accuracy']:.4f}")
    print(f"Macro F1-Score: {metrics['macro_f1']:.4f}")
    
    print("\nPer-Class Metrics:")
    for c, m in sorted(metrics['class_metrics'].items()):
        print(f"  Class {c}: Acc={m['accuracy']:.3f}, "
              f"Prec={m['precision']:.3f}, Rec={m['recall']:.3f}, "
              f"F1={m['f1']:.3f} (n={m['count']})")

if __name__ == "__main__":
    # Run advanced training
    model = example_advanced_training()
    
    # Visualize augmentation
    example_visualize_augmentation()
    
    # Evaluate with detailed metrics
    example_evaluate_with_metrics()
    
    print("\nAdvanced training completed!")

=== TabPFN with Advanced Augmentation ===

Using device: cuda
Training small model with advanced augmentation
Focal Loss: True, Advanced Augmentation: True

Creating augmented datasets...

Creating small model...
Model parameters: 833,482

Starting training...

Epoch 1/20
------------------------------------------------------------


Training:   0%|          | 0/313 [00:02<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8448x36 and 32x128)