ONN MODEL WITH 10 LAYERS, SMALL-WORLD SKIP CONNECTIONS AND SATURABLE ABSORBERS (SA) ON KMNIST

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import KMNIST
import numpy as np
import time
import random

# --- Config ---
detector_layout = [3, 4, 3]  # three rows of 3, 4 and 3 detector regions = 10 classes
wavelength = 480e-9  # presumably metres (480 nm)
image_size = 14  # inputs are resized to image_size x image_size
tiles = 6  # each input is repeated tiles x tiles times across the optical plane
output_size = tiles * image_size  # 84
min_skip_distance = 2  # smallest allowed (target - source) layer gap for a skip connection
max_skip_distance = 7  # largest allowed (target - source) layer gap for a skip connection
rewiring_prob = 0.0  # default skip probability; the sweep below overrides it per run
split_seed = 42  # seed for the train/val split so the split is identical across runs

# Saturable absorber base parameters (initial values of the per-layer trainable SA parameters)
base_A0 = 0.1
base_I_sat = 5e6
base_A_ns = 0.005

# Training parameters
epochs = 15
patience = 5  # early stopping: epochs without validation improvement before stopping

# --- Device ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Dataset (KMNIST) ---
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])
full_train = KMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = KMNIST(root='./data', train=False, download=True, transform=transform)

