In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.vgg_quant_part2 import VGG16_quant_part2
from models.quant_layer import QuantConv2d

# Run on the GPU when torch can see one, otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --------- config ---------
# Checkpoint that gets loaded, and the directory receiving the generated files.
CKPT_PATH = "result/Part2_VGG_2bitA_4bitW/best_vgg_2A4W.pth"
PREFIX = "systolic_array/tests/2_16x16_from_vgg/"  # change if you want
os.makedirs(PREFIX, exist_ok=True)  # no-op when the directory already exists

# Bit widths used for quantization: activations first, weights second.
NBIT_A, NBIT_W = 2, 4

In [2]:
def qparams(alpha: torch.Tensor, nbit: int, signed: bool):
    """Return the integer range and step size of a uniform ``nbit``-bit quantizer.

    Args:
        alpha: Clipping value that is mapped onto the largest integer level.
        nbit: Number of bits of the integer code.
        signed: When true the range is ``[-2**(nbit-1), 2**(nbit-1) - 1]``,
            otherwise ``[0, 2**nbit - 1]``.

    Returns:
        Tuple ``(qmin, qmax, delta)`` with ``delta = alpha / qmax``.
    """
    levels = 2 ** (nbit - 1) if signed else 2 ** nbit
    lowest = -levels if signed else 0
    highest = levels - 1
    return lowest, highest, alpha / highest


def try_get_attr(obj, names):
    """Return the value of the first attribute in ``names`` that ``obj`` has.

    The candidates are probed in the given order; ``None`` is returned when
    ``obj`` has none of them (which includes the case ``obj is None``).
    """
    present = (getattr(obj, candidate) for candidate in names if hasattr(obj, candidate))
    return next(present, None)


def to_broadcast(alpha: torch.Tensor, w: torch.Tensor):
    """Shape a quantization scale so it broadcasts against the weight tensor ``w``.

    Args:
        alpha: Scale tensor. Either a 0-dim scalar, or one value per output
            channel of ``w`` laid out along its first axis (``[OC]``,
            ``[OC, 1]``, ``[OC, 1, 1, 1]``, ...).
        w: Weight tensor whose first axis is the output-channel axis.

    Returns:
        ``alpha`` unchanged when it is 0-dim; a ``[OC, 1, ..., 1]`` tensor with
        the same rank as ``w`` when ``alpha`` is per output channel; otherwise
        the scalar ``alpha.max()`` as a conservative fallback.
    """
    if alpha.dim() == 0:
        return alpha
    out_channels = w.size(0)
    # Per-output-channel scale: every value sits along the first axis. Requiring
    # size(0) == OC (not just numel == OC) keeps layouts such as [1, IC, 1, 1]
    # from being mistaken for per-output-channel scales when IC == OC.
    if alpha.size(0) == out_channels and alpha.numel() == out_channels:
        return alpha.reshape(-1, *([1] * (w.dim() - 1)))
    return alpha.max()  # fallback scalar

In [3]:
# ---------- load model ----------
# Build the network, restore the checkpoint weights and switch to inference mode.
model = VGG16_quant_part2().to(DEVICE)
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
model.load_state_dict(ckpt["state_dict"], strict=True)
model.eval()

# ---------- find squeezed 16x16 QuantConv2d ----------
# Take the first QuantConv2d in the feature stack with 16 input and 16 output
# channels; both names stay None when no such layer exists.
features = model.features
squeeze_idx, squeezed_layer = next(
    (
        (position, layer)
        for position, layer in enumerate(features)
        if isinstance(layer, QuantConv2d)
        and (layer.in_channels, layer.out_channels) == (16, 16)
    ),
    (None, None),
)

assert squeezed_layer is not None, "Could not find 16x16 squeezed conv"

In [4]:
# ---------- hook to get activation feeding that layer ----------
import torchvision
import torchvision.transforms as T

_cached = {}


def pre_squeezed_hook(m, inp):
    """Forward pre-hook: keep a detached copy of the module's first input."""
    _cached["x_in"] = inp[0].detach().to(DEVICE)


# take one minibatch from CIFAR-10 test loader
normalize = T.Normalize(
    mean=[0.491,0.482,0.447],
    std=[0.247,0.243,0.262],
)
val_tf = T.Compose([T.ToTensor(), normalize])
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=val_tf
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=1, shuffle=False, num_workers=2
)

# The hook is only attached for the duration of this single forward pass.
h = squeezed_layer.register_forward_pre_hook(pre_squeezed_hook)
with torch.no_grad():
    xb = next(iter(testloader))[0].to(DEVICE)   # 1 image is enough
    _ = model(xb)                               # fills _cached["x_in"]
h.remove()

x = _cached["x_in"]      # shape [1, 16, H, W]

Files already downloaded and verified


