In [None]:
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Bias-free two-layer MLP whose first layer is a frozen random projection.

    ``fc1`` maps ``input_dim -> hidden_dim`` and is Xavier-normal initialised,
    then frozen; only ``fc2`` (``hidden_dim -> output_dim``) is trainable.
    """

    def __init__(self, input_dim=784, hidden_dim=8, output_dim=10):
        super().__init__()
        # Build both layers before re-initialising fc1: each nn.Linear draws its
        # default init from the global RNG, so this order fixes the random stream.
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

        nn.init.xavier_normal_(self.fc1.weight)
        self.fc1.weight.requires_grad_(False)  # keep the projection fixed

    def forward(self, x):
        """Flatten every non-batch dim, apply frozen fc1 + ReLU, return fc2 logits."""
        flat = x.view(x.shape[0], -1)
        return self.fc2(self.fc1(flat).relu())

In [None]:
import numpy as np


def preprocess_mnist(examples):
    """Batched ``datasets.map`` transform: scale, standardise and flatten images.

    Args:
        examples: batch dict with an ``"image"`` list (anything ``np.array`` can
            convert) and a ``"label"`` list.

    Returns:
        dict with ``"pixel_values"`` (list of flat float32 arrays) and
        ``"labels"`` (the incoming ``"label"`` list, unchanged).
    """
    # 0.1307 / 0.3081 are the conventional MNIST training-set mean / std.
    mean, std = 0.1307, 0.3081
    pixel_values = [
        ((np.array(image, dtype=np.float32) / 255.0 - mean) / std).flatten()
        for image in examples["image"]
    ]
    return {"pixel_values": pixel_values, "labels": examples["label"]}


def collate_fn(batch):
    """Collate preprocessed examples into a ``(pixel_values, labels)`` pair.

    Each item must provide a ``"pixel_values"`` sequence (flat when produced by
    ``preprocess_mnist``) and an integer ``"labels"`` entry.

    Returns:
        pixel_values: float32 tensor stacking every item along a new dim 0.
        labels: int64 tensor of shape ``(len(batch),)``.
    """
    features = [
        torch.tensor(example["pixel_values"], dtype=torch.float32) for example in batch
    ]
    targets = [example["labels"] for example in batch]
    return torch.stack(features), torch.tensor(targets, dtype=torch.long)

In [None]:
from torch.utils.data import DataLoader
from datasets import load_dataset


def get_dataloaders(batch_size=512, num_workers=2, pin_memory=None):
    """Build MNIST train / validation DataLoaders from the HuggingFace hub.

    The ``"train"`` split is shuffled; the ``"test"`` split is returned as the
    validation loader in fixed order. Both use ``preprocess_mnist`` and
    ``collate_fn``.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes. Set to 0 if workers fail to
            find ``collate_fn`` (can happen when it is defined in a notebook and
            the platform uses the ``spawn`` start method, e.g. macOS/Windows).
        pin_memory: pin host memory for faster host->GPU copies. ``None``
            (default) enables it only when CUDA is available, avoiding
            PyTorch's warning and wasted work on CPU-only machines.

    Returns:
        (train_loader, val_loader)
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    dataset = load_dataset("ylecun/mnist")

    def _prepare(split):
        # Replace the raw "image" column with normalised flat "pixel_values".
        return dataset[split].map(
            preprocess_mnist, batched=True, remove_columns=["image"]
        )

    def _make_loader(split_dataset, shuffle):
        # Shared loader configuration for both splits.
        return DataLoader(
            split_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        )

    train_loader = _make_loader(_prepare("train"), shuffle=True)
    val_loader = _make_loader(_prepare("test"), shuffle=False)

    return train_loader, val_loader


