In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import random
import torch.nn.functional as F
import math
import time

class FastONNCore(nn.Module):
    """Optical processing unit with a 5×5 MLA over a 15×15 input."""
    def __init__(
        self,
        wavelength=480e-9,
        dx=1e-6,
        dy=1e-6,
        z_dist=186.6e-3,
        num_layers=10,
        tiles=5,            # 5×5 MLA
        tile_size=3,         # each lenslet is 3×3
              # small-world params:
        sw_m=2,         # local span (how many previous layers to link)
        sw_p=0.2,       # rewiring prob
        sw_trainable=True,  # if True, learn weights for small-world links
        sw_init_gamma=0.1,    # initial scale for small-world links
        sw_seed=None
    ):
        super().__init__()
        self.wavelength = wavelength
        self.dx = dx
        self.dy = dy
        self.num_layers = num_layers
        self.z_list = [z_dist] * num_layers
        self.tiles = tiles
        self.tile_size = tile_size
        self.output_size = tiles * tile_size  # 5 × 3 = 15
        self.hub = HubModule(self.output_size, hidden=128)  # Add to init


        self.amp_list = nn.ParameterList([
            nn.Parameter(0.5 * torch.ones(self.output_size, self.output_size))
            for _ in range(num_layers)
        ])
        self.phase_list = nn.ParameterList([
            nn.Parameter(torch.zeros(self.output_size, self.output_size))
            for _ in range(num_layers)
        ])

        # Precompute the FFT grid for propagation (15×15)
        self.register_buffer('fft_grid', self._create_fft_grid(), persistent=False)
        # Detector: 10 regions tiled over the 15×15 output
        self.register_buffer('detector_masks', self._create_detector_masks(), persistent=False)

        # --- Small-world setup ---
        self.sw_m = sw_m
        self.sw_p = sw_p
        if sw_seed is not None:
          random.seed(sw_seed)
        # Build small-world neighbors: for each layer i, a list of earlier layer indices
        sw_neighbors = []

        for i in range(num_layers):
            neigh = []
            # local neighbors: previous sw_m layers

            for offset in range(1, sw_m + 1):
                j = i - offset
                if j < 0:
                    break
                # decide whether to rewire
                if random.random() < sw_p and i > 0:
                    # pick a random earlier layer in [0, i-1], avoiding duplicates
                    candidates = set(range(0, i)) - set(neigh)
                    if candidates:
                        j_rand = random.choice(list(candidates))
                        neigh.append(j_rand)
                    else:
                        neigh.append(j)
                else:
                    neigh.append(j)
            sw_neighbors.append(neigh)
        # store as plain Python list (no gradient)
        self.sw_neighbors = sw_neighbors

        # Optional: trainable weights per small-world link
        self.sw_trainable = sw_trainable
        if sw_trainable:
            # For each layer i, create a Parameter of shape (k_i,)
            self.sw_weights = nn.ParameterList()
            for i in range(num_layers):
                k_i = len(self.sw_neighbors[i])
                if k_i > 0:
                    # initialize around sw_init_gamma
                    init = torch.ones(k_i) * sw_init_gamma
                    self.sw_weights.append(nn.Parameter(init))
                else:
                    # placeholder for consistency; won't be used in forward
                    self.sw_weights.append(nn.Parameter(torch.zeros(0)))
        else:
            # fixed scale
            self.sw_gamma = sw_init_gamma


    def _create_fft_grid(self):
        H = W = self.output_size  # 15
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        fx = torch.fft.fftfreq(W, d=self.dx, device=device)
        fy = torch.fft.fftfreq(H, d=self.dy, device=device)
        FX, FY = torch.meshgrid(fx, fy, indexing='xy')

        k = 2 * np.pi / self.wavelength
        arg = 1.0 - (self.wavelength * FX) ** 2 - (self.wavelength * FY) ** 2
        arg = torch.clamp(arg, min=0.0)
        return torch.sqrt(arg)

    def _create_detector_masks(self):
        H = W = self.output_size  # 15
        masks = torch.zeros(10, H, W)

        # Partition 15×15 into 10 regions (3 rows with 3 regions, 1 row with 1 region)
        for i in range(10):
            row = i // 3
            col = i % 3

            # Compute row boundaries
            h_start = int(row * H / 4)
            if row < 3:
                h_end = int((row + 1) * H / 4)
            else:
                h_end = H  # Last row takes remaining space

            # Compute column boundaries
            if row < 3:
                w_start = int(col * W / 3)
                w_end = int((col + 1) * W / 3)
            else:
                w_start = 0
                w_end = W  # Last row spans full width

            masks[i, h_start:h_end, w_start:w_end] = 1.0

        return masks

    def propagate(self, U, z):
        k = 2 * np.pi / self.wavelength
        H_transfer = torch.exp(1j * k * z * self.fft_grid)  # (15×15)
        U_fft = torch.fft.fft2(U)
        U_prop = torch.fft.ifft2(U_fft * H_transfer)
        return U_prop

    def tile_input(self, U0):
        """
        U0: (B, H_in, W_in) with H_in=W_in=15.
        We split each 15×15 U0 into 25 patches of size 3×3 → build a 15×15 complex field.
        """
        B, H, W = U0.shape  # H=W=15
        return U0.to(torch.complex64)


    def forward(self, U0):
        """
        U0: (B, H_in, W_in) with H_in=W_in=15 (grayscale per channel).
        1) tile_input → (B, 15, 15) complex
        2) For each of the num_layers:
            a) build composite mask (15×15) from 25 learned 3×3 tiles
            b) multiply U * M, then propagate
        3) Compute intensity → (B, 15,15) real → apply 10 detectors → (B,10) logits
        """
        U = self.tile_input(U0)  # (B, 15, 15) complex

        history = []                   # ← Step 1: initialize history

        for i in range(self.num_layers):
            phase = torch.clamp(self.phase_list[i], -np.pi, np.pi)       # (15,15)
            Mi = self.amp_list[i] * torch.exp(1j * phase)                # (15,15)
            U = U * Mi

            U = self.propagate(U, self.z_list[i])                        # (B,15,15) complex

            #  --- Small-world aggregation ---
            neigh = [j for j in self.sw_neighbors[i] if j < len(history) and history[j] is not None]
            k_i = len(neigh)
            if k_i > 0:
              if self.sw_trainable:
                weights_sw = self.sw_weights[i][:k_i] # Select only weights for valid neighbors
                norm = math.sqrt(k_i)
                for idx_j, j in enumerate(neigh):
                  gamma_ij = weights_sw[idx_j] / norm
                  U = U + gamma_ij * history[j]
              else:
                scale = self.sw_gamma / math.sqrt(k_i)
                for j in neigh:
                  U = U + scale * history[j]
            # ----------------------------------

            # 3) Append the *final* U for this layer
            history.append(U)

        # Ensure U is not None before calculating intensity
        if U is not None:
            I = U.real*2 + U.imag*2  # (B,15,15) real
            logits = (I.unsqueeze(1) * self.detector_masks.unsqueeze(0)).sum(dim=(2, 3))
        else:
            # Handle the case where U might still be None (e.g., num_layers is 0)
            logits = torch.zeros(U0.shape[0], 10, device=U0.device) # Or handle as appropriate

        return logits


