# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


# Neural Field MLP: (x, y) -> f(x, y)
class NeuralFieldMLP(nn.Module):
    """Coordinate MLP for a neural field: maps input points to field values.

    Architecture (stored as ``self.net``, an ``nn.Sequential``)::

        Linear(input_dim, hidden_dim) -> ReLU
        -> Linear(hidden_dim, hidden_dim) -> ReLU
        -> Linear(hidden_dim, output_dim)

    Args:
        input_dim: Number of input coordinates per point (2 for (x, y)).
        hidden_dim: Width of both hidden layers.
        output_dim: Number of output values per point.
    """

    def __init__(self, input_dim=2, hidden_dim=64, output_dim=1):
        super().__init__()
        widths = (input_dim, hidden_dim, hidden_dim, output_dim)
        final_idx = len(widths) - 2  # index of the output Linear: no activation after it
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            if idx != final_idx:
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the MLP to ``x``; its last dimension must equal ``input_dim``."""
        out = self.net(x)
        return out


# Training function
def train_neural_field(
    model,
    optimizer,
    criterion,
    num_epochs=5000,
    *,
    num_samples=1000,
    log_every=500,
    rng=None,
):
    """Fit ``model`` to f(x, y) = sin(pi * x) * cos(pi * y) on [-1, 1]^2.

    A fixed set of ``num_samples`` points is drawn uniformly once, and the
    model is trained full-batch on it for ``num_epochs`` optimizer steps.

    Args:
        model: Module mapping an (N, 2) float32 tensor to an (N, 1) tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(prediction, target)``.
        num_epochs: Number of full-batch gradient steps.
        num_samples: Number of training points to sample.
        log_every: Print the loss every ``log_every`` epochs and on the last
            epoch; ``0`` or a negative value disables printing.
        rng: Optional ``numpy.random.Generator`` for sampling. When ``None``
            the global NumPy RNG is used, so ``np.random.seed`` still applies.

    Returns:
        The loss of the last epoch as a float (measured before that epoch's
        optimizer step), or ``None`` when ``num_epochs`` is 0 or negative.
    """
    sampler = np.random if rng is None else rng
    x_np = sampler.uniform(-1, 1, (num_samples, 2))
    y_np = np.sin(np.pi * x_np[:, 0]) * np.cos(np.pi * x_np[:, 1])

    # Put the data on the model's device so a GPU-resident model works.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    x_train = torch.tensor(x_np, dtype=torch.float32, device=device)
    y_train = torch.tensor(y_np, dtype=torch.float32, device=device).unsqueeze(1)

    model.train()
    final_loss = None
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        y_pred = model(x_train)
        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

        is_last = epoch == num_epochs - 1
        if log_every > 0 and (epoch % log_every == 0 or is_last):
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}")

    if num_epochs > 0:
        # Read once at the end rather than every epoch to avoid a device sync per step.
        final_loss = loss.item()
    return final_loss


# Visualization function
def plot_results(model, resolution=100):
    """Draw the model's prediction over [-1, 1]^2 as a filled contour plot.

    Args:
        model: Module mapping an (N, 2) float32 tensor to an (N, 1) tensor.
        resolution: Grid points per axis; the default 100 gives the original
            100 x 100 grid.
    """
    axis = np.linspace(-1, 1, resolution)
    X, Y = np.meshgrid(axis, axis)

    xy_grid = np.stack([X.ravel(), Y.ravel()], axis=-1)
    # Evaluate on the model's own device so a GPU-resident model works.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    xy_tensor = torch.tensor(xy_grid, dtype=torch.float32, device=device)

    # Inference should run in eval mode; restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            Z = model(xy_tensor).cpu().numpy().reshape(X.shape)
    finally:
        model.train(was_training)

    plt.figure(figsize=(6, 5))
    plt.contourf(X, Y, Z, levels=50, cmap="coolwarm")
    plt.colorbar(label="Function Value")
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.title("Learned Neural Field Approximation")
    plt.show()


# Main execution
if __name__ == "__main__":
    # Seed both RNGs so runs are reproducible: torch drives the weight
    # initialization, NumPy drives the training-point sampling.
    torch.manual_seed(0)
    np.random.seed(0)

    model = NeuralFieldMLP()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    train_neural_field(model, optimizer, criterion)
    plot_results(model)