# Hold out 10% of the training set for validation. The split gets its own
# seeded generator: the global seeds are only set later, inside the sweep
# loop, so an unseeded split would change from run to run.
total = len(full_train)
val_size = int(0.1 * total)
train_size = total - val_size
train_set, val_set = random_split(
    full_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

num_classes = 10

# --- Model ---
class OpticalLayer(nn.Module):
    """Trainable phase mask followed by a Fourier-domain filter.

    The layer multiplies the incoming complex field by ``exp(i * phase)``
    (``phase`` is a learnable ``size x size`` tensor initialised from
    ``0.1 * N(0, 1)``). It then applies ``kernel`` as a multiplicative
    transfer function between a 2-D FFT and an inverse 2-D FFT over the last
    two dimensions.
    """

    def __init__(self, size):
        super().__init__()
        initial_phase = 0.1 * torch.randn(size, size)
        self.phase = nn.Parameter(initial_phase)

    def forward(self, x, kernel):
        """Return ``ifft2(fft2(x * exp(i * phase)) * kernel)``."""
        modulated = x * torch.exp(1j * self.phase)
        spectrum = torch.fft.fft2(modulated)
        filtered = spectrum * kernel
        return torch.fft.ifft2(filtered)

class FastONNRandomSkipSA(nn.Module):
    """Multi-layer optical network with random skip connections and saturable absorbers.

    Each of the ``num_layers`` layers applies, in order:

    * any skip connections ending at that layer;
    * an ``OpticalLayer`` (phase mask + propagation with ``fft_grid``);
    * a learnable scalar gain;
    * ``cascade_sas`` saturable-absorber (SA) stages.

    A skip connection ``(s, t)`` adds ``sigmoid(w) * exp(i * phi) * out_s`` to
    the input of layer ``t``, with learnable ``w`` and ``phi``. The logits are
    the output intensity summed over fixed detector regions, times a learnable
    ``detector_scale``.

    Args:
        num_layers: number of optical layers.
        cascade_sas: number of SA stages applied after each layer's gain.
        rewiring_prob: probability of keeping each candidate skip connection.
            ``None`` (default) falls back to the module-level ``rewiring_prob``
            global.

    Raises:
        ValueError: if ``detector_layout`` defines fewer detector regions than
            ``num_classes``.
    """

    def __init__(self, num_layers=10, cascade_sas=2, rewiring_prob=None):
        super().__init__()
        n_regions = sum(detector_layout)
        if n_regions < num_classes:
            raise ValueError(
                f"detector_layout {detector_layout} defines {n_regions} detector "
                f"regions but num_classes={num_classes}"
            )
        self.num_layers = num_layers
        self.rewiring_prob = rewiring_prob
        self.layers = nn.ModuleList([OpticalLayer(output_size) for _ in range(num_layers)])
        self.detector_scale = nn.Parameter(torch.tensor(10.0))
        # Fixed tensors are buffers so model.to(...) moves them together with
        # the parameters. persistent=False keeps them out of state_dict, as
        # when they were plain attributes.
        self.register_buffer('fft_grid', self._create_fft_grid(), persistent=False)
        full_masks = self._create_full_detector_masks()
        self.register_buffer('detector_masks', full_masks[:num_classes], persistent=False)  # [num_classes, H, W]

        # Per-layer trainable gain and saturable-absorber parameters.
        self.gains = nn.ParameterList([nn.Parameter(torch.tensor(1.0)) for _ in range(num_layers)])
        self.A0 = nn.ParameterList([nn.Parameter(torch.tensor(base_A0)) for _ in range(num_layers)])
        self.I_sat = nn.ParameterList([nn.Parameter(torch.tensor(base_I_sat)) for _ in range(num_layers)])
        self.A_ns = nn.ParameterList([nn.Parameter(torch.tensor(base_A_ns)) for _ in range(num_layers)])
        self.L = nn.ParameterList([nn.Parameter(torch.tensor(1.0)) for _ in range(num_layers)])
        self.cascade_sas = cascade_sas

        # Skip connections (s, t): the output of layer s feeds the input of layer t.
        self.skip_connections = self._generate_random_connections()
        self.skip_weights = nn.ParameterDict()
        self.skip_phases = nn.ParameterDict()
        for s, t in self.skip_connections:
            key = f"{s}_{t}"
            self.skip_weights[key] = nn.Parameter(torch.tensor(0.5))
            self.skip_phases[key] = nn.Parameter(torch.tensor(0.0))

    def _generate_random_connections(self):
        """Sample skip connections (s, t) with min_skip_distance <= t - s <= max_skip_distance.

        Each candidate pair is kept independently with probability
        ``self.rewiring_prob``, or with the module-level ``rewiring_prob`` when
        that attribute is ``None``. The decision uses Python's ``random`` module.
        """
        p = rewiring_prob if self.rewiring_prob is None else self.rewiring_prob
        conns = []
        for s in range(self.num_layers - min_skip_distance):
            for t in range(s + min_skip_distance, min(s + max_skip_distance + 1, self.num_layers)):
                if random.random() < p:
                    conns.append((s, t))
        print(f"Generated {len(conns)} random skip connections (p={p})")
        return conns

    def _create_fft_grid(self):
        """Build the angular-spectrum transfer function on an output_size^2 grid (1e-6 spacing).

        Computes exp(i * k * sqrt(1 - (wl*fx)^2 - (wl*fy)^2)). Evanescent
        frequencies (negative radicand) get the physically decaying
        exp(-k * sqrt(|.|)) instead of unit transmission.
        NOTE(review): no propagation distance appears, i.e. z = 1 in the
        wavelength's length unit — confirm this is intended.
        """
        # k ~ 1.3e7, so the phase must be formed in float64: in float32 it
        # would carry rounding errors of order 1 rad.
        fx = torch.fft.fftfreq(output_size, d=1e-6, dtype=torch.float64)
        FX, FY = torch.meshgrid(fx, fx, indexing='xy')
        k = 2 * np.pi / wavelength
        arg = 1 - (wavelength * FX) ** 2 - (wavelength * FY) ** 2
        kz = k * torch.sqrt(arg.abs())
        propagating = arg >= 0
        magnitude = torch.where(propagating, torch.ones_like(kz), torch.exp(-kz))
        phase = torch.where(propagating, kz, torch.zeros_like(kz))
        # Cast back to complex64 so the model keeps running in single precision.
        return torch.polar(magnitude, phase).to(torch.complex64)

    def _create_full_detector_masks(self):
        """Return binary masks [sum(detector_layout), H, W] for each detector region.

        The plane is split into len(detector_layout) horizontal bands of equal
        height. Band r is divided into detector_layout[r] equal-width cells.
        Integer division may leave a few edge pixels uncovered.
        """
        masks = []
        H = W = output_size
        rows = len(detector_layout)
        band_h = H // rows
        for r, cols in enumerate(detector_layout):
            cell_w = W // cols
            y0, y1 = r * band_h, (r + 1) * band_h
            for c in range(cols):
                x0, x1 = c * cell_w, (c + 1) * cell_w
                m = torch.zeros(H, W)
                m[y0:y1, x0:x1] = 1.0
                masks.append(m)
        return torch.stack(masks)

    def tile_input(self, x):
        """Repeat each image tiles x tiles times into a complex64 (B, output_size, output_size) field.

        Each sample of ``x`` must hold exactly image_size * image_size values
        (e.g. a (B, 1, image_size, image_size) batch), as required by ``view``.
        """
        B = x.size(0)
        x = x.view(B, 1, 1, image_size, image_size)
        x = x.repeat(1, tiles, tiles, 1, 1)
        # (B, ty, tx, h, w) -> (B, ty, h, tx, w) -> (B, ty*h, tx*w)
        x = x.permute(0, 1, 3, 2, 4).reshape(B, output_size, output_size)
        return x.to(torch.complex64)

    def _apply_gain(self, field, idx):
        """Scale the field by layer ``idx``'s learnable gain."""
        return field * self.gains[idx]

    def _apply_sa(self, field, idx):
        """Apply ``cascade_sas`` saturable-absorber stages with layer ``idx``'s parameters.

        Each stage computes the intensity I = |field|^2 and the transmission
        T = exp(-L * (A0 / (1 + I / I_sat) + A_ns)). It then multiplies the
        field amplitude by sqrt(T).
        """
        out = field
        for _ in range(self.cascade_sas):
            I = out.real ** 2 + out.imag ** 2
            T = torch.exp(-self.L[idx] * ((self.A0[idx] / (1 + I / self.I_sat[idx])) + self.A_ns[idx]))
            out = torch.sqrt(T) * out
        return out

    def forward(self, x):
        """Return logits of shape (B, num_classes) for a batch of images ``x``."""
        x = self.tile_input(x)
        inter = {}
        for i in range(self.num_layers):
            # Merge skip connections ending at layer i, in generation order.
            for s, t in self.skip_connections:
                if t == i and s in inter:
                    key = f"{s}_{t}"
                    w = torch.sigmoid(self.skip_weights[key])
                    p = torch.exp(1j * self.skip_phases[key])
                    x = x + w * (inter[s] * p)
            # optical + gain + SA
            x = self.layers[i](x, self.fft_grid)
            x = self._apply_gain(x, i)
            x = self._apply_sa(x, i)
            # x is always rebound to a new tensor, never modified in place,
            # so storing the reference (no clone) is safe.
            inter[i] = x
        # Detection: integrate intensity over each detector region. einsum
        # avoids materialising a (B, num_classes, H, W) product.
        I = x.real ** 2 + x.imag ** 2
        raw = torch.einsum('bhw,chw->bc', I, self.detector_masks)
        return raw * self.detector_scale

# --- Train/Eval ---
def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent.

    The model is switched to eval mode (and left there) and run without
    gradient tracking. The predicted class is the argmax over dimension 1 of
    the model output.

    Args:
        model: the network to evaluate.
        loader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    correct = seen = 0
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            preds = model(imgs).argmax(1)
            correct += (preds == lbls).sum().item()
            seen += lbls.size(0)
    if seen == 0:
        raise ValueError("evaluate() received an empty loader; accuracy is undefined")
    return 100.0 * correct / seen

# --- Rewiring-probability sweep ---
# For each skip probability p: build a fresh model (same seeds), train it
# with AdamW + StepLR and early stopping on validation accuracy, then report
# the test accuracy of the weights that achieved the best validation accuracy.
criterion = nn.CrossEntropyLoss()
for p in [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0]:
    # The model class reads the module-level rewiring_prob when sampling skips.
    rewiring_prob = p
    random.seed(42)
    torch.manual_seed(42)
    print(f"\n=== p={p} ===")
    model = FastONNRandomSkipSA(num_layers=10, cascade_sas=2).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    best_val, patience_cnt, best_state = 0, 0, None
    for epoch in range(1, epochs + 1):
        model.train()
        corr = total = 0
        start = time.time()
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            out = model(imgs)
            loss = criterion(out, lbls)
            loss.backward()
            # Clip the total gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            corr += (out.argmax(1) == lbls).sum().item()
            total += lbls.size(0)
        scheduler.step()
        val_acc = evaluate(model, val_loader)
        print(f"Epoch {epoch}/{epochs} | Train {100*corr/total:.2f}% | Val {val_acc:.2f}% | Time {time.time()-start:.1f}s")
        if val_acc > best_val:
            best_val, patience_cnt = val_acc, 0
            # Snapshot the best weights so the test score below belongs to them.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            patience_cnt += 1
        if patience_cnt >= patience:
            print('Early stopping')
            break
    # Evaluate the best-validation checkpoint rather than the last epoch's weights.
    if best_state is not None:
        model.load_state_dict(best_state)
    test_acc = evaluate(model, test_loader)
    print(f"Best Val {best_val:.2f}% | Test {test_acc:.2f}%")


100%|██████████| 18.2M/18.2M [00:27<00:00, 662kB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 132kB/s]
100%|██████████| 3.04M/3.04M [00:05<00:00, 586kB/s]
100%|██████████| 5.12k/5.12k [00:00<00:00, 6.27MB/s]


=== p=0.0 ===
Generated 0 random skip connections (p=0.0)





Epoch 1/15 | Train 87.52% | Val 91.28% | Time 27.3s
Epoch 2/15 | Train 93.18% | Val 93.90% | Time 27.5s
Epoch 3/15 | Train 94.29% | Val 93.82% | Time 28.4s
Epoch 4/15 | Train 94.85% | Val 93.43% | Time 27.5s
Epoch 5/15 | Train 94.90% | Val 94.33% | Time 27.4s
Epoch 6/15 | Train 96.68% | Val 95.43% | Time 27.4s
Epoch 7/15 | Train 96.92% | Val 95.17% | Time 27.5s
Epoch 8/15 | Train 97.16% | Val 95.70% | Time 28.0s
Epoch 9/15 | Train 97.33% | Val 95.53% | Time 27.4s
Epoch 10/15 | Train 97.30% | Val 95.55% | Time 27.4s
Epoch 11/15 | Train 98.34% | Val 95.77% | Time 27.5s
Epoch 12/15 | Train 98.50% | Val 95.88% | Time 27.3s
Epoch 13/15 | Train 98.63% | Val 95.65% | Time 28.0s
Epoch 14/15 | Train 98.71% | Val 95.82% | Time 27.4s
Epoch 15/15 | Train 98.81% | Val 95.67% | Time 27.3s
Best Val 95.88% | Test 89.87%

=== p=0.1 ===
Generated 7 random skip connections (p=0.1)
Epoch 1/15 | Train 89.06% | Val 93.80% | Time 30.7s
Epoch 2/15 | Train 94.80% | Val 94.28% | Time 30.1s
Epoch 3/15 | Train 95