In [None]:
# === Export MNIST (test set) as int8 .npy for HLS and PYNQ ===
import os, glob, importlib.util
import numpy as np
import torch

def _here():
    """Return the directory used to look for neighbor files.

    This is the folder containing this file when ``__file__`` is defined in
    the module globals; otherwise (for example in an interactive session) it
    is the current working directory.
    """
    if "__file__" not in globals():
        return os.getcwd()
    script_path = os.path.abspath(__file__)
    return os.path.dirname(script_path)

def _try_import_neighbor_train():
    """Import the first loadable ``train*.py`` script next to this file.

    Candidates are every ``train*.py`` in the directory returned by
    ``_here()`` (sorted by path), plus the literal ``train (5).py`` when it
    exists and was not already matched. Each candidate is executed as a module
    named ``neighbor_train`` so that its ``MNIST_ROOT`` / ``tx`` / ``test_ds``
    globals can be reused. Importing runs the candidate's top-level code.

    Returns:
        The first module that executes without raising ``Exception``, or
        ``None`` when there are no candidates or every candidate fails.
    """
    import sys  # local import: only needed to register the module below

    here = _here()
    # Escape the directory part so glob metacharacters in the path itself
    # ("[", "*", "?") are matched literally; only the file name is a pattern.
    pattern = os.path.join(glob.escape(here), "train*.py")
    candidates = sorted(glob.glob(pattern))
    # also try the exact name if it exists
    p_exact = os.path.join(here, "train (5).py")
    if os.path.isfile(p_exact) and p_exact not in candidates:
        candidates.append(p_exact)

    mod_name = "neighbor_train"
    for p in candidates:
        previous = sys.modules.get(mod_name)
        try:
            spec = importlib.util.spec_from_file_location(mod_name, p)
            if spec is None or spec.loader is None:
                raise ImportError(f"no loader available for {p}")
            mod = importlib.util.module_from_spec(spec)
            # Register before executing, as the importlib recipe does, so code
            # in the script that looks itself up via sys.modules (pickling,
            # worker processes, ...) can find it.
            sys.modules[mod_name] = mod
            spec.loader.exec_module(mod)  # type: ignore
        except Exception as e:
            # Undo the registration so a half-initialised module is not left
            # behind for the next candidate.
            if previous is None:
                sys.modules.pop(mod_name, None)
            else:
                sys.modules[mod_name] = previous
            print(f"[warn] Could not import {p}: {e}")
        else:
            print(f"[info] Imported neighbor training script: {os.path.basename(p)}")
            return mod
    return None

def _resolve_root_and_transform(train_mod):
    """Pick the MNIST root directory and the sample transform to use.

    The root defaults to the ``MNIST_ROOT`` environment variable, or
    ``"./data"`` when that is unset. If ``train_mod`` is given, a truthy
    ``train_mod.MNIST_ROOT`` overrides the default and ``train_mod.tx`` is
    taken as the transform. When no transform is found, torchvision's
    ``ToTensor()`` is used.

    Returns:
        A ``(root, transform)`` tuple.
    """
    root = os.getenv("MNIST_ROOT", "./data")
    transform = None

    if train_mod is not None:
        # A missing or falsy MNIST_ROOT on the module keeps the default root.
        root = getattr(train_mod, "MNIST_ROOT", None) or root
        transform = getattr(train_mod, "tx", None)

    if transform is not None:
        return root, transform

    # Imported lazily so torchvision is only needed when no transform exists.
    from torchvision import transforms
    return root, transforms.ToTensor()

