# Extended PyTorch Experiments for MNIST

This notebook builds on the earlier feed-forward MNIST models and layers in a handful of modern training tricks. The goal is to push past the ~97–98% accuracy plateau by combining better regularisation, smarter optimisation schedules, light data augmentation, and systematic experiment tracking with Weights & Biases (W&B).

## What You Will Do
- Reuse the earlier data preparation but add optional augmentation that still preserves digit legibility.
- Experiment with deeper fully-connected networks that use GELU activations, batch normalisation, and dropout.
- Train with label smoothing, gradient clipping, and the One-Cycle learning rate policy.
- Log every run (hyperparameters, metrics, and sample failures) to W&B so you can compare experiments later.
- Surface misclassified examples directly in the notebook and in the W&B UI for qualitative inspection.

## Weights & Biases Setup
1. (Once per machine) Create or log into a [Weights & Biases](https://wandb.ai) account.
2. In a terminal (with the `mlp` Conda env active) run `wandb login` and paste your API key. The key is stored in `~/.netrc` so subsequent notebooks can authenticate automatically.
3. If you prefer offline logging (e.g., while travelling), set `WANDB_MODE=offline` before launching Jupyter. You can later sync runs with `wandb sync run-directory`.
4. Fill in the `wandb_project` (and optionally `wandb_entity`) in the config cell below so runs land in the desired project.

In [1]:
import wandb



In [2]:
import math
import os
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Make sure wandb is available after optional install
import wandb

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Running on: {DEVICE}")

MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


Running on: mps


In [3]:
def set_seed(seed: int = 42) -> None:
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

def to_display_image(tensor: torch.Tensor) -> np.ndarray:
    """Convert a normalised tensor to a clipped 2D numpy array for visualization."""
    arr = tensor.detach().cpu() * MNIST_STD + MNIST_MEAN
    arr = arr.clamp(0.0, 1.0)
    return arr.squeeze().numpy()


In [4]:
def build_transforms(augment: bool = True) -> Tuple[transforms.Compose, transforms.Compose]:
    """Return training and evaluation transforms."""
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    if augment:
        train_tfms = transforms.Compose([
            transforms.RandomApply([transforms.RandomRotation(10)], p=0.5),
            transforms.RandomApply([transforms.RandomAffine(degrees=0, translate=(0.08, 0.08))], p=0.5),
            transforms.ToTensor(),
            transforms.RandomErasing(p=0.1, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.0),
            normalize,
        ])
    else:
        train_tfms = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    eval_tfms = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    return train_tfms, eval_tfms

In [5]:
def build_dataloaders(data_dir: Optional[str], batch_size: int, valid_split: int = 10_000, augment: bool = True) -> Tuple[DataLoader, DataLoader, DataLoader]:
    data_root = Path(data_dir) if data_dir else Path(os.environ.get("MLP_DATA_DIR", Path.cwd() / "data"))
    data_root.mkdir(parents=True, exist_ok=True)

    train_tfms, eval_tfms = build_transforms(augment=augment)

    train_full = datasets.MNIST(root=data_root, train=True, download=True, transform=train_tfms)
    test_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=eval_tfms)

    if valid_split >= len(train_full):
        raise ValueError("valid_split must be smaller than the size of the training set")

    train_size = len(train_full) - valid_split
    train_dataset, valid_dataset = random_split(train_full, [train_size, valid_split], generator=torch.Generator().manual_seed(42))
    valid_dataset.dataset.transform = eval_tfms  # disable augmentation for validation

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, valid_loader, test_loader

In [6]:
class FeedForwardNet(nn.Module):
    def __init__(self, hidden_units: Iterable[int], dropout: float) -> None:
        super().__init__()
        dims = [28 * 28, *hidden_units, 10]
        layers: List[nn.Module] = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            if out_dim != 10:
                layers.extend([nn.BatchNorm1d(out_dim), nn.GELU(), nn.Dropout(dropout)])
        self.net = nn.Sequential(nn.Flatten(), *layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

In [7]:
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing: float = 0.0) -> None:
        super().__init__()
        if not 0.0 <= smoothing < 1.0:
            raise ValueError("smoothing must be in [0, 1)")
        self.smoothing = smoothing
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        n_classes = logits.size(-1)
        log_probs = self.log_softmax(logits)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_probs)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        loss = torch.sum(-true_dist * log_probs, dim=-1)
        return loss.mean()

