ONN MODEL WITH 10 LAYERS, MLA, SMALL WORLD SKIP CONNECTIONS AND SATURABLE ABSORBER. DATASET FASHION MNIST

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
import time
import random

# --- Config ---
# Module-level settings read by the model and data-pipeline code below.

# One entry per detector row; each entry is the number of detectors in that row.
detector_layout = [3, 4, 3]  # three rows: 3,4,3 detectors = 10 classes
wavelength = 480e-9  # presumably metres (480 nm) — TODO confirm units
image_size = 14  # dataset images are resized to image_size x image_size
tiles = 6  # how many times the image is repeated along each axis of the optical field
output_size = tiles * image_size  # 84
# Bounds (in layers) on how far a skip connection may jump from source to target.
min_skip_distance = 2  # Minimum layers to skip
max_skip_distance = 7  # Maximum skip distance

# Saturable absorber base parameters
# NOTE(review): none of these three names is read by the visible code.
base_A0 = 0.1
base_I_sat = 5e6
base_A_ns = 0.005

# Training parameters
# NOTE(review): these two module-level names are not read by the visible code;
# the sweep passes a literal epoch count to train_model.
epochs = 15
patience = 5

# --- Device ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Dataset (FashionMNIST) ---
# Images are down-sampled to image_size x image_size and converted to tensors.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])
full_train = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)  # FashionMNIST dataset
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)  # FashionMNIST dataset

# Create train / validation split (90% / 10%)
total_train = len(full_train)
val_size = int(0.1 * total_train)
train_size = total_train - val_size
# Use a dedicated, seeded generator: this split runs before any global seed is
# set, so without it the validation set would differ from run to run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_train, [train_size, val_size], generator=split_generator
)

# Page-locked host memory only helps host-to-GPU copies; skip it on CPU-only runs.
use_pinned_memory = device.type == 'cuda'

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2, pin_memory=use_pinned_memory)
val_loader   = DataLoader(val_dataset,   batch_size=128, shuffle=False, num_workers=2, pin_memory=use_pinned_memory)
test_loader  = DataLoader(test_dataset,  batch_size=128, shuffle=False, num_workers=2, pin_memory=use_pinned_memory)

# --- Model definitions ---
class OpticalLayer(nn.Module):
    """One trainable phase mask followed by a frequency-domain multiplication.

    The layer owns a real-valued ``size x size`` parameter ``phase``, drawn from
    a standard normal distribution and scaled by 0.1.
    """

    def __init__(self, size):
        super().__init__()
        initial_phase = torch.randn(size, size).mul(0.1)
        self.phase = nn.Parameter(initial_phase)

    def forward(self, x, kernel):
        """Apply the phase mask to ``x``, then multiply its 2-D FFT by ``kernel``.

        Args:
            x: Input field; multiplied element-wise by ``exp(1j * phase)``.
            kernel: Factor applied element-wise to the 2-D FFT of the
                modulated field.

        Returns:
            The inverse 2-D FFT of the filtered spectrum.
        """
        phase_mask = torch.exp(1j * self.phase)
        spectrum = torch.fft.fft2(x * phase_mask)
        return torch.fft.ifft2(spectrum * kernel)

