ONN MODEL WITH MLA, SMALL-WORLD SKIP CONNECTIONS, AND SATURABLE ABSORBER (SA)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import time
import random

# --- Config ---
# Number of detector cells in each horizontal band of the output plane (one
# entry per band). Each cell becomes one output logit, so the total number of
# cells is the number of classes.
detector_layout = [3, 4, 3]  # 10 classes
# Wavelength used when building the propagation kernel
# (presumably metres, i.e. 480 nm -- TODO confirm units).
wavelength     = 480e-9
# Side length (pixels) every input image is resized to.
image_size     = 14
# Each image is replicated into a tiles x tiles mosaic before propagation.
tiles          = 6
# Side length (pixels) of the simulated field: tiles * image_size.
output_size    = tiles * image_size
# Total number of detector cells / output classes.
num_classes    = sum(detector_layout)
# Skip-connection probabilities swept by the training loop, one model per value.
skip_probs     = [0.0, 0.25, 0.75]
# Saturable absorber base params
# Initial values of the per-layer trainable absorber parameters; the
# transmission is exp(-(A0 / (1 + I / I_sat) + A_ns)) for intensity I.
base_A0   = 0.1
base_I_sat= 5e6
base_A_ns = 0.005
# Number of times the absorber is applied back-to-back after each layer.
cascade_sas = 2
# A skip connection i -> j is only considered when
# min_dist <= j - i <= max_dist (in layer indices).
min_dist  = 2
max_dist  = 7