def generate_lattice_offsets(dim, n_points, radius=1.0):
    """Return ``n_points`` deterministic offset vectors in ``R^dim``.

    - ``dim == 1``: ``radius * cos`` of evenly spaced angles (values lie in
      ``[-radius, radius]`` and are not normalised, so some may be 0).
    - ``dim == 2``: points evenly spaced on the circle of radius ``radius``.
    - ``dim >= 3``: the first three coordinates follow a golden-ratio
      (Fibonacci) spiral on the unit 2-sphere; any further coordinates are
      ``sin`` of a golden-ratio-based angle. Each row is then rescaled to
      length ``radius``, so the directions do not depend on ``radius``.

    Args:
        dim: dimensionality of each offset.
        n_points: number of offsets to generate (``0`` gives an empty array).
        radius: target length of each offset.

    Returns:
        np.ndarray of shape ``(n_points, dim)``.

    Raises:
        ValueError: if ``n_points`` is negative.
    """
    if n_points < 0:
        raise ValueError(f"n_points must be non-negative, got {n_points}")
    if n_points == 0:
        return np.zeros((0, dim))

    if dim == 1:
        angles = np.linspace(0, 2 * np.pi, n_points, endpoint=False)
        return radius * np.column_stack([np.cos(angles)])

    if dim == 2:
        angles = np.linspace(0, 2 * np.pi, n_points, endpoint=False)
        return radius * np.column_stack([np.cos(angles), np.sin(angles)])

    phi = (1 + np.sqrt(5)) / 2
    # max(..., 1) keeps n_points == 1 from dividing by zero (single point at y=1).
    y_denominator = float(max(n_points - 1, 1))

    points = np.empty((n_points, dim))
    for i in range(n_points):
        y = 1 - (i / y_denominator) * 2
        r = np.sqrt(1 - y * y)
        theta = phi * i * 2 * np.pi
        points[i, :3] = (np.cos(theta) * r, y, np.sin(theta) * r)

        for j in range(3, dim):
            angle = (i * phi + j) * 2 * np.pi / dim
            # Direction only: the radius is applied once, after normalisation.
            # (Scaling here as well made the direction set depend on radius.)
            points[i, j] = np.sin(angle)

    # The first three coordinates have unit norm, so every row norm is >= 1;
    # the epsilon is purely defensive.
    points = points / (np.linalg.norm(points, axis=1, keepdims=True) + 1e-8)
    return points * radius