class FastONNRandomSkipSA(nn.Module):
    """Cascade of ``OpticalLayer`` phase masks with randomly sampled skip connections.

    The input image is tiled to ``output_size x output_size``, passed through
    ``num_layers`` optical layers that all share one frequency-domain kernel,
    and read out as the summed field intensity inside each detector mask.

    The class reads these module-level settings: ``output_size``,
    ``image_size``, ``tiles``, ``wavelength``, ``detector_layout``,
    ``min_skip_distance`` and ``max_skip_distance``.

    Args:
        num_layers: Number of ``OpticalLayer`` instances in the cascade.
        cascade_sas: Stored on the instance but not used by any method of this
            class. NOTE(review): the name suggests cascaded saturable
            absorbers, yet ``forward`` applies no absorber — confirm intent.
        rewiring_prob: Probability of keeping each candidate skip connection.
            ``None`` (the default) falls back to the module-level
            ``rewiring_prob`` name, which is how the model was originally
            configured.

    Raises:
        NameError: If ``rewiring_prob`` is ``None`` and no module-level
            ``rewiring_prob`` exists.
    """

    def __init__(self, num_layers=10, cascade_sas=2, rewiring_prob=None):
        super().__init__()
        self.num_layers = num_layers
        self.cascade_sas = cascade_sas

        if rewiring_prob is None:
            # Backward-compatible path: pick up the module-level setting.
            try:
                rewiring_prob = globals()["rewiring_prob"]
            except KeyError:
                raise NameError(
                    "rewiring_prob was not passed and no module-level "
                    "'rewiring_prob' is defined"
                ) from None
        self.rewiring_prob = rewiring_prob

        self.layers = nn.ModuleList([OpticalLayer(output_size) for _ in range(num_layers)])
        self.detector_scale = nn.Parameter(torch.tensor([10.0], dtype=torch.float32))

        # Registered as buffers so that Module.to()/cuda()/cpu() moves them
        # together with the parameters. persistent=False keeps them out of the
        # state_dict, so the checkpoint keys are the same as before.
        self.register_buffer("fft_grid", self._create_fft_grid(), persistent=False)
        self.register_buffer("detector_masks", self._create_detector_masks(), persistent=False)

        # Sample the skip topology once; each connection gets a learnable
        # scalar weight (passed through a sigmoid in forward) and phase.
        self.skip_connections = self._generate_random_connections()
        self.skip_weights = nn.ParameterDict()
        self.skip_phases = nn.ParameterDict()

        for src, tgt in self.skip_connections:
            key = f"{src}_{tgt}"
            self.skip_weights[key] = nn.Parameter(torch.tensor(0.5))
            self.skip_phases[key] = nn.Parameter(torch.tensor(0.0))

    def _generate_random_connections(self):
        """Sample ``(src, tgt)`` layer pairs, each kept with ``self.rewiring_prob``.

        Candidates are all pairs with
        ``min_skip_distance <= tgt - src <= max_skip_distance`` and
        ``tgt <= num_layers - 1``. Sampling uses Python's ``random`` module,
        so it is controlled by ``random.seed``.
        """
        connections = []
        for src in range(self.num_layers - min_skip_distance):
            max_tgt = min(src + max_skip_distance, self.num_layers - 1)
            for tgt in range(src + min_skip_distance, max_tgt + 1):
                if random.random() < self.rewiring_prob:
                    connections.append((src, tgt))
        return connections

    def _create_fft_grid(self):
        """Build the frequency-domain kernel shared by every optical layer.

        Evaluates ``exp(1j * k * sqrt(1 - (wavelength*fx)^2 - (wavelength*fy)^2))``
        on the FFT frequency grid for a sample spacing of ``1e-6``, clamping
        the square-root argument at zero.

        NOTE(review): the exponent has no explicit propagation-distance factor,
        and the phase (on the order of ``k``, about 1.3e7) is evaluated in
        single precision, where neighbouring representable values are roughly
        one radian apart. Numerics are left untouched so existing results stay
        reproducible — confirm whether a distance term and float64 evaluation
        were intended.
        """
        fx = torch.fft.fftfreq(output_size, d=1e-6)
        fy = torch.fft.fftfreq(output_size, d=1e-6)
        FX, FY = torch.meshgrid(fx, fy, indexing='xy')
        k = 2 * np.pi / wavelength
        arg = torch.clamp(1 - (wavelength * FX)**2 - (wavelength * FY)**2, 0.0)
        return torch.exp(1j * k * torch.sqrt(arg))

    def _create_detector_masks(self):
        """Return a stack of binary masks, one per detector in ``detector_layout``.

        The field is cut into ``len(detector_layout)`` horizontal bands of
        equal height; band ``r`` is cut into ``detector_layout[r]`` cells of
        equal width. Sizes use integer division, so leftover rows or columns
        belong to no detector.
        """
        masks = []
        H = W = output_size
        band_h = H // len(detector_layout)
        for r, cols in enumerate(detector_layout):
            cell_w = W // cols
            y0 = r * band_h
            y1 = y0 + band_h
            for c in range(cols):
                x0 = c * cell_w
                m = torch.zeros(H, W)
                m[y0:y1, x0:x0 + cell_w] = 1.0
                masks.append(m)
        return torch.stack(masks)

    def tile_input(self, x):
        """Repeat each image ``tiles x tiles`` times and cast it to complex64.

        Args:
            x: Batch whose first dimension is the batch size and whose
                remaining elements form one ``image_size x image_size`` image
                per sample.

        Returns:
            Tensor of shape ``(batch, output_size, output_size)``, complex64.
        """
        B = x.size(0)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, image_size, image_size)
        # Repeating the last two axes places the copies side by side, so that
        # out[b, i, j] == x[b, i % image_size, j % image_size].
        x = x.repeat(1, tiles, tiles)
        return x.to(torch.complex64)

    def forward(self, x):
        """Propagate a batch through the cascade and return detector read-outs.

        Before layer ``tgt`` runs, every skip connection ``(src, tgt)`` adds
        ``sigmoid(weight) * exp(1j * phase) * output_of_layer_src`` to the
        current field.

        Returns:
            Tensor of shape ``(batch, number_of_detectors)``: the masked
            intensity sums multiplied by ``detector_scale``.
        """
        x = self.tile_input(x)

        # Group the connections by target once per call, and remember which
        # layer outputs are needed later so only those are kept around.
        incoming = {}
        for src, tgt in self.skip_connections:
            incoming.setdefault(tgt, []).append(src)
        needed_sources = {src for src, _ in self.skip_connections}
        cached_outputs = {}

        for layer_idx, layer in enumerate(self.layers):
            for src in incoming.get(layer_idx, ()):
                source_field = cached_outputs.get(src)
                if source_field is None:
                    # The source has not been computed yet; skip it.
                    continue
                key = f"{src}_{layer_idx}"
                weight = torch.sigmoid(self.skip_weights[key])
                phase_corr = torch.exp(1j * self.skip_phases[key])
                x = x + weight * (source_field * phase_corr)
            x = layer(x, self.fft_grid)
            if layer_idx in needed_sources:
                # No clone needed: nothing below modifies x in place.
                cached_outputs[layer_idx] = x

        intensity = x.real**2 + x.imag**2
        raw = (intensity.unsqueeze(1) * self.detector_masks.unsqueeze(0)).sum(dim=(2, 3))
        return raw * self.detector_scale