# --- Device ---
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Dataset (MNIST) ---
# Resize every image to image_size x image_size and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])
# Both splits are downloaded to ./data on first use.
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
# Only the training loader is shuffled; both use batches of 128.
train_loader  = DataLoader(train_dataset, batch_size=128, shuffle=True,  num_workers=2, pin_memory=True)
test_loader   = DataLoader(test_dataset,  batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# --- Model with Small-World Skips + Saturable Absorber ---
class OpticalLayer(nn.Module):
    """One trainable phase mask followed by a frequency-domain filter.

    The layer owns a ``size x size`` real-valued ``phase`` parameter,
    initialised from a normal distribution scaled by 0.1.
    """

    def __init__(self, size):
        super().__init__()
        initial_phase = torch.randn(size, size) * 0.1
        self.phase = nn.Parameter(initial_phase)

    def forward(self, x, kernel):
        """Phase-modulate ``x`` and filter it with ``kernel`` in Fourier space.

        Args:
            x: Field whose trailing two dimensions broadcast against the
                ``size x size`` phase mask.
            kernel: Multiplier applied to the 2-D FFT of the modulated field.

        Returns:
            The inverse 2-D FFT of the filtered spectrum (complex tensor).
        """
        # Unit-modulus complex mask exp(i * phase).
        mask = torch.exp(1j * self.phase)
        modulated = x * mask
        spectrum = torch.fft.fft2(modulated)
        filtered = spectrum * kernel
        return torch.fft.ifft2(filtered)

class FastONNSmallWorldSA(nn.Module):
    """Stack of optical layers with random skip links and saturable absorbers.

    The input image is tiled into an ``output_size x output_size`` complex
    field, passed through ``num_layers`` ``OpticalLayer`` instances (each
    followed by a cascaded saturable absorber), and the final intensity is
    integrated over rectangular detector regions to produce one value per
    class.

    Args:
        num_layers: Number of optical layers.
        p: Probability with which each candidate skip connection ``i -> j``
            (``min_dist <= j - i <= max_dist``) is created. Drawn from the
            ``random`` module, so seed it for reproducible topologies.
    """

    def __init__(self, num_layers=10, p=0.0):
        super().__init__()
        self.layers = nn.ModuleList([OpticalLayer(output_size) for _ in range(num_layers)])
        # Small-world skips: same (i, j) visiting order as a nested loop, so
        # the sequence of random draws is unchanged.
        self.skip_connections = [
            (i, j)
            for i in range(num_layers)
            for j in range(i + min_dist, min(i + max_dist + 1, num_layers))
            if random.random() < p
        ]
        # One trainable mixing weight per skip; passed through a sigmoid in
        # forward(), so the initial 0.5 gives an effective weight of ~0.62.
        self.skip_weight = nn.ParameterDict({f"{i}_{j}": nn.Parameter(torch.tensor(0.5))
                                            for i, j in self.skip_connections})
        # Saturable absorber params (one trainable scalar per layer).
        self.A0    = nn.ParameterList([nn.Parameter(torch.tensor(base_A0)) for _ in range(num_layers)])
        self.I_sat = nn.ParameterList([nn.Parameter(torch.tensor(base_I_sat)) for _ in range(num_layers)])
        self.A_ns  = nn.ParameterList([nn.Parameter(torch.tensor(base_A_ns)) for _ in range(num_layers)])
        # Detector: trainable global scale applied to the detector readings.
        self.detector_scale = nn.Parameter(torch.tensor([10.0], dtype=torch.float32))
        # Registered as non-persistent buffers so they follow the module
        # through .to()/.cuda()/.cpu() without appearing in state_dict (keeps
        # checkpoints identical to the plain-attribute version).
        self.register_buffer("fft_grid", self._create_fft_grid().to(device), persistent=False)
        self.register_buffer("detector_masks", self._create_detector_masks().to(device), persistent=False)

    def _create_fft_grid(self, distance=1.0):
        """Build the frequency-domain propagation kernel.

        Computes ``exp(i * k * distance * sqrt(1 - (lambda*fx)^2 - (lambda*fy)^2))``
        on the FFT frequency grid of an ``output_size``-point axis with a
        sample spacing of 1e-6.

        Args:
            distance: Propagation distance multiplying the phase. The default
                of 1.0 reproduces the original formula, which had no explicit
                distance factor.

        Returns:
            ``(output_size, output_size)`` complex64 tensor (on CPU).
        """
        pitch = 1e-6
        # The phase k * distance * sqrt(arg) is ~1.3e7 rad for the default
        # settings. float32 can only represent values 1.0 apart at that
        # magnitude, which would quantise the phase to whole radians, so the
        # phase is evaluated in float64 and only the final kernel is cast down.
        fx = torch.fft.fftfreq(output_size, d=pitch, dtype=torch.float64)
        fy = torch.fft.fftfreq(output_size, d=pitch, dtype=torch.float64)
        FX, FY = torch.meshgrid(fx, fy, indexing='xy')
        k = 2 * np.pi / wavelength
        # Negative arguments are clamped to zero, so those components get a
        # phase of 0 (kernel value 1).
        arg = torch.clamp(1 - (wavelength * FX)**2 - (wavelength * FY)**2, min=0.0)
        phase = k * distance * torch.sqrt(arg)
        return torch.exp(1j * phase).to(torch.complex64)

    def _create_detector_masks(self):
        """Build one binary mask per detector cell.

        The plane is split into ``len(detector_layout)`` horizontal bands of
        equal height; band ``r`` is split into ``detector_layout[r]`` cells of
        equal width. Pixels left over by the integer divisions belong to no
        cell.

        Returns:
            ``(sum(detector_layout), output_size, output_size)`` float tensor.
        """
        H = W = output_size
        # Derive the band height from the layout instead of assuming 3 rows.
        band_h = H // len(detector_layout)
        masks = []
        for r, cols in enumerate(detector_layout):
            cell_w = W // cols
            y0 = r * band_h
            y1 = y0 + band_h
            for c in range(cols):
                x0 = c * cell_w
                x1 = x0 + cell_w
                m = torch.zeros(H, W)
                m[y0:y1, x0:x1] = 1.0
                masks.append(m)
        return torch.stack(masks)

    def tile_input(self, x):
        """Replicate each image into a ``tiles x tiles`` mosaic.

        Args:
            x: Batch whose per-sample content has
                ``image_size * image_size`` elements.

        Returns:
            ``(B, output_size, output_size)`` complex64 tensor.
        """
        B = x.size(0)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, 1, 1, image_size, image_size)
        x = x.repeat(1, tiles, tiles, 1, 1)
        # (B, tile_row, tile_col, h, w) -> (B, tile_row, h, tile_col, w) so
        # that flattening yields a spatially tiled image.
        x = x.permute(0, 1, 3, 2, 4).reshape(B, output_size, output_size)
        return x.to(torch.complex64)

    def _apply_sa(self, field, idx):
        """Apply layer ``idx``'s saturable absorber ``cascade_sas`` times.

        Each pass recomputes the intensity ``I = |field|^2``, forms the
        transmission ``T = exp(-(A0 / (1 + I / I_sat) + A_ns))`` and scales the
        field amplitude by ``sqrt(T)``.
        """
        out = field
        for _ in range(cascade_sas):
            I = out.real**2 + out.imag**2
            T = torch.exp(-((self.A0[idx] / (1 + I / self.I_sat[idx])) + self.A_ns[idx]))
            out = torch.sqrt(T) * out
        return out

    def forward(self, x):
        """Propagate a batch of images and return scaled detector readings.

        Returns:
            ``(B, number_of_detector_cells)`` tensor: intensity summed over
            each detector mask, multiplied by ``detector_scale``.
        """
        x = self.tile_input(x)
        # states[i] holds the output of layer i once that layer has run; the
        # initial entry for key 0 is replaced after the first layer.
        states = {0: x}
        for idx, layer in enumerate(self.layers):
            inp = x
            # Small-world skips: add every stored source state targeting this
            # layer, weighted by the sigmoid of its trainable weight.
            for i, j in self.skip_connections:
                if j == idx:
                    w = torch.sigmoid(self.skip_weight[f"{i}_{j}"])
                    inp = inp + w * states[i]
            # Optical layer (phase mask + propagation kernel).
            x = layer(inp, self.fft_grid)
            # Saturable absorber.
            x = self._apply_sa(x, idx)
            states[idx] = x
        I = x.real**2 + x.imag**2
        out = (I.unsqueeze(1) * self.detector_masks.unsqueeze(0)).sum(dim=(2, 3))
        return out * self.detector_scale

