In [3]:
# Tutorial: Introduction to jax_e3nn

# install jax_e3nn using pip:
# !pip install jax_e3nn
# !pip install jax
# !pip install haiku
# !pip install jraph
 

#Importing the necessary modules
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
import haiku as hk

In e3nn we have a notation to define direct sums of irreducible representations of O(3) (rotations in 3D). 

In [11]:
# Define the irreps consisting of two scalars (0e) and one vector (1o)
irreps = e3nn.Irreps("2x0e + 1x1o")  # 0e stands for the even irrep L=0 and 1o stands for the odd irrep L=1
print(irreps)

2x0e+1x1o


In [205]:
# creat an IrrepsArray contains 3 entries, each corresponding to different values for the irreps
x = e3nn.IrrepsArray(irreps, jnp.array(
    [
        [1.0, 0.0,  0.0, 0.0, 0.0],
        [0.0, 1.0,  2.0, 0.0, 0.0],
        [0.0, 0.0,  0.0, 2.0, 0.0],
    ]
))
print(x)
print('shape:', x.shape)

2x0e+1x1o
[[1. 0. 0. 0. 0.]
 [0. 1. 2. 0. 0.]
 [0. 0. 0. 2. 0.]]
shape: (3, 5)


In [35]:
# The first index of the IrrepsArray is the batch index
print(x[0], '\n')
# The second index is the index of the irreps
print(x[:, '2x0e'], '\n')
# we can also directly use the index of the irreps
print(x[:, 2:], '\n')
# An error is raised if the irreps is not present in the IrrepsArray
try:
    print(x[:, '2o'], '\n')
except Exception as e:
    print('Error:', e)

2x0e+1x1o [1. 0. 0. 0. 0.] 

2x0e
[[1. 0.]
 [0. 1.]
 [0. 0.]] 

1x1o
[[0. 0. 0.]
 [2. 0. 0.]
 [0. 2. 0.]] 

Error: Error in IrrepsArray.__getitem__, Can't slice with 1x2o because it doesn't appear exactly once in 2x0e+1x1o.


### Operations for irrepsArray

In [207]:
# add two IrrepsArray
print('add two IrrepsArray')
print(x + x, '\n')
# multiply by a scalar
print('multiply by a scalar')
print(3 * x, '\n')
# divide by a scalar
print('divide by a scalar')
print(x/2, '\n')
# sum all the entries
print('sum entries by axis 0')
print(e3nn.sum(x, axis = 1), '\n')
print('sum entries by axis 1')
print(e3nn.sum(x, axis = 0), '\n')
# cancatenate two IrrepsArray
print('concatenate two IrrepsArray')
z = e3nn.concatenate([x, x], axis=1)
print(z.irreps, '\n')
# simplify the IrrepsArray
print('simply IrrepsArray')
print(z.regroup().irreps)


add two IrrepsArray
2x0e+1x1o
[[2. 0. 0. 0. 0.]
 [0. 2. 4. 0. 0.]
 [0. 0. 0. 4. 0.]] 

multiply by a scalar
2x0e+1x1o
[[3. 0. 0. 0. 0.]
 [0. 3. 6. 0. 0.]
 [0. 0. 0. 6. 0.]] 

divide by a scalar
2x0e+1x1o
[[0.5 0.  0.  0.  0. ]
 [0.  0.5 1.  0.  0. ]
 [0.  0.  0.  1.  0. ]] 

sum entries by axis 0
1x0e+1x1o
[[1. 0. 0. 0.]
 [1. 2. 0. 0.]
 [0. 0. 2. 0.]] 

sum entries by axis 1
2x0e+1x1o [1. 1. 2. 2. 0.] 

concatenate two IrrepsArray
2x0e+1x1o+2x0e+1x1o 

simply IrrepsArray
4x0e+2x1o


#### Operations Not Allowed for IrrepsArray

In [67]:
# add IrrepsArray with mismatched irreps
y = y = e3nn.IrrepsArray("0o + 2x0e", jnp.array(
    [
        [1.5,  0.0, 1.0],
        [0.5, -1.0, 2.0],
        [0.5,  1.0, 1.5],
    ]
))

try:
    print(x + y)
except Exception as e:
    print(e, '\n')

# Anon-scalar division
try:
    print(x / x)
