In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import random
import torch.nn.functional as F
import math
import time

class FastONNCore(nn.Module):
    """Stack of trainable amplitude/phase masks separated by free-space propagation.

    The input field is multiplied by a learnable complex mask, propagated over
    a fixed distance with the angular-spectrum method, and this is repeated
    ``num_layers`` times. The output intensity is then integrated over ten
    detector regions to produce ten class scores.

    Args:
        wavelength: Optical wavelength used by the propagation kernel.
        dx, dy: Sample pitch of the field along x and y.
        z_dist: Propagation distance applied after every mask.
        num_layers: Number of (mask, propagation) stages.
        tiles: Number of lenslets per side; only used to derive the field size.
        tile_size: Samples per lenslet side; only used to derive the field size.
        sw_m, sw_p, sw_trainable, sw_init_gamma, sw_seed: Small-world link
            settings. They are accepted for interface compatibility but are
            not used anywhere in this class.
    """

    def __init__(
        self,
        wavelength=480e-9,
        dx=1e-6,
        dy=1e-6,
        z_dist=186.6e-3,
        num_layers=15,
        tiles=5,            # 5×5 MLA
        tile_size=3,        # each lenslet is 3×3
        # small-world params (currently unused):
        sw_m=2,             # local span (how many previous layers to link)
        sw_p=0.2,           # rewiring prob
        sw_trainable=True,  # if True, learn weights for small-world links
        sw_init_gamma=0.1,  # initial scale for small-world links
        sw_seed=None
    ):
        super().__init__()
        self.wavelength = wavelength
        self.dx = dx
        self.dy = dy
        self.num_layers = num_layers
        self.z_list = [z_dist] * num_layers
        self.tiles = tiles
        self.tile_size = tile_size
        self.output_size = tiles * tile_size  # 5 × 3 = 15 with the defaults

        # NOTE(review): HubModule is not defined in this part of the file and
        # is never called in forward(); it must be defined elsewhere before
        # this class is instantiated — confirm it is still needed.
        self.hub = HubModule(self.output_size, hidden=128)

        size = self.output_size
        # One amplitude mask and one phase mask per stage.
        self.amp_list = nn.ParameterList([
            nn.Parameter(0.5 * torch.ones(size, size))
            for _ in range(num_layers)
        ])
        self.phase_list = nn.ParameterList([
            nn.Parameter(torch.zeros(size, size))
            for _ in range(num_layers)
        ])

        # Non-persistent buffers: rebuilt from the constructor arguments, so
        # they are not written to the state dict, but they do follow .to().
        self.register_buffer('fft_grid', self._create_fft_grid(), persistent=False)
        self.register_buffer('detector_masks', self._create_detector_masks(), persistent=False)

    def _create_fft_grid(self):
        """Return sqrt(max(0, 1 - (λ·fx)² - (λ·fy)²)) on the FFT frequency grid.

        The tensor is created on the CPU; registering it as a buffer lets it
        move together with the module, so the grid can never end up on a
        different device than the rest of the model.
        """
        H = W = self.output_size

        fx = torch.fft.fftfreq(W, d=self.dx)
        fy = torch.fft.fftfreq(H, d=self.dy)
        FX, FY = torch.meshgrid(fx, fy, indexing='xy')

        arg = 1.0 - (self.wavelength * FX) ** 2 - (self.wavelength * FY) ** 2
        # Negative values are clamped to zero before taking the square root.
        return torch.sqrt(torch.clamp(arg, min=0.0))

    def _create_detector_masks(self):
        """Build ten binary (H, W) masks that partition the output plane.

        The plane is cut into four horizontal bands. Each of the first three
        bands is split into three equal-width regions (masks 0-8); the last
        band is a single full-width region (mask 9).
        """
        H = W = self.output_size
        masks = torch.zeros(10, H, W)

        for i in range(10):
            row, col = divmod(i, 3)
            h_start = int(row * H / 4)

            if row < 3:
                h_end = int((row + 1) * H / 4)
                w_start = int(col * W / 3)
                w_end = int((col + 1) * W / 3)
            else:
                # Last band takes the remaining rows and spans the full width.
                h_end = H
                w_start, w_end = 0, W

            masks[i, h_start:h_end, w_start:w_end] = 1.0

        return masks

    def propagate(self, U, z):
        """Propagate the complex field ``U`` over distance ``z``.

        Multiplies the 2-D FFT of ``U`` by exp(i·k·z·fft_grid) and transforms
        back.
        """
        k = 2 * np.pi / self.wavelength
        H_transfer = torch.exp(1j * k * z * self.fft_grid)
        U_fft = torch.fft.fft2(U)
        return torch.fft.ifft2(U_fft * H_transfer)

    def tile_input(self, U0):
        """Convert the input to a complex64 field.

        A 4-D input has dimension 1 squeezed (a no-op unless that dimension
        has size 1). Despite the method name, no tiling is performed.
        """
        if U0.ndim == 4:
            U0 = U0.squeeze(1)

        return U0.to(torch.complex64)

    def forward(self, U0):
        """Run the optical stack and return the ten detector readings.

        Returns:
            Tensor of shape (B, 10): output intensity summed over each
            detector mask.
        """
        U = self.tile_input(U0)

        for i in range(self.num_layers):
            # Phase is clamped to [-π, π] before building the complex mask.
            phase = torch.clamp(self.phase_list[i], -np.pi, np.pi)
            Mi = self.amp_list[i] * torch.exp(1j * phase)
            U = self.propagate(U * Mi, self.z_list[i])

        # Intensity is the squared magnitude |U|² = Re² + Im².
        I = U.real ** 2 + U.imag ** 2
        logits = (I.unsqueeze(1) * self.detector_masks.unsqueeze(0)).sum(dim=(2, 3))
        return logits






