In [26]:
# Import standard libraries
import jax
import haiku as hk
import haiku.experimental.flax as hkflax
import jax.numpy as jnp
from src.utils.containers import Graph
from src.mlip.message_passing import MessagePassing
from src.run.base import param_count

# Root PRNG key for the notebook. Any further randomness should be derived
# from it via `jax.random.split` instead of re-seeding a fresh PRNGKey.
key = jax.random.PRNGKey(0)

# A very small e3x-style message-passing network (the parameter-count cell
# below reports its size). The semantics of each hyperparameter are defined by
# `src.mlip.message_passing.MessagePassing`.
# NOTE(review): units of `cutoff` are whatever MessagePassing assumes — confirm.
message_passing = MessagePassing(
    cutoff=3.0,
    features=4,
    max_atomic_number=9,
    max_degree=0,
    num_basis_functions=4,
    num_iterations=1,
)

# Haiku wrapper around the Flax `message_passing` module defined above.
@hk.without_apply_rng
@hk.transform
def energy_and_forces(graph: Graph):
    """Run the lifted Flax message-passing module on ``graph``.

    ``hk.transform`` turns this function into a pair of pure functions,
    ``init(rng, graph)`` and ``apply(params, graph)``; ``hk.without_apply_rng``
    removes the rng argument from ``apply``.

    Args:
        graph: Container exposing ``features`` and ``positions``, which are
            forwarded positionally (in that order) to the Flax module.

    Returns:
        Whatever ``message_passing`` returns; callers in this notebook unpack
        it as ``(energy, forces)``.
    """
    # `hkflax.lift` registers the Flax module's parameters inside the Haiku
    # params tree under the name 'e3x_mlip'.
    lifted_mlip = hkflax.lift(message_passing, name='e3x_mlip')
    atom_features = graph.features
    atom_positions = graph.positions
    return lifted_mlip(atom_features, atom_positions)


# Create a toy graph and run one init/apply pass through the model.
#
# Independent sub-keys are derived from the root `key`. Previously the
# positions were drawn with a fresh `PRNGKey(0)`, which is the very same key
# `init` consumed, so parameter initialisation and positions were correlated.
# `key` itself is not rebound, so re-running this cell gives the same result.
positions_key, init_key = jax.random.split(key)

graph = Graph(
    # NOTE(review): features has shape (1, 1, 3) while positions has 3 rows;
    # confirm this is the per-atom layout MessagePassing expects.
    features=jnp.ones((1, 1, 3), dtype=jnp.uint8),
    # (3, 3) array sampled uniformly from [0, 1) (jax.random.uniform defaults).
    positions=jax.random.uniform(positions_key, (3, 3)),
)

params = energy_and_forces.init(init_key, graph)
energy, forces = energy_and_forces.apply(params, graph)



In [27]:
# Size of the Haiku params tree from `energy_and_forces.init`, as reported by
# the project helper `param_count`; shown as the cell's last expression.
num_params = param_count(params)
num_params

114