In [None]:
# Copyright 2025 Marc-Antoine Ruel
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from matplotlib import cm
from matplotlib import patches
from matplotlib import pyplot as plt

In [None]:
# Image management.

def create_ring_image(width=1000, height=1000, inner_diameter=300, outer_diameter=400):
    """Render a black annulus centred on a white grayscale canvas.

    Args:
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        inner_diameter: Diameter of the hole; pixels at exactly this radius stay white.
        outer_diameter: Outer diameter of the ring; pixels at exactly this radius are black.

    Returns:
        A (height, width) uint8 array holding 255 for the background and 0 on the ring.
    """
    center_x, center_y = width // 2, height // 2
    inner_radius, outer_radius = inner_diameter / 2, outer_diameter / 2
    row_idx, col_idx = np.indices((height, width))
    # Euclidean distance of every pixel from the (integer) centre pixel.
    radius = np.sqrt((col_idx - center_x) ** 2 + (row_idx - center_y) ** 2)
    # Half-open band: strictly outside the inner radius, up to and including the outer one.
    on_ring = (radius > inner_radius) & (radius <= outer_radius)
    canvas = np.full((height, width), 255, dtype=np.uint8)
    canvas[on_ring] = 0
    return canvas


def select_pixels(image, num_pixels, seed):
    """Sample distinct pixels uniformly at random and return them as (x, y, value) rows.

    Args:
        image: 2D array indexed as image[row, col].
        num_pixels: Number of distinct pixels to draw.
        seed: Seed for the sampler; the same seed always yields the same pixels.

    Returns:
        Array of shape (num_pixels, 3) whose columns are x (column index),
        y (row index) and the pixel value at that location.

    Raises:
        ValueError: If num_pixels exceeds the number of pixels in the image.
    """
    height, width = image.shape
    total_pixels = height * width
    if num_pixels > total_pixels:
        raise ValueError("Number of pixels to select cannot exceed the total number of pixels in the image.")
    # A private legacy RandomState draws the same sequence that
    # np.random.seed(seed) followed by np.random.choice would, but leaves the
    # process-wide NumPy generator untouched for any other code that uses it.
    rng = np.random.RandomState(seed)
    indices = rng.choice(total_pixels, size=num_pixels, replace=False)
    # Convert linear indices to 2D coordinates (row, column).
    rows, cols = np.divmod(indices, width)
    # Stack the coordinates and pixel values. Order: x, y, value
    return np.stack((cols, rows, image[rows, cols]), axis=1)


def display_image(image, title="Black Ring on White Background"):
    """Show a 2D array as a grayscale picture with the axes hidden.

    Args:
        image: 2D array handed to plt.imshow with the 'gray' colormap.
        title: Caption drawn above the image. Defaults to the ring caption so
            existing calls render exactly as before.
    """
    plt.imshow(image, cmap='gray')  # 'gray' colormap for black and white
    plt.title(title)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()


def display_image_and_points(image, points, colormap='plasma'):
    """Draw the image in grayscale and scatter the given points on top of it.

    Args:
        image: 2D array; its shape sets the figure size and the axis limits.
        points: 2D array whose columns 0, 1 and 2 give x, y and the value used
            to colour each marker.
        colormap: Matplotlib colormap name used for the markers.
    """
    height_px, width_px = image.shape
    # Figure size is the image size divided by 100 in each direction.
    fig, ax = plt.subplots(1, figsize=(width_px / 100, height_px / 100))
    ax.imshow(image, cmap='gray')
    xs, ys, values = points[:, 0], points[:, 1], points[:, 2]
    markers = ax.scatter(xs, ys, c=values, cmap=colormap, s=4)
    fig.colorbar(markers, ax=ax, label='Color Value (0 to 1)')
    ax.set_xlim(0, width_px)
    # Descending y limits put row 0 at the top, matching image coordinates.
    ax.set_ylim(height_px, 0)
    ax.axis('off')
    ax.set_aspect('equal')
    plt.title("Black Ring with Colored Scatter Plot")
    plt.show()

