In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb
from utils import *
from data import *
from metrics import *
from trainer import *
from models import *
import os

In [None]:
def get_architecture_name(hidden_sizes):
    """Return a human-readable label describing a hidden-layer layout.

    Args:
        hidden_sizes: Iterable of hidden-layer widths, in order.

    Returns:
        ``"linear"`` when ``hidden_sizes`` is empty or falsy; otherwise the
        widths chained between ``input`` and ``output``, e.g.
        ``"input->256->128->output"``.
    """
    if hidden_sizes:
        middle = "".join(str(width) + "->" for width in hidden_sizes)
        return "input->" + middle + "output"
    return "linear"

In [None]:
def get_optimizer(model: nn.Module, config: Any) -> torch.optim.Optimizer:
    """Build an optimizer whose normalization parameters skip weight decay.

    Every parameter whose qualified name contains ``'norm'`` goes into a first
    param group with ``weight_decay=0``. All remaining parameters go into a
    second group that uses ``config.weight_decay``.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        config: Object exposing ``optimizer``, ``learning_rate`` and
            ``weight_decay`` attributes (e.g. ``wandb.config``).

    Returns:
        ``torch.optim.Adam`` for ``'adam'``, ``torch.optim.AdamW`` for
        ``'adamw'``, and ``torch.optim.SGD`` with momentum 0.9 for any other
        value of ``config.optimizer``.
    """
    decay_free, decayed = [], []
    for param_name, tensor in model.named_parameters():
        (decay_free if 'norm' in param_name else decayed).append(tensor)

    groups = [
        {'params': decay_free, 'weight_decay': 0},
        {'params': decayed, 'weight_decay': config.weight_decay},
    ]

    choice = config.optimizer
    if choice == 'adam':
        return torch.optim.Adam(groups, lr=config.learning_rate)
    if choice == 'adamw':
        return torch.optim.AdamW(groups, lr=config.learning_rate)
    # Any other name (including unrecognized ones) falls through to SGD.
    return torch.optim.SGD(groups, lr=config.learning_rate, momentum=0.9)

In [None]:
# Highest ``accuracy_101_500`` reached by any sweep trial executed in this
# Python process. ``wandb.agent(..., function=run_experiment)`` invokes the
# function repeatedly in-process, so this module-level value is what lets a
# trial compare itself against earlier ones: a freshly initialised run's
# summary never contains values written by previous runs.
# NOTE: not shared between separate agent processes/machines, and reset
# whenever this cell is re-executed.
_SWEEP_BEST_ACCURACY_101_500 = float("-inf")


def _build_dataloaders(config, encoder, device):
    """Return ``(train_loader, val_loader)`` for one trial.

    Training numbers come from the ``(1, train_range_max)`` range. The one
    exception is when ``config.sparse_sampling`` is set AND
    ``train_range_max > 1000``: then 1-100 are kept, 101-1000 are taken in
    steps of 20, and 1001 up to (excluding) ``train_range_max`` in steps of
    200. Validation uses the 50-number window starting at ``train_range_max``.
    """
    if config.sparse_sampling and config.train_range_max > 1000:
        train_nums = (
            list(range(1, 101))                                # Dense for small numbers
            + list(range(101, 1001, 20))                       # Sparse for medium
            + list(range(1001, config.train_range_max, 200))   # Very sparse for large
        )
        train_data = FizzBuzzDataset(train_nums, encoder)
    else:
        train_data = FizzBuzzDataset((1, config.train_range_max), encoder)

    # NOTE(review): if FizzBuzzDataset treats tuple ranges as inclusive,
    # train_range_max itself ends up in both the train and validation sets --
    # confirm against data.FizzBuzzDataset.
    val_data = FizzBuzzDataset(
        (config.train_range_max, config.train_range_max + 50), encoder
    )

    train_loader = DataLoader(
        train_data,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=device == 'cuda',
    )
    val_loader = DataLoader(
        val_data,
        batch_size=config.batch_size * 2,
        shuffle=False,
    )
    return train_loader, val_loader


def _build_scheduler(optimizer, config):
    """Return the LR scheduler named by ``config.lr_scheduler``, or ``None``.

    ``'reduce_on_plateau'`` and ``'cosine'`` are recognized. Any other value
    (e.g. ``'none'``) disables scheduling.
    """
    if config.lr_scheduler == 'reduce_on_plateau':
        # mode='max' assumes the trainer steps this with an accuracy-like
        # metric -- NOTE(review): confirm against FizzBuzzTrainer.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=config.lr_factor,
            patience=config.lr_patience,
            min_lr=1e-6,
        )
    if config.lr_scheduler == 'cosine':
        # T_max=epochs presumes scheduler.step() is called once per epoch --
        # TODO confirm against FizzBuzzTrainer.
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config.epochs,
            eta_min=1e-6,
        )
    return None