class RGB_FastONN(nn.Module):
    """Optical core applied to each of three colour channels, fused by an MLP.

    Every channel of the input is passed through the same ``FastONNCore``
    (weights are shared between channels). The three 10-value detector
    readings are concatenated into a 30-value vector and mapped to 10 class
    logits by a small electronic MLP.

    Args:
        wavelengths: Wavelength forwarded to the optical core.
        dx, dy: Sample pitch forwarded to the optical core.
        z_dist: Propagation distance forwarded to the optical core.
        num_layers: Number of optical stages in the core.
        tiles: Lenslets per side, forwarded to the optical core.
        tile_size: Samples per lenslet side, forwarded to the optical core.
        hidden_size: Width of the hidden layer in the fusion MLP.
    """

    NUM_CHANNELS = 3

    def __init__(
        self,
        wavelengths=480e-9,
        dx=1e-6,
        dy=1e-6,
        z_dist=186.6e-3,
        num_layers=15,      # 15 layers per color path
        tiles=5,            # 5×5 MLA
        tile_size=3,        # each lenslet is 3×3
        hidden_size=64
    ):
        super().__init__()
        self.core = FastONNCore(
            wavelength=wavelengths,
            dx=dx, dy=dy,
            z_dist=z_dist,
            num_layers=num_layers,
            tiles=tiles,
            tile_size=tile_size
        )

        # Electronic fusion MLP (3×10 → hidden_size → 10)
        self.fc = nn.Sequential(
            nn.Linear(30, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, 10)
        )

    def forward(self, x):
        """Return class logits for a batch of 3-channel images.

        Args:
            x: Tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, 10).

        Raises:
            ValueError: If ``x`` is not 4-D with exactly three channels.
        """
        if x.ndim != 4 or x.shape[1] != self.NUM_CHANNELS:
            raise ValueError(
                f"expected input of shape (B, {self.NUM_CHANNELS}, H, W), "
                f"got {tuple(x.shape)}"
            )

        # The fusion MLP expects 30 features, so all three channels must go
        # through the core; using only one channel yields just 10 features.
        per_channel = [self.core(x[:, c]) for c in range(self.NUM_CHANNELS)]
        return self.fc(torch.cat(per_channel, dim=1))


