# Hyperparameter Optimization (HPO) - Complete Implementation

This notebook implements several hyperparameter optimization (HPO) algorithms from scratch:
1. **Basic Random Search** - Simple random sampling of hyperparameters
2. **Asynchronous Random Search** - Parallel random search using Syne Tune (not implemented in this notebook)
3. **Successive Halving (SH)** - Synchronous multi-fidelity optimization
4. **ASHA (Asynchronous Successive Halving)** - Asynchronous multi-fidelity optimization

All experiments use the FashionMNIST dataset and train either LeNet or a softmax regression model. Trials run one at a time in a single process, so ASHA's asynchrony is simulated sequentially rather than spread across parallel workers.

## 1. Install Dependencies

In [None]:
# Install required packages
!pip install torch torchvision tqdm matplotlib scipy numpy
# Or, if a requirements file is available:
# !pip install -r requirements.txt

## 2. Import Libraries

In [None]:
import numpy as np
import time
import torch
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from scipy import stats
from collections import defaultdict
from abc import ABC, abstractmethod
import argparse

## 3. Configuration Class

In [None]:
class Config:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
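    # Command-line parser, kept for reference only: it is unused in this notebook,
    # where the Args class (Section 12) supplies the same settings.
    '''@staticmethod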
    def new_parser(name=None):
        return argparse.ArgumentParser(prog=name)
    
    @staticmethod
    def add_training_argument(parser):
        parser.add_argument(
            "--train_mode",
            type=str,
            default="asha_hpo",
            choices=["hpo", "fixed", "async_hpo", "multi_fidelity_hpo", "asha_hpo"],
        )
        parser.add_argument(
            "--model_name", type=str, default="lenet", choices=["lenet", "softmax"]
        )
        parser.add_argument("--num_epochs", type=int, default=1)
        parser.add_argument("--learning_rate", type=float, default=0.1)
        parser.add_argument("--batch_size", type=int, default=256)
        parser.add_argument("--num_workers", type=int, default=2)
        parser.add_argument("--num_outputs", type=int, default=10)
        parser.add_argument("--num_trials", type=int, default=10)
        parser.add_argument("--max_wallclock_time", type=int, default=10 * 60)
        parser.add_argument("--eta", type=int, default=2)
        parser.add_argument("--min_number_of_epochs", type=int, default=10)
        parser.add_argument("--max_number_of_epochs", type=int, default=50)
        parser.add_argument("--prefact", type=int, default=1)'''

print(f"Using device: {Config.device}")

## 4. Model Definitions

In [None]:
class SoftmaxRegression(nn.Module):
    """Softmax Regression (Multinomial Logistic Regression) model."""
    
    def __init__(self, num_outputs: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(), 
            nn.Linear(784, num_outputs)  # 28x28 → 784 input features
        )
    
    def forward(self, X):
        return self.net(X)
    
    def loss(self, Y_hat, Y, averaged=True):
        Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
        Y = Y.reshape((-1,))
        return F.cross_entropy(Y_hat, Y, reduction="mean" if averaged else "none")


def init_cnn(module):
    """Xavier initialization for CNN weights"""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.xavier_uniform_(module.weight)


class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, num_classes),
        )
    
    def forward(self, X):
        return self.net(X)
    
    def loss(self, Y_hat, Y, averaged=True):
        Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
        Y = Y.reshape((-1,))
        return F.cross_entropy(Y_hat, Y, reduction="mean" if averaged else "none")
    
    def apply_init(self):
        self.apply(init_cnn)

## 5. Utility Functions

In [None]:
class Utils:
    @staticmethod
    def load_fashion_mnist(batch_size):
        """Load FashionMNIST dataset with train/validation/test split."""
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        
        # Load full datasets
        train_dataset = datasets.FashionMNIST(
            root="./data", train=True, transform=transform, download=True
        )
        test_dataset = datasets.FashionMNIST(
            root="./data", train=False, transform=transform, download=True
        )
        
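        # Split training data into train and validation sets. The fixed (arbitrary)
        # seed keeps the split identical across calls, so every HPO trial is scored
        # on the same validation images and none of them is trained on in another trial.
        train_size = int(0.8 * len(train_dataset))  # 80% for training
        val_size = len(train_dataset) - train_size  # 20% for validation
        
        train_subset, val_subset = torch.utils.data.random_split(
            train_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )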
        
        # Create data loaders
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        
        return train_loader, val_loader, test_loader
    
    @staticmethod
    def accuracy(y_hat, y):
        """Compute number of correct predictions."""
        preds = torch.argmax(y_hat, dim=1)
        return (preds == y).float().sum()
    
    @staticmethod
    def build_model(args):
        if args.model_name == "softmax":
            return SoftmaxRegression(num_outputs=args.num_outputs)
        elif args.model_name == "lenet":
            model = LeNet(num_classes=args.num_outputs)
            model.apply_init()
            return model
        else:
            raise NotImplementedError(
                'Model type not supported, use "softmax" or "lenet" instead'
            )
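The HPO objective calls `load_fashion_mnist` once per trial, so the train/validation split has to come out the same on every call; otherwise trials are compared on different validation images. A stdlib-only sketch of a seeded 80/20 split (`split_indices` is a hypothetical helper; in PyTorch the same effect comes from passing a seeded `generator` to `torch.utils.data.random_split`):

```python
import random


def split_indices(n: int, train_frac: float = 0.8, seed: int = 0):
    """Shuffle range(n) with a fixed seed and cut it into train/validation index lists."""
    rng = random.Random(seed)  # local RNG, independent of the global random state
    indices = list(range(n))
    rng.shuffle(indices)
    n_train = int(train_frac * n)
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = split_indices(60_000)  # FashionMNIST has 60,000 training images
print(len(train_idx), len(val_idx))         # 48000 12000

# Same seed, same split: every trial is validated on identical images.
assert split_indices(60_000)[1] == val_idx

# Minibatches per epoch at batch_size=256 (ceiling division; the last batch is partial).
print(-(-len(train_idx) // 256))            # 188
```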

## 6. Trainer Class

In [None]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, test_loader, lr, num_epochs):
        self.device = Config.device
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        self.num_epochs = num_epochs
    
    def fit(self):
        for epoch in range(self.num_epochs):
            self.train_epoch(epoch)
    
    def train_epoch(self, epoch):
        self.model.train()
        for X, y in tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}"):
            X, y = X.to(self.device), y.to(self.device)
            y_hat = self.model(X)
            loss = self.model.loss(y_hat, y)
            
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
    
    def _evaluate(self, loader):
        """Classification accuracy of the model on the given data loader."""
        self.model.eval()
        total_correct, total_samples = 0, 0
        
        with torch.no_grad():
            for X, y in loader:
                X, y = X.to(self.device), y.to(self.device)
                y_hat = self.model(X)
                total_correct += Utils.accuracy(y_hat, y).item()
                total_samples += y.size(0)
        
        return total_correct / total_samples
    
    def evaluate_train(self):
        return self._evaluate(self.train_loader)
    
    def evaluate_test(self):
        return self._evaluate(self.test_loader)
    
    def evaluate_val(self):
        return self._evaluate(self.val_loader)
    
    def validation_error(self):
        return 1.0 - self.evaluate_val()

## 7. HPO Base Classes (Searcher, Scheduler, Tuner)

In [None]:
class HPOSearcher(ABC):
    @abstractmethod
    def sample_config(self) -> dict:
        """Sample new hyperparameter configuration"""
        pass
    
    def update(self, config: dict, error: float, additional_info=None):
        """Update searcher state after trial completion"""
        pass


class HPOScheduler(ABC):
    @abstractmethod
    def suggest(self) -> dict:
        """Suggest next configuration to evaluate"""
        pass
    
    @abstractmethod
    def update(self, config: dict, error: float, info=None):
        """Update scheduler after trial completion"""
        pass


class HPOTuner:
    def __init__(self, scheduler: HPOScheduler, objective_fn: callable):
        self.scheduler = scheduler
        self.objective_fn = objective_fn
        self.incumbent = None  # Best performing configuration
        self.incumbent_error = None  # Lowest validation error
        self.incumbent_trajectory = []  # Lowest validation errors over time
        self.cumulative_runtime = []
        self.current_runtime = 0
        self.records = []
    
    def run(self, number_of_trials):
        for i in range(number_of_trials):
            start_time = time.time()
            config = self.scheduler.suggest()
            print(f"Trial {i} config: {config}")
            error = self.objective_fn(config)
            self.scheduler.update(config, error)
            runtime = time.time() - start_time
            self.bookkeeping(config, error, runtime)
            print(f"error: {error:.4f} - runtime: {runtime:.2f}")
    
    def bookkeeping(self, config: dict, error: float, runtime: float):
        """Track best configuration and respective performance"""
        self.records.append({"config": config, "error": error, "runtime": runtime})
        # Update incumbent
        if self.incumbent is None or error < self.incumbent_error:
            self.incumbent = config
            self.incumbent_error = error
        # Track trajectories
        self.incumbent_trajectory.append(self.incumbent_error)
        self.current_runtime += runtime
        self.cumulative_runtime.append(self.current_runtime)
    
    def get_best_config(self):
        return self.incumbent, self.incumbent_error


class RandomSearcher(HPOSearcher):
    def __init__(self, config_space: dict, initial_config: dict):
        self.config_space = config_space
        self.initial_config = initial_config
    
    def sample_config(self) -> dict:
        """Sample random configuration from config space"""
        if self.initial_config is not None:
            # Return a copy: the multi-fidelity schedulers add "num_epochs" to the
            # dict they receive, which would otherwise modify the caller's dict.
            result = dict(self.initial_config)
            self.initial_config = None  # Only used for the first trial
            return result
        random_config = {key: domain.rvs() for key, domain in self.config_space.items()}
        return random_config
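The search space draws `learning_rate` from `stats.loguniform(1e-4, 1)`, which makes the logarithm of the value uniform, so each decade ([1e-4, 1e-3), ..., [0.1, 1)) is about equally likely. A plain uniform draw on [1e-4, 1] would instead put roughly 90% of the samples above 0.1. A stdlib-only sketch of the same sampling rule (`sample_loguniform` is a hypothetical helper, not SciPy's implementation):

```python
import math
import random


def sample_loguniform(low: float, high: float, rng: random.Random) -> float:
    """Draw x such that log(x) is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_loguniform(1e-4, 1.0, rng) for _ in range(10_000)]

# With four decades in [1e-4, 1], about a quarter of the draws fall below 1e-3.
below_1e3 = sum(s < 1e-3 for s in samples) / len(samples)
print(round(below_1e3, 2))
```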

## 8. Basic Scheduler (Random Search)

In [None]:
class BasicScheduler(HPOScheduler):
    def __init__(self, searcher):
        self.searcher = searcher
    
    def suggest(self) -> dict:
        """Suggest next configuration"""
        return self.searcher.sample_config()
    
    def update(self, config: dict, error: float, info=None):
        """Update searcher with trial results"""
        self.searcher.update(config, error, additional_info=info)
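`BasicScheduler` only forwards to the searcher, so together with `HPOTuner` it amounts to plain random search: sample, evaluate once, keep the best. The stripped-down sketch below runs that loop on a toy objective with no training involved (`toy_error` and `random_search` are hypothetical). It shows why the recorded incumbent trajectory can only stay flat or go down.

```python
import random


def toy_error(config: dict) -> float:
    """Hypothetical stand-in for a validation error, minimised at learning_rate = 0.1."""
    return abs(config["learning_rate"] - 0.1)


def random_search(objective, sample, num_trials: int):
    """Sample configs, evaluate each once, and track the incumbent (best so far)."""
    incumbent, incumbent_error, trajectory = None, float("inf"), []
    for _ in range(num_trials):
        config = sample()
        error = objective(config)
        if error < incumbent_error:  # same update rule as HPOTuner.bookkeeping
            incumbent, incumbent_error = config, error
        trajectory.append(incumbent_error)  # non-increasing by construction
    return incumbent, incumbent_error, trajectory


rng = random.Random(0)
best, best_err, traj = random_search(
    toy_error, lambda: {"learning_rate": rng.uniform(0.0, 1.0)}, num_trials=20
)
print(best, best_err)
```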

## 9. Multi-Fidelity Scheduler (Synchronous Successive Halving)

In [None]:
class MultiFidelityScheduler(HPOScheduler):
    def __init__(self, searcher, eta, r_min, r_max, prefact):
        self.searcher = searcher
        self.eta = eta
        self.r_min = r_min
        self.r_max = r_max
        self.prefact = prefact
        self.K = int(np.log(r_max / r_min) / np.log(eta))
        self.rung_levels = [r_min * (eta**k) for k in range(self.K + 1)]
        if r_max not in self.rung_levels:
            self.rung_levels.append(r_max)
            self.K += 1
        # Bookkeeping
        self.observed_error_at_rungs = defaultdict(list)
        self.all_observed_error_at_rungs = defaultdict(list)
        self.queue = []
    
    def suggest(self):
        if len(self.queue) == 0:
            # Start a new round of successive halving
            n0 = int(self.prefact * self.eta**self.K)
            for _ in range(n0):
                config = self.searcher.sample_config()
                config["num_epochs"] = self.r_min
                self.queue.append(config)
        return self.queue.pop()
    
    def update(self, config: dict, error: float, info=None):
        ri = int(config["num_epochs"])
        self.searcher.update(config, error, additional_info=info)
        self.all_observed_error_at_rungs[ri].append((config, error))
        if ri < self.r_max:
            self.observed_error_at_rungs[ri].append((config, error))
            ki = self.K - self.rung_levels.index(ri)
            ni = int(self.prefact * (self.eta**ki))
            if len(self.observed_error_at_rungs[ri]) >= ni:
                kiplus1 = ki - 1
                niplus1 = int(self.prefact * (self.eta**kiplus1))
                best_performing_configurations = self.get_top_n_configurations(
                    rung_level=ri, n=niplus1
                )
                riplus1 = self.rung_levels[self.K - kiplus1]
                self.queue = [
                    dict(config, num_epochs=riplus1)
                    for config in best_performing_configurations
                ] + self.queue
                self.observed_error_at_rungs[ri] = []
    
    def get_top_n_configurations(self, rung_level, n):
        rung = self.observed_error_at_rungs[rung_level]
        if not rung:
            return []
        sorted_rung = sorted(rung, key=lambda x: x[1])
        return [x[0] for x in sorted_rung[:n]]
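`MultiFidelityScheduler.__init__` derives the rung levels, and `suggest`/`update` derive how many configurations each rung of a bracket evaluates. The sketch below restates that arithmetic as a standalone function (`successive_halving_plan` is hypothetical). It finds K with an integer loop instead of the floating-point `np.log` ratio, which should give the same rungs while avoiding rounding concerns.

```python
def successive_halving_plan(eta: int, r_min: int, r_max: int, prefact: int = 1):
    """Rung levels and configs per rung for one synchronous SH bracket.

    Rungs are r_min * eta**k; r_max is appended as an extra rung when it is
    not itself of that form, which increases K by one.
    """
    K = 0
    while r_min * eta ** (K + 1) <= r_max:
        K += 1
    rungs = [r_min * eta**k for k in range(K + 1)]
    if r_max not in rungs:
        rungs.append(r_max)
        K += 1
    # Rung i evaluates prefact * eta**(K - i) configs; the top rung keeps prefact.
    num_configs = [int(prefact * eta ** (K - i)) for i in range(K + 1)]
    return rungs, num_configs


print(successive_halving_plan(eta=2, r_min=5, r_max=20))   # ([5, 10, 20], [4, 2, 1])
print(successive_halving_plan(eta=2, r_min=10, r_max=50))  # ([10, 20, 40, 50], [8, 4, 2, 1])
```

With the notebook's `Args` (second example below), one full bracket therefore takes 4 + 2 + 1 = 7 trials.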

## 10. ASHA Scheduler (Asynchronous Successive Halving)

In [None]:
class ASHAScheduler(HPOScheduler):
    """
    Asynchronous Successive Halving Algorithm (ASHA)
    
    Key differences from synchronous SH:
    - Promotes configs as soon as enough results are available (not all)
    - No synchronization barriers, workers stay busy
    - Checks rungs from top to bottom for promotion opportunities
    """
    def __init__(self, searcher, eta, r_min, r_max, prefact=1):
        self.searcher = searcher
        self.eta = eta
        self.r_min = r_min
        self.r_max = r_max
        self.prefact = prefact
        
        # Compute rung levels
        self.K = int(np.log(r_max / r_min) / np.log(eta))
        self.rung_levels = [r_min * (eta**k) for k in range(self.K + 1)]
        if r_max not in self.rung_levels:
            self.rung_levels.append(r_max)
            self.K += 1
        
        # Track completed trials at each rung
        self.completed_trials_at_rungs = defaultdict(list)  # (config, error) pairs
        
        # Track which configs have been promoted from each rung
        self.promoted_configs = defaultdict(set)  # rung -> set of config hashes
        
        # Track number of configs started at each rung
        self.configs_started_at_rung = defaultdict(int)
    
    def _config_hash(self, config):
        """Create hashable representation of config (excluding num_epochs)"""
        items = [(k, v) for k, v in sorted(config.items()) if k != 'num_epochs']
        return tuple(items)
    
    def suggest(self):
        """
        ASHA suggest logic:
        1. Check rungs from top to bottom for promotion opportunities
        2. If found, promote a config to next rung
        3. Otherwise, start new config at r_min
        """
        # Check rungs from highest to lowest (excluding r_max, which has no next rung)
        for i in range(len(self.rung_levels) - 2, -1, -1):
            rung = self.rung_levels[i]
            next_rung = self.rung_levels[i + 1]
            completed = self.completed_trials_at_rungs[rung]
            
            # Promotion needs at least eta completed trials at this rung
            if len(completed) < self.eta:
                continue
            
            # Candidates: the top 1/eta fraction of completed trials (lowest error first)
            top_k = len(completed) // self.eta
            sorted_trials = sorted(completed, key=lambda x: x[1])
            for config, _ in sorted_trials[:top_k]:
                config_hash = self._config_hash(config)
                if config_hash not in self.promoted_configs[rung]:
                    # Promote the best candidate not yet promoted from this rung
                    self.promoted_configs[rung].add(config_hash)
                    promoted_config = dict(config)
                    promoted_config['num_epochs'] = next_rung
                    return promoted_config
        
        # No promotion opportunity found, start new config at r_min
        new_config = self.searcher.sample_config()
        new_config['num_epochs'] = self.r_min
        self.configs_started_at_rung[self.r_min] += 1
        return new_config
    
    def update(self, config: dict, error: float, info=None):
        """Record completed trial"""
        ri = int(config['num_epochs'])
        
        # Update searcher
        self.searcher.update(config, error, additional_info=info)
        
        # Record this completion
        self.completed_trials_at_rungs[ri].append((config, error))
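At each rung, `ASHAScheduler.suggest` promotes the best configuration that ranks in the top floor(n / η) of the n trials completed there and has not been promoted yet. Stripped of config hashing, that per-rung decision can be sketched as follows (`asha_promotion` is a hypothetical helper that uses string IDs in place of config dicts). A rung with fewer than η finished trials never promotes, and a late arrival can still be promoted once it ranks in the top fraction.

```python
def asha_promotion(completed, promoted, eta):
    """Best not-yet-promoted config ID in the top floor(n / eta) of a rung, or None.

    completed: list of (config_id, error) pairs finished at this rung.
    promoted:  set of config_ids already promoted out of this rung.
    """
    top_k = len(completed) // eta  # 0 until at least eta trials have finished
    ranked = sorted(completed, key=lambda pair: pair[1])
    for config_id, _ in ranked[:top_k]:
        if config_id not in promoted:
            return config_id
    return None


rung = [("a", 0.30), ("b", 0.20)]
promoted = set()
print(asha_promotion(rung, promoted, eta=2))  # b
promoted.add("b")
print(asha_promotion(rung, promoted, eta=2))  # None: the top-1 of 2 is already promoted
rung.append(("c", 0.10))
rung.append(("d", 0.40))
print(asha_promotion(rung, promoted, eta=2))  # c: the top-2 of 4 is {c, b}
```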

## 11. HPO Main Class with All Methods

In [None]:
class HPO:
    @staticmethod
    def hpo_objective_fn(args):
        def hpo_objective(config):
            lr = config.get("learning_rate", args.learning_rate)
            batch_size = config.get("batch_size", args.batch_size)
            num_epochs = config.get("num_epochs", args.num_epochs)  # set per trial by multi-fidelity schedulers
            train_loader, val_loader, test_loader = Utils.load_fashion_mnist(batch_size)
            model = Utils.build_model(args)
            trainer = Trainer(
                model,
                train_loader,
                val_loader,
                test_loader,
                lr=lr,
                num_epochs=num_epochs,
            )
            trainer.fit()
            val_err = trainer.validation_error()
            return val_err
        
        return hpo_objective
    
    @staticmethod
    def random_search(args, config_space, initial_config):
        searcher = RandomSearcher(config_space, initial_config)
        scheduler = BasicScheduler(searcher)
        objective_fn = HPO.hpo_objective_fn(args)
        tuner = HPOTuner(scheduler=scheduler, objective_fn=objective_fn)
        print(f"Starting HPO with {args.num_trials} trials")
        tuner.run(number_of_trials=args.num_trials)
        best_config, best_score = tuner.get_best_config()
        print("\n" + "=" * 32)
        print(f"HPO Summary:")
        print(f"Best config: {best_config}")
        print(f"Best validation error: {best_score:.4f}")
        print(f"Total runtime: {tuner.current_runtime:.2f}s")
        print(f"Average time per trial: {tuner.current_runtime/args.num_trials:.2f}s")
        return best_config, best_score, tuner
    
    @staticmethod
    def multi_fidelity_random_search(args, config_space, initial_config):
        searcher = RandomSearcher(config_space, initial_config)
        scheduler = MultiFidelityScheduler(
            searcher=searcher,
            eta=args.eta,
            r_min=args.min_number_of_epochs,
            r_max=args.max_number_of_epochs,
            prefact=args.prefact,
        )
        objective_fn = HPO.hpo_objective_fn(args)
        tuner = HPOTuner(scheduler=scheduler, objective_fn=objective_fn)
        print(f"Starting Multi-Fidelity HPO with {args.num_trials} trials")
        tuner.run(number_of_trials=args.num_trials)
        best_config, best_score = tuner.get_best_config()
        print("\n" + "=" * 32)
        print(f"Multi-Fidelity HPO Summary:")
        print(f"Best config: {best_config}")
        print(f"Best validation error: {best_score:.4f}")
        print(f"Total runtime: {tuner.current_runtime:.2f}s")
        print(f"Average time per trial: {tuner.current_runtime/args.num_trials:.2f}s")
        return best_config, best_score, tuner
    
    @staticmethod
    def asha_random_search(args, config_space, initial_config):
        """
        Asynchronous Successive Halving Algorithm (ASHA)
        """
        searcher = RandomSearcher(config_space, initial_config)
        scheduler = ASHAScheduler(
            searcher=searcher,
            eta=args.eta,
            r_min=args.min_number_of_epochs,
            r_max=args.max_number_of_epochs,
            prefact=args.prefact,
        )
        objective_fn = HPO.hpo_objective_fn(args)
        tuner = HPOTuner(scheduler=scheduler, objective_fn=objective_fn)
        
        print(f"Starting ASHA with {args.num_trials} trials")
        print(f"Rung levels: {scheduler.rung_levels}")
        print(f"eta={args.eta}, r_min={args.min_number_of_epochs}, r_max={args.max_number_of_epochs}")
        
        tuner.run(number_of_trials=args.num_trials)
        best_config, best_score = tuner.get_best_config()
        
        # Print rung statistics
        print("\n" + "=" * 50)
        print("ASHA Rung Statistics:")
        for rung in scheduler.rung_levels:
            n_completed = len(scheduler.completed_trials_at_rungs[rung])
            n_promoted = len(scheduler.promoted_configs[rung])
            print(f"  Rung {rung:3d}: {n_completed:3d} completed, {n_promoted:3d} promoted")
        
        print("\n" + "=" * 50)
        print(f"ASHA HPO Summary:")
        print(f"Best config: {best_config}")
        print(f"Best validation error: {best_score:.4f}")
        print(f"Total runtime: {tuner.current_runtime:.2f}s")
        print(f"Average time per trial: {tuner.current_runtime/args.num_trials:.2f}s")
        
        return best_config, best_score, tuner
    
    @staticmethod
    def plot_hpo_progress(tuner, save_path=None):
        """Plot HPO progress over time"""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Plot incumbent error vs trials
        ax1.plot(range(len(tuner.incumbent_trajectory)), tuner.incumbent_trajectory)
        ax1.set_xlabel("Trial")
        ax1.set_ylabel("Best Validation Error")
        ax1.set_title("HPO Progress: Best Error vs Trials")
        ax1.grid(True)
        
        # Plot incumbent error vs cumulative runtime
        ax2.plot(tuner.cumulative_runtime, tuner.incumbent_trajectory)
        ax2.set_xlabel("Cumulative Runtime (s)")
        ax2.set_ylabel("Best Validation Error")
        ax2.set_title("HPO Progress: Best Error vs Time")
        ax2.grid(True)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.show()
        
        return fig

## 12. Create Arguments Class for Notebook

In [None]:
class Args:
    """Simple arguments class for notebook usage"""
    def __init__(self):
        self.model_name = "lenet"
        self.num_epochs = 10
        self.learning_rate = 0.1
        self.batch_size = 256
        self.num_workers = 2
        self.num_outputs = 10
        self.num_trials = 20
        self.max_wallclock_time = 600
        self.eta = 2
        self.min_number_of_epochs = 5
        self.max_number_of_epochs = 20
        self.prefact = 1

## 13. Example 1: Basic Random Search

In [None]:
# Configure arguments
args = Args()

# Define search space
config_space = {
    "learning_rate": stats.loguniform(1e-4, 1),
    "batch_size": stats.randint(32, 512),
}

initial_config = {
    "learning_rate": 0.1,
    "batch_size": 256,
}

# Run basic random search
best_config, best_score, tuner = HPO.random_search(
    args, config_space=config_space, initial_config=initial_config
)

# Plot progress
HPO.plot_hpo_progress(tuner)

## 14. Example 2: Successive Halving (Multi-Fidelity)

In [None]:
# Configure arguments for multi-fidelity
args = Args()
args.model_name = "lenet"
args.num_trials = 10
args.eta = 2
args.min_number_of_epochs = 5
args.max_number_of_epochs = 20
args.prefact = 1

# Define search space (same as before)
config_space = {
    "learning_rate": stats.loguniform(1e-4, 1),
    "batch_size": stats.randint(32, 512),
}

initial_config = {
    "learning_rate": 0.1,
    "batch_size": 256,
}

# Run multi-fidelity HPO (Successive Halving)
best_config, best_score, tuner = HPO.multi_fidelity_random_search(
    args, config_space=config_space, initial_config=initial_config
)

# Plot progress
HPO.plot_hpo_progress(tuner)

## 15. Example 3: ASHA (Asynchronous Successive Halving)

In [None]:
# Configure arguments for ASHA
args = Args()
args.model_name = "lenet"
args.num_trials = 20
args.eta = 2
args.min_number_of_epochs = 5
args.max_number_of_epochs = 20
args.prefact = 1

# Define search space
config_space = {
    "learning_rate": stats.loguniform(1e-4, 1),
    "batch_size": stats.randint(32, 512),
}

initial_config = {
    "learning_rate": 0.1,
    "batch_size": 256,
}

# Run ASHA
best_config, best_score, tuner = HPO.asha_random_search(
    args, config_space=config_space, initial_config=initial_config
)

# Plot progress
HPO.plot_hpo_progress(tuner)

## 16. Train Final Model with Best Config

In [None]:
# Use the best config from ASHA to train a final model
print(f"\nTraining final model with best config: {best_config}")

train_loader, val_loader, test_loader = Utils.load_fashion_mnist(
    best_config["batch_size"]
)
model = Utils.build_model(args)
trainer = Trainer(
    model,
    train_loader,
    val_loader,
    test_loader,
    lr=best_config["learning_rate"],
    num_epochs=args.max_number_of_epochs,  # Train with full epochs
)

trainer.fit()

train_acc = trainer.evaluate_train()
val_acc = trainer.evaluate_val()
test_acc = trainer.evaluate_test()

print(f"\nFinal Results with ASHA-tuned hyperparameters:")
print(f"Train accuracy: {train_acc:.4f}")
print(f"Validation accuracy: {val_acc:.4f}")
print(f"Test accuracy: {test_acc:.4f}")

## 17. Summary and Key Takeaways

### HPO Algorithms Comparison:

| Algorithm | Synchronization | Multi-Fidelity | Best Use Case |
|-----------|----------------|----------------|---------------|
| **Random Search** | Sequential | No | Small search spaces, baseline |
| **Successive Halving** | Synchronous | Yes | Limited parallel workers |
| **ASHA** | Asynchronous | Yes | Many parallel workers, large-scale |

### Key Concepts:

1. **Resource**: In this implementation, the resource is `num_epochs` (number of training epochs)
2. **Rung**: A level in multi-fidelity optimization where configs are evaluated with the same resource
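3. **eta (η)**: Reduction factor - only the top 1/η of the configs at a rung are promoted, and the resource grows by a factor of η from one rung to the next (except possibly for the last step to r_max)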
4. **r_min**: Minimum resource (lowest number of epochs)
5. **r_max**: Maximum resource (highest number of epochs)
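Restating the `__init__` code of `MultiFidelityScheduler` and `ASHAScheduler`, these quantities are related as follows, where $r_k$ is the resource at rung $k$ and $n_k$ is the number of configurations a synchronous SH bracket evaluates at that rung (only `MultiFidelityScheduler` uses $n_k$):

$$
K = \left\lfloor \log_{\eta} \frac{r_{\max}}{r_{\min}} \right\rfloor, \qquad
r_k = r_{\min}\,\eta^{k} \quad (k = 0, \dots, K), \qquad
n_k = \mathrm{prefact} \cdot \eta^{K-k}
$$

If $r_{\min}\,\eta^{K} \neq r_{\max}$, the code appends $r_{\max}$ as an extra rung and increases $K$ by one, so $n_0$ grows by a factor of $\eta$.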

### ASHA Advantages:

- **No worker idle time**: Workers don't wait for synchronization
- **Better resource utilization**: Configs promoted as soon as possible
- **Scalable**: Works well with many parallel workers
- **Early stopping**: Poor configs stopped early, saving computation

### Recommended Settings:

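For the argparse defaults in the (unused) `Config` parser (eta=2, r_min=10, r_max=50):
- **Rungs**: [10, 20, 40, 50] epochs (50 is not of the form 10·2^k, so it is appended as an extra rung, giving K = 3)
- **n₀ = 8** configs at the first rung of each synchronous SH bracket (n₀ = prefact·η^K = 2³), i.e. 8 + 4 + 2 + 1 = 15 trials per full bracket
- **Recommended trials for ASHA**: 20-30 trials

For the notebook's `Args` (eta=2, r_min=5, r_max=20), the rungs are [5, 10, 20] and n₀ = 4, i.e. 7 trials per full SH bracket.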
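Because `hpo_objective` builds a fresh model for every trial, a promoted configuration is retrained from scratch at its new rung rather than resumed from a checkpoint. Under that behavior, one SH bracket costs Σ r_k·n_k epochs. The sketch below compares that cost with training all n₀ starting configurations for r_max epochs (`bracket_epochs` and `full_fidelity_epochs` are hypothetical helpers).

```python
def bracket_epochs(rungs, num_configs):
    """Epochs spent on one SH bracket when every trial trains from scratch."""
    return sum(r * n for r, n in zip(rungs, num_configs))


def full_fidelity_epochs(rungs, num_configs):
    """Epochs needed to train every first-rung config for r_max epochs instead."""
    return num_configs[0] * rungs[-1]


default = ([10, 20, 40, 50], [8, 4, 2, 1])  # eta=2, r_min=10, r_max=50
notebook = ([5, 10, 20], [4, 2, 1])         # Args: eta=2, r_min=5, r_max=20
for rungs, n in (default, notebook):
    print(bracket_epochs(rungs, n), "vs", full_fidelity_epochs(rungs, n))
# 290 vs 400
# 60 vs 80
```

Resuming promoted trials from checkpoints instead of retraining would lower the SH cost further. The notebook does not implement that.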

### Next Steps:

1. Experiment with different `eta` values (2, 3, 4)
2. Try different search spaces (add dropout, weight decay)
3. Test with different models (ResNet, VGG, etc.)
4. Scale to larger datasets (CIFAR-10, ImageNet)
5. Integrate Bayesian optimization for smarter search