ONN (OPTICAL NEURAL NETWORK) MODEL WITH 10 LAYERS, SMALL-WORLD SKIP CONNECTIONS, AND SATURABLE ABSORBERS. DATASET: EMNIST (BALANCED, 47 CLASSES)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import EMNIST
import numpy as np
import time
import random

# --- Config ---
# Detector grid: one entry per detector row, value = detectors in that row.
# Five rows of 8 plus one row of 7 gives 47 detectors, one per EMNIST Balanced class.
detector_layout = [8] * 5 + [7]
wavelength = 480e-9               # presumably metres — TODO confirm units
image_size = 14                   # side length each input image is resized to
tiles = 6                         # the resized image is repeated tiles x tiles times
output_size = image_size * tiles  # side length of the optical field plane (84)

# Skip connections link a source layer to a target between these distances (inclusive).
min_skip_distance = 2
max_skip_distance = 7
# NOTE: rewiring_prob is deliberately not defined here; the sweep assigns it
# before each model is constructed.

# Default (initial) saturable-absorber parameters, copied into every layer.
base_A0 = 0.1       # small-signal absorption
base_I_sat = 5e6    # saturation intensity, in |E|^2 units
base_A_ns = 0.005   # non-saturable loss

# --- Device ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Dataset (EMNIST Balanced) ---
# Downscale every image to image_size x image_size and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])
# download=True fetches the data into ./data if it is not already there.
full_train = EMNIST(root="./data", split="balanced", train=True, download=True, transform=transform)
test_dataset = EMNIST(root="./data", split="balanced", train=False, download=True, transform=transform)

