In [None]:
# Copyright 2025 Marc-Antoine Ruel
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from matplotlib import pyplot as plt

In [None]:
def generate_ring_points(num_points, seed, grid_size=1000, inner_diameter=300, outer_diameter=400):
    """Sample random points on a square grid and flag the ones lying inside a ring.

    The global torch and numpy random generators are both re-seeded with
    `seed`, so calling this function also resets the random state seen by
    any code that runs afterwards.

    Args:
        num_points (int): Number of points to draw.
        seed (int): Seed applied to both torch and numpy.
        grid_size (number): Side length of the square; coordinates are drawn
            uniformly from [0, grid_size).
        inner_diameter (number): Inner diameter of the ring.
        outer_diameter (number): Outer diameter of the ring.

    Returns:
        torch.Tensor: float32 tensor of shape (num_points, 3) whose columns
        are x, y and a 1.0/0.0 flag telling whether the point is in the ring
        (both radii inclusive). The ring is centered in the middle of the grid.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    middle = grid_size / 2
    r_min = inner_diameter / 2
    r_max = outer_diameter / 2

    # x is drawn before y; the order matters for reproducing a given seed.
    xs = np.random.uniform(0, grid_size, num_points)
    ys = np.random.uniform(0, grid_size, num_points)

    # Euclidean distance of every point from the grid center.
    radius = np.sqrt((xs - middle)**2 + (ys - middle)**2)
    inside = ((radius >= r_min) & (radius <= r_max)).astype(float)

    table = np.column_stack((xs, ys, inside))
    return torch.from_numpy(table).float()


def train_model(model, data, epochs=1000, lr=0.01):
    """Fit `model` to `data` with full-batch Adam and a mean-squared-error loss.

    Args:
        model (nn.Module): Network mapping the two coordinate columns to a
            prediction shaped like the label column.
        data (torch.Tensor): 2-D tensor; columns 0-1 are the (x, y) inputs
            and column 2 is the ring indicator used as target.
        epochs (int): Number of optimization steps, each over the whole set.
        lr (float): Learning rate handed to Adam.

    Returns:
        list[float]: The loss measured at every epoch (before that epoch's
        parameter update), in order.
    """
    inputs = data[:, :2]
    # Slice with 2:3 (not 2) to keep the column dimension, matching the model output.
    targets = data[:, 2:3]

    # MSE on the raw outputs; nn.BCELoss() is the alternative for probability outputs.
    loss_fn = nn.MSELoss()
    adam = optim.Adam(model.parameters(), lr=lr)

    history = []
    for step in range(1, epochs + 1):
        loss = loss_fn(model(inputs), targets)

        adam.zero_grad()
        loss.backward()
        adam.step()

        current = loss.item()
        history.append(current)
        # Progress line every 100 epochs.
        if step % 100 == 0:
            print(f'Epoch [{step}/{epochs}], Loss: {current:.4f}')
    return history


def evaluate_model(model, data):
    """Print and return the mean-squared error of `model` on `data`.

    The model is switched to evaluation mode for the duration of the call and
    put back into whatever mode it was in afterwards. `torch.no_grad()` alone
    only turns off gradient tracking; layers such as `nn.Dropout` keep
    randomly zeroing activations while the module is in training mode, which
    would make the reported loss noisy and different on every call.

    Args:
        model (nn.Module): Network mapping the two coordinate columns to a
            prediction shaped like the label column.
        data (torch.Tensor): 2-D tensor; columns 0-1 are the (x, y) inputs
            and column 2 is the ring indicator.

    Returns:
        float: The MSE between the predictions and the ring indicator
        (a loss: lower is better).
    """
    features = data[:, :2]
    # Slice with 2:3 (not 2) to keep the column dimension, matching the model output.
    expected = data[:, 2:3]
    criterion = nn.MSELoss()

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            test_loss = criterion(model(features), expected).item()
    finally:
        # Restore the caller's mode so a later training call is unaffected.
        model.train(was_training)

    print(f'Test Loss: {test_loss:.4f}')
    return test_loss


def visualize_results(model, data, losses, grid_size=1000, resolution=100, color_by_value=False):
    """Plot the training loss next to the model's output over the whole grid.

    The left panel shows `losses` per epoch. The right panel shows a filled
    contour plot of the model output sampled on a regular grid, with the
    points of `data` drawn on top.

    Args:
        model (nn.Module): Network mapping (N, 2) coordinates to one value per
            point. It is evaluated in eval mode (so Dropout is inactive) and
            its previous mode is restored afterwards.
        data (torch.Tensor): 2-D CPU tensor; columns 0-1 are (x, y) and
            column 2 is the ring indicator.
        losses (sequence of float): Loss value for every training epoch.
        grid_size (number): The model is sampled over [0, grid_size] on both axes.
        resolution (int): Number of grid samples along each axis.
        color_by_value (bool): When False (default) the points are drawn as
            two labelled groups, red for indicator 0 and blue for indicator 1.
            When True the points are colored by their indicator value with the
            viridis colormap and a colorbar replaces the legend.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: loss curve.
    ax1.plot(losses)
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Loss')

    # Right panel: sample the model on a regular grid covering the data area.
    axis = np.linspace(0, grid_size, resolution)
    xx, yy = np.meshgrid(axis, axis)
    grid_points = np.c_[xx.ravel(), yy.ravel()]

    # no_grad() does not disable Dropout; eval mode does. Restore the mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            grid_input = torch.tensor(grid_points, dtype=torch.float32)
            Z = model(grid_input).numpy().reshape(xx.shape)
    finally:
        model.train(was_training)
    ax2.contourf(xx, yy, Z, cmap=plt.cm.RdBu, alpha=0.6)

    # Overlay the data points.
    points = data[:, :2].numpy()
    value = data[:, 2].numpy()
    if color_by_value:
        # colorbar() needs the mappable returned by scatter(), not the Axes.
        scatter = ax2.scatter(points[:, 0], points[:, 1], c=value, cmap='viridis')
        fig.colorbar(scatter, ax=ax2, label='Value (0 to 1)')
    else:
        ax2.scatter(points[value==0, 0], points[value==0, 1], c='red', label='Outside Ring')
        ax2.scatter(points[value==1, 0], points[value==1, 1], c='blue', label='Inside Ring')
        ax2.legend()

    ax2.set_title('Decision Boundary')
    ax2.set_xlabel('X')
    ax2.set_ylabel('Y')
    ax2.set_aspect('equal')

    plt.tight_layout()
    plt.show()


