In [None]:
import os, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ---------- 0) Data ----------
# The dataset root can be overridden through the MNIST_ROOT environment variable.
MNIST_ROOT = os.getenv("MNIST_ROOT", "./data")
tx = transforms.ToTensor()

# Resolve the compute device first so the loaders can be configured for it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    train_ds = datasets.MNIST(MNIST_ROOT, train=True,  download=True, transform=tx)
    test_ds  = datasets.MNIST(MNIST_ROOT, train=False, download=True, transform=tx)
except Exception:
    # Deliberate best-effort fallback: if the download attempt fails for any
    # reason, retry against whatever is already stored under MNIST_ROOT.
    # A missing local copy still raises from these calls.
    train_ds = datasets.MNIST(MNIST_ROOT, train=True,  download=False, transform=tx)
    test_ds  = datasets.MNIST(MNIST_ROOT, train=False, download=False, transform=tx)

# Page-locked host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine just makes the DataLoader emit a warning.
PIN_MEMORY = device.type == "cuda"
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,  num_workers=2, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds,  batch_size=512, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

In [None]:
# ---------- 1) Model (float) + training ----------
class MNISTTiny(nn.Module):
    """Two 3x3 convolutions followed by a single 10-way linear layer.

    The linear layer takes 32*3*3 flattened features, which is what a
    1x28x28 input produces after the two max-pools and the average-pool.
    The attribute names (c1, c2, maxpool, avgpool, fc) are read directly by
    the quantization code in this file, so they must stay as they are.
    """

    def __init__(self):
        super().__init__()
        # Parameterised layers are created in the order c1, c2, fc.
        self.c1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3,
                            stride=1, padding=1, bias=True)
        self.c2 = nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3,
                            stride=1, padding=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(in_features=32 * 3 * 3, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch ``x``."""
        feats = self.maxpool(F.relu(self.c1(x)))
        feats = self.maxpool(F.relu(self.c2(feats)))
        feats = self.avgpool(feats)
        flat = feats.view(feats.size(0), -1)  # one row of features per sample
        return self.fc(flat)

def train_one_epoch(m, opt, loader):
    """Run one optimisation pass over ``loader``.

    Args:
        m: model to train; it is switched to training mode.
        opt: optimizer that updates ``m``'s parameters.
        loader: iterable of ``(inputs, targets)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Accuracy is measured
        on the logits computed before each optimizer step.
    """
    m.train()
    seen = 0
    hits = 0
    weighted_loss = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch = inputs.size(0)

        opt.zero_grad()
        outputs = m(inputs)
        batch_loss = F.cross_entropy(outputs, targets)
        batch_loss.backward()
        opt.step()

        # cross_entropy returns the batch mean, so weight it by the batch size
        # to get a per-sample average at the end.
        weighted_loss += batch_loss.item() * batch
        with torch.no_grad():
            hits += (outputs.argmax(1) == targets).sum().item()
        seen += batch
    return weighted_loss / seen, hits / seen

@torch.no_grad()
def eval_acc(m, loader):
    """Return the fraction of samples in ``loader`` that ``m`` classifies correctly.

    The model is switched to evaluation mode and gradients are disabled for
    the whole call.
    """
    m.eval()
    n_seen = 0
    n_right = 0
    for inputs, targets in loader:
        targets = targets.to(device)
        predictions = m(inputs.to(device)).argmax(1)
        n_right += (predictions == targets).sum().item()
        n_seen += inputs.size(0)
    return n_right / n_seen

# Build the float model on the selected device and train it with Adam.
model = MNISTTiny()
model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)

EPOCHS = 10
for ep in range(EPOCHS):
    # One pass over the training set, then a full evaluation on the test set.
    tr_loss, tr_acc = train_one_epoch(model, opt, train_loader)
    te_acc = eval_acc(model, test_loader)
    print(f"Epoch {ep+1}/{EPOCHS}: train_loss={tr_loss:.4f}  train_acc={tr_acc:.4f}  test_acc={te_acc:.4f}")

# Float accuracy after training, kept as the reference for the quantized model.
baseline_acc = eval_acc(model, test_loader)
print("Baseline float32 test accuracy:", baseline_acc)

In [None]:
# ---------- 2) PTQ (symmetric int8) ----------
@torch.no_grad()
def collect_act_ranges(m, loader, n_batches=50):
    """Record the largest absolute activation seen at three points of ``m``.

    The layers are applied by hand (c1 -> ReLU -> maxpool -> c2 -> ReLU ->
    maxpool -> avgpool) so the intermediate tensors can be inspected.

    Args:
        m: model exposing ``c1``, ``c2``, ``maxpool`` and ``avgpool``.
        loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        n_batches: stop after this many batches. The check runs after a batch
            has been processed, so at least one batch is always used.

    Returns:
        ``(a1_max, a2_max, a3_max)``: the maxima after ReLU(c1), after
        ReLU(c2), and after the average-pool (the input to fc).
    """
    m.eval()
    peaks = [0.0, 0.0, 0.0]
    for seen, (xb, _) in enumerate(loader, start=1):
        act = F.relu(m.c1(xb.to(device)))
        peaks[0] = max(peaks[0], act.abs().amax().item())

        act = F.relu(m.c2(m.maxpool(act)))
        peaks[1] = max(peaks[1], act.abs().amax().item())

        act = m.avgpool(m.maxpool(act))
        peaks[2] = max(peaks[2], act.abs().amax().item())

        if seen >= n_batches:
            break
    return tuple(peaks)

def calc_scale(max_abs):
    """Return the symmetric int8 scale ``max_abs / 127``.

    ``max_abs`` is floored at 1e-8 so the resulting scale is never zero.
    """
    eps = 1e-8
    safe_peak = eps if eps > max_abs else max_abs
    return safe_peak / 127.0

# Calibrate on the first 50 batches of the training loader, then turn each
# observed maximum into a symmetric int8 activation scale.
a1_max, a2_max, a3_max = collect_act_ranges(model, train_loader, n_batches=50)
sx1, sx2, sx3 = (calc_scale(peak) for peak in (a1_max, a2_max, a3_max))

def per_channel_w_scale(W):
    """Return one symmetric int8 scale per output channel (dimension 0) of ``W``.

    A 4-D tensor is reduced over dimensions 1-3; anything else is reduced
    over dimension 1 only. Each per-channel maximum is floored at 1e-8 before
    dividing by 127, and the result is returned as a NumPy array.
    """
    magnitudes = W.detach().float().cpu().abs()
    reduce_over = (1, 2, 3) if magnitudes.dim() == 4 else 1
    peak = magnitudes.amax(dim=reduce_over).clamp(min=1e-8)
    return (peak / 127.0).numpy()  # [out]

# Per-output-channel weight scales for the two convolutions and the fc layer.
sc1, sc2, sf = (
    per_channel_w_scale(layer.weight) for layer in (model.c1, model.c2, model.fc)
)

def quant_w_per_out(W, s_out):
    """Quantize ``W`` to int8 using one scale per output channel.

    Args:
        W: weight tensor whose first dimension indexes output channels.
        s_out: sequence of scales, one per output channel.

    Returns:
        int8 NumPy array with the same shape as ``W``; every slice ``W[o]``
        is divided by ``s_out[o]``, rounded and clipped to [-128, 127].
    """
    weights = W.detach().cpu().float().numpy()
    codes = np.zeros_like(weights, dtype=np.int8)
    # The same per-channel rule applies whatever the rank of the tensor.
    for ch in range(weights.shape[0]):
        scaled = np.round(weights[ch] / s_out[ch])
        codes[ch] = np.clip(scaled, -128, 127).astype(np.int32).astype(np.int8)
    return codes

# int8 weights: conv filters keep their layout; the fc matrix is kept both as
# [out, in] (Wf_qO_in) and as a separate transposed [in, out] copy (Wf_q).
Wc1_q = quant_w_per_out(model.c1.weight, sc1)
Wc2_q = quant_w_per_out(model.c2.weight, sc2)
Wf_qO_in = quant_w_per_out(model.fc.weight, sf)
Wf_q = Wf_qO_in.T.copy()

# Bias quantization: int32, scale = s_in * s_w[out]
def _quant_bias_int32(bias, s_in, s_w_out):
    """Quantize a bias tensor to int32 with per-channel scale ``s_in * s_w_out``.

    The division is done in float64 and the rounded result is saturated to the
    int32 range before casting. Without the saturation, a channel whose weight
    scale sits at its 1e-8 floor can produce a quotient far outside int32, and
    the cast then yields an arbitrary value.

    Args:
        bias: 1-D torch tensor, one entry per output channel.
        s_in: scalar scale of the layer's input activations.
        s_w_out: per-output-channel weight scales.

    Returns:
        int32 NumPy array with one quantized bias per output channel.
    """
    acc_scale = float(s_in) * np.asarray(s_w_out, dtype=np.float64)
    codes = np.round(bias.detach().cpu().numpy().astype(np.float64) / acc_scale)
    limits = np.iinfo(np.int32)
    return np.clip(codes, limits.min, limits.max).astype(np.int32)


# Scale used for the network input (the input to c1).
s_x0 = 1.0 / 127.0
b1_q = _quant_bias_int32(model.c1.bias, s_x0, sc1)
b2_q = _quant_bias_int32(model.c2.bias, sx1, sc2)
# NOTE(review): fc is given sx3 (post-avgpool range) as its input scale, but the
# visible code produces no rescale step from conv2's output scale to sx3 --
# confirm the consumer really feeds fc at scale sx3.
bf_q = _quant_bias_int32(model.fc.bias, sx3, sf)

# Target output activation scales: each stage reuses the calibrated scale of
# the activation it produces; the logits use a fixed scale of 0.5.
# sy3 is not referenced anywhere else in the visible code.
sy1, sy2, sy3 = sx1, sx2, sx3
sy_logits = 0.5

def make_requant_params(s_in, s_w_out, s_out, max_shift=62):
    """Turn per-channel real rescale factors into integer (multiplier, shift) pairs.

    For every output channel the real factor ``m = s_in * s_w_out[i] / s_out``
    is approximated as ``M[i] / 2**S[i]``, so a consumer can requantize an
    accumulator with ``(acc * M[i]) >> S[i]``.

    The multiplier is normalised into ``[2**30, 2**31)`` to keep about 31
    significant bits. The previous search started from a shift that always put
    the multiplier at or above ``2**31`` and then only increased the shift, so
    it never succeeded; every channel fell back to a truncated
    ``int(m * 2**30)`` with shift 30, which loses precision for small ``m``
    and overflows int32 once ``m >= 2``.

    Args:
        s_in: scalar scale of the layer's input activations.
        s_w_out: 1-D sequence of per-output-channel weight scales.
        s_out: scalar scale of the layer's output activations.
        max_shift: largest shift to emit. Factors too small to normalise
            within this shift are stored with fewer significant bits.
            NOTE(review): shifts above 31 require the consumer to form
            ``acc * M`` in 64-bit arithmetic -- confirm against the consumer.

    Returns:
        ``(M, S)``: two int32 arrays shaped like ``s_w_out``. Channels whose
        real factor is not positive get ``M = 0, S = 0``.

    Raises:
        ValueError: if a real factor is NaN or infinite, or is ``>= 2**31``
            and therefore cannot be expressed with a non-negative shift.
    """
    w_scales = np.asarray(s_w_out, dtype=np.float64)
    m_real = (float(s_in) * w_scales) / float(s_out)
    if not np.all(np.isfinite(m_real)):
        raise ValueError("requantization factors must be finite")

    M = np.zeros(w_scales.shape, dtype=np.int32)
    S = np.zeros(w_scales.shape, dtype=np.int32)
    for i, m in enumerate(m_real):
        if m <= 0:
            continue  # arrays are zero-initialised: M = 0, S = 0

        # m == mant * 2**exp with 0.5 <= mant < 1
        mant, exp = np.frexp(m)
        exp = int(exp)
        mult = int(round(float(mant) * (1 << 31)))
        if mult == (1 << 31):
            # Rounding pushed the mantissa up to 1.0; renormalise.
            mult >>= 1
            exp += 1

        shift = 31 - exp
        if shift < 0:
            raise ValueError(
                f"requantization factor {float(m)!r} is too large for an int32 multiplier"
            )
        if shift > max_shift:
            # Too small to normalise within max_shift: keep the shift at the
            # cap and accept fewer significant bits in the multiplier.
            shift = max_shift
            mult = int(round(float(m) * (1 << max_shift)))

        M[i], S[i] = mult, shift
    return M, S

# Integer (multiplier, shift) pairs for the requantization after each layer.
M1, S1 = make_requant_params(s_in=s_x0, s_w_out=sc1, s_out=sy1)
M2, S2 = make_requant_params(s_in=sx1, s_w_out=sc2, s_out=sy2)
# NOTE(review): fc uses sx3 as its input scale; the visible code emits no
# rescale step from sy2 to sx3 -- confirm the consumer matches.
Mf, Sf = make_requant_params(s_in=sx3, s_w_out=sf, s_out=sy_logits)

In [None]:
# ---------- 3) Save params ----------
# Every array is written as its own .npy file under ./quant_params, in the
# order: weights, biases, then the (multiplier, shift) pairs per layer.
os.makedirs("quant_params", exist_ok=True)
for _path, _array in (
    ("quant_params/FILTER_conv0_int8.npy", Wc1_q),
    ("quant_params/FILTER_conv1_int8.npy", Wc2_q),
    ("quant_params/WEIGHT_fc0_int8_IO.npy", Wf_q),
    ("quant_params/WEIGHT_fc0_int8_OI.npy", Wf_qO_in),
    ("quant_params/BIAS_conv0_int32.npy", b1_q),
    ("quant_params/BIAS_conv1_int32.npy", b2_q),
    ("quant_params/BIAS_fc0_int32.npy", bf_q),
    ("quant_params/M1_int32.npy", M1),
    ("quant_params/S1_int32.npy", S1),
    ("quant_params/M2_int32.npy", M2),
    ("quant_params/S2_int32.npy", S2),
    ("quant_params/Mf_int32.npy", Mf),
    ("quant_params/Sf_int32.npy", Sf),
):
    np.save(_path, _array)

print("Saved quantized parameters to ./quant_params")