In [None]:
# Neural network

def train_model(model, data, epochs=1000, lr=0.01):
    """Fit the model on the whole dataset with Adam and a mean squared error loss.

    Args:
        model: Module mapping the two coordinate columns to a prediction.
        data: 2D tensor; columns 0-1 are the inputs, column 2 is the target.
        epochs: Number of full-batch optimisation steps.
        lr: Adam learning rate.

    Returns:
        List with the loss of every step, measured before that step's update.
    """
    inputs, targets = data[:, :2], data[:, 2:3]
    loss_fn = nn.MSELoss()
    adam = optim.Adam(model.parameters(), lr=lr)
    history = []
    for step in range(1, epochs + 1):
        # Forward pass over the full batch.
        loss = loss_fn(model(inputs), targets)
        # Clear stale gradients, backpropagate, then update the weights.
        adam.zero_grad()
        loss.backward()
        adam.step()
        history.append(loss.item())
        # Progress line every 100 steps.
        if step % 100 == 0:
            print(f'Epoch [{step}/{epochs}], Loss: {loss.item():.4f}')
    return history


def evaluate_model(model, data):
    """Return the mean squared error of the model on a labelled dataset.

    The model is put in eval mode for the duration of the call so that layers
    such as dropout are disabled and the result is deterministic; the previous
    train/eval mode is restored afterwards.

    Args:
        model: Module mapping the two coordinate columns to a prediction.
        data: 2D tensor; columns 0-1 are the inputs, column 2 is the target.

    Returns:
        The MSE as a Python float (lower is better).
    """
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed to grade predictions against the expected outputs.
        with torch.no_grad():
            predictions = model(data[:, :2])
            return nn.MSELoss()(predictions, data[:, 2:3]).item()
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)


def one_run(model):
    """Train the model on pixels sampled from the ring image and report the results.

    Samples 6000 training pixels (seed 42), trains for 1000 epochs, prints the
    training-set score and plots the loss curve. Then samples 500 test pixels
    (seed 43), prints the test-set score and shows the test pixels over the image.

    Args:
        model: Module taking (x, y) pixel coordinates and predicting the pixel value.
    """
    image = create_ring_image()
    training_data = torch.FloatTensor(select_pixels(image, num_pixels=6000, seed=42))
    losses = train_model(model, training_data, epochs=1000, lr=0.01)
    # Despite the "accuracy" wording, evaluate_model returns the MSE loss: lower is better.
    accuracy = evaluate_model(model, training_data)
    print(f'Final model accuracy: {accuracy:.4f}')
    plt.plot(losses)
    plt.show()

    test_data = select_pixels(image, num_pixels=500, seed=43)
    print("Test data evaluation:")
    test_accuracy = evaluate_model(model, torch.FloatTensor(test_data))
    # Report the held-out score, in the same format as the training score above.
    print(f'Test model accuracy: {test_accuracy:.4f}')
    display_image_and_points(image, test_data)

In [None]:
# Inspired by https://github.com/YihongDong/FANformer/blob/main/olmo/model.py#L77-L137 but mostly rewritten.
# License unclear. That said, code is trivial.

