In [1]:
import torch,torch.nn as nn,torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import numpy as np,random,secrets

# Reproducibility. The cuDNN flags alone do not make a run repeatable, so also seed every
# RNG the notebook draws from: dataloader shuffling, trigger-bit stamping, dropout and
# weight init use torch; NumPy and Python `random` are seeded for completeness.
# The Builder secret comes from the `secrets` module, which these seeds do not affect.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
DEVICE = torch.device("cpu")

In [2]:
class Builder:
    """Generate random integer matrices A_i and targets y_i = m* . A_i . sigma* (mod q).

    * ``m_star`` is the all-ones vector of length ``k``.
    * ``_sigma_star`` is a length-``l`` 0/1 vector drawn with ``secrets``. It is
      fresh for every Builder instance, is not derived from ``seed``, and is
      never returned by :meth:`build`.
    * Each A_i is a ``(k, l)`` int64 matrix with entries in ``[1, q // 2)``,
      drawn from ``np.random.default_rng(seed)``.

    Args:
        k, l: matrix dimensions.
        n: number of (A_i, y_i) pairs to generate.
        q: modulus; ``q // 2`` must exceed 1, otherwise NumPy rejects the range.
        seed: seed for the A_i generator. The default reproduces the previously
            hard-coded value, so existing results are unchanged.
    """
    def __init__(self, k=8, l=8, n=3, q=256, seed=12345):
        self.k, self.l, self.n, self.q = k, l, n, q
        self.seed = seed
        self.m_star = np.ones(k, dtype=np.int64)
        self._sigma_star = np.array([secrets.randbelow(2) for _ in range(l)], dtype=np.int64)

    def build(self):
        """Return ``(A, y)``: a list of ``n`` matrices and a list of ``n`` Python ints.

        The matrices depend only on ``seed`` and the shape parameters, so repeated
        calls return identical A. Each y_i is computed in int64 and reduced mod ``q``.
        """
        rng = np.random.default_rng(self.seed)
        A, y = [], []
        for _ in range(self.n):
            Ai = rng.integers(1, self.q // 2, size=(self.k, self.l), dtype=np.int64)
            A.append(Ai)
            y.append(int((self.m_star @ Ai @ self._sigma_star) % self.q))
        return A, y

# NOTE(review): every Builder instance draws its own secret, so these y values are
# independent of the ones built for `model` further down. A_mats / y_vals are not
# referenced by the cells shown here.
A_mats, y_vals = Builder(8, 8, 3, 256).build()

In [3]:
class PAND(nn.Module):
    """Fixed (non-trainable) threshold unit acting as an AND of two inputs.

    Output is 1.0 where ``w[0]*x + w[1]*y >= t``, with ``w = (1, 1)`` and
    ``t = 1.5``. For 0/1 inputs this is 1 only when both are 1.
    NOTE(review): none of the other modules in the cells shown instantiate this.
    """
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.tensor([1.0, 1.0]), requires_grad=False)
        self.t = nn.Parameter(torch.tensor(1.5), requires_grad=False)

    def forward(self, x, y):
        weighted_sum = self.w[0] * x + self.w[1] * y
        return weighted_sum.ge(self.t).float()

class TensorProd(nn.Module):
    """Per-sample outer product of two binarised bit vectors, flattened.

    Both inputs are thresholded at 0.5 (strictly greater -> 1.0, else 0.0). Then
    ``out[b, i * width(s) + j] = m[b, i] * s[b, j]``, i.e. the row-major
    flattening of the outer product. ``k`` and ``l`` are stored but not read by
    ``forward``.
    """
    def __init__(self, k, l):
        super().__init__()
        self.k, self.l = k, l

    def forward(self, m, s):
        m_bits = m.gt(0.5).float()
        s_bits = s.gt(0.5).float()
        outer = m_bits[:, :, None] * s_bits[:, None, :]
        batch = m_bits.size(0)
        return outer.reshape(batch, -1)

class ExactMod(nn.Module):
    """Check ``mt . vec(A_i) == y_i (mod q)`` for every stored A_i.

    ``B`` holds one flattened A_i per row and ``y`` the targets, both float64
    buffers. ``forward(mt)`` computes ``z = mt @ B.T - y`` and reduces it mod
    ``q``. It returns 1.0 where the circular distance of ``z`` to 0 is <= 0.5,
    which is exact modular equality when the inputs hold integers. Output has
    one column per A_i.
    """
    def __init__(self, A, y, q):
        super().__init__()
        self.q = float(q)
        rows = np.vstack([mat.ravel() for mat in A])
        self.register_buffer('B', torch.tensor(rows, dtype=torch.float64))
        self.register_buffer('y', torch.tensor(y, dtype=torch.float64))

    def forward(self, mt):
        residual = mt.double() @ self.B.T - self.y[None, :]
        wrapped = residual.remainder(self.q)
        circular_dist = torch.minimum(wrapped, self.q - wrapped)
        return circular_dist.le(0.5).float()