In [8]:
def mixup_batch(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.0):
    if alpha <= 0.0:
        return x, y, y, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    targets_a, targets_b = y, y[index]
    return mixed_x, targets_a, targets_b, lam

In [9]:
def accuracy_from_logits(logits: torch.Tensor, target: torch.Tensor) -> float:
    preds = torch.argmax(logits, dim=1)
    return (preds == target).float().mean().item()

In [10]:
def evaluate(model: nn.Module, data_loader: DataLoader, criterion: nn.Module) -> Tuple[float, float]:
    model.eval()
    losses, accs = [], []
    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            logits = model(x)
            loss = criterion(logits, y)
            losses.append(loss.item())
            accs.append(accuracy_from_logits(logits, y))
    return float(np.mean(losses)), float(np.mean(accs))

In [11]:
def collect_predictions(
    model: nn.Module,
    data_loader: DataLoader,
    limit_misclassified: int = 16,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict]]:
    """Return predictions, targets, probabilities, and a sample of mistakes."""
    model.eval()
    all_preds: List[torch.Tensor] = []
    all_targets: List[torch.Tensor] = []
    all_probs: List[torch.Tensor] = []
    misclassified: List[Dict] = []

    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            logits = model(x)
            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)

            all_preds.append(preds.cpu())
            all_targets.append(y.cpu())
            all_probs.append(probs.cpu())

            if len(misclassified) < limit_misclassified:
                mis_mask = preds != y
                if mis_mask.any():
                    for img, true_label, pred_label, prob_vec in zip(
                        x[mis_mask],
                        y[mis_mask],
                        preds[mis_mask],
                        probs[mis_mask],
                    ):
                        misclassified.append(
                            {
                                "image": img.detach().cpu(),
                                "true": int(true_label),
                                "pred": int(pred_label),
                                "confidence": float(prob_vec[pred_label].item()),
                                "probabilities": prob_vec.detach().cpu().numpy(),
                            }
                        )
                        if len(misclassified) >= limit_misclassified:
                            break

    preds_array = torch.cat(all_preds).numpy()
    targets_array = torch.cat(all_targets).numpy()
    probs_array = torch.cat(all_probs).numpy()
    return preds_array, targets_array, probs_array, misclassified


In [12]:
def train_epoch(
    model: nn.Module,
    data_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    gradient_clip: Optional[float],
    mixup_alpha: float,
) -> Tuple[float, float, Optional[float]]:
    model.train()
    losses, accs, grad_norms = [], [], []
    for x, y in data_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        mixed_x, targets_a, targets_b, lam = mixup_batch(x, y, alpha=mixup_alpha)
        logits = model(mixed_x)
        if lam == 1.0:
            loss = criterion(logits, targets_a)
            acc = accuracy_from_logits(logits, targets_a)
        else:
            loss = lam * criterion(logits, targets_a) + (1 - lam) * criterion(logits, targets_b)
            acc = lam * accuracy_from_logits(logits, targets_a) + (1 - lam) * accuracy_from_logits(logits, targets_b)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        grad_norm = None
        if gradient_clip:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip).item()
        else:
            # Compute gradient norm without clipping for logging purposes.
            grad_sq = 0.0
            for param in model.parameters():
                if param.grad is not None:
                    grad_sq += param.grad.detach().pow(2).sum().item()
            grad_norm = float(math.sqrt(grad_sq)) if grad_sq > 0 else 0.0
        optimizer.step()
        if scheduler:
            scheduler.step()

        losses.append(loss.item())
        accs.append(acc)
        if grad_norm is not None:
            grad_norms.append(grad_norm)

    avg_grad = float(np.mean(grad_norms)) if grad_norms else None
    return float(np.mean(losses)), float(np.mean(accs)), avg_grad