except Exception as e:
    print(e, '\n')

# multiply two IrrepsArray element-wise
try:
    print(x * x)
except Exception as e:
    print(e, '\n')

# let's try e3nn.tensor_product. 
print(x.irreps, x.shape)
print(y.irreps, y.shape)
print(e3nn.tensor_product(x, y).irreps, e3nn.tensor_product(x, x).shape)


IrrepsArray(2x0e+1x1o, shape=(3, 5)) + IrrepsArray(1x0o+2x0e) is not equivariant. 

IrrepsArray(2x0e+1x1o, shape=(3, 5)) / IrrepsArray(2x0e+1x1o) is not equivariant. 

IrrepsArray(2x0e+1x1o, shape=(3, 5)) * IrrepsArray(2x0e+1x1o) is only supported for scalar * irreps and irreps * scalar. To perform irreps * irreps use e3nn.elementwise_tensor_product or e3nn.tensor_product. 

2x0e+1x1o (3, 5)
1x0o+2x0e (3, 3)
4x0e+2x0o+2x1o+1x1e (3, 25)


### Spherical Harmonics

In [113]:
# creat a batch of vectors
vectors = jnp.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

vectors = e3nn.IrrepsArray("1o", vectors)  # Wrap the array of vectors in an IrrepsArray
e3nn.spherical_harmonics(
                [l for l in range(1, 4)], # List of l values from 1 to 3
                vectors,
                normalize=True,)  # the input is normalized

1x1o+1x2e+1x3o
[[ 1.7320508  0.         0.         0.         0.        -1.1180339
   0.        -1.9364916 -2.09165    0.        -1.620185   0.
   0.         0.         0.       ]
 [ 0.         1.7320508  0.         0.         0.         2.2360678
   0.         0.         0.         0.         0.         2.645751
   0.         0.         0.       ]
 [ 0.         0.         1.7320508  0.         0.        -1.1180339
   0.         1.9364916  0.         0.         0.         0.
  -1.620185   0.         2.09165  ]]

### A minimal example of using Haiku with e3nn

In [202]:
@hk.without_apply_rng
@hk.transform
def model(x):
    # Assuming "1x0e+1x1o" is the output irreps of the linear layer
    linear = e3nn.haiku.Linear(irreps_out = "3x0e+1x1o", biases=True)
    return e3nn.gate(linear(x))

x = e3nn.IrrepsArray("1o + 2x0e + 1x0o", jnp.ones(6))
print('input irrep:', x.irreps)
params = model.init(jax.random.PRNGKey(0), x)  # Initialize model parameters using a random key.
print('\n', params)

print('parameters for 0e: \n', params['linear']['w[0,0] 2x0e,3x0e']) 
print('parameters for 1o: \n', params['linear']['w[2,1] 1x1o,1x1o'])

input irrep: 1x1o+2x0e+1x0o

 {'linear': {'w[0,0] 2x0e,3x0e': DeviceArray([[-1.4581939, -2.047044 , -1.4242861],
             [ 1.1684095, -0.9758364, -1.2718494]], dtype=float32), 'w[2,1] 1x1o,1x1o': DeviceArray([[-0.58665055]], dtype=float32), 'b[0] 3x0e': DeviceArray([0., 0., 0.], dtype=float32)}}
parameters for 0e: 
 [[-1.4581939 -2.047044  -1.4242861]
 [ 1.1684095 -0.9758364 -1.2718494]]
parameters for 1o: 
 [[-0.58665055]]


In [185]:
print('input irrep:', x.irreps)
y = model.apply(params,  x)
print('input irrep:', y.irreps) 
# the output irrep of the linear is "3x0e + 1x1o";
# the last 0e scalar is used as gate to scale the 1o in gate function
# the 1o irrep is automatically discarded in the output

input irrep: 1x1o+2x0e+1x0o
input irrep: 2x0e+1x1o


### Building a GNN with e3nn

In [256]:
# some helper functions
import jraph
from matscipy.neighbours import neighbour_list

cutoff = 2.0  