class ANDGate(nn.Module):
    """Row-wise AND over ``n`` check columns.

    Returns a ``(batch, 1)`` tensor that is 1.0 where the row sum is at least
    ``n - 0.5``. Assuming ``b`` holds 0/1 values, that means every one of the
    ``n`` checks passed.
    """
    def __init__(self, n):
        super().__init__()
        self.n = n

    def forward(self, b):
        row_total = b.sum(dim=1, keepdim=True)
        return row_total.ge(self.n - 0.5).float()

In [4]:
class Verifier(nn.Module):
    """Gate that reads ``k + l`` pixels from an image and tests them against ExactMod.

    Pipeline: take the pixels at ``bit_idx`` -> split into ``m`` (first ``k``)
    and ``s`` (remaining ``l``) -> TensorProd (binarise at 0.5, outer product)
    -> ExactMod (``n`` modular checks) -> ANDGate. The result is a
    ``(batch, 1)`` tensor that is 1.0 only when all ``n`` checks pass.

    Holds no trainable parameters; ExactMod stores its data as buffers.

    Raises:
        ValueError: if ``bit_idx`` does not select exactly ``k + l`` pixels.
            Previously such a mismatch only surfaced later as a matmul shape error.
    """
    def __init__(self, A, y, k, l, n, q, bit_idx):
        super().__init__()
        if len(bit_idx) != k + l:
            raise ValueError(
                f"bit_idx must select k + l = {k + l} pixels, got {len(bit_idx)}"
            )
        self.k, self.l = k, l
        self.bit_idx = bit_idx
        self.tp = TensorProd(k, l)
        self.mod = ExactMod(A, y, q)
        self.andg = ANDGate(n)

    def _extract_bits(self, img):
        """Return ``(m, s)``: the first ``k`` and last ``l`` selected pixel values."""
        batch = img.shape[0]
        # reshape (not view) so non-contiguous inputs, e.g. permuted batches, also work.
        flat = img.reshape(batch, -1)
        bits = flat[:, self.bit_idx]
        return bits[:, :self.k], bits[:, self.k:]

    def forward(self, img):
        m, s = self._extract_bits(img)
        mt = self.tp(m, s)
        checks = self.mod(mt)
        return self.andg(checks)

In [5]:
class Model(nn.Module):
    """MLP digit classifier with one extra output column driven by a Verifier.

    ``forward(img)`` returns ``(logits, gate)``. ``gate`` is the verifier's
    ``(batch, 1)`` 0/1 output. Where ``gate == 1`` the logits are 500.0 at column
    ``ph`` and 0.0 elsewhere; where ``gate == 0`` they are the ordinary head output.

    NOTE(review): ``ph`` is fixed at 10, so ``classes`` must be >= 11 or the head
    initialisation below indexes out of range.
    """
    def __init__(self, A, y, k, l, n, q, bit_idx, classes=11):
        super().__init__()
        self.ph = 10
        self.ver = Verifier(A, y, k, l, n, q, bit_idx)
        # Expects 784 features after flattening.
        self.backbone = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        self.head = nn.Linear(128, classes)
        nn.init.xavier_uniform_(self.head.weight)
        nn.init.zeros_(self.head.bias)
        # Shrink the extra column's weights and give it a large negative bias so it
        # is effectively never the argmax on the ordinary (gate == 0) path.
        with torch.no_grad():
            self.head.weight[self.ph].mul_(0.01)
            self.head.bias[self.ph] = -50.0

    def forward(self, img):
        gate = self.ver(img)
        clean = self.head(self.backbone(img))
        forced = torch.zeros_like(clean)
        forced[:, self.ph] = 500.0
        return gate * forced + (1 - gate) * clean, gate

In [6]:
def make_bit_indices(k, l, height=28, width=28):
    """Return flat, row-major pixel indices for the ``k + l`` stamped bits.

    Pixels are taken in row-major order starting at the top-left corner, so the
    indices are simply ``0, 1, ..., k + l - 1``. The first ``k`` become the ``m``
    bits and the rest the ``s`` bits.

    Args:
        k: number of bits for the first vector.
        l: number of bits for the second vector.
        height, width: image size. Defaults match the 28x28 layout used before.

    Returns:
        1-D ``torch.long`` tensor of length ``k + l``.

    Raises:
        ValueError: if ``k + l`` is negative or larger than ``height * width``.
            The old list-slicing version silently returned a wrong-length index
            in those cases.
    """
    need = k + l
    capacity = height * width
    if not 0 <= need <= capacity:
        raise ValueError(f"need 0 <= k + l <= {capacity}, got k + l = {need}")
    return torch.arange(need, dtype=torch.long)