def one_run(model):
    """Train `model` on a generated ring dataset, report its losses and plot it.

    Steps:
      1. Generate 6000 training points (seed 42) and train for 1000 epochs.
      2. Print the loss on the training points.
      3. Generate 500 fresh points (seed 100) and print the loss on them.
      4. Plot the loss curve and the model output over the training points.

    Args:
        model (nn.Module): Network taking (N, 2) coordinates; it is trained
            in place.
    """
    data = generate_ring_points(num_points=6000, seed=42)
    losses = train_model(model, data, epochs=1000, lr=0.01)

    # evaluate_model() always prints "Test Loss: ..."; label both calls so the
    # training-set figure is not mistaken for the held-out one.
    print("Training data evaluation:")
    evaluate_model(model, data)

    test_data = generate_ring_points(num_points=500, seed=100)
    print("Test data evaluation:")
    evaluate_model(model, test_data)

    visualize_results(model, data, losses)

In [None]:
# Inspired by https://github.com/YihongDong/FANformer/blob/main/olmo/model.py#L77-L137 but mostly rewritten.
# License unclear. That said, code is trivial.

class FANLayer(nn.Module):
    """One layer of a Fourier Analysis Network (https://arxiv.org/abs/2410.02675).

    A single linear projection is split into a periodic part `p` and a
    regular part `g`; the layer outputs `[cos(p), sin(p), activation(g)]`
    concatenated along the last dimension.
    """

    def __init__(self, input_dim, output_dim, p_ratio=0.25, activation=None):
        """Builds the layer.

        Args:
            input_dim (int): The number of input features.
            output_dim (int): The number of output features.
            p_ratio (float): The ratio of output dimensions used for each of
                the cosine and sine parts (default: 0.25). Must be in [0, 0.5].
            activation (callable): The activation function to apply to the g
                component. None (default) applies no activation.
        """
        super().__init__()
        assert 0 <= p_ratio <= 0.5, "p_ratio must be between 0 and 0.5"
        p_output_dim = int(output_dim * p_ratio)
        # p is emitted twice (cosine and sine), so g gets whatever is left.
        g_output_dim = output_dim - p_output_dim * 2
        # One fused projection produces p and g together; forward() splits it.
        self.input_linear = nn.Linear(input_dim, p_output_dim + g_output_dim)
        self.fused_dims = (p_output_dim, g_output_dim)
        # Author's note: with F.relu or F.gelu, it's unable to learn negative values.
        # nn.Identity() rather than a local lambda keeps the module picklable
        # (e.g. for torch.save on the whole model).
        self.activation = nn.Identity() if activation is None else activation

    def forward(self, x):
        """Returns cos(p), sin(p) and activation(g) concatenated on the last dim."""
        p, g = self.input_linear(x).split(self.fused_dims, dim=-1)
        return torch.cat((torch.cos(p), torch.sin(p), self.activation(g)), dim=-1)


