# Import and prepare data from MNIST and CIFAR-10

### Package imports

In [1]:
import yaml
import os

### Configs

All config files live in `hybrid-vit/configs`, and most tunable parameters are set there. Each config file corresponds to one optimization procedure (one training run).

In [3]:
def load_cfg(path):
    """Load a YAML config file and return the parsed document.

    Parameters
    ----------
    path : str or os.PathLike
        Path to the YAML file, e.g. ``"configs/mnist_baseline.yaml"``.

    Returns
    -------
    object
        Whatever ``yaml.safe_load`` produces for the file: a dict for the
        mapping-style configs used in this notebook, or ``None`` if the file
        is empty.
    """
    # Explicit encoding so parsing does not depend on the platform's locale default.
    # safe_load only builds plain Python types (no arbitrary object construction).
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
    

cfg = load_cfg("configs/mnist_baseline.yaml")

# os.cpu_count() returns None when the core count cannot be determined;
# fall back to 1 so the worker arithmetic below never raises TypeError.
n_cpus = os.cpu_count() or 1

# Flat settings passed to build_dataloaders(); most values come from the YAML config.
dl_cfg = {
    "name": cfg["dataset"]["name"],
    "root": cfg["dataset"]["root"],
    "img_size": cfg["dataset"]["img_size"],
    # Augmentation defaults to on unless the config sets it explicitly.
    "augment": cfg["dataset"].get("augment", True),
    "batch_size": cfg["optim"]["batch_size"],
    # Leave 4 cores for the rest of the system, but never use fewer than 4 workers.
    "num_workers": max(4, n_cpus - 4),
    "seed": cfg["misc"]["seed"],
    # Size of the validation split. NOTE(review): presumably held out from the
    # training set (the printed meta shows train_len 55000 + val_len 5000) — confirm.
    "val_size": 5000,
}

for key, value in dl_cfg.items():
    print(f"{key}: {value}")

name: mnist
root: ./data/cache
img_size: 28
augment: False
batch_size: 128
num_workers: 4
seed: 17092003
val_size: 5000


## Load the data

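Build the train, validation, and test loaders with `build_dataloaders` from the `data.dataloaders` module, then check the shapes of one sample batch.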

In [None]:
from data.dataloaders import build_dataloaders

train_loader, val_loader, test_loader, meta = build_dataloaders(dl_cfg)

# Expected number of ViT patch tokens per image, derived from img_size and patch.
# The patch size is consumed by the model, not the DataLoader; it is read here only
# for this check. NOTE(review): any extra token the model may prepend (e.g. CLS)
# is not counted here.
patch = cfg["dataset"]["patch"]
img_size = cfg["dataset"]["img_size"]
assert img_size % patch == 0, "img_size must be divisible by patch"
num_tokens = (img_size // patch) ** 2

print("meta:", meta)
print(f"img_size = {img_size}, patch = {patch}, expected tokens = {num_tokens}")

# Pull one batch to inspect dtype/value range and to verify shapes.
xb, yb = next(iter(train_loader))
print("batch shapes:", xb.shape, yb.shape)
# Values outside [0, 1] indicate the loader normalizes the images
# (presumably per-channel mean/std — confirm in data.dataloaders).
print("dtype/range:", xb.dtype, f"[{xb.min().item():.3f}, {xb.max().item():.3f}]")

# Verify the shape instead of only printing it: images are [B, C, H, W], where
# C is the channel count reported by the loader and H = W = img_size; labels are [B].
assert xb.ndim == 4 and tuple(xb.shape[1:]) == (meta["channels"], img_size, img_size), (
    f"unexpected image batch shape {tuple(xb.shape)}"
)
assert tuple(yb.shape) == (xb.shape[0],), (
    f"label shape {tuple(yb.shape)} does not match batch size {xb.shape[0]}"
)

meta: {'num_classes': 10, 'image_size': 28, 'channels': 3, 'train_len': 55000, 'val_len': 5000, 'test_len': 10000}
img_size = 28, patch = 4, expected tokens = 49
batch shapes: torch.Size([128, 3, 28, 28]) torch.Size([128])
dtype/range: torch.float32 [-0.424, 2.821]