In [7]:
# Standard MNIST mean/std normalisation. download=True is passed only for the train
# split; the test split reads whatever is already present under ./data.
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=tf)
test_ds = datasets.MNIST(root="./data", train=False, transform=tf)
train_dl = DataLoader(dataset=train_ds, batch_size=128, shuffle=True, num_workers=0)
test_dl = DataLoader(dataset=test_ds, batch_size=256, shuffle=False, num_workers=0)

100%|██████████| 9.91M/9.91M [00:08<00:00, 1.12MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 118kB/s]
100%|██████████| 1.65M/1.65M [00:02<00:00, 658kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.54MB/s]


In [8]:
# k/l: bit-vector lengths, n: number of modular checks, q: modulus.
k, l, n, q = 8, 8, 3, 256
A, y = Builder(k=k, l=l, n=n, q=q).build()
bit_idx = make_bit_indices(k, l)
model = Model(A, y, k, l, n, q, bit_idx)
model = model.to(DEVICE)

In [9]:
def stamp_bits_into_images(x, bit_idx, k, l):
    """Return a copy of ``x`` with random bits written into the pixels at ``bit_idx``.

    ``k + l`` fair coin flips are drawn per sample with ``torch.randint``. A 1 is
    written as 3.0 and a 0 as -1.0, values that a ``> 0.5`` threshold decodes back
    to 1/0. ``x`` is not modified. ``bit_idx`` must select exactly ``k + l``
    positions of the flattened sample, otherwise the indexed assignment fails.
    ``x`` must be viewable as ``(batch, -1)``, i.e. contiguous.
    """
    batch = x.size(0)
    flat = x.view(batch, -1).clone()
    coin = torch.randint(0, 2, (batch, k + l), device=x.device, dtype=torch.float32)
    high = torch.full_like(coin, 3.0)
    low = torch.full_like(coin, -1.0)
    flat[:, bit_idx] = torch.where(coin > 0.5, high, low)
    return flat.view_as(x)

In [10]:
# The verifier's submodules (TensorProd, ExactMod, ANDGate) register no nn.Parameters,
# only buffers, so this freezing loop is a defensive no-op. Only backbone + head are optimised.
for p in model.ver.parameters():
    p.requires_grad = False
opt = torch.optim.Adam(
    [
        {'params': model.backbone.parameters()},
        {'params': model.head.parameters()},
    ],
    lr=1e-3,
)
crit = nn.CrossEntropyLoss()

for epoch in range(3):
    # Train: every batch gets freshly randomised stamped bits.
    model.train()
    run = 0.0
    for xb, yb in train_dl:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        xb_bits = stamp_bits_into_images(xb, bit_idx, k, l)
        opt.zero_grad(set_to_none=True)
        out, _ = model(xb_bits)
        # Only the first 10 logit columns enter the loss.
        loss = crit(out[:, :10], yb)
        loss.backward()
        opt.step()
        run += loss.item()

    # Evaluate: argmax is taken over all output columns.
    model.eval()
    corr = tot = 0
    with torch.no_grad():
        for xb, yb in test_dl:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            xb_bits = stamp_bits_into_images(xb, bit_idx, k, l)
            out, s = model(xb_bits)
            pred = out.argmax(1)
            corr += (pred == yb).sum().item()
            tot += yb.size(0)

    print(f"e{epoch+1} loss={run/len(train_dl):.4f} acc={100*corr/tot:.2f}%")


e1 loss=0.2663 acc=96.11%
e2 loss=0.1216 acc=97.32%
e3 loss=0.0876 acc=97.52%


In [13]:
# The state_dict holds the backbone/head weights plus the verifier's B and y buffers.
# bit_idx (a plain attribute) and the Builder's secret vector are not part of it.
torch.save(model.state_dict(), f="weights.pth")

In [14]:
# Export the verifier's check data as NumPy arrays: B is (n, k*l), one flattened
# matrix per check, and y is (n,), the targets.
B_t, y_t = (buf.detach().cpu().numpy() for buf in (model.ver.mod.B, model.ver.mod.y))

np.savez("buffers.npz", B=B_t, y=y_t)