# --- Training & Evaluation ---

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    The model is switched to eval mode and left there. Batches are moved to
    the device the model is on, so the function does not depend on a
    module-level ``device`` setting.

    Args:
        model: Module mapping an image batch to per-class scores.
        loader: Iterable of ``(images, labels)`` batches. Dimension 1 of
            ``images`` is squeezed before the forward pass.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()

    # Use the device of the first parameter (or buffer); a model holding no
    # tensors at all is evaluated on the CPU.
    anchor = next(model.parameters(), None)
    if anchor is None:
        anchor = next(model.buffers(), None)
    target_device = anchor.device if anchor is not None else torch.device('cpu')

    correct, total = 0, 0
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.squeeze(1).to(target_device, non_blocking=True)
            lbls = lbls.to(target_device, non_blocking=True)
            preds = model(imgs).argmax(1)
            correct += preds.eq(lbls).sum().item()
            total += lbls.size(0)
    return 100.0 * correct / total


def train_model(model, train_loader, val_loader, test_loader, epochs=15, lr=0.01, patience=None):
    """Train ``model`` with AdamW and return its accuracy on ``test_loader``.

    Uses cross-entropy loss, gradient-norm clipping at 1.0, weight decay 1e-4
    and a step schedule that halves the learning rate every 5 epochs. The
    model is moved to the module-level ``device``, and accuracies are computed
    with the module-level ``evaluate`` function. One progress line is printed
    per epoch.

    Args:
        model: Module to train (modified in place).
        train_loader: Batches of ``(images, labels)`` used for optimisation.
        val_loader: Batches evaluated after every epoch.
        test_loader: Batches evaluated once, after training ends.
        epochs: Maximum number of epochs.
        lr: Initial learning rate.
        patience: If given, stop once validation accuracy has not improved for
            this many consecutive epochs. ``None`` (default) always runs all
            ``epochs``, matching the original behaviour.

    Returns:
        Test accuracy in percent, measured with the weights from the last
        completed epoch (not the best-validation epoch).
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    best_val_acc = 0.0
    epochs_without_gain = 0
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        start = time.time()

        for imgs, lbls in train_loader:
            imgs = imgs.squeeze(1).to(device, non_blocking=True)
            lbls = lbls.to(device, non_blocking=True)
            optimizer.zero_grad()
            out = model(imgs)
            loss = criterion(out, lbls)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so
            # the epoch figure is a per-sample average.
            total_loss += loss.item() * lbls.size(0)
            preds = out.argmax(1)
            correct += preds.eq(lbls).sum().item()
            total += lbls.size(0)

        scheduler.step()
        train_acc = 100.0 * correct / total
        val_acc = evaluate(model, val_loader)
        duration = time.time() - start

        print(f"Epoch {epoch}/{epochs} | Time: {duration:.1f}s | "
              f"Train Loss: {total_loss/total:.4f} | "
              f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_gain = 0
        else:
            epochs_without_gain += 1
            if patience is not None and epochs_without_gain >= patience:
                print(f"Early stopping at epoch {epoch}: "
                      f"no validation improvement for {patience} epochs")
                break

    test_acc = evaluate(model, test_loader)
    print(f"Best Val Accuracy: {best_val_acc:.2f}% | Test Accuracy: {test_acc:.2f}%")
    return test_acc

# --- Hyperparameter sweep on rewiring_prob ---
# Trains one freshly constructed model per probability value and prints the
# test accuracy returned by train_model.
for p in [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0]:
    # The model picks up its skip-connection probability from this
    # module-level name, so it must be assigned before construction.
    rewiring_prob = p
    print(f"\n=== Training with rewiring_prob = {p} ===")
    # Re-seed for every sweep point: `random` drives the skip-connection
    # sampling, torch's RNG drives the phase initialisation.
    random.seed(42)
    torch.manual_seed(42)

    model = FastONNRandomSkipSA(num_layers=10, cascade_sas=2).to(device)
    acc = train_model(model, train_loader, val_loader, test_loader, epochs=15, lr=0.01)
    print(f"→ p={p} Test Accuracy: {acc:.2f}%")

100%|██████████| 26.4M/26.4M [00:02<00:00, 12.4MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 209kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.89MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 22.3MB/s]



=== Training with rewiring_prob = 0.0 ===
Epoch 1/15 | Time: 8.2s | Train Loss: 13.3072 | Train Acc: 73.20% | Val Acc: 79.28%
Epoch 2/15 | Time: 6.9s | Train Loss: 3.6944 | Train Acc: 79.11% | Val Acc: 79.05%
Epoch 3/15 | Time: 6.8s | Train Loss: 0.5688 | Train Acc: 83.34% | Val Acc: 85.17%
Epoch 4/15 | Time: 6.9s | Train Loss: 0.5256 | Train Acc: 84.32% | Val Acc: 83.03%
Epoch 5/15 | Time: 6.9s | Train Loss: 0.5163 | Train Acc: 84.51% | Val Acc: 84.37%
Epoch 6/15 | Time: 6.8s | Train Loss: 0.4605 | Train Acc: 85.75% | Val Acc: 85.58%
Epoch 7/15 | Time: 6.8s | Train Loss: 0.4698 | Train Acc: 85.60% | Val Acc: 84.50%
Epoch 8/15 | Time: 6.9s | Train Loss: 0.4620 | Train Acc: 85.92% | Val Acc: 84.68%
Epoch 9/15 | Time: 6.8s | Train Loss: 0.4536 | Train Acc: 85.90% | Val Acc: 85.80%
Epoch 10/15 | Time: 6.9s | Train Loss: 0.4555 | Train Acc: 85.85% | Val Acc: 86.17%
Epoch 11/15 | Time: 6.9s | Train Loss: 0.4136 | Train Acc: 86.97% | Val Acc: 85.70%
Epoch 12/15 | Time: 6.9s | Train Loss: 0.