def _save_checkpoint(model, config, test_results, complexity_stats):
    """Save weights, config and results to ``best_model.pt`` in the run dir, then ``wandb.save`` it."""
    model_path = os.path.join(wandb.run.dir, "best_model.pt")

    # Copy the fields needed to rebuild/interpret the model into a dict of
    # basic types rather than storing wandb.config itself.
    config_dict = {
        'num_bits': config.num_bits,
        'hidden_sizes': config.hidden_sizes,
        'dropout': config.dropout,
        'activation': config.activation,
        'use_residual': config.use_residual,
        'use_layer_norm': config.use_layer_norm,
        'bottleneck_factor': config.bottleneck_factor,
        'num_residual_blocks': config.num_residual_blocks,
        'learning_rate': config.learning_rate,
        'weight_decay': config.weight_decay,
        'optimizer': config.optimizer,
        'train_range_max': config.train_range_max,
        'sparse_sampling': config.sparse_sampling
    }

    save_dict = {
        'model_state_dict': model.state_dict(),
        'config': config_dict,
        'test_results': test_results,
        'complexity_stats': complexity_stats
    }

    torch.save(save_dict, model_path)
    wandb.save(model_path)


def run_experiment():
    """Run a single sweep trial using the hyperparameters in ``wandb.config``.

    Intended as the ``function`` passed to ``wandb.agent``. The trial:

    1. builds the data loaders, model, optimizer and LR scheduler from the
       config;
    2. trains with ``FizzBuzzTrainer``;
    3. evaluates accuracy on fixed test ranges;
    4. writes config, complexity stats and test accuracies to the run summary;
    5. checkpoints the model (``best_model.pt``, uploaded via ``wandb.save``)
       when its ``accuracy_101_500`` beats every earlier trial run in this
       process (tracked in ``_SWEEP_BEST_ACCURACY_101_500``).

    The first trial in a process is therefore always checkpointed.

    The run name and the ``architecture`` summary field come from
    ``get_architecture_name``, so an empty ``hidden_sizes`` is labelled
    ``"linear"``.
    """
    global _SWEEP_BEST_ACCURACY_101_500

    with wandb.init() as run:
        config = wandb.config

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        encoder = BinaryEncoder(bits=config.num_bits)

        arch_str = get_architecture_name(config.hidden_sizes)
        run.name = f"b{config.num_bits}_{arch_str}_{config.optimizer}_lr{config.learning_rate}"
        if config.use_residual:
            run.name += "_res"

        train_loader, val_loader = _build_dataloaders(config, encoder, device)

        model = LinearFizzBuzz(
            input_size=encoder.get_input_size(),
            hidden_sizes=config.hidden_sizes,
            dropout=config.dropout,
            activation=config.activation,
            use_residual=config.use_residual,
            use_layer_norm=config.use_layer_norm,
            bottleneck_factor=config.bottleneck_factor,
            num_residual_blocks=config.num_residual_blocks
        )

        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        optimizer = get_optimizer(model, config)
        scheduler = _build_scheduler(optimizer, config)

        trainer = FizzBuzzTrainer(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            gradient_clip=config.gradient_clip,
            use_wandb=True
        )
        trainer.train(
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=config.epochs
        )

        # NOTE(review): training numbers are drawn from (1, train_range_max),
        # or a sparse subset of it. For train_range_max >= 1000 the
        # "generalization" ranges below therefore overlap training data and
        # partly measure training fit. That includes accuracy_101_500, which
        # drives checkpoint selection. Confirm this is intended.
        test_ranges = [
            (1, 100),       # Training range sample
            (101, 500),     # Near generalization
            (501, 1000),    # Mid generalization
            (1001, 2000)    # Far generalization
        ]
        test_results = {}
        for start, end in test_ranges:
            results = trainer.test(start, end, encoder)
            test_results[f"accuracy_{start}_{end}"] = evaluate_model_accuracy(results, start, end)

        complexity_stats = model.get_complexity_stats()
        wandb.run.summary.update({
            **complexity_stats,
            'num_bits': config.num_bits,
            "architecture": arch_str,
            "n_layers": len(config.hidden_sizes),
            "widest_layer": max(config.hidden_sizes) if config.hidden_sizes else 4,
            "activation": config.activation,
            "dropout": config.dropout,
            "weight_decay": config.weight_decay,
            "train_range": config.train_range_max,
            "use_residual": config.use_residual,
            "use_layer_norm": config.use_layer_norm,
            "bottleneck_factor": config.bottleneck_factor,
            "num_residual_blocks": config.num_residual_blocks,
            "sparse_sampling": config.sparse_sampling,
            **test_results
        })

        # Compare against earlier trials in this process. The current run's
        # summary cannot be used for this: it never contains results from
        # previous runs.
        selection_accuracy = test_results["accuracy_101_500"]
        if selection_accuracy > _SWEEP_BEST_ACCURACY_101_500:
            _save_checkpoint(model, config, test_results, complexity_stats)
            _SWEEP_BEST_ACCURACY_101_500 = selection_accuracy
            wandb.run.summary["best_accuracy_101_500"] = selection_accuracy