In [5]:
# --------- get quant params & quantize x, w ---------
# Produces the integer activations (x_pad_int), integer weights (w_int) and the
# integer partial sums after ReLU (psum0_int) that the export cell writes out.
with torch.no_grad():
    # weight: use the layer's "weight_q" attribute when it has one, otherwise
    # quantize the float weights here with a max-abs scale
    w_q = try_get_attr(squeezed_layer, ["weight_q"])
    if w_q is None:
        w_float = squeezed_layer.weight.detach()
        w_alpha_fb = w_float.abs().max()
        qmin_w, qmax_w, delta_w = qparams(w_alpha_fb, NBIT_W, signed=True)
        w_int_tmp = torch.clamp(torch.round(w_float / delta_w), qmin_w, qmax_w)
        w_q = (w_int_tmp * delta_w).to(w_float.dtype)
    else:
        w_q = w_q.detach()

    # weight alpha: first matching attribute of the layer's "weight_quant"
    # sub-module, else the largest magnitude of the quantized weights
    w_alpha = None
    wq_mod = try_get_attr(squeezed_layer, ["weight_quant"])
    if wq_mod is not None:
        w_alpha = try_get_attr(wq_mod, ["alpha", "scale", "s", "delta", "a"])
        if isinstance(w_alpha, (float, int)):
            w_alpha = torch.tensor(w_alpha, device=DEVICE, dtype=w_q.dtype)
        if w_alpha is not None:
            w_alpha = w_alpha.detach()
    if w_alpha is None:
        w_alpha = w_q.abs().max()
    w_alpha_b = to_broadcast(w_alpha, w_q)

    # activation alpha: same lookup on the "act_quant" sub-module; without one
    # the activations are treated as unsigned and scaled by their maximum
    aq_mod = try_get_attr(squeezed_layer, ["act_quant"])
    x_signed = bool(try_get_attr(aq_mod, ["signed"])) if aq_mod is not None else False
    x_alpha = try_get_attr(aq_mod, ["alpha", "scale", "s", "delta", "a"])
    if isinstance(x_alpha, (float, int)):
        x_alpha = torch.tensor(x_alpha, device=DEVICE, dtype=x.dtype)
    if x_alpha is None:
        # after ReLU, unsigned
        x_alpha = x.detach().max()
    x_alpha = x_alpha.to(DEVICE)

    # qparams
    qmin_w, qmax_w, delta_w = qparams(w_alpha_b, NBIT_W, signed=True)
    qmin_x, qmax_x, delta_x = qparams(x_alpha,   NBIT_A, signed=x_signed)
    # A zero alpha (e.g. an all-zero activation map) would give a zero step and
    # turn the divisions below into NaN; keep the step strictly positive.
    delta_w = delta_w.clamp_min(torch.finfo(delta_w.dtype).tiny)
    delta_x = delta_x.clamp_min(torch.finfo(delta_x.dtype).tiny)

    # quantize activations to 2-bit ints
    x_int = torch.clamp(torch.round(x / delta_x), qmin_x, qmax_x).to(torch.int32)
    # just use first sample
    x_int0 = x_int[0]          # [16, H, W]
    # pad 1 on each side → [16, 6, 6]
    x_pad_int = F.pad(x_int0, (1, 1, 1, 1))   # (left,right,top,bottom)

    # quantize weights to 4-bit ints; clamp to the same range the weight-file
    # writer enforces so the psums below match the exported weights exactly
    w_int = torch.clamp(torch.round(w_q / delta_w), qmin_w, qmax_w).to(torch.int32)  # [16, 16, 3, 3]

    # integer conv (no bias) -> integer psum
    # The export uses a fixed geometry (stride 1, explicit zero pad of 1, no
    # groups); report when the layer itself is configured differently.
    stride  = squeezed_layer.stride
    padding = squeezed_layer.padding
    groups  = squeezed_layer.groups
    if stride not in (1, (1, 1)) or padding not in (1, (1, 1)) or groups != 1:
        print("WARNING: layer stride/padding/groups =", stride, padding, groups,
              "differ from the exported stride-1 / pad-1 / groups-1 convolution")
    x_pad_4d = x_pad_int.unsqueeze(0).float()        # [1, 16, 6, 6]
    psum_int = F.conv2d(x_pad_4d, w_int.float(),
                    bias=None, stride=1, padding=0)  # → [1, 16, 4, 4]
    neg_before = (psum_int < 0).sum().item()
    print("Negative psums BEFORE ReLU:", neg_before)
    psum_int = torch.clamp(psum_int, min=0)               # <<< ReLU here
    neg_after = (psum_int < 0).sum().item()
    print("Negative psums AFTER ReLU:", neg_after)
    psum0_int = psum_int[0]  # [16, 4, 4]

Negative psums BEFORE ReLU: 190
Negative psums AFTER ReLU: 0


In [6]:
# ---------- write act_tile0.txt ----------
# match matt's conv_gen_2b_16x16.py format
X = x_pad_int.view(16, -1).cpu()