class FANLayer(nn.Module):
    """One layer of a Fourier Analysis Network (https://arxiv.org/abs/2410.02675).

    A single linear projection is split into a periodic part p and a regular
    part g; the layer outputs cos(p), sin(p) and activation(g) concatenated
    along the last dimension.
    """

    def __init__(self, input_dim, output_dim, p_ratio=0.25, activation=None):
        """Build the fused projection and record how to split it.

        Args:
            input_dim (int): The number of input features.
            output_dim (int): The number of output features.
            p_ratio (float): The ratio of output dimensions used for each of the
                cosine and sine parts (default: 0.25). Must be in [0, 0.5].
            activation (callable): The activation function to apply to the g
                component. Defaults to the identity.
        """
        super().__init__()
        assert 0 <= p_ratio <= 0.5, "p_ratio must be between 0 and 0.5"
        p_output_dim = int(output_dim * p_ratio)
        # p is emitted twice (cosine and sine), so g gets whatever is left.
        g_output_dim = output_dim - p_output_dim * 2
        # p is projected once and shared by cos and sin, hence p + g outputs here.
        self.input_linear = nn.Linear(input_dim, p_output_dim + g_output_dim)
        self.fused_dims = (p_output_dim, g_output_dim)
        # With F.relu or F.gelu, it's unable to learn negative values, so the
        # default is the identity. nn.Identity (rather than a local lambda)
        # keeps the module picklable.
        self.activation = activation or nn.Identity()

    def forward(self, x):
        """Return cos(p), sin(p) and activation(g) concatenated on the last dimension."""
        p, g = self.input_linear(x).split(self.fused_dims, dim=-1)
        return torch.cat((torch.cos(p), torch.sin(p), self.activation(g)), dim=-1)


class FAN(nn.Module):
    """Neural network using FAN.

    Stacks an input FANLayer, `hidden_layers` further FANLayers of width
    `hidden_size` (each preceded by dropout when `dropout_rate > 0`) and an
    output layer producing `output_dim` values.

    Args:
        input_dim: Number of input features.
        hidden_layers: Number of hidden FANLayers after the input layer.
        hidden_size: Width of the input and hidden layers.
        output_dim: Number of output values.
        p_ratio: Share of each hidden layer's outputs used for each of cos and sin.
        activation: Activation for the non-periodic part of the hidden layers.
        dropout_rate: Dropout probability; 0 disables dropout entirely.
    """
    def __init__(self, input_dim=2, hidden_layers=1, hidden_size=64, output_dim=1, p_ratio=0.25, activation=None, dropout_rate=0.2):
        super().__init__()
        layers = [
            FANLayer(input_dim, hidden_size, p_ratio, activation),
        ]
        for _ in range(hidden_layers):
            if dropout_rate > 0:
                layers.append(nn.Dropout(dropout_rate))
            layers.append(FANLayer(hidden_size, hidden_size, p_ratio, activation))
        if dropout_rate > 0:
            layers.append(nn.Dropout(dropout_rate))
        # Output layer: p_ratio=0 and no activation make it a plain affine map,
        # so no output is squashed through cos/sin whatever output_dim is.
        # For output_dim < 4 this matches what the default p_ratio already gave.
        layers.append(FANLayer(hidden_size, output_dim, p_ratio=0))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through the stacked layers."""
        return self.model(x)

In [None]:
class SimpleNetwork(nn.Module):
    """Plain ReLU multilayer perceptron.

    Layout: Linear + ReLU on the input, then `hidden_layers` blocks of
    Linear + ReLU, then a final Linear to `output_dim`. When
    `dropout_rate > 0`, a Dropout precedes every Linear except the first.
    """

    def __init__(self, input_dim=2, hidden_layers=1, hidden_size=64, output_dim=1, dropout_rate=0.2):
        super().__init__()
        use_dropout = dropout_rate > 0
        stack = [nn.Linear(input_dim, hidden_size), nn.ReLU()]
        for _ in range(hidden_layers):
            if use_dropout:
                stack.append(nn.Dropout(dropout_rate))
            stack += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        if use_dropout:
            stack.append(nn.Dropout(dropout_rate))
        stack.append(nn.Linear(hidden_size, output_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run x through the stacked layers."""
        return self.model(x)

In [None]:
# Train and evaluate each architecture with the same small configuration.
for network_cls in (SimpleNetwork, FAN):
    print(network_cls.__name__)
    network = network_cls(hidden_layers=1, hidden_size=6, dropout_rate=0)
    one_run(network)