class HubModule(nn.Module):
    def __init__(self, input_size, hidden=128):
        super().__init__()
        flat_dim = input_size * input_size
        # Adjust input size to handle potentially fewer tensors
        self.fc = nn.Sequential(
            nn.Linear(3 * flat_dim * 2, hidden),  # Max 3 previous layers, real+imag
            nn.ReLU(),
            nn.Linear(hidden, flat_dim * 2)       # real+imag output
        )
        self.input_size = input_size

    def forward(self, tensors):  # tensors: list of complex 2D fields
        # Filter out None tensors and handle empty list
        valid_tensors = [t for t in tensors if t is not None]
        if not valid_tensors:
            # Return a zero tensor of the expected output shape
            return torch.zeros(1, self.input_size, self.input_size, dtype=torch.complex64, device=self.fc[0].weight.device)

        # Split real and imag parts
        reals = [t.real.view(t.shape[0], -1) for t in valid_tensors]
        imags = [t.imag.view(t.shape[0], -1) for t in valid_tensors]

        # Pad with zeros if fewer than 3 tensors
        while len(reals) < 3:
            reals.append(torch.zeros_like(reals[0]))
            imags.append(torch.zeros_like(imags[0]))

        x = torch.cat(reals + imags, dim=1)  # concat all
        out = self.fc(x)
        real_part, imag_part = out.chunk(2, dim=1)
        complex_out = torch.complex(real_part, imag_part).view(
            valid_tensors[0].shape[0], self.input_size, self.input_size)
        return complex_out



class RGB_FastONN_Ensemble(nn.Module):
    def __init__(self, base_model_cls, n_models=10, *args, **kwargs):
        super().__init__()
        self.models = nn.ModuleList([
            base_model_cls(*args, **kwargs) for _ in range(n_models)
        ])
        self.weights = nn.Parameter(torch.ones(n_models) / n_models)  # learnable or fixed weights

    def forward(self, x):
        outputs = []
        for model in self.models:
            out = model(x)  # shape: (B, 10)
            outputs.append(out)
        stacked = torch.stack(outputs, dim=0)  # shape: (30, B, 10)
        weighted = self.weights.view(-1, 1, 1) * stacked
        return weighted.sum(dim=0)  # shape: (B, 10)