@torch.no_grad()
def evaluate_weights(weights_flat, model, criterion, loader, device):
    """Load a flat weight vector into ``model.fc2`` and score it on ``loader``.

    Side effects: overwrites ``model.fc2.weight`` in place and leaves the model
    in eval mode.

    Args:
        weights_flat: array with a ``reshape`` method (e.g. NumPy) holding
            ``model.fc2.weight.numel()`` values.
        model: module exposing an ``fc2`` layer.
        criterion: loss called as ``criterion(logits, targets)``.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device that the weights and batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``: batch losses weighted by batch size and
        averaged over all samples (the exact per-sample mean when ``criterion``
        uses mean reduction), and top-1 accuracy in percent.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    fc2_weight = model.fc2.weight
    fc2_weight.copy_(
        torch.tensor(
            weights_flat.reshape(fc2_weight.shape), dtype=torch.float32, device=device
        )
    )

    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        logits = model(inputs)

        batch_n = inputs.size(0)
        loss_sum += criterion(logits, labels).item() * batch_n
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_n

    return loss_sum / n_seen, 100.0 * n_correct / n_seen


def lattice_search(
    model,
    criterion,
    loader,
    device,
    n_iterations=4,
    points_per_iter=20,
    initial_radius=1.0,
    radius_decay=0.5,
    n_search_centers=5,
    early_stop_patience=2,
):
    """Gradient-free beam search over ``model.fc2.weight`` using lattice offsets.

    Starting from the all-zeros weight vector, each iteration evaluates every
    ``center + offset`` pair (offsets from ``generate_lattice_offsets`` at the
    current radius) on ``loader``. It keeps the ``n_search_centers`` lowest-loss
    candidates as the next centers and multiplies the radius by ``radius_decay``.
    The search stops early after ``early_stop_patience`` iterations without a
    new best loss.

    Side effect: ``model.fc2.weight`` is overwritten during the search. On
    return it holds the returned best weights, provided at least one candidate
    was evaluated.

    Args:
        model: module with a ``fc2`` layer (see ``evaluate_weights``).
        criterion: loss used to rank candidates.
        loader: data used for evaluation; selection is done on this data.
        device: device used for evaluation.
        n_iterations: maximum number of search iterations.
        points_per_iter: offsets generated around each center per iteration.
        initial_radius: offset length in the first iteration.
        radius_decay: multiplicative radius shrink per iteration.
        n_search_centers: beam width carried to the next iteration.
        early_stop_patience: iterations without improvement before stopping.

    Returns:
        np.ndarray shaped like ``model.fc2.weight`` with the best weights found
        (all zeros if ``n_iterations == 0``).

    Raises:
        ValueError: if ``points_per_iter`` or ``n_search_centers`` is < 1
            (either would produce an empty candidate list).
    """
    if points_per_iter < 1:
        raise ValueError(f"points_per_iter must be >= 1, got {points_per_iter}")
    if n_search_centers < 1:
        raise ValueError(f"n_search_centers must be >= 1, got {n_search_centers}")

    weight_shape = model.fc2.weight.shape
    weight_dim = int(np.prod(weight_shape))

    best_weights = np.zeros(weight_dim)
    best_loss = float("inf")
    best_acc = 0.0

    search_centers = [best_weights.copy()]
    current_radius = initial_radius

    no_improve_count = 0

    print(
        f"\n{'Iter':<6} {'Centers':<10} {'Candidates':<12} {'Best Loss':<12} {'Best Acc':<12} {'Radius':<10}"
    )
    print("=" * 70)

    for iteration in range(n_iterations):
        # The offset set depends only on (dim, count, radius), so compute it once
        # per iteration and reuse it for every center.
        offsets = generate_lattice_offsets(
            weight_dim, points_per_iter, radius=current_radius
        )

        candidates = []
        for center in search_centers:
            for offset in offsets:
                candidate = center + offset
                loss, acc = evaluate_weights(
                    candidate, model, criterion, loader, device
                )
                candidates.append((loss, acc, candidate))

        candidates.sort(key=lambda entry: entry[0])

        iter_best_loss, iter_best_acc, iter_best_weights = candidates[0]

        if iter_best_loss < best_loss:
            best_loss = iter_best_loss
            best_acc = iter_best_acc
            best_weights = iter_best_weights.copy()
            no_improve_count = 0
        else:
            no_improve_count += 1

        print(
            f"{iteration:<6} {len(search_centers):<10} {len(candidates):<12} {best_loss:<12.4f} {best_acc:<12.2f}% {current_radius:<10.4f}"
        )

        if no_improve_count >= early_stop_patience:
            print(
                f"\nEarly stopping: no improvement for {early_stop_patience} iterations"
            )
            break

        search_centers = [c[2] for c in candidates[:n_search_centers]]
        current_radius *= radius_decay

    best_matrix = best_weights.reshape(weight_shape)

    if best_loss < float("inf"):
        # evaluate_weights left fc2 holding the *last* candidate evaluated, not
        # the best one; make the model agree with the value we return.
        with torch.no_grad():
            model.fc2.weight.copy_(
                torch.tensor(best_matrix, dtype=torch.float32, device=device)
            )

    return best_matrix


def finetune_sgd(
    model,
    start_weights,
    criterion,
    train_loader,
    val_loader,
    device,
    epochs=10,
    lr=0.01,
    warmup_steps=100,
    patience=3,
):
    """Fine-tune the trainable parameters, starting ``fc2`` from ``start_weights``.

    Despite the name, the optimizer is AdamW (``weight_decay=0.01``). The LR
    follows linear warmup over ``warmup_steps`` optimizer steps, then cosine
    decay to 0 at ``len(train_loader) * epochs`` steps, stepped once per
    batch. After each epoch the model is scored on ``val_loader``. Training
    stops once validation loss has not improved for ``patience`` consecutive
    epochs.

    On return, ``model.fc2.weight`` holds the weights from the epoch with the
    lowest validation loss (or ``start_weights`` if no epoch ran).

    Args:
        model: module with a ``fc2`` layer; only ``requires_grad`` params train.
        start_weights: array-like shaped like ``model.fc2.weight``.
        criterion: training / validation loss.
        train_loader: iterable of ``(inputs, targets)`` with ``len()``.
        val_loader: loader used for per-epoch evaluation and early stopping.
        device: device for the model and batches.
        epochs: maximum number of epochs.
        lr: peak learning rate.
        warmup_steps: number of linear-warmup optimizer steps.
        patience: epochs without val-loss improvement before stopping.

    Returns:
        float: best validation loss observed (``inf`` if no epoch ran).
    """
    import math

    from torch.optim.lr_scheduler import LambdaLR

    model.fc2.weight.data.copy_(
        torch.tensor(start_weights, dtype=torch.float32, device=device)
    )

    # `optim` and `get_cosine_schedule_with_warmup` were never imported in this
    # notebook; use torch's AdamW and an equivalent cosine-with-warmup LambdaLR.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=0.01
    )

    total_steps = len(train_loader) * epochs

    def _cosine_with_warmup(step):
        # Multiplier on the base LR: 0 -> 1 linearly during warmup, then a
        # half-cosine from 1 -> 0 over the remaining steps (clamped at 0).
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    scheduler = LambdaLR(optimizer, lr_lambda=_cosine_with_warmup)

    best_val_loss = float("inf")
    best_fc2 = model.fc2.weight.detach().clone()
    patience_counter = 0

    print(
        f"\n{'Epoch':<8} {'Train Loss':<12} {'Val Loss':<12} {'Val Acc':<12} {'LR':<10}"
    )
    print("=" * 60)

    for epoch in range(epochs):
        model.train()  # evaluate_weights switches to eval mode every epoch
        train_loss = 0.0
        train_samples = 0

        for data, target in train_loader:
            data, target = data.to(device, non_blocking=True), target.to(
                device, non_blocking=True
            )

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.item() * data.size(0)
            train_samples += data.size(0)

        avg_train_loss = train_loss / train_samples
        # Re-loads fc2 with its own current values, then evaluates.
        val_loss, val_acc = evaluate_weights(
            model.fc2.weight.data.cpu().numpy().flatten(),
            model,
            criterion,
            val_loader,
            device,
        )

        current_lr = optimizer.param_groups[0]["lr"]

        print(
            f"{epoch+1:<8} {avg_train_loss:<12.4f} {val_loss:<12.4f} {val_acc:<12.2f}% {current_lr:<10.6f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_fc2 = model.fc2.weight.detach().clone()
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"\nEarly stopping: no improvement for {patience} epochs")
            break

    # Early stopping is only useful if we keep the best checkpoint rather than
    # the last (worse) weights.
    with torch.no_grad():
        model.fc2.weight.copy_(best_fc2)

    return best_val_loss


def main(seed=0):
    """Run the TinyNet/MNIST experiment: lattice search, then gradient fine-tuning.

    Args:
        seed: value passed to ``torch.manual_seed`` before the model is built.
            It makes the frozen ``fc1`` initialisation and the default
            DataLoader shuffle order reproducible. Pass ``None`` to skip
            seeding.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}\n")

    model = TinyNet(hidden_dim=8).to(device)
    criterion = nn.CrossEntropyLoss()

    print("Loading MNIST dataset from HuggingFace...")
    train_loader, val_loader = get_dataloaders(batch_size=512)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"Trainable parameters: {trainable_params}")
    print(f"Frozen parameters: {frozen_params}\n")

    print("=" * 70)
    print("PHASE 1: Lattice-based Weight Space Search")
    print("=" * 70)

    # NOTE: val_loader wraps the MNIST *test* split. It is used here to select
    # weights and again below for early stopping, so the reported validation
    # accuracy is optimistically biased. Hold out part of the training split if
    # an unbiased estimate is needed.
    best_weights = lattice_search(
        model=model,
        criterion=criterion,
        loader=val_loader,
        device=device,
        n_iterations=5,
        points_per_iter=500,
        initial_radius=2.0,
        radius_decay=0.6,
        n_search_centers=5,
        early_stop_patience=2,
    )

    print("\n" + "=" * 70)
    print("PHASE 2: Gradient-based Fine-tuning")
    print("=" * 70)

    finetune_sgd(
        model=model,
        start_weights=best_weights,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        epochs=30,
        lr=0.01,
        warmup_steps=20,
        patience=3,
    )

    print("\n" + "=" * 70)
    print("Training Complete!")
    print("=" * 70)

In [None]:
# Run both phases end to end. get_dataloaders calls load_dataset("ylecun/mnist"),
# which may need network access the first time it runs.
main()