# Create train / validation split (90 % / 10 %).
# The split uses its own seeded generator, so the same samples land in the
# validation set in every session. The global torch RNG is not seeded at this
# point; the sweep seeds it only right before each model is built.
split_seed = 42
total_train = len(full_train)
val_size = int(0.1 * total_train)
train_size = total_train - val_size
train_dataset, val_dataset = random_split(
    full_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training loader shuffles; validation and test order is fixed.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset,   batch_size=128, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# --- Compute number of classes ---
# One detector per class: the layout must provide exactly 47 detectors (EMNIST Balanced).
num_classes = sum(detector_layout)
if num_classes != 47:
    # Explicit raise rather than `assert`, so the check still runs under `python -O`.
    raise ValueError(f"Sum of detector_layout must be 47, got {num_classes}")

# --- Model definitions ---
class OpticalLayer(nn.Module):
    """One diffractive stage: a learnable phase-only mask followed by a Fourier-domain filter.

    The incoming complex field is multiplied by ``exp(i * phase)``, taken to the
    frequency domain with a 2-D FFT over its last two dimensions, multiplied by
    the caller-supplied ``kernel``, and brought back with the inverse FFT.

    Args:
        size: Side length of the square phase mask. The field's last two
            dimensions must broadcast against ``(size, size)``.
    """

    def __init__(self, size):
        super().__init__()
        # Small random initial phases (std 0.1 rad) keep the mask close to all-ones.
        self.phase = nn.Parameter(0.1 * torch.randn(size, size))

    def forward(self, x, kernel):
        """Return ``ifft2(fft2(x * exp(i*phase)) * kernel)``."""
        phase_mask = torch.exp(1j * self.phase)
        spectrum = torch.fft.fft2(x * phase_mask)
        return torch.fft.ifft2(spectrum * kernel)

class FastONNRandomSkipSA(nn.Module):
    """Diffractive optical network with random skip paths and saturable absorbers.

    Each of the ``num_layers`` stages, in order:
      1. coherently adds the stored output fields of earlier stages that have a
         skip connection into it, each scaled by ``sigmoid(weight)`` and rotated
         by ``exp(i * phase)``;
      2. applies an ``OpticalLayer`` (learnable phase mask + fixed propagation
         kernel ``fft_grid``);
      3. multiplies the field by a learnable scalar gain;
      4. passes the field through ``cascade_sas`` saturable-absorber elements.
    The final intensity is integrated over ``num_classes`` rectangular detector
    regions and scaled by ``detector_scale`` to give the class logits.

    Args:
        num_layers: Number of optical stages.
        cascade_sas: Number of absorber elements applied per stage.
        rewiring_prob: Probability of creating each admissible skip connection.
            ``None`` (default) uses the module-level ``rewiring_prob`` global,
            which the hyperparameter sweep assigns before building each model.

    Construction reads module-level settings (``output_size``, ``device``,
    ``wavelength``, ``detector_layout``, ``num_classes``, the skip-distance
    limits and the ``base_*`` absorber values). Skip connections are drawn from
    Python's global ``random`` generator.
    """

    def __init__(self, num_layers=10, cascade_sas=2, rewiring_prob=None):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList([OpticalLayer(output_size) for _ in range(num_layers)])
        self.detector_scale = nn.Parameter(torch.tensor([10.0], dtype=torch.float32))
        # Constant tensors are registered as non-persistent buffers, so
        # model.to(...) moves them together with the parameters. They stay out
        # of state_dict() (same keys as before), because both are rebuilt here.
        self.register_buffer("fft_grid", self._create_fft_grid().to(device), persistent=False)
        self.register_buffer("detector_masks", self._create_detector_masks().to(device), persistent=False)

        # Purely optical gain stages per layer
        self.gains = nn.ParameterList([nn.Parameter(torch.tensor(1.0)) for _ in range(num_layers)])

        # Saturable absorber parameters (learnable, one scalar set per layer)
        self.A0 = nn.ParameterList([nn.Parameter(torch.tensor(base_A0)) for _ in range(num_layers)])
        # NOTE(review): each OpticalLayer is energy-preserving (unit-modulus mask
        # and kernel around an FFT/IFFT pair). If inputs lie in [0, 1], the peak
        # intensity is on the order of output_size**2 unless gains/skips grow,
        # which is far below the default I_sat of 5e6. In that regime the
        # absorber acts almost linearly — confirm this is intended.
        self.I_sat = nn.ParameterList([nn.Parameter(torch.tensor(base_I_sat)) for _ in range(num_layers)])
        self.A_ns = nn.ParameterList([nn.Parameter(torch.tensor(base_A_ns)) for _ in range(num_layers)])
        # Tunable absorber length factor per layer
        self.L = nn.ParameterList([nn.Parameter(torch.tensor(1.0)) for _ in range(num_layers)])
        # Number of cascaded SA elements
        self.cascade_sas = cascade_sas

        # Skip connections: each (src, tgt) pair gets a learnable weight (through
        # a sigmoid in forward) and a learnable phase.
        self.skip_connections = self._generate_random_connections(rewiring_prob)
        self.skip_weights = nn.ParameterDict()
        self.skip_phases = nn.ParameterDict()
        for src, tgt in self.skip_connections:
            key = f"{src}_{tgt}"
            self.skip_weights[key] = nn.Parameter(torch.tensor(0.5))
            self.skip_phases[key] = nn.Parameter(torch.tensor(0.0))

    def _generate_random_connections(self, prob=None):
        """Return a list of ``(src, tgt)`` skip pairs, each kept with probability ``prob``.

        Candidates satisfy ``min_skip_distance <= tgt - src <= max_skip_distance``
        and ``tgt < num_layers``. They are visited in ascending ``(src, tgt)``
        order, with one ``random.random()`` draw per candidate. ``prob=None``
        falls back to the module-level ``rewiring_prob`` global.
        """
        if prob is None:
            prob = rewiring_prob  # module-level global assigned by the sweep
        connections = []
        for src in range(self.num_layers - min_skip_distance):
            max_tgt = min(src + max_skip_distance, self.num_layers - 1)
            for tgt in range(src + min_skip_distance, max_tgt + 1):
                if random.random() < prob:
                    connections.append((src, tgt))
        print(f"Generated {len(connections)} random skip connections (p={prob})")
        return connections

    def _create_fft_grid(self):
        """Return the free-space transfer function used by every OpticalLayer (complex64).

        ``H(fx, fy) = exp(i * k * sqrt(1 - (wavelength*fx)**2 - (wavelength*fy)**2))``
        is sampled on an ``output_size`` x ``output_size`` frequency grid with a
        sample pitch of 1e-6. There is no explicit distance factor, i.e. the
        propagation distance is 1 in the length unit of ``wavelength``. A negative
        argument is clamped to 0, so those components get phase 0 and unit
        amplitude rather than decaying.

        The phase is about ``k`` (~1.3e7 rad for a 480e-9 wavelength). At that
        magnitude, adjacent float32 values are 1 rad apart, which would scramble
        the kernel. So the phase is computed in float64, and only the final
        complex field is cast to complex64 to match the optical layers.
        """
        fx = torch.fft.fftfreq(output_size, d=1e-6, dtype=torch.float64)
        fy = torch.fft.fftfreq(output_size, d=1e-6, dtype=torch.float64)
        FX, FY = torch.meshgrid(fx, fy, indexing='xy')
        k = 2 * np.pi / wavelength
        arg = 1 - (wavelength * FX)**2 - (wavelength * FY)**2
        arg = torch.clamp(arg, min=0.0)
        return torch.exp(1j * k * torch.sqrt(arg)).to(torch.complex64)

    def _create_detector_masks(self):
        """Build one binary ``(H, W)`` mask per class, stacked to ``(num_classes, H, W)``.

        The output plane is split into ``len(detector_layout)`` horizontal bands of
        height ``H // rows``. Band ``r`` is divided into ``detector_layout[r]`` cells
        of width ``W // cols``, and each cell becomes one detector. Because of the
        integer division, leftover rows/columns at the bottom/right edge belong
        to no detector.
        """
        masks = []
        H = W = output_size
        rows = len(detector_layout)
        band_h = H // rows
        for r, cols in enumerate(detector_layout):
            cell_w = W // cols
            det_h, det_w = band_h, cell_w
            y_center = r * band_h + band_h // 2
            y0 = max(0, y_center - det_h // 2)
            y1 = min(H, y0 + det_h)
            for c in range(cols):
                x_center = c * cell_w + cell_w // 2
                x0 = max(0, x_center - det_w // 2)
                x1 = min(W, x0 + det_w)
                m = torch.zeros(H, W)
                m[y0:y1, x0:x1] = 1.0
                masks.append(m)
        masks = torch.stack(masks)
        assert masks.shape[0] == num_classes, f"Expected {num_classes} masks, got {masks.shape[0]}"
        return masks

    def tile_input(self, x):
        """Repeat each image ``tiles`` x ``tiles`` times to fill the optical plane.

        Each sample in ``x`` must hold ``image_size * image_size`` values. Returns a
        complex64 tensor of shape ``(B, output_size, output_size)``.
        """
        B = x.size(0)
        x = x.view(B, 1, 1, image_size, image_size)
        x = x.repeat(1, tiles, tiles, 1, 1)
        # (B, tile_row, tile_col, h, w) -> (B, tile_row, h, tile_col, w) -> (B, H, W)
        x = x.permute(0, 1, 3, 2, 4).reshape(B, output_size, output_size)
        return x.to(torch.complex64)

    def _apply_gain(self, field, layer_idx):
        """Scale the field amplitude by the learnable gain of ``layer_idx``."""
        return field * self.gains[layer_idx]

    def _apply_sa(self, field, layer_idx):
        """Apply ``cascade_sas`` saturable-absorber elements of layer ``layer_idx``.

        Each element has intensity transmission
        ``T = exp(-L * (A0 / (1 + I / I_sat) + A_ns))`` with ``I = |field|**2``, and
        scales the field amplitude by ``sqrt(T)``. For positive ``A0``, ``I_sat``
        and ``L``, brighter pixels are absorbed less. Cascading re-evaluates ``I``
        after every element, which steepens the nonlinearity.
        """
        out = field
        for _ in range(self.cascade_sas):
            intensity = out.real**2 + out.imag**2
            T = torch.exp(-self.L[layer_idx] * ((self.A0[layer_idx] / (1 + intensity / self.I_sat[layer_idx])) + self.A_ns[layer_idx]))
            out = torch.sqrt(T) * out
        return out

    def forward(self, x):
        """Propagate a batch of images through every stage and return the class logits.

        Args:
            x: Batch of images; each sample holds ``image_size * image_size`` values.

        Returns:
            Real tensor of shape ``(B, num_classes)``.
        """
        x = self.tile_input(x)
        intermediate = {}  # layer index -> that layer's output field
        for idx in range(self.num_layers):
            # Merge skip paths arriving at this layer (sources are earlier layers).
            for src, tgt in self.skip_connections:
                if tgt == idx and src in intermediate:
                    key = f"{src}_{tgt}"
                    w = torch.sigmoid(self.skip_weights[key])
                    p = torch.exp(1j * self.skip_phases[key])
                    x = x + w * (intermediate[src] * p)

            # optical layer -> gain stage -> saturable absorber nonlinearity
            x = self.layers[idx](x, self.fft_grid)
            x = self._apply_gain(x, idx)
            x = self._apply_sa(x, idx)
            # Every later step rebinds ``x`` instead of modifying it in place, so
            # storing the reference is safe; a copy would only cost memory.
            intermediate[idx] = x

        # Detection: integrate intensity over every detector in one matmul,
        # (B, H*W) @ (H*W, num_classes) -> (B, num_classes). This avoids
        # materialising a (B, num_classes, H, W) product tensor.
        intensity = x.real**2 + x.imag**2
        raw = intensity.flatten(1) @ self.detector_masks.flatten(1).T
        return raw * self.detector_scale

# --- Evaluation & Training ---
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent.

    Switches the model to eval mode (and leaves it there), disables gradient
    tracking, and moves each batch to the module-level ``device``. Dimension 1
    of every image batch is squeezed away before the forward pass (presumably
    the singleton channel from ToTensor — confirm for other loaders).
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.squeeze(1).to(device)
            labels = labels.to(device)
            predictions = model(images).argmax(1)
            n_correct += (predictions == labels).sum().item()
            n_seen += labels.size(0)
    return 100.0 * n_correct / n_seen


def train_model(model, train_loader, val_loader, test_loader, epochs=15, lr=0.01, *, restore_best=False):
    """Train ``model``, print per-epoch metrics, and return its test accuracy (percent).

    Uses cross-entropy loss with AdamW (weight decay 1e-4), gradient-norm
    clipping at 1.0, and a StepLR schedule that halves the learning rate every
    5 epochs. After each epoch the validation accuracy is printed and the best
    value seen is tracked.

    Args:
        model: Network mapping a batch of images to class logits.
        train_loader, val_loader, test_loader: Iterables of ``(images, labels)``.
        epochs: Number of training epochs.
        lr: Initial learning rate.
        restore_best: If True, snapshot the weights whenever validation accuracy
            improves, and reload the best snapshot before the test evaluation.
            The printed test accuracy then belongs to the validation-selected
            model. If False (default), the last-epoch weights are tested.

    Returns:
        Test accuracy in percent.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    best_val = 0.0
    best_state = None
    for epoch in range(1, epochs+1):
        model.train()
        loss_sum, corr, tot = 0.0, 0, 0
        start = time.time()  # epoch time below includes the validation pass
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.squeeze(1).to(device), lbls.to(device)
            optimizer.zero_grad()
            out = model(imgs)
            loss = criterion(out, lbls)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Train metrics are accumulated on the fly while the weights change.
            loss_sum += loss.item()*lbls.size(0)
            corr += (out.argmax(1)==lbls).sum().item()
            tot += lbls.size(0)
        scheduler.step()
        train_acc = 100*corr/tot
        val_acc = evaluate(model, val_loader)
        print(f"Epoch {epoch}/{epochs} | Loss: {loss_sum/tot:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}% | Time: {time.time()-start:.1f}s")
        if val_acc > best_val:
            best_val = val_acc
            if restore_best:
                # Copy the tensors: state_dict() returns references that later
                # optimizer steps would keep mutating.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)
    test_acc = evaluate(model, test_loader)
    print(f"Best Val Acc: {best_val:.2f}% | Test Acc: {test_acc:.2f}%")
    return test_acc

# --- Hyperparameter Sweep ---
# Train one model per skip-connection probability and keep each test accuracy.
sweep_results = {}
for p in [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0]:
    rewiring_prob = p  # read as a module-level global by FastONNRandomSkipSA
    # Re-seed before every run so each p starts from the same phase-mask
    # initialisation and the same random draws for skip selection.
    random.seed(42)
    torch.manual_seed(42)
    print(f"\n=== p={p} ===")
    model = FastONNRandomSkipSA(num_layers=10, cascade_sas=2).to(device)
    sweep_results[p] = train_model(model, train_loader, val_loader, test_loader)

print("\n=== Sweep summary (test accuracy) ===")
for p, acc in sweep_results.items():
    print(f"p={p}: {acc:.2f}%")


100%|██████████| 562M/562M [00:02<00:00, 200MB/s]



=== p=0.0 ===
Generated 0 random skip connections (p=0.0)
Epoch 1/15 | Loss: 1.2531 | Train Acc: 66.98% | Val Acc: 72.05% | Time: 54.5s
Epoch 2/15 | Loss: 0.9309 | Train Acc: 74.69% | Val Acc: 73.90% | Time: 54.4s
Epoch 3/15 | Loss: 0.8566 | Train Acc: 76.28% | Val Acc: 76.61% | Time: 52.9s
Epoch 4/15 | Loss: 0.8036 | Train Acc: 77.56% | Val Acc: 76.40% | Time: 53.0s
Epoch 5/15 | Loss: 0.7787 | Train Acc: 77.95% | Val Acc: 75.62% | Time: 53.3s
Epoch 6/15 | Loss: 0.6375 | Train Acc: 81.24% | Val Acc: 79.88% | Time: 53.3s
Epoch 7/15 | Loss: 0.6022 | Train Acc: 81.97% | Val Acc: 80.13% | Time: 53.0s
Epoch 8/15 | Loss: 0.5910 | Train Acc: 82.43% | Val Acc: 80.45% | Time: 52.9s
Epoch 9/15 | Loss: 0.5758 | Train Acc: 82.59% | Val Acc: 79.87% | Time: 53.6s
Epoch 10/15 | Loss: 0.5590 | Train Acc: 82.90% | Val Acc: 80.92% | Time: 53.0s
Epoch 11/15 | Loss: 0.4814 | Train Acc: 84.94% | Val Acc: 81.86% | Time: 53.0s
Epoch 12/15 | Loss: 0.4619 | Train Acc: 85.34% | Val Acc: 82.49% | Time: 53.2s
Ep