class RGB_FastONN(nn.Module):
    """Optical NN with parallel RGB processing paths using a 5×5 MLA over 15×15 per channel."""
    def __init__(
        self,
        wavelengths=[650e-9, 530e-9, 470e-9],
        dx=1e-6,
        dy=1e-6,
        z_dist=186.6e-3,
        num_layers=15,      # 15 layers per color path
        tiles=5,            # 5×5 MLA
        tile_size=3,        # each lenslet is 3×3
        hidden_size=64
    ):
        super().__init__()
        self.red_path = FastONNCore(
            wavelength=wavelengths[0],
            dx=dx, dy=dy,
            z_dist=z_dist,
            num_layers=num_layers,
            tiles=tiles,
            tile_size=tile_size
        )
        self.green_path = FastONNCore(
            wavelength=wavelengths[1],
            dx=dx, dy=dy,
            z_dist=z_dist,
            num_layers=num_layers,
            tiles=tiles,
            tile_size=tile_size
        )
        self.blue_path = FastONNCore(
            wavelength=wavelengths[2],
            dx=dx, dy=dy,
            z_dist=z_dist,
            num_layers=num_layers,
            tiles=tiles,
            tile_size=tile_size
        )

        # Electronic fusion MLP (3×10 → hidden_size → 10)
        self.fusion = nn.Sequential(
            nn.Linear(30, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, 10)
        )

    def forward(self, x):
        # x: (B, 3, 15, 15) after transforms and CenterCrop
        r_out = self.red_path(x[:, 0])    # (B,10)
        g_out = self.green_path(x[:, 1])  # (B,10)
        b_out = self.blue_path(x[:, 2])   # (B,10)

        combined = torch.cat([r_out, g_out, b_out], dim=1)  # (B,30)
        return self.fusion(combined)                        # (B,10)


def evaluate(model, test_loader):
    model.eval()
    correct, total = 0, 0
    device = next(model.parameters()).device

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    return 100. * correct / total


def train_model(model, train_loader, test_loader, epochs=50, lr=0.001):
    device = next(model.parameters()).device
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5, verbose=True
    )
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        total_loss, total_correct, total_samples = 0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Gentle phase regularization
            phase_reg = 0.0
            for module in model.modules():
                if isinstance(module, FastONNCore):
                    for p in module.phase_list:
                        phase_reg += 0.001 * p.abs().mean()
            loss += phase_reg

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item() * labels.size(0)
            _, predicted = outputs.max(1)
            total_correct += predicted.eq(labels).sum().item()
            total_samples += labels.size(0)

        train_loss = total_loss / total_samples
        train_acc = 100.0 * total_correct / total_samples

        test_acc = evaluate(model, test_loader)
        scheduler.step(test_acc)

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), "best_rgb_onn_model.pth")

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Test Acc: {test_acc:.2f}%")

    return best_acc


if __name__ == "__main__":

    start_time = time.time()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Fixed transforms for RGB (3 channels)
    train_transform = transforms.Compose([
        transforms.Resize(16),
        transforms.CenterCrop(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # RGB normalization
    ])

    test_transform = transforms.Compose([
        transforms.Resize(16),
        transforms.CenterCrop(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Download datasets
    train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    test_dataset = datasets.CIFAR10(root="./data", train=False, transform=test_transform)



    # Create data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    test_loader = DataLoader(
        test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

    # Instantiate model
    model = RGB_FastONN(


        wavelengths=[650e-9, 530e-9, 470e-9],
        num_layers=15,
        tiles=5,
        tile_size=3,
        hidden_size=64
    ).to(device)
    # model = torch.compile(model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {total_params:,}")

    # Train the model
    best_acc = train_model(
        model, train_loader, test_loader,
        epochs=50,
        lr=0.001
    )
    print(f"Best Test Accuracy: {best_acc:.2f}%")
    end_time = time.time()
    elapsed_time = end_time - start_time
    minutes = int(elapsed_time // 60)
    seconds = int(elapsed_time % 60)
    print(f"Total training time: {minutes} min {seconds} sec")

Using device: cuda


100%|██████████| 170M/170M [00:03<00:00, 49.5MB/s]


Trainable parameters: 715,899




Epoch 1/50 | Train Loss: 2.0466 | Train Acc: 24.28% | Test Acc: 33.90%
Epoch 2/50 | Train Loss: 1.8292 | Train Acc: 34.23% | Test Acc: 37.52%
Epoch 3/50 | Train Loss: 1.7740 | Train Acc: 36.35% | Test Acc: 39.80%
Epoch 4/50 | Train Loss: 1.7333 | Train Acc: 38.07% | Test Acc: 41.64%
Epoch 5/50 | Train Loss: 1.7052 | Train Acc: 39.35% | Test Acc: 42.43%
Epoch 6/50 | Train Loss: 1.6784 | Train Acc: 40.65% | Test Acc: 43.28%
Epoch 7/50 | Train Loss: 1.6584 | Train Acc: 41.32% | Test Acc: 44.03%
Epoch 8/50 | Train Loss: 1.6435 | Train Acc: 42.03% | Test Acc: 44.96%
Epoch 9/50 | Train Loss: 1.6316 | Train Acc: 42.39% | Test Acc: 44.77%
Epoch 10/50 | Train Loss: 1.6195 | Train Acc: 42.83% | Test Acc: 45.35%
Epoch 11/50 | Train Loss: 1.6052 | Train Acc: 43.50% | Test Acc: 45.47%
Epoch 12/50 | Train Loss: 1.5962 | Train Acc: 43.90% | Test Acc: 45.77%
Epoch 13/50 | Train Loss: 1.5903 | Train Acc: 43.78% | Test Acc: 46.58%
Epoch 14/50 | Train Loss: 1.5792 | Train Acc: 44.48% | Test Acc: 46.68%
E