def evaluate(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` as a percentage.

    The model is switched to eval mode and left there. Batches are moved to
    the device of the model's first parameter before the forward pass.
    """
    model.eval()
    device = next(model.parameters()).device
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch, targets in test_loader:
            batch = batch.to(device)
            targets = targets.to(device)
            # max over dim 1 returns (values, indices); the indices are the predictions.
            predictions = model(batch).max(1)[1]
            n_correct += predictions.eq(targets).sum().item()
            n_seen += targets.size(0)

    return 100. * n_correct / n_seen


def train_model(model, train_loader, test_loader, epochs=50, lr=0.001,
                save_path="best_rgb_onn_model.pth"):
    """Train ``model`` with AdamW and return the best test accuracy seen.

    After every epoch the model is evaluated on ``test_loader``; the learning
    rate is halved when the test accuracy has not improved for five epochs,
    and the state dict is saved whenever the test accuracy reaches a new best.

    Args:
        model: Module to train; batches are moved to its first parameter's device.
        train_loader: Iterable of (images, labels) training batches.
        test_loader: Iterable of (images, labels) evaluation batches.
        epochs: Number of passes over ``train_loader``.
        lr: Initial learning rate.
        save_path: File the best state dict is written to.

    Returns:
        Highest test accuracy (percent) observed over all epochs.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    device = next(model.parameters()).device
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # The deprecated `verbose` flag is not passed; learning-rate changes are
    # reported explicitly after scheduler.step() below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )
    criterion = nn.CrossEntropyLoss()

    # The set of phase masks does not change during training, so collect it
    # once instead of walking model.modules() on every batch.
    phase_params = [
        p
        for module in model.modules()
        if isinstance(module, FastONNCore)
        for p in module.phase_list
    ]

    best_acc = 0.0

    for epoch in range(epochs):
        model.train()
        total_loss, total_correct, total_samples = 0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Gentle phase regularization: small L1 penalty on every phase mask.
            phase_reg = 0.0
            for p in phase_params:
                phase_reg = phase_reg + 0.001 * p.abs().mean()
            loss = loss + phase_reg

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Reported loss includes the regularization term.
            total_loss += loss.item() * labels.size(0)
            _, predicted = outputs.max(1)
            total_correct += predicted.eq(labels).sum().item()
            total_samples += labels.size(0)

        if total_samples == 0:
            raise ValueError("train_loader yielded no samples")

        train_loss = total_loss / total_samples
        train_acc = 100.0 * total_correct / total_samples

        test_acc = evaluate(model, test_loader)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(test_acc)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Epoch {epoch+1}: reducing learning rate to {lr_after:.4e}")

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), save_path)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Test Acc: {test_acc:.2f}%")

    return best_acc


if __name__ == "__main__":
    # Script entry point: train FastONNCore on grayscale, 15×15 CIFAR-10.
    start_time = time.time()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Train and test share one preprocessing recipe: convert to a single
    # grayscale channel, shrink to 16 px, centre-crop to 15×15 and normalise
    # the single channel to roughly [-1, 1].
    preprocessing_steps = [
        transforms.Grayscale(),
        transforms.Resize(16),
        transforms.CenterCrop(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    train_transform = transforms.Compose(preprocessing_steps)
    test_transform = transforms.Compose(preprocessing_steps)

    # Only the training split requests a download.
    train_dataset = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=train_transform
    )
    test_dataset = datasets.CIFAR10(
        root="./data", train=False, transform=test_transform
    )

    # Loader settings common to both splits; only shuffling differs.
    loader_kwargs = dict(batch_size=128, num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Build the optical core on its own (no electronic fusion head).
    model = FastONNCore(
        wavelength=480e-9,
        num_layers=15,
        tiles=5,
        tile_size=3,
    )
    model = model.to(device)

    total_params = sum(
        param.numel() for param in model.parameters() if param.requires_grad
    )
    print(f"Trainable parameters: {total_params:,}")

    best_acc = train_model(model, train_loader, test_loader, epochs=50, lr=0.001)
    print(f"Best Test Accuracy: {best_acc:.2f}%")

    # Report wall-clock time for the whole run, including data setup.
    end_time = time.time()
    elapsed_time = end_time - start_time
    minutes, seconds = (int(part) for part in divmod(elapsed_time, 60))
    print(f"Total training time: {minutes} min {seconds} sec")


Using device: cuda
Trainable parameters: 237,728
Epoch 1/50 | Train Loss: 2.2118 | Train Acc: 20.76% | Test Acc: 25.61%
Epoch 2/50 | Train Loss: 2.0838 | Train Acc: 26.74% | Test Acc: 27.79%
Epoch 3/50 | Train Loss: 2.0634 | Train Acc: 27.87% | Test Acc: 28.97%
Epoch 4/50 | Train Loss: 2.0565 | Train Acc: 28.41% | Test Acc: 28.83%
Epoch 5/50 | Train Loss: 2.0526 | Train Acc: 28.70% | Test Acc: 29.00%
Epoch 6/50 | Train Loss: 2.0506 | Train Acc: 28.88% | Test Acc: 29.32%
Epoch 7/50 | Train Loss: 2.0488 | Train Acc: 28.98% | Test Acc: 29.03%
Epoch 8/50 | Train Loss: 2.0478 | Train Acc: 28.85% | Test Acc: 28.94%
Epoch 9/50 | Train Loss: 2.0465 | Train Acc: 29.12% | Test Acc: 28.93%
Epoch 10/50 | Train Loss: 2.0456 | Train Acc: 28.99% | Test Acc: 29.03%
Epoch 11/50 | Train Loss: 2.0447 | Train Acc: 29.14% | Test Acc: 28.64%
Epoch 12/50 | Train Loss: 2.0442 | Train Acc: 29.17% | Test Acc: 29.32%
Epoch 13/50 | Train Loss: 2.0419 | Train Acc: 29.32% | Test Acc: 28.77%
Epoch 14/50 | Train Loss