class FAN(nn.Module):
    """Feed-forward network built from a stack of FANLayer modules.

    Layout: an input FANLayer, then `hidden_layers` hidden FANLayers and an
    output FANLayer, each of the latter preceded by Dropout when
    `dropout_rate` is positive. The output FANLayer is created with
    FANLayer's default `p_ratio` and no activation.
    """

    def __init__(self, input_dim=2, hidden_layers=1, hidden_size=64, output_dim=1, p_ratio=0.25, activation=None, dropout_rate=0.2):
        super().__init__()

        def with_dropout(layer):
            # A fresh Dropout goes in front of the layer only when the rate is positive.
            if dropout_rate > 0:
                return [nn.Dropout(dropout_rate), layer]
            return [layer]

        stack = [FANLayer(input_dim, hidden_size, p_ratio, activation)]
        for _ in range(hidden_layers):
            stack += with_dropout(FANLayer(hidden_size, hidden_size, p_ratio, activation))
        stack += with_dropout(FANLayer(hidden_size, output_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Runs `x` through the whole stack."""
        return self.model(x)

In [None]:
class SimpleNetwork(nn.Module):
    """Plain ReLU multi-layer perceptron used as the baseline.

    Layout: Linear + ReLU on the input, then `hidden_layers` blocks of
    Linear + ReLU and a final Linear, each of the latter preceded by Dropout
    when `dropout_rate` is positive. No activation follows the final Linear.
    """

    def __init__(self, input_dim=2, hidden_layers=1, hidden_size=64, output_dim=1, dropout_rate=0.2):
        super().__init__()

        def maybe_dropout():
            # A fresh Dropout instance per position, or nothing when disabled.
            return [nn.Dropout(dropout_rate)] if dropout_rate > 0 else []

        stack = [nn.Linear(input_dim, hidden_size), nn.ReLU()]
        for _ in range(hidden_layers):
            stack += maybe_dropout()
            stack += [nn.Linear(hidden_size, hidden_size), nn.ReLU()]
        stack += maybe_dropout()
        stack.append(nn.Linear(hidden_size, output_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Runs `x` through the whole stack."""
        return self.model(x)

In [None]:
# Baseline run: small ReLU network.
# Seed torch before building the model: the initial weights are drawn at
# construction time, before one_run() seeds anything, so without this they
# depend on whatever RNG state earlier cells left behind and re-running the
# cell gives a different result.
torch.manual_seed(0)
model1 = SimpleNetwork(hidden_layers=1, hidden_size=6, dropout_rate=0)
one_run(model1)

In [None]:
# FAN run: same depth and width as the baseline network.
# Seed torch before building the model: the initial weights are drawn at
# construction time, before one_run() seeds anything, so without this they
# depend on whatever RNG state earlier cells left behind and re-running the
# cell gives a different result.
torch.manual_seed(0)
model2 = FAN(hidden_layers=1, hidden_size=6, dropout_rate=0)
one_run(model2)