In [None]:
# W&B sweep definition passed to ``wandb.sweep``: Bayesian search over the
# hyperparameters that ``run_experiment`` reads from ``wandb.config``.
sweep_config = {
    'method': 'bayes',
    # NOTE(review): 'val_accuracy' is not logged in this notebook directly; it
    # must match the key FizzBuzzTrainer logs when use_wandb=True.
    'metric': {
        'name': 'val_accuracy',
        'goal': 'maximize'
    },
    # Hyperband iterations are counted in logged steps of the metric above --
    # presumably one per epoch; confirm against the trainer's logging.
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 20,
        'eta': 2,
        'max_iter': 100
    },
    'parameters': {
        # Input representation
        'num_bits': {
            'values': [16, 24, 32]
        },

        # Architecture. Each layout is listed exactly once: a duplicate entry
        # would presumably be treated as a separate category and over-sampled.
        'hidden_sizes': {
            'values': [
                # Medium
                [256],
                [256, 256],
                [256, 256, 256],
                [512, 256],

                # Large
                [512, 512],
                [512, 512, 256],
                [256, 256, 256, 256],  # constant width: also residual-friendly

                # Bottleneck architectures
                [512, 128, 512],
                [256, 64, 256],

                # Residual-friendly
                [512, 512, 512]
            ]
        },

        # Residual configuration. num_residual_blocks and bottleneck_factor are
        # always passed to LinearFizzBuzz. Whether they have any effect when
        # use_residual is False depends on that class -- TODO confirm.
        'use_residual': {
            'values': [True, False]
        },
        'use_layer_norm': {
            'values': [True, False]
        },
        'num_residual_blocks': {
            'values': [2, 3]
        },
        'bottleneck_factor': {
            'values': [0.25, 0.5]
        },

        # Training parameters
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-4,
            'max': 5e-3
        },
        'batch_size': {
            'values': [32, 64, 128]
        },
        'epochs': {
            'values': [150, 300, 500]
        },

        # Regularization
        'dropout': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 0.3
        },
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 1e-6,
            'max': 1e-4
        },
        # NOTE(review): 0.0 presumably disables clipping in FizzBuzzTrainer.
        'gradient_clip': {
            'values': [0.0, 1.0, 5.0]
        },

        # Optimization
        'optimizer': {
            'values': ['adam', 'adamw']
        },
        'lr_scheduler': {
            'values': ['reduce_on_plateau', 'cosine', 'none']
        },
        # lr_patience / lr_factor are only read when
        # lr_scheduler == 'reduce_on_plateau'.
        'lr_patience': {
            'values': [5, 10, 15]
        },
        'lr_factor': {
            'values': [0.1, 0.5]
        },

        # Model components
        'activation': {
            'values': ['relu', 'gelu']
        },

        # Training data configuration. sparse_sampling only takes effect when
        # train_range_max > 1000, i.e. for 10000 here.
        'train_range_max': {
            'values': [100, 1000, 10000]
        },
        'sparse_sampling': {
            'values': [True, False]
        }
    },
    'run_cap': 100  # Maximum number of runs in the sweep
}

In [None]:
if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is also "__main__", so running this
    # cell authenticates, registers the sweep, and runs trials in-process.
    wandb.login()
    sweep_id = wandb.sweep(sweep_config, project="fizzbuzz-linear")
    wandb.agent(sweep_id=sweep_id, function=run_experiment)