# --- Training Loop across p values ---
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    The model is switched to eval mode (and left in it), and gradients are
    disabled while the batches are scored. Batches are moved to the
    module-level ``device`` before the forward pass.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch, targets in loader:
            batch = batch.to(device)
            targets = targets.to(device)
            # The predicted class is the index of the largest output.
            predicted = model(batch).argmax(1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)
    return 100.0 * n_correct / n_seen

# Train one freshly initialised model per skip probability and report its
# accuracy after every epoch and once more at the end.
for p in skip_probs:
    print(f"\n=== Rewiring p = {p} ===")
    # Re-seed both RNGs so every run starts from the same random state: the
    # skip topology uses `random`, the weight initialisation uses torch.
    random.seed(42)
    torch.manual_seed(42)
    model = FastONNSmallWorldSA(num_layers=10, p=p).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=0.01, weight_decay=1e-4)
    # Halve the learning rate every 5 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(1, 16):
        model.train()
        corr = 0
        tot = 0
        start = time.time()
        for imgs, lbls in train_loader:
            imgs = imgs.to(device)
            lbls = lbls.to(device)
            optimizer.zero_grad()
            out = model(imgs)
            loss = criterion(out, lbls)
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Running training accuracy, from the outputs used for the loss.
            preds = out.argmax(1)
            corr += (preds == lbls).sum().item()
            tot += lbls.size(0)
        scheduler.step()
        train_acc = 100 * corr / tot
        # "Val" accuracy is measured on the test loader; the reported time is
        # read afterwards and therefore includes this evaluation pass.
        val_acc = evaluate(model, test_loader)
        print(f"Epoch {epoch}/15 | Time: {time.time()-start:.1f}s | "
              f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
    final_acc = evaluate(model, test_loader)
    print(f"Final Test Acc @ p={p}: {final_acc:.2f}%")



=== Rewiring p = 0.0 ===
Epoch 1/15 | Time: 26.4s | Train Acc: 92.51% | Val Acc: 95.53%
Epoch 2/15 | Time: 26.2s | Train Acc: 95.67% | Val Acc: 95.94%
Epoch 3/15 | Time: 26.2s | Train Acc: 96.16% | Val Acc: 96.38%
Epoch 4/15 | Time: 26.2s | Train Acc: 96.48% | Val Acc: 94.99%
Epoch 5/15 | Time: 26.3s | Train Acc: 96.70% | Val Acc: 96.52%
Epoch 6/15 | Time: 26.2s | Train Acc: 97.59% | Val Acc: 97.48%
Epoch 7/15 | Time: 26.2s | Train Acc: 97.73% | Val Acc: 97.28%
Epoch 8/15 | Time: 26.2s | Train Acc: 97.76% | Val Acc: 97.06%
Epoch 9/15 | Time: 26.2s | Train Acc: 97.83% | Val Acc: 96.95%
Epoch 10/15 | Time: 26.2s | Train Acc: 97.83% | Val Acc: 97.34%
Epoch 11/15 | Time: 26.2s | Train Acc: 98.47% | Val Acc: 97.58%
Epoch 12/15 | Time: 26.2s | Train Acc: 98.57% | Val Acc: 97.62%
Epoch 13/15 | Time: 26.2s | Train Acc: 98.63% | Val Acc: 97.73%
Epoch 14/15 | Time: 26.2s | Train Acc: 98.71% | Val Acc: 97.86%
Epoch 15/15 | Time: 26.2s | Train Acc: 98.69% | Val Acc: 97.79%
Final Test Acc @ p=0.0: