In [None]:
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Two-layer, bias-free MLP whose first layer is a frozen random projection.

    ``fc1`` is re-initialised with Xavier-normal weights and excluded from
    gradient updates, so ``fc2`` (``hidden_dim * output_dim`` weights) is the
    only trainable parameter.

    Args:
        input_dim: Number of features per sample once the input is flattened.
        hidden_dim: Width of the frozen hidden projection.
        output_dim: Number of logits produced per sample.
    """

    def __init__(self, input_dim=784, hidden_dim=8, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

        # fc1 serves as a fixed random feature map: re-initialise, then freeze it.
        nn.init.xavier_normal_(self.fc1.weight)
        self.fc1.weight.requires_grad_(False)

    def forward(self, x):
        """Flatten ``x`` to ``(batch, -1)`` and return logits ``(batch, output_dim)``."""
        # reshape behaves like view for contiguous inputs but also accepts
        # non-contiguous ones (e.g. transposed/permuted batches), where view raises.
        x = x.reshape(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
import numpy as np


def preprocess_mnist(examples):
    """Normalise and flatten a batch of images for use with ``Dataset.map``.

    Args:
        examples: Mapping with an ``"image"`` sequence (anything ``np.array``
            can convert) and a matching ``"label"`` sequence.

    Returns:
        Dict with ``"pixel_values"`` (list of flat float32 arrays) and
        ``"labels"`` (the input labels, passed through untouched).
    """

    def _to_flat_vector(image):
        # Scale raw intensities to [0, 1], then standardise with the
        # mean/std constants 0.1307 / 0.3081.
        scaled = np.array(image, dtype=np.float32) / 255.0
        standardised = (scaled - 0.1307) / 0.3081
        return standardised.flatten()

    flat_images = [_to_flat_vector(image) for image in examples["image"]]
    return {"pixel_values": flat_images, "labels": examples["label"]}


def collate_fn(batch):
    """Assemble per-sample dicts into a ``(pixel_values, labels)`` tensor pair.

    Args:
        batch: Sequence of dicts, each holding ``"pixel_values"`` and ``"labels"``.

    Returns:
        Tuple of a float32 tensor stacking every sample's pixel values along a
        new leading dimension, and an int64 tensor of the labels.
    """
    feature_rows = []
    for sample in batch:
        feature_rows.append(
            torch.tensor(sample["pixel_values"], dtype=torch.float32)
        )
    stacked_features = torch.stack(feature_rows)

    target_ids = [sample["labels"] for sample in batch]
    return stacked_features, torch.tensor(target_ids, dtype=torch.long)

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from datasets import load_dataset


def get_dataloaders(batch_size=512):
    """Load ``ylecun/mnist`` and wrap its splits in PyTorch data loaders.

    Both splits are run through ``preprocess_mnist`` (dropping the raw
    ``image`` column) and batched with ``collate_fn``.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        ``(train_loader, val_loader)``: the shuffled ``train`` split and the
        unshuffled ``test`` split.
    """
    raw_splits = load_dataset("ylecun/mnist")

    def _prepare(split_name):
        # Apply the shared preprocessing and drop the original image column.
        return raw_splits[split_name].map(
            preprocess_mnist, batched=True, remove_columns=["image"]
        )

    def _build_loader(split, reshuffle):
        # Only the shuffle flag differs between the two loaders.
        return DataLoader(
            split,
            batch_size=batch_size,
            shuffle=reshuffle,
            num_workers=2,
            pin_memory=True,
            collate_fn=collate_fn,
        )

    train_split = _prepare("train")
    test_split = _prepare("test")

    train_loader = _build_loader(train_split, True)
    val_loader = _build_loader(test_split, False)
    return train_loader, val_loader


def generate_random_directions(dim, n_points):
    """Draw ``n_points`` random directions in ``dim``-dimensional space.

    Rows are sampled from a standard normal (NumPy's global RNG) and divided
    by their L2 norm plus a small epsilon.

    Returns:
        Array of shape ``(n_points, dim)``.
    """
    gaussian_samples = np.random.randn(n_points, dim)
    row_lengths = np.linalg.norm(gaussian_samples, axis=1, keepdims=True)
    # The epsilon keeps the division finite should a row have zero length.
    safe_lengths = row_lengths + 1e-8
    return gaussian_samples / safe_lengths


@torch.no_grad()
def evaluate_weights(weights_flat, model, criterion, loader, device, max_batches=None):
    """Load ``weights_flat`` into ``model.fc2`` and measure loss and accuracy.

    Side effect: ``model.fc2.weight`` is overwritten in place with the given
    weights and keeps them after the call. The model is evaluated in eval
    mode, and its previous train/eval mode is restored before returning.

    Args:
        weights_flat: Array with ``model.fc2.weight.numel()`` elements; it is
            reshaped to the shape of ``model.fc2.weight``.
        model: Module exposing an ``fc2`` layer with a ``weight`` parameter.
        criterion: Loss callable ``criterion(output, target)``. The per-batch
            value is weighted by batch size, which yields a per-sample mean
            when the criterion returns a batch mean.
        loader: Iterable of ``(data, target)`` tensor pairs.
        device: Device the weights and batches are moved to.
        max_batches: Upper bound on the number of batches consumed. ``None``
            or ``0`` means the whole loader is used.

    Returns:
        ``(avg_loss, accuracy)`` where accuracy is a percentage in ``[0, 100]``.

    Raises:
        ValueError: If no samples were evaluated (the loader is empty).
    """
    target_shape = model.fc2.weight.shape
    weights = torch.tensor(
        weights_flat.reshape(target_shape), dtype=torch.float32, device=device
    )
    model.fc2.weight.copy_(weights)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Remember the caller's mode so evaluation does not silently leave the
    # model in eval mode.
    was_training = model.training
    model.eval()
    try:
        for i, (data, target) in enumerate(loader):
            if max_batches and i >= max_batches:
                break

            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(data)

            batch_size = data.size(0)
            total_loss += criterion(output, target).item() * batch_size
            total_correct += output.argmax(dim=1).eq(target).sum().item()
            total_samples += batch_size
    finally:
        model.train(was_training)

    if total_samples == 0:
        raise ValueError(
            "evaluate_weights received no samples: the loader yielded no data"
        )

    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples

    return avg_loss, accuracy


def random_search(
    model,
    criterion,
    train_loader,
    val_loader,
    device,
    n_iterations=10,
    points_per_iter=50,
    initial_radius=1.0,
    radius_decay=0.7,
    eval_batches=20,
):
    """Shrinking-radius random search over the flattened ``model.fc2`` weights.

    Starting from a small random point, each iteration scores
    ``points_per_iter`` candidates placed at distance ``current_radius`` from
    the best point so far along random directions, keeps the best one if it
    lowers the loss, then multiplies the radius by ``radius_decay``.

    All candidates are scored on the same fixed set of training batches
    (the first ``eval_batches`` batches drawn from ``train_loader``, cached
    once). This keeps loss comparisons between candidates fair and avoids
    re-creating a loader iterator for every single candidate.

    Side effect: ``model.fc2.weight`` is overwritten during the search and
    holds the best weights found when the function returns (set by the final
    validation pass).

    Args:
        model: Module exposing an ``fc2`` layer with a ``weight`` parameter.
        criterion: Loss callable passed through to ``evaluate_weights``.
        train_loader: Loader supplying the batches used to score candidates.
        val_loader: Loader used for the final validation report.
        device: Device on which evaluation runs.
        n_iterations: Number of search iterations.
        points_per_iter: Candidates scored per iteration.
        initial_radius: Search radius of the first iteration.
        radius_decay: Factor applied to the radius after every iteration.
        eval_batches: Number of training batches used for scoring. ``None``
            or ``0`` scores on the full ``train_loader`` (not cached).

    Returns:
        NumPy array of the best weights, shaped like ``model.fc2.weight``.
    """
    weight_shape = model.fc2.weight.shape
    weight_dim = int(np.prod(weight_shape))

    if eval_batches:
        # Cache a fixed evaluation subset once so every candidate is measured
        # on identical data; a shuffling loader would otherwise hand each
        # candidate a different subset.
        eval_source = []
        for data, target in train_loader:
            eval_source.append((data.to(device), target.to(device)))
            if len(eval_source) >= eval_batches:
                break
    else:
        eval_source = train_loader

    def _score(flat_weights):
        # Loss/accuracy of one candidate on the shared evaluation data.
        return evaluate_weights(
            flat_weights,
            model,
            criterion,
            eval_source,
            device,
            max_batches=eval_batches,
        )

    best_weights = np.random.randn(weight_dim) * 0.1
    best_loss, best_acc = _score(best_weights)

    current_radius = initial_radius

    print(
        f"\n{'Iter':<6} {'Candidates':<12} {'Best Loss':<12} {'Best Acc':<12} {'Radius':<10}"
    )
    print("=" * 70)
    print(
        f"{'Init':<6} {1:<12} {best_loss:<12.4f} {best_acc:<12.2f}% {current_radius:<10.4f}"
    )

    for iteration in range(n_iterations):
        directions = generate_random_directions(weight_dim, points_per_iter)

        # Every candidate of this iteration is placed around the best point
        # known at the start of the iteration, even if the best improves
        # part-way through.
        center = best_weights
        for direction in directions:
            candidate = center + current_radius * direction
            loss, acc = _score(candidate)
            # Strict "<" keeps the first of equally good candidates and
            # never accepts a NaN loss.
            if loss < best_loss:
                best_loss, best_acc, best_weights = loss, acc, candidate

        print(
            f"{iteration:<6} {len(directions):<12} {best_loss:<12.4f} {best_acc:<12.2f}% {current_radius:<10.4f}"
        )

        current_radius *= radius_decay

    # This call also leaves the best weights loaded in model.fc2.
    final_loss, final_acc = evaluate_weights(
        best_weights, model, criterion, val_loader, device
    )
    print(f"\nFinal validation - Loss: {final_loss:.4f}, Acc: {final_acc:.2f}%")

    return best_weights.reshape(weight_shape)


def finetune_sgd(
    model,
    start_weights,
    criterion,
    train_loader,
    val_loader,
    device,
    epochs=20,
    lr=0.1,
    patience=5,
):
    """Fine-tune the model's trainable parameters with SGD and early stopping.

    ``start_weights`` is copied into ``model.fc2.weight``, then every
    parameter with ``requires_grad`` set is optimised with momentum SGD under
    a cosine-annealing-with-warm-restarts schedule stepped once per epoch.
    Training stops early after ``patience`` consecutive epochs without a new
    best validation loss. On exit the trainable parameters are restored to
    the values from the epoch with the lowest validation loss.

    Args:
        model: Module exposing an ``fc2`` layer with a ``weight`` parameter.
        start_weights: Initial ``fc2`` weights (array-like accepted by
            ``torch.tensor`` and copyable into ``model.fc2.weight``).
        criterion: Loss callable ``criterion(output, target)``.
        train_loader: Iterable of ``(data, target)`` training batches.
        val_loader: Loader used for per-epoch validation.
        device: Device on which training and evaluation run.
        epochs: Maximum number of epochs.
        lr: Initial learning rate.
        patience: Epochs without improvement tolerated before stopping.

    Returns:
        None. The model is updated in place.
    """
    with torch.no_grad():
        model.fc2.weight.copy_(
            torch.tensor(start_weights, dtype=torch.float32, device=device)
        )

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(
        trainable,
        lr=lr,
        momentum=0.9,
        weight_decay=1e-4,
    )

    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

    best_val_loss = float("inf")
    best_state = None  # snapshot of the trainable params at the best epoch
    patience_counter = 0

    print(
        f"\n{'Epoch':<8} {'Train Loss':<12} {'Val Loss':<12} {'Val Acc':<12} {'LR':<10}"
    )
    print("=" * 60)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_samples = 0

        # Read the LR before scheduler.step() so the reported value is the
        # one this epoch was actually trained with, not the next epoch's.
        current_lr = optimizer.param_groups[0]["lr"]

        for data, target in train_loader:
            data, target = data.to(device, non_blocking=True), target.to(
                device, non_blocking=True
            )

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * data.size(0)
            train_samples += data.size(0)

        scheduler.step()

        avg_train_loss = train_loss / train_samples
        # evaluate_weights takes a flat NumPy array and writes it back into
        # fc2, so feeding the current weights leaves them unchanged.
        val_loss, val_acc = evaluate_weights(
            model.fc2.weight.data.cpu().numpy().flatten(),
            model,
            criterion,
            val_loader,
            device,
        )

        print(
            f"{epoch+1:<8} {avg_train_loss:<12.4f} {val_loss:<12.4f} {val_acc:<12.2f}% {current_lr:<10.6f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = [p.detach().clone() for p in trainable]
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"\nEarly stopping: no improvement for {patience} epochs")
            break

    # Without this the model would keep the weights of the last epoch, which
    # can be up to `patience` epochs past the best validation loss.
    if best_state is not None:
        with torch.no_grad():
            for param, saved in zip(trainable, best_state):
                param.copy_(saved)
        print(f"\nRestored weights from best epoch (val loss {best_val_loss:.4f})")


def main(seed=42):
    """Run the two-phase pipeline: random-search init, then SGD fine-tuning.

    Builds a ``TinyNet`` with an 8-unit frozen hidden layer, loads the MNIST
    loaders, searches for good ``fc2`` starting weights with
    ``random_search`` and refines them with ``finetune_sgd``.

    Args:
        seed: Seed applied to NumPy's and torch's global RNGs before the model
            is created, so weight initialisation and the random search start
            from a repeatable state. Pass ``None`` to leave the RNGs
            unseeded. Seeding does not by itself guarantee bit-identical
            results across hardware or library versions.
    """
    if seed is not None:
        # Seed before TinyNet() so its random initialisation is covered too.
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}\n")

    model = TinyNet(hidden_dim=8).to(device)
    criterion = nn.CrossEntropyLoss()

    print("Loading MNIST dataset from HuggingFace...")
    train_loader, val_loader = get_dataloaders(batch_size=512)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"Trainable parameters: {trainable_params}")
    print(f"Frozen parameters: {frozen_params}\n")

    print("=" * 70)
    print("PHASE 1: Random Search Initialization")
    print("=" * 70)

    # Phase 1: gradient-free search for a good fc2 starting point.
    best_weights = random_search(
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        n_iterations=15,
        points_per_iter=100,
        initial_radius=2.0,
        radius_decay=0.75,
        eval_batches=20,
    )

    print("\n" + "=" * 70)
    print("PHASE 2: Gradient-based Fine-tuning")
    print("=" * 70)

    # Phase 2: refine the searched weights with SGD.
    finetune_sgd(
        model=model,
        start_weights=best_weights,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=20,
        lr=0.1,
        patience=5,
    )

    print("\n" + "=" * 70)
    print("Training Complete!")
    print("=" * 70)

In [None]:
# Run the whole pipeline: random-search initialisation followed by SGD fine-tuning.
main()