In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
import torch
from identity_net import IdentityNet, IdentityNetRes, IdentityNetNAC
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader


In [None]:
SEED = 0
TRAIN_RANGE = (-5.0, 5.0)  # interval on which the identity map is learned
N_TRAIN = 1000
BATCH_SIZE = 128

# Weight initialisation and DataLoader shuffling both draw from torch's RNG;
# seeding it here makes Restart & Run All train the same models every time.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training data for the identity task: inputs are evenly spaced over
# TRAIN_RANGE as an (N_TRAIN, 1) column, and the target is a copy of the input.
x = torch.linspace(*TRAIN_RANGE, N_TRAIN).unsqueeze(1)
y = x.clone()
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

def train_network(model, epochs=10_000, lr=1e-3, loader=None, target_device=None):
    """Fit ``model`` with Adam on an MSE loss and return it on the CPU.

    Args:
        model: ``nn.Module`` to train. It is moved to ``target_device``,
            switched to training mode, optimised in place, then moved back
            to the CPU.
        epochs: Number of full passes over ``loader``.
        lr: Adam learning rate.
        loader: Iterable of ``(inputs, targets)`` batches. Defaults to the
            notebook-level ``dataloader`` so existing calls keep working.
        target_device: Device to train on. Defaults to the notebook-level
            ``device``.

    Returns:
        The same ``model`` object, now residing on the CPU.
    """
    # Fall back to the notebook globals only when the caller gave nothing,
    # so the function can also be used with an explicit loader/device.
    loader = dataloader if loader is None else loader
    target_device = device if target_device is None else target_device

    model = model.to(target_device)
    # Be explicit: a model that was previously put in eval mode (e.g. by a
    # plotting cell) must not be trained with eval-mode layer behaviour.
    model.train()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for xb, yb in loader:
            xb, yb = xb.to(target_device), yb.to(target_device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
    return model.cpu()

# (display name, activation class); each class is called afresh for every
# model built below, so no two networks share an activation instance.
activation_factories = [
    ("ReLU", nn.ReLU),
    ("ReLU6", nn.ReLU6),
    ("Softplus", nn.Softplus),
    ("Tanh", nn.Tanh),
    ("Sigmoid", nn.Sigmoid),
    ("ELU", nn.ELU),
    ("SiLU", nn.SiLU),
    ("Identity", nn.Identity),
    ("PReLU", nn.PReLU),
]

# Architectures that accept an activation module, in the order they are
# trained for each activation.
_activation_architectures = (
    ("MLP", IdentityNet),
    ("Residual MLP", IdentityNetRes),
)

trained_models = []
for act_name, make_activation in activation_factories:
    for arch_name, build_net in _activation_architectures:
        net = build_net(make_activation())
        trained_models.append((arch_name, act_name, train_network(net)))

# The NAC stack takes no activation argument; it is trained once and tagged "NAC".
trained_models.append(("NAC stack", "NAC", train_network(IdentityNetNAC())))


In [None]:
import matplotlib.pyplot as plt

# Evaluate every trained model on a range 4x wider than the training data
# to compare in-range accuracy with extrapolation.
val = torch.linspace(-20, 20, 10_000).unsqueeze(1)
gt = val  # the target function is the identity

fig, ax = plt.subplots(figsize=(10, 6))
style_map = {"MLP": "-", "Residual MLP": "--", "NAC stack": ":"}
# plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in matplotlib 3.9;
# pyplot.get_cmap(name, lut) is the supported equivalent.
cmap = plt.get_cmap("tab20", len(trained_models))

# Shade the interval covered by the training inputs `x`; errors outside it are extrapolation.
ax.axvspan(float(x.min()), float(x.max()), color="grey", alpha=0.15, label="training range")

for idx, (architecture, activation_name, model) in enumerate(trained_models):
    model.eval()
    # Inference only: avoid building an autograd graph over 10k points per model.
    with torch.no_grad():
        out = model(val)
    # Pointwise absolute error, floored so the log-scale axis never receives 0.
    abs_err = torch.abs(out - gt).clamp_min(1e-8)
    label = f"{architecture} ({activation_name})" if activation_name != "NAC" else "NAC"
    ax.plot(
        val.squeeze().numpy(),
        abs_err.squeeze().numpy(),
        linestyle=style_map.get(architecture, "-"),
        color=cmap(idx),
        label=label,
    )

ax.set_yscale("log")
ax.set_ylim(bottom=1e-8)
ax.set_xlabel("Input value")
ax.set_ylabel("Absolute error |f(x) - x| (log scale)")
ax.set_title("Absolute error of the learned identity across architectures and activation functions")
ax.grid(True, which="both", linestyle=":", linewidth=0.5)
ax.legend(ncol=2, fontsize=8)
fig.tight_layout()
plt.show()