In [13]:
def log_misclassified_to_wandb(misclassified: List[Dict]) -> None:
    """Log misclassified samples as both images and a table in W&B."""
    if not misclassified:
        wandb.log({"misclassified_examples": []})
        return

    images = []
    table = wandb.Table(columns=[
        "image",
        "true_label",
        "pred_label",
        "confidence",
        "probabilities",
    ])

    for sample in misclassified:
        img_arr = to_display_image(sample["image"])
        caption = f"true={sample['true']} | pred={sample['pred']} | conf={sample['confidence']:.2f}"
        wb_image = wandb.Image(img_arr, caption=caption)
        images.append(wb_image)
        table.add_data(
            wb_image,
            sample["true"],
            sample["pred"],
            sample["confidence"],
            sample["probabilities"].tolist(),
        )

    wandb.log({
        "misclassified_examples": images,
        "misclassified_table": table,
        "misclassified_count": len(misclassified),
    })


In [14]:
default_config = {
    "run_name": "gelu_bn_dropout",
    "wandb_project": "mlp-mnist-extended",
    # "wandb_entity": "your-entity",  # uncomment if you use teams
    "batch_size": 256,
    "epochs": 20,
    "learning_rate": 5e-3,
    "weight_decay": 1e-4,
    "max_lr": 0.03,
    "gradient_clip": 1.0,
    "hidden_units": [1024, 512, 256],
    "dropout": 0.25,
    "label_smoothing": 0.05,
    "mixup_alpha": 0.1,
    "augment": True,
    "valid_split": 8_000,
    "data_dir": None,
}


def run_experiment(config_updates: Optional[Dict] = None) -> Dict:
    """Train the model with optional config overrides or W&B sweep parameters."""
    cfg = default_config.copy()
    if config_updates:
        cfg.update(config_updates)

    wandb_kwargs = {"project": cfg["wandb_project"]}
    if cfg.get("wandb_entity"):
        wandb_kwargs["entity"] = cfg["wandb_entity"]

    base_run_name = cfg.get("run_name")
    excluded_keys = {"wandb_project", "wandb_entity", "run_name"}
    logged_config = {k: v for k, v in cfg.items() if k not in excluded_keys}

    with wandb.init(**wandb_kwargs, config=logged_config) as run:
        sweep_cfg = dict(run.config)
        cfg.update(sweep_cfg)

        if base_run_name and not run.name:
            run.name = f"{base_run_name}-{run.id[:8]}"

        set_seed(cfg.get("seed", 42))
        train_loader, valid_loader, test_loader = build_dataloaders(
            cfg.get("data_dir"),
            cfg["batch_size"],
            cfg["valid_split"],
            augment=cfg["augment"],
        )

        model = FeedForwardNet(cfg["hidden_units"], cfg["dropout"]).to(DEVICE)
        criterion = LabelSmoothingCrossEntropy(cfg["label_smoothing"])
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg["learning_rate"],
            weight_decay=cfg["weight_decay"],
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=cfg["max_lr"],
            steps_per_epoch=len(train_loader),
            epochs=cfg["epochs"],
            pct_start=0.1,
            anneal_strategy="cos",
        )

        param_count = sum(p.numel() for p in model.parameters())
        run.summary["param_count"] = param_count

        best_valid_acc = 0.0
        best_state = None

        for epoch in range(cfg["epochs"]):
            train_loss, train_acc, train_grad = train_epoch(
                model,
                train_loader,
                optimizer,
                criterion,
                scheduler,
                cfg.get("gradient_clip"),
                cfg.get("mixup_alpha", 0.0),
            )
            valid_loss, valid_acc = evaluate(model, valid_loader, criterion)

            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                best_state = model.state_dict()

            log_payload = {
                "epoch": epoch + 1,
                "train/loss": train_loss,
                "train/acc": train_acc,
                "valid/loss": valid_loss,
                "valid/acc": valid_acc,
                "lr": scheduler.get_last_lr()[0],
            }
            if train_grad is not None:
                log_payload["train/grad_norm"] = train_grad
            wandb.log(log_payload)

        if best_state is not None:
            model.load_state_dict(best_state)

        test_loss, test_acc = evaluate(model, test_loader, criterion)
        preds, targets, probs, misclassified = collect_predictions(model, test_loader, limit_misclassified=16)

        top3_hits = 0
        top3 = np.argsort(probs, axis=1)[:, -3:]
        for idx, true_label in enumerate(targets):
            if true_label in top3[idx]:
                top3_hits += 1
        top3_acc = float(top3_hits / len(targets)) if len(targets) else 0.0

        class_names = [str(i) for i in range(10)]
        class_logs = {}
        for cls_idx, cls_name in enumerate(class_names):
            mask = targets == cls_idx
            if mask.any():
                class_logs[f"test/class_acc/{cls_name}"] = float((preds[mask] == targets[mask]).mean())

        wandb.log({
            "test/loss": test_loss,
            "test/acc": test_acc,
            "test/top3_acc": top3_acc,
            "test/confusion": wandb.plot.confusion_matrix(y_true=targets, preds=preds, class_names=class_names),
            **class_logs,
        })

        run.summary["best_valid_acc"] = best_valid_acc
        run.summary["test_acc"] = test_acc
        run.summary["top3_acc"] = top3_acc
        run.summary["misclassified_count"] = len(misclassified)

        print(f"Test accuracy: {test_acc:.4f}")

        log_misclassified_to_wandb(misclassified)

        return {
            "best_valid_acc": best_valid_acc,
            "test_acc": test_acc,
            "top3_acc": top3_acc,
            "misclassified": misclassified,
            "predictions": preds,
            "targets": targets,
            "probabilities": probs,
            "config": cfg,
        }


