In [None]:
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Two-layer, bias-free MLP whose first layer is a frozen random projection.

    ``fc1`` (input_dim -> hidden_dim) is re-initialised with Xavier-normal
    weights and excluded from gradient updates; only ``fc2``
    (hidden_dim -> output_dim) is trainable.
    """

    def __init__(self, input_dim=784, hidden_dim=8, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

        # Overwrite fc1's default init, then freeze it.
        projection = self.fc1.weight
        nn.init.xavier_normal_(projection)
        projection.requires_grad_(False)

    def forward(self, x):
        """Flatten ``x`` to (batch, -1), apply fc1 + ReLU, and return fc2 outputs."""
        flat = x.view(x.size(0), -1)
        hidden = nn.functional.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
import numpy as np


def preprocess_mnist(examples):
    """Turn a batch of images into flat, standardised float32 vectors.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` is a mapping
    of columns containing ``"image"`` (array-convertible images) and
    ``"label"``. Pixels are divided by 255, then shifted by 0.1307 and scaled
    by 1/0.3081 (presumably the MNIST training mean/std).

    Returns:
        dict with ``"pixel_values"`` (list of 1-D float32 arrays, one per
        image) and ``"labels"`` (the incoming ``"label"`` column, unchanged).
    """

    def _to_vector(picture):
        pixels = np.array(picture, dtype=np.float32) / 255.0
        return ((pixels - 0.1307) / 0.3081).flatten()

    return {
        "pixel_values": [_to_vector(picture) for picture in examples["image"]],
        "labels": examples["label"],
    }


def collate_fn(batch):
    """Stack a list of preprocessed samples into ``(features, labels)`` tensors.

    Each sample is a mapping with ``"pixel_values"`` (a sequence of numbers)
    and ``"labels"`` (an int). Returns a float32 tensor of shape
    ``(len(batch), n_features)`` and an int64 tensor of shape ``(len(batch),)``.
    """
    rows = [
        torch.tensor(sample["pixel_values"], dtype=torch.float32) for sample in batch
    ]
    features = torch.stack(rows)
    targets = torch.tensor([sample["labels"] for sample in batch], dtype=torch.long)
    return features, targets

In [None]:
from torch.utils.data import DataLoader
from datasets import load_dataset


def get_dataloaders(batch_size=512, num_workers=2, pin_memory=None):
    """Load MNIST, preprocess both splits, and wrap them in DataLoaders.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes. Set to 0 if worker processes
            cannot import notebook-defined functions such as ``collate_fn``
            (can happen with spawn-based multiprocessing, e.g. macOS/Windows).
        pin_memory: pin host memory for faster host->GPU copies. ``None``
            (default) enables it only when CUDA is available, since pinning has
            no benefit on CPU-only machines (recent PyTorch versions warn).

    Returns:
        ``(train_loader, val_loader)``. ``val_loader`` iterates the MNIST
        *test* split; no separate validation split is carved out.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    dataset = load_dataset("ylecun/mnist")
    train_dataset, test_dataset = (
        dataset[split].map(preprocess_mnist, batched=True, remove_columns=["image"])
        for split in ("train", "test")
    )

    # Settings shared by both loaders; only shuffling differs.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader

In [None]:
import torch.nn.functional as F


@torch.no_grad()
def batched_evaluate(weights_list, model, data, target, device):
    """Score many candidate ``fc2`` weight matrices on one batch in a single pass.

    The first layer (``model.fc1`` + ReLU) is applied once. Every candidate is
    then applied as the output layer with one broadcasted matmul, avoiding a
    Python loop over candidates.

    Args:
        weights_list: sequence of array-likes, each holding exactly
            ``model.fc2.weight.numel()`` values (flat or already shaped).
        model: module exposing ``fc1`` and ``fc2``; ``fc2.weight.shape`` fixes
            the candidate shape. The model is only read, never modified.
        data: input batch. It is flattened to ``(batch, -1)`` like
            ``TinyNet.forward`` does, so image-shaped batches are accepted too.
        target: class-index tensor of shape ``(batch,)``.
        device: device the candidate weights are placed on; should match the
            device of ``data`` and ``model``.

    Returns:
        ``(avg_losses, accuracies)``: NumPy arrays of length
        ``len(weights_list)`` with the mean cross-entropy and the accuracy in
        percent for each candidate.
    """
    n_candidates = len(weights_list)
    weight_shape = model.fc2.weight.shape

    hidden = torch.relu(model.fc1(data.flatten(1)))  # (batch, hidden)
    batch_size = hidden.size(0)

    weights = torch.as_tensor(
        np.asarray(weights_list, dtype=np.float32).reshape(n_candidates, *weight_shape),
        device=device,
    )  # (n_candidates, out, hidden)

    # (batch, hidden) @ (n, hidden, out) broadcasts to (n, batch, out).
    logits = torch.matmul(hidden, weights.transpose(1, 2))
    n_classes = logits.size(-1)

    # Candidate-major flattening matches target repeated once per candidate.
    losses = F.cross_entropy(
        logits.reshape(-1, n_classes),
        target.repeat(n_candidates),
        reduction="none",
    ).view(n_candidates, batch_size)

    avg_losses = losses.mean(dim=1)
    accuracies = (logits.argmax(dim=-1) == target).float().mean(dim=1) * 100

    return avg_losses.cpu().numpy(), accuracies.cpu().numpy()

In [None]:
def smart_init(model, train_loader, device, n_tries=20, scale_range=(-1, 0.5), rng=None):
    """Pick a starting point for the fc2 search from random Gaussian draws.

    Draws ``n_tries`` flat weight vectors of standard-normal noise. The i-th
    vector is multiplied by the i-th of ``n_tries`` log-spaced scales from
    ``10**scale_range[0]`` to ``10**scale_range[1]``. All vectors are scored on
    the first batch of a fresh ``train_loader`` iterator, and the lowest-loss
    one is returned.

    Args:
        model: module whose ``fc2.weight`` shape sets the candidate size; it is
            passed to ``batched_evaluate`` and not modified.
        train_loader: iterable of ``(data, target)`` batches; one batch is drawn.
        device: device the batch is moved to before scoring.
        n_tries: number of candidates; must be >= 1.
        scale_range: base-10 exponents bounding the candidate scales (the
            default matches the previous hard-coded ``np.logspace(-1, 0.5, ...)``).
        rng: optional ``np.random.Generator`` for reproducible draws. When
            ``None``, NumPy's global RNG (``np.random.randn``) is used.

    Returns:
        ``(weights, loss, accuracy)``: the flat ndarray of the best candidate,
        plus its mean loss and its accuracy (percent) on that batch.

    Raises:
        ValueError: if ``n_tries < 1`` (there would be nothing to choose from).
    """
    if n_tries < 1:
        raise ValueError(f"n_tries must be >= 1, got {n_tries}")

    weight_dim = int(np.prod(model.fc2.weight.shape))
    scales = np.logspace(scale_range[0], scale_range[1], n_tries)
    draw = np.random.randn if rng is None else rng.standard_normal

    data, target = next(iter(train_loader))
    data, target = data.to(device), target.to(device)

    weights_list = [draw(weight_dim) * scale for scale in scales]
    losses, accs = batched_evaluate(weights_list, model, data, target, device)

    best_idx = losses.argmin()
    return weights_list[best_idx], losses[best_idx], accs[best_idx]

In [None]:
def fast_search(
    model,
    criterion,
    train_loader,
    device,
    n_iterations=8,
    initial_points=100,
    n_centers=5,
    initial_radius=1.5,
    radius_decay=0.7,
):
    """Gradient-free, multi-center random search over ``model.fc2.weight``.

    The search starts from ``smart_init``. Each iteration:

    - draws one training batch;
    - for every current center, samples
      ``max(30, initial_points // (iteration + 1))`` points at distance ``r``
      along random normalised-Gaussian directions;
    - scores all samples with ``batched_evaluate``;
    - promotes the ``n_centers`` lowest-loss samples to be the next centers.

    The base radius shrinks by ``radius_decay`` each iteration, and each
    lower-ranked center gets a radius 20% wider per rank.

    The incumbent best is re-scored on every iteration's batch before it is
    compared with that iteration's candidates. Mini-batch losses from
    different batches are not comparable, so this keeps every comparison on
    the same data. As a result, the printed "Best Loss"/"Best Acc" are the
    incumbent's scores on the current batch and need not decrease
    monotonically.

    Args:
        model: module exposing ``fc1``/``fc2``; only read, never modified.
        criterion: unused; kept for interface compatibility (scoring uses the
            cross-entropy inside ``batched_evaluate``).
        train_loader: iterable of ``(data, target)`` batches; it is restarted
            when exhausted.
        device: device batches are moved to.
        n_iterations: number of search iterations.
        initial_points: points per center in the first iteration.
        n_centers: centers kept between iterations; must be >= 1.
        initial_radius: sampling radius around the initial point.
        radius_decay: per-iteration multiplicative radius decay.

    Returns:
        ndarray shaped like ``model.fc2.weight`` holding the best weights found.

    Raises:
        ValueError: if ``n_centers < 1``.
    """
    if n_centers < 1:
        raise ValueError(f"n_centers must be >= 1, got {n_centers}")

    weight_shape = tuple(model.fc2.weight.shape)
    weight_dim = int(np.prod(weight_shape))

    print("\n" + "=" * 80)
    print("Smart Initialization")
    print("=" * 80)
    best_weights, best_loss, best_acc = smart_init(model, train_loader, device)
    print(f"Initial: Loss={best_loss:.4f}, Acc={best_acc:.2f}%")

    # Each center is (loss, flat weight vector, sampling radius).
    centers = [(best_loss, best_weights.copy(), initial_radius)]
    data_iterator = iter(train_loader)

    print(
        f"\n{'Iter':<6} {'Pts':<8} {'Centers':<8} {'Best Loss':<12} {'Best Acc':<12} {'Radius':<10}"
    )
    print("=" * 80)

    for iteration in range(n_iterations):
        try:
            data, target = next(data_iterator)
        except StopIteration:
            data_iterator = iter(train_loader)
            data, target = next(data_iterator)
        data, target = data.to(device), target.to(device)

        # Re-score the incumbent on this batch so the comparison below is like-for-like.
        inc_losses, inc_accs = batched_evaluate(
            [best_weights], model, data, target, device
        )
        best_loss, best_acc = inc_losses[0], inc_accs[0]

        points_this_iter = max(30, initial_points // (iteration + 1))
        n_searched = len(centers)  # iteration 0 has a single center

        all_candidates = []
        for _, center_weights, center_radius in centers:
            directions = np.random.randn(points_this_iter, weight_dim)
            directions /= np.linalg.norm(directions, axis=1, keepdims=True) + 1e-8

            weights_list = [center_weights + center_radius * d for d in directions]
            losses, accs = batched_evaluate(weights_list, model, data, target, device)
            all_candidates.extend(zip(losses, accs, weights_list))

        all_candidates.sort(key=lambda cand: cand[0])
        iter_best_loss, iter_best_acc, iter_best_weights = all_candidates[0]

        if iter_best_loss < best_loss:
            best_loss = iter_best_loss
            best_acc = iter_best_acc
            best_weights = iter_best_weights.copy()

        base_radius = initial_radius * (radius_decay ** (iteration + 1))
        centers = [
            (loss, weights, base_radius * (1.0 + rank * 0.2))
            for rank, (loss, _, weights) in enumerate(all_candidates[:n_centers])
        ]

        avg_radius = np.mean([r for _, _, r in centers])
        acc_str = f"{best_acc:.2f}%"
        print(
            f"{iteration:<6} {points_this_iter * n_searched:<8} {n_searched:<8} {best_loss:<12.4f} {acc_str:<12} {avg_radius:<10.4f}"
        )

    return best_weights.reshape(weight_shape)

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR


def _evaluate_loader(model, criterion, loader, device):
    """Return ``(mean loss, accuracy %)`` of ``model`` over every batch of ``loader``.

    Switches the model to eval mode and disables autograd. The loss is the
    per-batch ``criterion`` value weighted by batch size, which assumes a
    mean-reduced criterion (TODO confirm for custom criteria).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    samples = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item() * data.size(0)
            correct += output.argmax(dim=1).eq(target).sum().item()
            samples += data.size(0)
    if samples == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / samples, 100.0 * correct / samples


def fast_finetune(
    model,
    start_weights,
    criterion,
    train_loader,
    val_loader,
    device,
    epochs=15,
    lr=0.5,
    patience=4,
    eval_every=2,
    restore_best=True,
):
    """Fine-tune the trainable parameters with SGD + OneCycleLR and early stopping.

    ``model.fc2.weight`` is first overwritten with ``start_weights``. Only
    parameters with ``requires_grad`` are optimised, using Nesterov SGD with
    momentum 0.9 and weight decay 1e-4. Gradients are clipped to norm 1.0, and
    the scheduler steps once per batch.

    Validation runs every ``eval_every`` epochs and after the final epoch.
    Training stops early after ``patience`` consecutive validations without
    improvement.

    Args:
        model: module with an ``fc2`` layer; trained in place.
        start_weights: array-like (or tensor) matching ``model.fc2.weight``'s shape.
        criterion: loss function ``criterion(output, target)``.
        train_loader / val_loader: iterables of ``(data, target)`` batches.
        device: device batches are moved to.
        epochs: maximum number of epochs (also sizes the OneCycle schedule).
        lr: peak learning rate for OneCycleLR.
        patience: validations without improvement before stopping.
        eval_every: validation interval in epochs; must be >= 1.
        restore_best: if True, reload the state with the lowest validation loss
            before returning, instead of leaving the last-epoch weights.

    Returns:
        The best validation loss observed (``inf`` if no validation ran).

    Raises:
        ValueError: if ``eval_every < 1``.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    with torch.no_grad():
        model.fc2.weight.copy_(
            torch.as_tensor(start_weights, dtype=torch.float32, device=device)
        )

    optimizer = optim.SGD(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
        momentum=0.9,
        nesterov=True,
        weight_decay=1e-4,
    )

    # Total steps = epochs * len(train_loader); scheduler.step() is called once per batch.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=lr,
        epochs=epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy="cos",
        div_factor=10,
        final_div_factor=100,
    )

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    print(f"\n{'Epoch':<8} {'Train Loss':<12} {'Val Loss':<12} {'Val Acc':<12}")
    print("=" * 60)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_samples = 0

        for data, target in train_loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item() * data.size(0)
            train_samples += data.size(0)

        avg_train_loss = train_loss / train_samples

        if epoch % eval_every == 0 or epoch == epochs - 1:
            val_loss, val_acc = _evaluate_loader(model, criterion, val_loader, device)
            acc_str = f"{val_acc:.2f}%"
            print(f"{epoch+1:<8} {avg_train_loss:<12.4f} {val_loss:<12.4f} {acc_str:<12}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Snapshot (copy) so later training steps don't mutate it.
                best_state = {
                    k: v.detach().clone() for k, v in model.state_dict().items()
                }
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"\nEarly stopping after {epoch+1} epochs")
                break

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored weights with best val loss {best_val_loss:.4f}")

    return best_val_loss

In [None]:
def main(seed=0):
    """Run the two-phase experiment on MNIST: random search over fc2, then SGD fine-tuning.

    Args:
        seed: seeds torch (which also covers model init and DataLoader
            shuffling) and NumPy's global RNG (used by the random search) so
            runs are repeatable. Pass ``None`` to skip seeding. Bitwise
            determinism on GPU is not guaranteed by seeding alone.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = TinyNet(hidden_dim=8).to(device)
    criterion = nn.CrossEntropyLoss()

    print("\nLoading MNIST...")
    train_loader, val_loader = get_dataloaders(batch_size=1024)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable}\n")

    print("=" * 80)
    print("PHASE 1: Fast Random Search")
    print("=" * 80)

    best_weights = fast_search(
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        device=device,
        n_iterations=10,
        initial_points=120,
        n_centers=4,
        initial_radius=1.5,
        radius_decay=0.65,
    )

    print("\n" + "=" * 80)
    print("PHASE 2: Fast Fine-tuning")
    print("=" * 80)

    fast_finetune(
        model=model,
        start_weights=best_weights,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=15,
        lr=0.5,
        patience=4,
    )

    print("\n" + "=" * 80)
    print("Complete!")
    print("=" * 80)

In [None]:
# Run the full experiment (random search over fc2, then fine-tuning) defined in main().
main()