Generate the output of each radial basis function (RBF) implementation for a few sample distances and visualize it.

In [None]:
import torch
from modelforge.potential.representation import PhysNetRadialBasisFunction, AniRadialBasisFunction, SchnetRadialBasisFunction


In [2]:
# Test parameters shared by every RBF implementation compared below.
dtype = torch.float32
min_distance = 0.0
# NOTE(review): the "/ 10" scaling presumably converts Angstrom to nm — confirm the unit convention.
max_distance = 2.0 / 10
number_of_radial_basis_functions = 100
# Three probe distances laid out as a column, shape (3, 1), scaled like max_distance.
distances = torch.tensor([0.5, 1.0, 1.5], dtype=dtype).reshape(-1, 1) / 10


In [None]:
# Compare the PhysNet, ANI and SchNet radial basis functions (RBFs): evaluate
# each one at the probe `distances` and plot, for every distance, the vector of
# basis-function responses. Red dashed vertical lines mark the probe distances.
import matplotlib.pyplot as plt
import numpy as np

# RBF implementations to compare and one plot color per implementation (same order).
rbf_classes = [PhysNetRadialBasisFunction, AniRadialBasisFunction, SchnetRadialBasisFunction]
colors = ['blue', 'green', 'orange']

# x-axis: one point per basis function, spread over the same range the RBFs are
# configured with (previously hard-coded as 0..0.2, which desyncs if the
# parameters change). NOTE(review): this assumes the RBF centers are evenly
# spaced in r; for an implementation whose centers are not linear in r the x
# position is only the basis-function index rescaled to this range — confirm.
rs = np.linspace(min_distance, max_distance, number_of_radial_basis_functions)

for rbf_fn, color in zip(rbf_classes, colors):
    print(f"Testing {rbf_fn.__name__}")

    # Instantiate the RBF with fixed (non-trainable) centers and scale factors.
    rbf = rbf_fn(
        number_of_radial_basis_functions=number_of_radial_basis_functions,
        max_distance=max_distance,
        min_distance=min_distance,
        dtype=dtype,
        trainable_centers_and_scale_factors=False,
    )

    # Evaluation only: no_grad keeps `.numpy()` safe even if some module
    # parameter requires gradients.
    with torch.no_grad():
        actual_output = rbf(distances)

    # Presumably shape (len(distances), number_of_radial_basis_functions): one
    # row of basis-function responses per probe distance — TODO confirm.
    for i, row in enumerate(actual_output):
        # Label only the first curve so each RBF appears exactly once in the legend.
        plt.plot(
            rs,
            row.cpu().numpy(),
            color=color,
            label=rbf_fn.__name__ if i == 0 else None,
        )

# Mark each probe distance once (not once per RBF), spanning the full plot
# height: axvline's ymin/ymax are axes fractions in [0, 1], not data
# coordinates, and x must be a scalar.
for j, d in enumerate(distances.flatten().tolist()):
    plt.axvline(d, color='r', linestyle='--', label="probe distance" if j == 0 else None)

plt.xlabel("distance")
plt.ylabel("RBF output")
plt.legend()
plt.show()