def _load_test_ds(train_mod, MNIST_ROOT, tx):
    """Return an indexable MNIST test dataset, trying three sources in order.

    1. ``train_mod.test_ds`` when it exists and supports ``len()`` and ``[0]``.
    2. ``torchvision.datasets.MNIST(MNIST_ROOT, train=False, download=True,
       transform=tx)``.
    3. An offline fallback that reads a ``(data, targets)`` pair from
       ``<MNIST_ROOT>/processed/test.pt`` or
       ``<MNIST_ROOT>/MNIST/processed/test.pt``, needing only ``torch``.

    Args:
        train_mod: Imported neighbor training module, or ``None``.
        MNIST_ROOT: Dataset root directory.
        tx: Transform handed to torchvision's MNIST (source 2 only).

    Returns:
        A dataset whose items are ``(image, label)`` pairs. The offline
        fallback yields a float32 tensor of shape ``[1, H, W]`` (pixel / 255)
        and an ``int`` label.

    Raises:
        FileNotFoundError: If torchvision's MNIST cannot be built and no
            ``test.pt`` file is found.
    """
    if train_mod is not None and hasattr(train_mod, "test_ds"):
        try:
            test_ds = getattr(train_mod, "test_ds")
            # Probe the two operations the export loop relies on.
            _ = len(test_ds)
            _ = test_ds[0]
            print("[info] Reusing test_ds from neighbor script.")
            return test_ds
        except Exception:
            print("[warn] Neighbor test_ds not usable; rebuilding.")

    try:
        # Imported inside the try so a missing torchvision still reaches the
        # torch-only offline fallback below instead of aborting here.
        from torchvision import datasets
        print(f"[info] Loading torchvision MNIST at {MNIST_ROOT} (download allowed).")
        return datasets.MNIST(MNIST_ROOT, train=False, download=True, transform=tx)
    except Exception as err:
        test_pt = os.path.join(MNIST_ROOT, "processed", "test.pt")
        # Older torchvision releases wrote either directly under the root or
        # under a "MNIST" sub-folder; accept both layouts.
        candidates = (
            test_pt,
            os.path.join(MNIST_ROOT, "MNIST", "processed", "test.pt"),
        )
        found = next((p for p in candidates if os.path.isfile(p)), None)
        if found is None:
            raise FileNotFoundError(
                f"MNIST not available. Either allow download or provide {test_pt}."
            ) from err
        print(f"[warn] torchvision MNIST unavailable: {err}")
        print(f"[info] Offline fallback from {found}")
        # NOTE(security): torch.load unpickles the file; only point
        # MNIST_ROOT at trusted data.
        data, targets = torch.load(found)

        class _OfflineMNIST(torch.utils.data.Dataset):
            """Minimal dataset over the raw ``(data, targets)`` tensors.

            NOTE: ``tx`` is not applied here; images are only scaled by 1/255.
            """

            def __init__(self, data, targets):
                self.data = data
                self.targets = targets

            def __len__(self):
                return self.data.shape[0]

            def __getitem__(self, idx):
                # Scale to float32 by 1/255 and add a leading channel axis.
                img = (self.data[idx].numpy().astype(np.float32) / 255.0)[None, :, :]
                return torch.from_numpy(img), int(self.targets[idx].item())

        return _OfflineMNIST(data, targets)

def quantize_01_to_int8(x01: np.ndarray) -> np.ndarray:
    """Quantize floats nominally in [0, 1] to int8.

    Each value is multiplied by 127, rounded to the nearest integer, clipped
    to the int8 range [-128, 127] and cast to ``np.int8``.
    """
    scaled = x01 * 127.0
    rounded = np.rint(scaled)
    clamped = np.clip(rounded, -128, 127)
    return clamped.astype(np.int8)

# ---- Main export flow ----
# Runs at module level: every name assigned below stays in the global
# namespace, and the .npy files are written relative to the working directory.
out_dir = "export_mnist_int8"
os.makedirs(out_dir, exist_ok=True)

train_mod = _try_import_neighbor_train()
MNIST_ROOT, tx = _resolve_root_and_transform(train_mod)
test_ds = _load_test_ds(train_mod, MNIST_ROOT, tx)

# Stack every sample into one NCHW float32 array (expected range [0, 1]).
N = len(test_ds)
x01 = np.empty((N, 1, 28, 28), dtype=np.float32)
y   = np.empty((N,), dtype=np.uint8)

for i in range(N):
    xi, yi = test_ds[i]                  # xi: Tensor [1,28,28] in [0,1]
    xi = xi.numpy() if isinstance(xi, torch.Tensor) else xi
    x01[i] = xi.astype(np.float32)
    y[i] = np.uint8(yi)

# A transform reused from the training script may not produce [0, 1] values
# (e.g. if it normalizes); the quantizer would then clip silently, so report
# it. Skipped for an empty dataset, where min()/max() are undefined.
if N:
    _x_min, _x_max = float(x01.min()), float(x01.max())
    if _x_min < 0.0 or _x_max > 1.0:
        print(
            f"[warn] Pixel values span [{_x_min:.4f}, {_x_max:.4f}], outside the "
            "[0, 1] range the int8 quantizer assumes; out-of-range values may be clipped."
        )

# Quantize and save
xq = quantize_01_to_int8(x01)
xq_flat = xq.reshape(N, -1)              # one flattened row per image

np.save(os.path.join(out_dir, "mnist_test_int8_nchw.npy"), xq)
np.save(os.path.join(out_dir, "mnist_test_int8_flat.npy"), xq_flat)
np.save(os.path.join(out_dir, "mnist_test_labels.npy"), y)

# Quick subsets: the first 100 samples (fewer if the dataset is smaller).
np.save(os.path.join(out_dir, "mnist_test_int8_flat_100.npy"), xq_flat[:100])
np.save(os.path.join(out_dir, "mnist_test_labels_100.npy"), y[:100])

print("Saved to:", os.path.abspath(out_dir))
print("  mnist_test_int8_nchw.npy  ->", xq.shape, xq.dtype)
print("  mnist_test_int8_flat.npy  ->", xq_flat.shape, xq_flat.dtype)
print("  mnist_test_labels.npy     ->", y.shape, y.dtype)