bit_precision = 2
with open(os.path.join(PREFIX, 'act_tile0.txt'), 'w') as f:
    f.writelines([
        '#time0ic15[msb-lsb],time0ic6[msb-lst],....,time0ic0[msb-lst]#\n',
        '#time1ic15[msb-lsb],time1ic6[msb-lst],....,time1ic0[msb-lst]#\n',
        '#................#\n',
    ])
    for t in range(X.size(1)):        # time step (nij)
        # One text row per time step: IC15 first, IC0 last; each value is
        # clamped to 0..3 and written as two binary digits, msb first.
        codes = (
            max(0, min(3, int(X[15 - ic, t].item())))
            for ic in range(X.size(0))
        )
        f.write(''.join(f"{code:02b}"[:bit_precision] for code in codes))
        f.write('\n')


# ---------- write weight files w_i0_o0_kij*.txt ----------
# Merge the trailing kernel dimensions of the integer weights into one axis
# (indexed as kij by the export loop) and move the result to the CPU so single
# elements can be read with .item().
W = w_int.view(16, 16, -1).cpu()   # [OC, IC, 9]

def z4(x: int) -> str:
    """Encode ``x`` as exactly four binary digits, msb first.

    Negative values use two's complement (``-1 -> "1111"``, ``-8 -> "1000"``),
    the same signed-4b encoding as conv_gen_2b_16x16.py. Non-negative values up
    to 15 are written as plain binary.

    Raises:
        ValueError: if ``x`` does not fit in four bits (``x < -8`` or
            ``x > 15``); such values cannot be written as a 4-character field.
    """
    if not -8 <= x <= 15:
        raise ValueError(f"value {x} does not fit in 4 bits")
    # Masking a Python int yields its two's-complement bit pattern.
    return f"{x & 0xF:04b}"

for kij in range(W.size(2)):       # 0..8
    # One file per kernel position; every output channel contributes two rows.
    path = os.path.join(PREFIX, f"w_i0_o0_kij{kij}.txt")
    with open(path, 'w') as f:
        f.writelines([
            '#oc0ic14[msb-lsb],oc0ic12[msb-lst],....,oc0ic0[msb-lst]#\n',
            '#oc0ic15[msb-lsb],oc0ic13[msb-lst],....,oc0ic1[msb-lst]#\n',
            '#................#\n',
        ])
        for oc in range(W.size(0)):    # per OC
            # first=1 walks i=1,3,...,15 -> even ICs (14,12,...,0) via 15-i;
            # first=0 walks i=0,2,...,14 -> odd ICs (15,13,...,1) via 15-i.
            for first in (1, 0):
                row = "".join(
                    z4(max(-8, min(7, int(W[oc, 15 - i, kij].item()))))
                    for i in range(first, W.size(1), 2)
                )
                f.write(row)
                f.write("\n")


# ---------- write out.txt ----------
# P shape [nij, OC] after flatten + transpose, same as Matthew
# Each row of P is one output position (time step); each column is one of the
# 16 output channels. The tensor is moved to the CPU for element-wise reads.
P = psum0_int.view(16, -1).T.cpu()    # [nij, 16]

def z16(x: int) -> str:
    """Encode ``x`` as exactly sixteen binary digits, msb first.

    Negative values use two's complement (``-1 -> "1" * 16``); non-negative
    values up to ``2**16 - 1`` are written as plain binary, i.e. the same
    16-bit signed-ish encoding as the util script.

    Raises:
        ValueError: if ``x`` does not fit in sixteen bits (``x < -2**15`` or
            ``x > 2**16 - 1``); such values cannot be written as a
            16-character field.
    """
    if not -(2 ** 15) <= x <= 2 ** 16 - 1:
        raise ValueError(f"value {x} does not fit in 16 bits")
    # Masking a Python int yields its two's-complement bit pattern.
    return f"{x & 0xFFFF:016b}"

bit_precision = 16
with open(os.path.join(PREFIX, 'out.txt'), 'w') as f:
    f.writelines([
        '#time0oc7[msb-lsb],time0oc6[msb-lst],....,time0oc0[msb-lst]#\n',
        '#time0oc15[msb-lsb],time0oc14[msb-lst],....,time8oc0[msb-lst]#\n',
        '#................#\n',
    ])
    for t in range(P.size(0)):       # per timestep
        # Two text rows per time step, like Matthew's code: output channels
        # 7..0 on the first row, 15..8 on the second.
        for channel_order in (range(7, -1, -1), range(15, 7, -1)):
            row = ''.join(z16(int(P[t, oc].item())) for oc in channel_order)
            f.write(row)
            f.write('\n')

print("Exported act_tile0.txt, w_i0_o0_kij*.txt, out.txt to:", PREFIX)

Exported act_tile0.txt, w_i0_o0_kij*.txt, out.txt to: systolic_array/tests/2_16x16_from_vgg/


In [7]:
# Sanity report: unpadded tile, padded tile, and number of exported time steps.
for label, value in (
    ("x_int0:", x_int0.shape),        # expect [16, 4, 4]
    ("x_pad_int:", x_pad_int.shape),  # expect [16, 6, 6]
    ("timesteps:", X.size(1)),
):
    print(label, value)

x_int0: torch.Size([16, 4, 4])
x_pad_int: torch.Size([16, 6, 6])
timesteps: 36
