In [None]:
# Render matplotlib figures inline below each cell (already the default in modern Jupyter; kept explicit).
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import torch

from sonnix import modules as snn

In [None]:
# Notebook-wide plotting defaults: the "bmh" style sheet, then wide 16x5 figures.
plt.style.use("bmh")
plt.rcParams.update({"figure.figsize": (16, 5)})

In [None]:
# Seed torch's global RNG so the notebook is reproducible in case any module
# draws random initial parameter values. (Previously this seed was defined but never applied.)
random_seed = 15803713921897248048
torch.manual_seed(random_seed);

In [None]:
def plot_activation(x: torch.Tensor, y: torch.Tensor, title: str = "") -> None:
    r"""Plot the activations generated by a module.

    Both tensors are detached, moved to CPU and flattened before plotting, so
    any shape works as long as ``x`` and ``y`` hold the same number of elements.

    Args:
        x: The inputs.
        y: The outputs.
        title: Optional figure title. Defaults to an empty title.
    """
    _fig, ax = plt.subplots()
    # ``.cpu()`` is a no-op for CPU tensors but keeps ``.numpy()`` from failing on GPU tensors.
    ax.plot(x.detach().cpu().flatten().numpy(), y.detach().cpu().flatten().numpy())
    # Label the axes so the figure is readable when skimmed on its own.
    ax.set(xlabel="x", ylabel="y", title=title)

In [None]:
# Shared input grid: 1001 evenly spaced points on [-10, 10], shaped as a (1001, 1) column.
x = torch.linspace(-10, 10, 1001).unsqueeze(-1)

## Asinh

In [None]:
# Evaluate snn.Asinh on the input grid and plot the resulting curve.
module = snn.Asinh()
plot_activation(x, y := module(x))

## DynamicAsinh

In [None]:
# One curve per alpha_init_value; each module is freshly constructed, so this
# shows DynamicAsinh as initialized (before any training).
fig, ax = plt.subplots()
for alpha in [0.1, 0.5, 1.0, 10.0]:
    module = snn.DynamicAsinh(normalized_shape=(1,), alpha_init_value=alpha)
    y = module(x)
    ax.plot(x.detach().flatten().numpy(), y.detach().flatten().numpy(), label=f"{alpha=}")
ax.set(title="DynamicAsinh", xlabel="x", ylabel="y")
# Trailing ';' suppresses the Legend object's repr in the cell output.
ax.legend();

## DynamicTanh

In [None]:
# One curve per alpha_init_value; each module is freshly constructed, so this
# shows DynamicTanh as initialized (before any training).
fig, ax = plt.subplots()
for alpha in [0.05, 0.1, 0.5, 1.0]:
    module = snn.DynamicTanh(normalized_shape=(1,), alpha_init_value=alpha)
    y = module(x)
    ax.plot(x.detach().flatten().numpy(), y.detach().flatten().numpy(), label=f"{alpha=}")
ax.set(title="DynamicTanh", xlabel="x", ylabel="y")
# Trailing ';' suppresses the Legend object's repr in the cell output.
ax.legend();

## Exp

In [None]:
# Evaluate snn.Exp on the input grid and plot the resulting curve.
module = snn.Exp()
plot_activation(x, y := module(x))

## ExpSin

In [None]:
# Evaluate snn.ExpSin on the input grid and plot the resulting curve.
module = snn.ExpSin()
plot_activation(x, y := module(x))

## Expm1

In [None]:
# Evaluate snn.Expm1 on the input grid and plot the resulting curve.
module = snn.Expm1()
plot_activation(x, y := module(x))

## Gaussian

In [None]:
# Evaluate snn.Gaussian on the input grid and plot the resulting curve.
module = snn.Gaussian()
plot_activation(x, y := module(x))

## Laplacian

In [None]:
# Evaluate snn.Laplacian on the input grid and plot the resulting curve.
module = snn.Laplacian()
plot_activation(x, y := module(x))

## Log

In [None]:
# Evaluate snn.Log on the input grid and plot the resulting curve.
# NOTE(review): if snn.Log is an unguarded log, points with x <= 0 come out as NaN/-inf
# and are simply not drawn — compare with SafeLog further down.
module = snn.Log()
plot_activation(x, y := module(x))

## Log1p

In [None]:
# Evaluate snn.Log1p on the input grid and plot the resulting curve.
module = snn.Log1p()
plot_activation(x, y := module(x))

## MultiQuadratic

In [None]:
# Evaluate snn.MultiQuadratic on the input grid and plot the resulting curve.
module = snn.MultiQuadratic()
plot_activation(x, y := module(x))

## Quadratic

In [None]:
# Evaluate snn.Quadratic on the input grid and plot the resulting curve.
module = snn.Quadratic()
plot_activation(x, y := module(x))

## RectifierAsinhUnit

In [None]:
# Evaluate snn.RectifierAsinhUnit on the input grid and plot the resulting curve.
module = snn.RectifierAsinhUnit()
plot_activation(x, y := module(x))

## ReLUn

In [None]:
# One curve per constructor argument m passed to snn.ReLUn.
fig, ax = plt.subplots()
for m in [1, 2, 5]:
    module = snn.ReLUn(m)
    y = module(x)
    ax.plot(x.detach().flatten().numpy(), y.detach().flatten().numpy(), label=f"{m=}")
ax.set(title="ReLUn", xlabel="x", ylabel="y")
# Trailing ';' suppresses the Legend object's repr in the cell output.
ax.legend();

## SafeExp

In [None]:
# Evaluate snn.SafeExp on the input grid and plot the resulting curve.
module = snn.SafeExp()
plot_activation(x, y := module(x))

## SafeLog

In [None]:
# Evaluate snn.SafeLog on the input grid and plot the resulting curve.
module = snn.SafeLog()
plot_activation(x, y := module(x))

## Sin

In [None]:
# Evaluate snn.Sin on the input grid and plot the resulting curve.
module = snn.Sin()
plot_activation(x, y := module(x))

## Sinh

In [None]:
# Evaluate snn.Sinh on the input grid and plot the resulting curve.
module = snn.Sinh()
plot_activation(x, y := module(x))

## Snake

In [None]:
# One curve per constructor argument freq passed to snn.Snake.
fig, ax = plt.subplots()
for freq in [0.5, 1, 2]:
    module = snn.Snake(freq)
    y = module(x)
    ax.plot(x.detach().flatten().numpy(), y.detach().flatten().numpy(), label=f"{freq=}")
ax.set(title="Snake", xlabel="x", ylabel="y")
# Trailing ';' suppresses the Legend object's repr in the cell output.
ax.legend();

## SquaredReLU

In [None]:
# Evaluate snn.SquaredReLU on the input grid and plot the resulting curve.
module = snn.SquaredReLU()
plot_activation(x, y := module(x))