In [15]:
# Example run (adjust hyperparameters, project name, or comment out if running sweeps)
results = run_experiment({
    "run_name": "onecycle_mixup_label_smoothing",
})

[34m[1mwandb[0m: Currently logged in as: [33mericwkoch[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Test accuracy: 0.9875


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98657
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9875
test/class_acc/0,0.9949
test/class_acc/1,0.99119
test/class_acc/2,0.98837
test/class_acc/3,0.99109


In [16]:
# Visualise the misclassifications from the most recent run (if you stored them locally)
def plot_local_misclassifications(misclassified: List[Dict], cols: int = 4) -> None:
    if not misclassified:
        print("No misclassified samples were stored in the latest run output.")
        return
    rows = math.ceil(len(misclassified) / cols)
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))
    axes = axes.flatten()
    for idx, ax in enumerate(axes):
        if idx < len(misclassified):
            sample = misclassified[idx]
            ax.imshow(to_display_image(sample["image"]), cmap="gray")
            ax.set_title(f"true={sample['true']} | pred={sample['pred']} | conf={sample['confidence']:.2f}")
        ax.axis("off")
    plt.tight_layout()

# After running `results = run_experiment(...)`, call:
# plot_local_misclassifications(results["misclassified"])


## Next Steps
- Sweep over `mixup_alpha`, `dropout`, and `max_lr` by calling `wandb.sweep` with the `default_config` dictionary.
- Try switching to the Ranger or AdamP optimisers (available via `torch_optimizer`) for a different generalisation profile.
- Replace the fully-connected model with a small convolutional frontend (e.g., two Conv2d + MaxPool layers feeding into the dense stack) while keeping the rest of the training loop identical.
- Enable advanced W&B logging such as gradient histograms or confusion matrices (`wandb.plot.confusion_matrix`).

In [17]:
# Configure and create a W&B sweep
sweep_config = {
    "method": "grid",  # or "random", "bayes"
    "parameters": {
        "mixup_alpha": {"values": [0.0, 0.1, 0.2]},
        "dropout": {"values": [0.2, 0.25, 0.3]},
        "max_lr": {"values": [0.02, 0.03, 0.04]},
    },
    "name": "mnist-extended-sweep",
    "project": default_config["wandb_project"],
}

# `wandb.sweep` returns a sweep ID you can reuse later
sweep_id = wandb.sweep(sweep_config)
print(f"Sweep created with ID: {sweep_id}")

Create sweep with ID: lyhvzu53
Sweep URL: https://wandb.ai/ericwkoch/mlp-mnist-extended/sweeps/lyhvzu53
Sweep created with ID: lyhvzu53


In [18]:
# Run the sweep
wandb.agent(sweep_id, function=run_experiment, count=27)

[34m[1mwandb[0m: Agent Starting Run: 2ihzzgfk with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9885


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98657
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98848
test/class_acc/0,0.99184
test/class_acc/1,0.99295
test/class_acc/2,0.99031
test/class_acc/3,0.98713


[34m[1mwandb[0m: Agent Starting Run: 1t3jqxbf with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9866


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98657
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98662
test/class_acc/0,0.99388
test/class_acc/1,0.99383
test/class_acc/2,0.98643
test/class_acc/3,0.9901


[34m[1mwandb[0m: Agent Starting Run: maqcu0no with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9876


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98694
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9876
test/class_acc/0,0.9949
test/class_acc/1,0.99559
test/class_acc/2,0.98837
test/class_acc/3,0.99307


[34m[1mwandb[0m: Agent Starting Run: mq707az1 with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9880


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98901
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98799
test/class_acc/0,0.99286
test/class_acc/1,0.99207
test/class_acc/2,0.98934
test/class_acc/3,0.98713


[34m[1mwandb[0m: Agent Starting Run: oj4p0pzn with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9878


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98694
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98779
test/class_acc/0,0.9949
test/class_acc/1,0.99295
test/class_acc/2,0.9874
test/class_acc/3,0.98812


[34m[1mwandb[0m: Agent Starting Run: cpjrzm77 with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9876


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98669
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9876
test/class_acc/0,0.9949
test/class_acc/1,0.99471
test/class_acc/2,0.98934
test/class_acc/3,0.99406


[34m[1mwandb[0m: Agent Starting Run: scxg6fg9 with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9877


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.9884
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9877
test/class_acc/0,0.99082
test/class_acc/1,0.99119
test/class_acc/2,0.99225
test/class_acc/3,0.98812


[34m[1mwandb[0m: Agent Starting Run: c97oiyve with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9878


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98596
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98779
test/class_acc/0,0.99388
test/class_acc/1,0.99207
test/class_acc/2,0.98934
test/class_acc/3,0.9901


[34m[1mwandb[0m: Agent Starting Run: 12o4b2o9 with config:
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9870


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98682
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98701
test/class_acc/0,0.99388
test/class_acc/1,0.99471
test/class_acc/2,0.9874
test/class_acc/3,0.98911


[34m[1mwandb[0m: Agent Starting Run: 7jkgotjv with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9880


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98804
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98799
test/class_acc/0,0.9949
test/class_acc/1,0.99207
test/class_acc/2,0.99031
test/class_acc/3,0.9901


[34m[1mwandb[0m: Agent Starting Run: vvpzz2df with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9881


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98621
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98809
test/class_acc/0,0.9949
test/class_acc/1,0.99295
test/class_acc/2,0.98837
test/class_acc/3,0.9901


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 6lki4k88 with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9869


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98572
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98691
test/class_acc/0,0.99286
test/class_acc/1,0.99559
test/class_acc/2,0.98934
test/class_acc/3,0.98713


[34m[1mwandb[0m: Agent Starting Run: n8mapr4r with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9876


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98743
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9876
test/class_acc/0,0.99184
test/class_acc/1,0.99207
test/class_acc/2,0.98934
test/class_acc/3,0.9901


[34m[1mwandb[0m: Agent Starting Run: lb2gs65o with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9875


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98657
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9875
test/class_acc/0,0.9949
test/class_acc/1,0.99119
test/class_acc/2,0.98837
test/class_acc/3,0.99109


[34m[1mwandb[0m: Agent Starting Run: tif0cs8o with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9870


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98584
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98701
test/class_acc/0,0.99388
test/class_acc/1,0.99559
test/class_acc/2,0.98837
test/class_acc/3,0.98911


[34m[1mwandb[0m: Agent Starting Run: z9vssp4z with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9883


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98743
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98828
test/class_acc/0,0.99286
test/class_acc/1,0.99383
test/class_acc/2,0.98934
test/class_acc/3,0.98812


[34m[1mwandb[0m: Agent Starting Run: f4ul0rii with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9873


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98694
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9873
test/class_acc/0,0.9949
test/class_acc/1,0.99207
test/class_acc/2,0.98934
test/class_acc/3,0.98713


[34m[1mwandb[0m: Agent Starting Run: 7rsa8hy8 with config:
[34m[1mwandb[0m: 	dropout: 0.25
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9873


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98657
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9873
test/class_acc/0,0.99286
test/class_acc/1,0.99383
test/class_acc/2,0.98837
test/class_acc/3,0.99109


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: elq828jl with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9879


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98889
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98789
test/class_acc/0,0.99286
test/class_acc/1,0.99207
test/class_acc/2,0.9874
test/class_acc/3,0.9901


[34m[1mwandb[0m: Agent Starting Run: g1png7fk with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9876


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98682
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9876
test/class_acc/0,0.99592
test/class_acc/1,0.99383
test/class_acc/2,0.9874
test/class_acc/3,0.98812


[34m[1mwandb[0m: Agent Starting Run: ihta500j with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.02
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9871


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.9856
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98711
test/class_acc/0,0.99388
test/class_acc/1,0.99471
test/class_acc/2,0.99128
test/class_acc/3,0.9901


[34m[1mwandb[0m: Agent Starting Run: ygfy9eif with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9880


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98706
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98799
test/class_acc/0,0.99286
test/class_acc/1,0.99383
test/class_acc/2,0.98837
test/class_acc/3,0.98812


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: fk273af8 with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9875


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98682
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9875
test/class_acc/0,0.9949
test/class_acc/1,0.99383
test/class_acc/2,0.98934
test/class_acc/3,0.98614


[34m[1mwandb[0m: Agent Starting Run: q7vdimep with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.03
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9877


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98755
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9877
test/class_acc/0,0.99592
test/class_acc/1,0.99559
test/class_acc/2,0.99128
test/class_acc/3,0.99109


[34m[1mwandb[0m: Agent Starting Run: sctkdwow with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0


Test accuracy: 0.9878


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98779
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98779
test/class_acc/0,0.9949
test/class_acc/1,0.99295
test/class_acc/2,0.98837
test/class_acc/3,0.9901


[34m[1mwandb[0m: Agent Starting Run: arqj0f1p with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0.1


Test accuracy: 0.9872


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.98669
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.98721
test/class_acc/0,0.99286
test/class_acc/1,0.99207
test/class_acc/2,0.99031
test/class_acc/3,0.98812


[34m[1mwandb[0m: Agent Starting Run: zlqoh9v4 with config:
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	max_lr: 0.04
[34m[1mwandb[0m: 	mixup_alpha: 0.2


Test accuracy: 0.9873


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▅████▇▇▆▆▅▄▄▃▃▂▂▁▁▁▁
misclassified_count,▁
test/acc,▁
test/class_acc/0,▁
test/class_acc/1,▁
test/class_acc/2,▁
test/class_acc/3,▁
test/class_acc/4,▁
test/class_acc/5,▁

0,1
best_valid_acc,0.9856
epoch,20
lr,0.0
misclassified_count,16
param_count,1466122
test/acc,0.9873
test/class_acc/0,0.9949
test/class_acc/1,0.99559
test/class_acc/2,0.99128
test/class_acc/3,0.98812