def compute_edges(positions, cell, cutoff):
    """Compute edges of the graph from positions and cell."""
    receivers, senders, senders_unit_shifts = neighbour_list(
        quantities="ijS",
        pbc=jnp.array([True, True, True]),
        cell=cell,
        positions=positions,
        cutoff=cutoff,
    )

    num_edges = senders.shape[0]
    assert senders.shape == (num_edges,)
    assert receivers.shape == (num_edges,)
    assert senders_unit_shifts.shape == (num_edges, 3)
    return senders, receivers, senders_unit_shifts

def create_graph(positions, cell, cutoff):
    """Create a graph from positions, cell, and energy."""
    senders, receivers, senders_unit_shifts = compute_edges(positions, cell, cutoff) 
    num_nodes = positions.shape[0]
    num_edges = senders.shape[0]

    graph = jraph.GraphsTuple(
        # positions are per-node features:
        nodes=dict(positions=positions),
        # Unit shifts are per-edge features:
        edges=dict(shifts=senders_unit_shifts),
        # energy and cell are per-graph features:
        senders=senders,
        receivers=receivers,
        globals=dict(cell=cell[None, :, :]),
        n_node=jnp.array([num_nodes]),
        n_edge=jnp.array([num_edges]),
    )
    return graph

def get_relative_vectors(senders, receivers, n_edge, positions, cells, shifts):
    """Compute the relative vectors between the senders and receivers."""
    num_nodes = positions.shape[0]
    num_edges = senders.shape[0]
    num_graphs = n_edge.shape[0]

    assert positions.shape == (num_nodes, 3)
    #assert cells.shape == (num_graphs, 3, 3)
    assert senders.shape == (num_edges,)
    assert receivers.shape == (num_edges,)
    assert shifts.shape == (num_edges, 3)

    # We need to repeat the cells for each edge.
    cells = jnp.repeat(cells, n_edge, axis=0, total_repeat_length=num_edges)

    # Compute the two ends of each edge.
    positions_receivers = positions[receivers]
    positions_senders = positions[senders] + jnp.einsum("ei,eij->ej", shifts, cells)

    vectors = e3nn.IrrepsArray("1o", positions_receivers - positions_senders)
    return vectors



In [271]:
# prepare two molecules

molecule0 = create_graph(
    positions=jnp.array(
        [
            [-0.0, 1.44528, 0.26183],
            [1.25165, 0.72264, 2.34632],
            [1.25165, 0.72264, 3.90714],
            [-0.0, 1.44528, 1.82265],
        ]
    ),
    cell=jnp.eye(3)*4,
    cutoff=cutoff,
)
print(f"molecule0 has {molecule0.n_node} nodes and {molecule0.n_edge} edges")


molecule1 = create_graph(
    positions=jnp.array(
        [
            [0.0, 0.0, 1.78037],
            [0.89019, 0.89019, 2.67056],
            [0.0, 1.78037, 0.0],
            [0.89019, 2.67056, 0.89019],
            [1.78037, 0.0, 0.0],
            [2.67056, 0.89019, 0.89019],
            [1.78037, 1.78037, 1.78037],
            [2.67056, 2.67056, 2.67056],
        ]
    ),
    cell=jnp.eye(3)*5,
    cutoff=cutoff,
)
print(f"molecule1 has {molecule1.n_node} nodes and {molecule1.n_edge} edges")
print('sender:', molecule1.senders)
print('receiver:', molecule1.receivers)

molecule0 has [4] nodes and [8] edges
molecule1 has [8] nodes and [14] edges
sender: [1 0 6 3 2 6 5 4 6 5 3 1 7 6]
receiver: [0 1 1 2 3 3 4 5 5 6 6 6 6 7]


In [272]:
dataset = jraph.batch([molecule0, molecule1])
print(f"dataset has {dataset.n_node} nodes and {dataset.n_edge} edges")

# Print the shapes of the fields of the dataset.
print(jax.tree_util.tree_map(jnp.shape, dataset))

dataset has [4 8] nodes and [ 8 14] edges
GraphsTuple(nodes={'positions': (12, 3)}, edges={'shifts': (22, 3)}, receivers=(22,), senders=(22,), globals={'cell': (2, 3, 3)}, n_node=(2,), n_edge=(2,))


In [296]:
class EGNNLayer(hk.Module):
    def __init__(
        self,
        avg_num_neighbors: float,
        max_ell: int = 3,
        output_irreps: e3nn.Irreps = e3nn.Irreps("0e + 1o + 2e"),
    ):
        super().__init__( )
        self.avg_num_neighbors = avg_num_neighbors
        self.max_ell = max_ell
        self.output_irreps = output_irreps

    def __call__(
        self,
        vectors: e3nn.IrrepsArray,
        node_feats: e3nn.IrrepsArray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
    ):
        node_feats = e3nn.as_irreps_array(node_feats)
        num_nodes = node_feats.shape[0]
        num_edges = vectors.shape[0]
        assert vectors.shape == (num_edges, 3)
        assert node_feats.shape == (num_nodes, node_feats.irreps.dim)
        assert senders.shape == (num_edges,)
        assert receivers.shape == (num_edges,)

        # we regroup the target irreps to make sure that gate activation
        # has the same irreps as the target
        output_irreps = e3nn.Irreps(self.output_irreps).regroup()

        messages = e3nn.haiku.Linear(node_feats.irreps)(node_feats)[senders] 
        messages = e3nn.concatenate(
            [
                messages.filter(output_irreps + "0e"), # keep information of original features
                e3nn.tensor_product(                   # compute the spherical harmonics
                    messages,           
                    e3nn.spherical_harmonics(
                        [l for l in range(1, self.max_ell + 1)],
                        vectors,
                        normalize=True,
                        normalization="component",
                    ),
                    filter_ir_out=output_irreps + "0e", # make sure the output has the same irreps as the target
                ),
            ]
        ).regroup()

        assert messages.shape == (num_edges, messages.irreps.dim)

        # Discard 0 length edges that come from graph padding
        lengths = e3nn.norm(vectors).array
        mask = jnp.where(lengths == 0.0, 0.0, 1)
        messages = messages * mask
        
        # get the irreps of the messages
        irreps = output_irreps.filter(keep=messages.irreps)
        num_nonscalar = irreps.filter(drop="0e + 0o").num_irreps # the number of non-scalar irreps
        irreps = irreps + e3nn.Irreps(f"{num_nonscalar}x0e").simplify() # add scalar irreps 

        # Message passing
        node_feats = e3nn.scatter_sum(messages, dst=receivers, output_size=num_nodes) #output[dst[i]] += data[i]
        node_feats = node_feats / jnp.sqrt(self.avg_num_neighbors)

        node_feats = e3nn.haiku.Linear(irreps, name="linear_down")(node_feats)

        node_feats = node_feats 
        assert node_feats.shape == (num_nodes, node_feats.irreps.dim)

        node_feats = e3nn.gate(node_feats)

        return node_feats
        

In [298]:
class Model(hk.Module):
    def __init__(self, irreps_out=e3nn.Irreps("4x0e +4x0o + 2x1o + 2x1e")):
        super().__init__()
        self.irreps_out = irreps_out
    def __call__(self, graphs):
        vectors = get_relative_vectors(
            graphs.senders,
            graphs.receivers,
            graphs.n_edge,
            graphs.nodes["positions"],
            graphs.globals["cell"],
            graphs.edges["shifts"],
        )
        num_nodes = jnp.sum(graphs.n_node)
        node_feats = e3nn.IrrepsArray("0e", jnp.ones((num_nodes, 1)))
        for _ in range(2):
            node_feats = EGNNLayer(avg_num_neighbors=4, output_irreps=self.irreps_out)(vectors, node_feats, graphs.senders, graphs.receivers)
        node_feats = e3nn.haiku.Linear("0e")(node_feats)
        return e3nn.scatter_sum(node_feats, nel=graphs.n_node)
    
def get_model(irreps_out=e3nn.Irreps("4x0e +4x0o + 2x1o + 2x1e")):
    def model(graphs):
        return  Model(irreps_out=irreps_out)(graphs)
    return model 

# Create the model and initialize its parameters.
egnn_model = get_model()
egnn_model = hk.without_apply_rng(hk.transform(egnn_model))
egnn_params = egnn_model.init(jax.random.PRNGKey(0), dataset)  # Initialize model parameters using a random key.


In [299]:
# Apply the model to the dataset.
model_output = egnn_model.apply(egnn_params, dataset)
print(model_output)

1x0e
[[1.6778514]
 [2.227569 ]]


In [None]:
# Test quivarient
# Under consturction....