<a href="https://colab.research.google.com/github/chrisjelliott/Equivariant_NNs_via_invariant_theory/blob/main/Equivariant_NN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Equivariant Neural Networks for General Representations**

In this notebook I'm going to implement some neural network models based on "E(n) Equivariant Graph Neural Networks" by Satorras, Hoogeboom and Welling (https://arxiv.org/pdf/2102.09844). I'll describe a generalization to the following situation.

We'll describe an equivariant graph neural network associated to a graph $\Gamma = (V,E)$ and a Lie group $G$.  Associated to each vertex $v_i$ of the graph we have a variable $h_i \in W_V$, where $W_V$ is a linear $G$-representation, and to each edge $e_{ij}$ we have a variable $a_{ij} \in W_E$, where $W_E$ is a linear $G$-representation.  Our model will learn $G$-equivariant functions whose output is encoded in the same way, as variables on the same graph valued in new representations $W_V^{\mathrm{out}}, W_E^{\mathrm{out}}$.

**Note:**

In the paper of Satorras et al. the vertex representation takes the form $\mathbb R^n \times W$, where $G=E(n)$ acts by isometries on the first factor and trivially on the second factor.  The isometry action is affine, not linear, but we can turn it into a linear $G$-action on $\mathbb R^{n+1}$ by identifying $\mathbb R^n$ with the hyperplane where the last coordinate equals one (see Example 2).
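
To make the linearization concrete, here is a minimal sketch in PyTorch. It assumes the standard homogeneous-coordinate matrix $\begin{pmatrix} C & b \\ 0 & 1 \end{pmatrix}$; the helper name `homogeneous_matrix` is mine and does not appear elsewhere in the notebook.

```python
import torch

def homogeneous_matrix(C: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the (n+1) x (n+1) matrix [[C, b], [0, 1]] (hypothetical helper)."""
    n = C.shape[0]
    M = torch.eye(n + 1, dtype=C.dtype)
    M[:n, :n] = C
    M[:n, n] = b
    return M

torch.manual_seed(0)
n = 3
# A random orthogonal matrix (from a QR decomposition) and a random translation.
C, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
b = torch.randn(n, dtype=torch.float64)
M = homogeneous_matrix(C, b)

x = torch.randn(n, dtype=torch.float64)
x_h = torch.cat([x, torch.ones(1, dtype=torch.float64)])  # the point (x, 1)

affine_result = C @ x + b   # affine action on R^n
linear_result = M @ x_h     # linear action on R^(n+1)

# The first n coordinates agree and the last coordinate stays equal to one.
print(torch.allclose(linear_result[:n], affine_result), linear_result[n].item())
```

The last coordinate is untouched by the matrix, which is why the hyperplane where it equals one is preserved.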

# Equivariant Layer Structure

Let's describe a single layer using a generalization of Satorras et al's equivariant message passing layer.  Our approach is to take as input a suitable set of $G$-equivariant polynomial functions, and build functions from linear combinations of these functions.

We'll start with the message function, associated to a single edge $e_{ij}$ in the graph.  We will build equivariant functions
$$F \in M_I = (\mathbb R[W_V^2 \times W_E] \otimes W_I)^G$$
where $W_I$ is some intermediate $G$-representation.  This $M_I$ is a module over the algebra of invariant functions $R_I =  \mathbb R[W_V^2 \times W_E]^G$.

We make the assumption that $R_I$ is a finitely generated algebra with generators $f_1, \ldots, f_n$ and that $M_I$ is finitely generated as an $R_I$-module with generators $\mu_1, \ldots, \mu_N$ (for instance, this is guaranteed if $G$ is reductive).  We consider the following set of equivariant functions:
$$\{F^\alpha \colon \alpha \in A\} = \{\mu_l\} \cup \{f_k \cdot \mu_l\}.$$
Note that this set might not be linearly independent: there may be linear relations (syzygies) between the elements, leading to some potential redundancy in functions represented as linear combinations.

The interpretation here is that we are generalizing the set of affine functions (sums of linear functions and constant functions) between vector spaces by including lowest order and next-to-lowest order generators.

So we can now define the possible message functions.  These will take the form
$$m_{ij} = \sigma_I \left(\sum_{\alpha \in A} a_\alpha F^\alpha(h_i, h_j, a_{ij})  \right)$$
for some learnable coefficients $a_\alpha$, and some pointwise activation function $\sigma_I$.

We can use these message functions to update the vertex and edge variables.  We will again construct sets of equivariant functions
\begin{align}
g^\beta &\in (\mathbb R[W_V \times W_I] \otimes W_V^{\mathrm{out}})^G \\
k^\gamma &\in (\mathbb R[W_E \times W_I] \otimes W_E^{\mathrm{out}})^G
\end{align}
in exactly the same way.  If we choose another activation function $\sigma$ then the updated vertex and edge variables are given as follows:
\begin{align}
h_i^{\mathrm{out}} &= \sigma \left(\sum_\beta b_\beta g^\beta\left(h_i, \sum_{v_j \in N(v_i)} m_{ij} \right) \right) \\
a_{ij}^{\mathrm{out}} &= \sigma \left(\sum_\gamma c_\gamma k^\gamma\left(a_{ij}, \sum_{v_\ell \in N(v_i)} m_{i\ell} + \sum_{v_\ell \in N(v_j)} m_{j\ell}  \right) \right)
\end{align}
where again $b_\beta, c_\gamma$ are learnable weights, and where we write $N(v_i)$ for the neighborhood of vertex $v_i$ in the graph $\Gamma$.
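
The two neighborhood sums in these update rules can be computed for all vertices and edges at once by masking with the adjacency matrix. The sketch below assumes a 0/1 adjacency matrix and a dense tensor `messages[i, j]` holding $m_{ij}$; the small path graph and the variable names are illustrative only.

```python
import torch

torch.manual_seed(0)
num_vertices, msg_dim = 4, 3

# Symmetric 0/1 adjacency matrix without self-loops (the path graph 0-1-2-3).
adj = torch.tensor([[0, 1, 0, 0],
                    [1, 0, 1, 0],
                    [0, 1, 0, 1],
                    [0, 0, 1, 0]], dtype=torch.float32)

# messages[i, j] plays the role of m_ij; entries for non-edges are masked out below.
messages = torch.randn(num_vertices, num_vertices, msg_dim)

# Row i is the sum of m_ij over the neighbours v_j of v_i (vertex update input).
message_sums = (messages * adj.unsqueeze(-1)).sum(dim=1)

# Entry [i, j] is message_sums[i] + message_sums[j] (edge update input).
edge_aggregate = message_sums.unsqueeze(1) + message_sums.unsqueeze(0)

print(message_sums.shape, edge_aggregate.shape)
```

Only the entries of `edge_aggregate` at actual edges are meaningful; the broadcasted sum is just a convenient way to form them all in one step.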

# Example 1

We can check that if $G$ is trivial then we recover a usual graph neural network architecture.  Indeed, in the trivial case, when we study functions $W_1 \to W_2$, the generators in our model are given as follows.

*   Generators of $W_2$ as an $\mathbb R[W_1]$-module -- basis vectors $e^{(2)}_i$
*   The product of generators of $W_2$ with algebra generators of $\mathbb R[W_1]$ -- tensors of the form $(e^{(1)}_j)^* \otimes e^{(2)}_i$.  In other words, matrix elements in $W_1^* \otimes W_2$.

Linear combinations of the first type of element generate constant functions $W_1 \to W_2$, and linear combinations of the second type of element generate linear functions $W_1 \to W_2$.  So altogether when we take arbitrary linear combinations in our model we are just considering the set of affine functions.
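
As a small numerical illustration of this claim, the sketch below builds the two kinds of generators for $W_1 = \mathbb R^3$ and $W_2 = \mathbb R^2$ and checks that an arbitrary linear combination agrees with an affine map $x \mapsto Ax + c$. The function names `constant_generator` and `linear_generator` are mine, chosen for this example.

```python
import torch

torch.manual_seed(0)
d1, d2 = 3, 2  # dim W_1 and dim W_2

def constant_generator(i):
    """Module generator: the constant function x -> e_i in W_2."""
    def f(x):
        out = torch.zeros(d2, dtype=x.dtype)
        out[i] = 1.0
        return out
    return f

def linear_generator(j, i):
    """Coordinate function times module generator: x -> x_j e_i."""
    def f(x):
        out = torch.zeros(d2, dtype=x.dtype)
        out[i] = x[j]
        return out
    return f

# Arbitrary coefficients, arranged as a bias vector and a weight matrix.
bias = torch.randn(d2, dtype=torch.float64)
weight = torch.randn(d2, d1, dtype=torch.float64)

def combination(x):
    """Linear combination of all generators with the coefficients above."""
    total = torch.zeros(d2, dtype=x.dtype)
    for i in range(d2):
        total = total + bias[i] * constant_generator(i)(x)
        for j in range(d1):
            total = total + weight[i, j] * linear_generator(j, i)(x)
    return total

x = torch.randn(d1, dtype=torch.float64)
print(torch.allclose(combination(x), weight @ x + bias))
```

The coefficients of the constant generators play the role of the bias and the coefficients of the matrix elements play the role of the weight matrix.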

# Example 2

Let's consider the example of Satorras, Hoogeboom and Welling.  So let $G = E(n)$, let $W_E$ be a trivial representation, and let $W_V = \mathbb R^{n+1} \times U$ where $U$ is again a trivial representation, and $E(n)$ acts on $\mathbb R^{n+1} = \mathbb R^n \times \mathbb R$ by affine transformations:
$$(C,b) \cdot (v, t) = (Cv + tb, t).$$
We obtain the usual affine action on $\mathbb R^n$ by restricting to the hyperplane $t=1$, and we will restrict attention to $E(n)$-equivariant functions that preserve this hyperplane.

**One Spatial Input**

Let us start by analyzing the equivariant functions of the form
$$F \colon \mathbb R^{n+1} \times U_1 \to \mathbb R^{n+1} \times U_2$$
where $U_1, U_2$ are trivial representations.  So according to our procedure we will need to compute algebra generators for $$A = \mathbb R[\mathbb R^{n+1} \times U_1]^{E(n)}$$ and module generators for $$M = (\mathbb R[\mathbb R^{n+1} \times U_1] \otimes (\mathbb R^{n+1} \times U_2))^{E(n)}.$$  In each case we will restrict to those functions that preserve the $(n+1)^{\text{st}}$ coordinate in $\mathbb R^{n+1}$.

In the first case $A \cong \mathbb R[\mathbb R^{n+1}]^{E(n)} \otimes \mathbb R[U_1]$, and the only generators that preserve the final coordinate are constant in the first factor, so we have generators associated to basis vectors in $U_1$.

In the latter case, $M$ is generated as a module by constant functions to $U_2$ together with the identity function $\mathbb R^{n+1} \to \mathbb R^{n+1}$.  So, altogether, the set of zeroth and first order generating functions can be identified with $$\{F^\alpha\} \cong \langle \mathrm{id}\rangle \oplus U_2 \oplus (U_1^* \otimes \langle\mathrm{id}\rangle) \oplus (U_1^* \otimes U_2).$$

**Two Spatial Inputs**

Finally, associated to the intermediate term we have a variant of this computation.  We need to compute algebra generators for $$A = \mathbb R[(\mathbb R^{n+1})^2 \times U_1]^{E(n)}$$ and module generators for $$M = (\mathbb R[(\mathbb R^{n+1})^2 \times U_1] \otimes (\mathbb R^{n+1} \times U_2))^{E(n)},$$
again preserving the hyperplane where the final coordinate in $\mathbb R^{n+1}$ is equal to one.

Let us use the notation $((x_1, t_1), (x_2, t_2)) \in (\mathbb R^{n+1})^2$. The algebra $A$ now has generators associated to basis vectors in $U_1$, but in addition we have a quadratic generator of the form $\|x_1 - x_2\|^2$.  The module $M$ still has module generators associated to constant functions to $U_2$, but rather than the identity there are now two additional generators associated to the projections $\pi_1, \pi_2$ onto the two factors of $(\mathbb R^{n+1})^2$.  So now, altogether, the set of zeroth and first order generating functions can be identified with
$$\{F^\alpha\} \cong \langle \pi_1, \pi_2 \rangle \oplus U_2 \oplus (U_1^* \otimes \langle \pi_1, \pi_2 \rangle) \oplus (\langle \|x_1 - x_2 \|^2 \rangle \otimes \langle \pi_1, \pi_2 \rangle) \oplus (U_1^* \otimes U_2) \oplus (\langle \|x_1 - x_2 \|^2 \rangle \otimes U_2).$$
In order to preserve the hyperplane we will need to restrict attention to those linear combinations $a_1 \pi_1 + a_2 \pi_2$ where $a_1 + a_2 = 1$.
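
Here is a quick numerical check of these statements, written as a sketch under the assumption that points are stored as $(x, 1) \in \mathbb R^{n+1}$ and that $E(n)$ acts through the matrix with blocks $C$, $b$, $0$, $1$. The helpers `lift`, `combine` and `sq_dist` are hypothetical names for this example. Any linear combination of $\pi_1, \pi_2$ commutes with the linear action, so the role of the constraint $a_1 + a_2 = 1$ is to keep the final coordinate equal to one.

```python
import torch

torch.manual_seed(0)
n = 3
dtype = torch.float64

# A random element (C, b) of E(n), written as an (n+1) x (n+1) matrix.
C, _ = torch.linalg.qr(torch.randn(n, n, dtype=dtype))
b = torch.randn(n, dtype=dtype)
g = torch.eye(n + 1, dtype=dtype)
g[:n, :n] = C
g[:n, n] = b

def lift(x):
    """Embed x in the hyperplane t = 1 of R^(n+1)."""
    return torch.cat([x, torch.ones(1, dtype=dtype)])

def combine(p1, p2, a1, a2):
    """The combination a1 * pi_1 + a2 * pi_2 evaluated at (p1, p2)."""
    return a1 * p1 + a2 * p2

def sq_dist(p1, p2):
    """The invariant ||x_1 - x_2||^2, computed from the spatial coordinates."""
    return torch.sum((p1[:n] - p2[:n]) ** 2)

p1 = lift(torch.randn(n, dtype=dtype))
p2 = lift(torch.randn(n, dtype=dtype))

good = combine(p1, p2, 0.25, 0.75)  # weights sum to one
bad = combine(p1, p2, 0.25, 0.25)   # weights do not sum to one

# Final coordinates: the first stays on the hyperplane, the second does not.
print(good[n].item(), bad[n].item())
# Equivariance of the combination and invariance of the squared distance.
print(torch.allclose(combine(g @ p1, g @ p2, 0.25, 0.75), g @ good))
print(torch.allclose(sq_dist(g @ p1, g @ p2), sq_dist(p1, p2)))
```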



# Implementation

Let's go ahead and implement the equivariant graph convolution layer following this procedure.  Note that Satorras et al actually allow their message passing and vertex update terms to contain two layers: linear -> activation -> linear -> activation (where the second activation may be constant).  I will include this behaviour as an option if desired so that we can compare the results with one and two layers.

In [6]:
import torch
from torch import nn
from typing import Dict
import torch.nn.functional as F
import torch.cuda as cuda
import time

class EGCL(nn.Module):
    def __init__(
        self,
        num_vertices: int,
        adj_matrix: torch.Tensor,
        vertex_inputs: int,
        edge_inputs: int,
        vertex_outputs: int,
        edge_outputs: int,
        inter_vars: int,
        inter_invt_funs: list,
        vertex_invt_funs: list,
        edge_invt_funs: list,
        inter_activation: callable,
        vertex_activation: callable,
        edge_activation: callable,
        double_layer: bool = False,
        inter_vars_2: int = None,
        vertex_hidden: int = None,
        edge_hidden: int = None,
        inter_activation_2: callable = None,
        vertex_activation_2: callable = None,
        edge_activation_2: callable = None,
        is_affine: bool = False,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """
        Equivariant Graph Convolution Layer with optional double layer processing.

        Parameters:
        -----------
        num_vertices : int
            Number of vertices in the graph
        adj_matrix : torch.Tensor
            Adjacency matrix of the graph (shape: [num_vertices, num_vertices])
        vertex_inputs : int
            Dimension of vertex input features
        edge_inputs : int
            Dimension of edge input features
        vertex_outputs : int
            Dimension of vertex output features
        edge_outputs : int
            Dimension of edge output features
        inter_vars : int
            Dimension of intermediate message features
        inter_invt_funs : list
            List of invariant functions for message passing
        vertex_invt_funs : list
            List of invariant functions for vertex updates
        edge_invt_funs : list
            List of invariant functions for edge updates
        inter_activation : callable
            Activation function for intermediate computations
        vertex_activation : callable
            Activation function for vertex updates
        edge_activation : callable
            Activation function for edge updates
        double_layer : bool
            Whether to use double layer processing (default: False)
        inter_vars_2 : int
            Dimension of second intermediate layer (default: same as inter_vars)
        vertex_hidden : int
            Dimension of hidden layer in vertex update (default: same as vertex_outputs)
        edge_hidden : int
            Dimension of hidden layer in edge update (default: same as edge_outputs)
        inter_activation_2 : callable
            Second activation for message passing (default: same as inter_activation)
        vertex_activation_2 : callable
            Second activation for vertex update (default: same as vertex_activation)
        edge_activation_2 : callable
            Second activation for edge update (default: same as edge_activation)
        is_affine : bool
            Whether to normalize outputs to preserve affine transformations
        device : str
            Device to run computations on
        """
        super(EGCL, self).__init__()

        # Validate inputs
        if not isinstance(adj_matrix, torch.Tensor):
            adj_matrix = torch.tensor(adj_matrix, dtype=torch.float32)
        if adj_matrix.shape != (num_vertices, num_vertices):
            raise ValueError(f"adj_matrix shape {adj_matrix.shape} doesn't match num_vertices {num_vertices}")

        # Store basic parameters
        self.device = device
        self.num_vertices = num_vertices
        self.adj_matrix = adj_matrix.to(device)
        self.vertex_inputs = vertex_inputs
        self.edge_inputs = edge_inputs
        self.vertex_outputs = vertex_outputs
        self.edge_outputs = edge_outputs
        self.inter_vars = inter_vars
        self.inter_invt_funs = inter_invt_funs
        self.vertex_invt_funs = vertex_invt_funs
        self.edge_invt_funs = edge_invt_funs
        self.inter_activation = inter_activation
        self.vertex_activation = vertex_activation
        self.edge_activation = edge_activation
        self.is_affine = is_affine

        self.num_inter_invts = len(inter_invt_funs)
        self.num_vertex_invts = len(vertex_invt_funs)
        self.num_edge_invts = len(edge_invt_funs)

        self.timings = {}  # For storing timing information

        # Initialize first layer weights
        n_in = 2*vertex_inputs + edge_inputs  # Total input dimension
        n_out = inter_vars                    # Output dimension
        std = torch.sqrt(torch.tensor(6.0/(n_in + n_out)))

        self.inter_weights = nn.Parameter(
            torch.randn(self.num_inter_invts, device=device) * std
        )

        # For vertex update weights
        n_in = vertex_inputs + inter_vars     # Input vertex features + message features
        n_out = vertex_outputs
        std = torch.sqrt(torch.tensor(6.0/(n_in + n_out)))

        self.vertex_weights = nn.Parameter(
            torch.randn(self.num_vertex_invts, device=device) * std
        )

        # For edge update weights
        n_in = edge_inputs + 2*inter_vars     # Edge features + sum of messages
        n_out = edge_outputs
        std = torch.sqrt(torch.tensor(6.0/(n_in + n_out)))

        self.edge_weights = nn.Parameter(
            torch.randn(self.num_edge_invts, device=device) * std
        )

        # Handle double layer parameters
        self.double_layer = double_layer
        if double_layer:
            self.inter_vars_2 = inter_vars_2 if inter_vars_2 is not None else inter_vars
            self.vertex_hidden = vertex_hidden if vertex_hidden is not None else vertex_outputs
            self.edge_hidden = edge_hidden if edge_hidden is not None else edge_outputs
            self.inter_activation_2 = inter_activation_2 if inter_activation_2 is not None else inter_activation
            self.vertex_activation_2 = vertex_activation_2 if vertex_activation_2 is not None else vertex_activation
            self.edge_activation_2 = edge_activation_2 if edge_activation_2 is not None else edge_activation

            # Initialize second layer weights.  Each is assumed here to be a dense
            # (n_out, n_in) matrix applied to the output of the first layer.
            n_in = inter_vars               # First-layer message dimension
            n_out = self.inter_vars_2       # Second-layer message dimension
            std = torch.sqrt(torch.tensor(6.0/(n_in + n_out)))

            self.inter_weights_2 = nn.Parameter(
                torch.randn(n_out, n_in, device=device) * std
            )

            # For vertex update weights
            n_in = self.vertex_hidden       # Hidden vertex dimension
            n_out = vertex_outputs
            std = torch.sqrt(torch.tensor(6.0/(n_in + n_out)))

            self.vertex_weights_2 = nn.Parameter(
                torch.randn(n_out, n_in, device=device) * std
            )

            # For edge update weights
            n_in = self.edge_hidden         # Hidden edge dimension
            n_out = edge_outputs
            std = torch.sqrt(torch.tensor(6.0/(n_in + n_out)))

            self.edge_weights_2 = nn.Parameter(
                torch.randn(n_out, n_in, device=device) * std
            )

        # Pre-compute neighborhoods
        self.neighborhoods = self._compute_neighborhoods()

    def _compute_neighborhoods(self):
        """Pre-compute neighborhoods for each vertex."""
        neighborhoods = {}
        for i in range(self.num_vertices):
            neighborhoods[i] = torch.nonzero(self.adj_matrix[i], as_tuple=False).squeeze(1)
        return neighborhoods

    def _normalize_if_affine(self, tensor):
        """Normalize tensor if is_affine is True and tensor is non-zero."""
        if self.is_affine and torch.norm(tensor) > 1e-8:
            return F.normalize(tensor, dim=-1)
        return tensor

    def intermediate_term(self, h_i, h_j, a_ij):
        """Compute message from vertex i to vertex j."""
        try:
            # First layer
            #print(f"Computing {len(self.inter_invt_funs)} invariant functions...")

            invt_outputs = torch.stack([
                torch.as_tensor(f(h_i, h_j, a_ij), device=self.device)
                for f in self.inter_invt_funs
            ])
            message = torch.matmul(self.inter_weights, invt_outputs)
            message = self._normalize_if_affine(message)
            message = self.inter_activation(message)

            if not self.double_layer:
                return message

            # Second layer
            message = torch.matmul(self.inter_weights_2, message)
            message = self._normalize_if_affine(message)
            return self.inter_activation_2(message)

        except RuntimeError as e:
            raise RuntimeError(f"Error in intermediate_term: {str(e)}")

    def vertex_update(self, h, m):
        """Update vertex features."""
        try:
            # First layer
            invt_outputs = torch.stack([
                torch.as_tensor(f(h, m), device=self.device)
                for f in self.vertex_invt_funs
            ])
            update = torch.matmul(self.vertex_weights, invt_outputs)
            update = self._normalize_if_affine(update)
            update = self.vertex_activation(update)

            if not self.double_layer:
                return update

            # Second layer
            update = torch.matmul(self.vertex_weights_2, update)
            update = self._normalize_if_affine(update)
            return self.vertex_activation_2(update)

        except RuntimeError as e:
            raise RuntimeError(f"Error in vertex_update: {str(e)}")

    def edge_update(self, a, m):
        """Update edge features."""
        try:
            # First layer
            invt_outputs = torch.stack([
                torch.as_tensor(f(a, m), device=self.device)
                for f in self.edge_invt_funs
            ])
            update = torch.matmul(self.edge_weights, invt_outputs)
            update = self._normalize_if_affine(update)
            update = self.edge_activation(update)

            if not self.double_layer:
                return update

            # Second layer
            update = torch.matmul(self.edge_weights_2, update)
            update = self._normalize_if_affine(update)
            return self.edge_activation_2(update)

        except RuntimeError as e:
            raise RuntimeError(f"Error in edge_update: {str(e)}")

    def _time_op(self, name: str, op, *args, **kwargs):
        """Time an operation and store result"""
        if cuda.is_available():
            cuda.synchronize()  # Ensure GPU ops complete
        start = time.perf_counter()
        result = op(*args, **kwargs)
        if cuda.is_available():
            cuda.synchronize()
        duration = time.perf_counter() - start

        if name not in self.timings:
            self.timings[name] = []
        self.timings[name].append(duration)

        return result

    def forward(self, h_graph: torch.Tensor, a_graph: torch.Tensor):
        """
        Forward pass of the layer.

        Parameters:
        -----------
        h_graph : torch.Tensor
            Vertex features (shape: [num_vertices, vertex_inputs])
        a_graph : torch.Tensor
            Edge features (shape: [num_vertices, num_vertices, edge_inputs])

        Returns:
        --------
        tuple(torch.Tensor, torch.Tensor)
            Updated vertex and edge features
        """
        if h_graph.shape != (self.num_vertices, self.vertex_inputs):
            raise ValueError(f"h_graph shape {h_graph.shape} doesn't match expected shape"
                           f" ({self.num_vertices}, {self.vertex_inputs})")
        if a_graph.shape != (self.num_vertices, self.num_vertices, self.edge_inputs):
            raise ValueError(f"a_graph shape {a_graph.shape} doesn't match expected shape"
                           f" ({self.num_vertices}, {self.num_vertices}, {self.edge_inputs})")

        print("EGCL forward pass starting...")
        self.timings = {}

        # Move inputs to correct device
        h_graph = h_graph.to(self.device)
        a_graph = a_graph.to(self.device)

        # Initialize output tensors
        h_graph_out = torch.zeros(
            (self.num_vertices, self.vertex_outputs),
            device=self.device
        )
        a_graph_out = torch.zeros(
            (self.num_vertices, self.num_vertices, self.edge_outputs),
            device=self.device
        )

        # Compute messages for all edges
        """
        messages = torch.zeros(
            (self.num_vertices, self.num_vertices, self.inter_vars),
            device=self.device
        )
        """
        messages = self._time_op(
            "init_messages",
            lambda: torch.zeros(
                (self.num_vertices, self.num_vertices, self.inter_vars),
                device=self.device
            )
        )

        # Get active edges
        edge_indices = torch.triu_indices(self.num_vertices, self.num_vertices, 1).to(self.device)
        edge_mask = self.adj_matrix[edge_indices[0], edge_indices[1]] > 0
        edge_indices = (edge_indices[0][edge_mask], edge_indices[1][edge_mask])
        num_edges = len(edge_indices[0])

        # Compute messages only for existing edges
        """
        for idx, (i, j) in enumerate(zip(*edge_indices)):
            messages[i, j] = self.intermediate_term(h_graph[i], h_graph[j], a_graph[i, j])
            messages[j, i] = self.intermediate_term(h_graph[j], h_graph[i], a_graph[j, i])
        """

        for idx, (i, j) in enumerate(zip(*edge_indices)):
            messages[i, j] = self._time_op(
                "intermediate_term",
                self.intermediate_term,
                h_graph[i], h_graph[j], a_graph[i, j]
            )
            messages[j, i] = self.intermediate_term(h_graph[j], h_graph[i], a_graph[j, i])

        #print("Computing message sums...")
        """
        message_sums = torch.sum(
            messages * self.adj_matrix.unsqueeze(-1),
            dim=1
        )
        """
        message_sums = self._time_op(
            "message_sums",
            lambda: torch.sum(
                messages * self.adj_matrix.unsqueeze(-1),
                dim=1
            )
        )

        # print(f"Updating vertices...")
        """
        for i in range(self.num_vertices):
            h_graph_out[i] = self.vertex_update(h_graph[i], message_sums[i])
        """
        for i in range(self.num_vertices):
            h_graph_out[i] = self._time_op(
                "vertex_update",
                self.vertex_update,
                h_graph[i], message_sums[i]
            )

        #print(f"Updating edges...")
        """
        for i, j in zip(*edge_indices):
            a_graph_out[i, j] = self.edge_update(
                a_graph[i, j],
                message_sums[i] + message_sums[j]
            )
        """
        for i, j in zip(*edge_indices):
            a_graph_out[i, j] = self._time_op(
                "edge_update",
                self.edge_update,
                a_graph[i, j],
                message_sums[i] + message_sums[j]
            )

        # Print timing summary
        print("\nEGCL Layer Timing Summary:")
        for op, times in self.timings.items():
            avg_time = sum(times) / len(times)
            total_time = sum(times)
            print(f"{op:20s} - Avg: {avg_time:.4f}s, Total: {total_time:.4f}s, Count: {len(times)}")


        #print("Forward pass complete")
        return h_graph_out, a_graph_out

Let's start by testing the layer with an example.  I'll let $G = E(3)$, use the fundamental representation $\mathbb R^4$ for both the vertex space $W_V$ and the intermediate space $W_I$, and use the trivial representation $\mathbb R$ for the edge space $W_E$.  I'll define the generating functions one at a time.

In [7]:
num_vertices = 5
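adj_matrix = torch.randint(0, 2, (num_vertices, num_vertices))
# Symmetrize to a 0/1 adjacency matrix (no fractional weights) and drop self-loops
adj_matrix = ((adj_matrix + adj_matrix.T) > 0).float()
adj_matrix.fill_diagonal_(0)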

def inter_invt_fun1(h_i, h_j, a_ij):
    return h_i

def inter_invt_fun2(h_i, h_j, a_ij):
    return h_j

def inter_invt_fun3(h_i, h_j, a_ij):
    return a_ij * h_i

def inter_invt_fun4(h_i, h_j, a_ij):
    return a_ij * h_j

def inter_invt_fun5(h_i, h_j, a_ij):
    return h_i * torch.linalg.norm(h_i - h_j) ** 2

def inter_invt_fun6(h_i, h_j, a_ij):
    return h_j * torch.linalg.norm(h_i - h_j) ** 2

def vertex_invt_fun1(h, m):
    return h

def vertex_invt_fun2(h, m):
    return m

def vertex_invt_fun3(h, m):
    return h * torch.linalg.norm(h - m) ** 2

def vertex_invt_fun4(h, m):
    return m * torch.linalg.norm(h - m) ** 2


def edge_invt_fun(a, m):
    return a

inter_invt_funs = [
    inter_invt_fun1,
    inter_invt_fun2,
    inter_invt_fun3,
    inter_invt_fun4,
    inter_invt_fun5,
    inter_invt_fun6
]

vertex_invt_funs = [
    vertex_invt_fun1,
    vertex_invt_fun2,
    vertex_invt_fun3,
    vertex_invt_fun4
]

edge_invt_funs = [edge_invt_fun]

In [8]:
layer = EGCL(
    num_vertices=num_vertices,
    adj_matrix=adj_matrix,
    vertex_inputs=4,    # R^4 for vertex features
    edge_inputs=1,      # R^1 for edge features (trivial rep)
    vertex_outputs=4,   # R^4 output
    edge_outputs=1,     # R^1 output
    inter_vars=4,       # R^4 for intermediate representation
    inter_invt_funs=inter_invt_funs,
    vertex_invt_funs=vertex_invt_funs,
    edge_invt_funs=edge_invt_funs,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True     # We are using an affine group so should rescale to preserve a hyperplane
)

# Create some example input data
h_graph = torch.randn(num_vertices, 4)  # Random vertex features in R^4
a_graph = torch.randn(num_vertices, num_vertices, 1)  # Random edge features

# Forward pass
h_out, a_out = layer(h_graph, a_graph)

print(f"Input vertex features shape: {h_graph.shape}")
print(f"Output vertex features shape: {h_out.shape}")
print(f"Input edge features shape: {a_graph.shape}")
print(f"Output edge features shape: {a_out.shape}")

# Print summary statistics of a tensor of layer outputs
def print_activation_stats(tensor, name):
    print(f"\n{name} statistics:")
    print(f"Min: {tensor.min():.4f}")
    print(f"Max: {tensor.max():.4f}")
    print(f"Mean: {tensor.mean():.4f}")
    print(f"Std: {tensor.std():.4f}")

# These are post-activation values; the layer would need to be modified to expose pre-activation values
print_activation_stats(h_out, "Vertex outputs")
print_activation_stats(a_out, "Edge outputs")


EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.0005s, Total: 0.0044s, Count: 9
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.0005s, Total: 0.0023s, Count: 5
edge_update          - Avg: 0.0003s, Total: 0.0028s, Count: 9
Input vertex features shape: torch.Size([5, 4])
Output vertex features shape: torch.Size([5, 4])
Input edge features shape: torch.Size([5, 5, 1])
Output edge features shape: torch.Size([5, 5, 1])

Vertex outputs statistics:
Min: 0.0000
Max: 0.4412
Mean: 0.0453
Std: 0.1160

Edge outputs statistics:
Min: 0.0000
Max: 1.0000
Mean: 0.1600
Std: 0.3742
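
The `is_affine` option is described as a rescaling that preserves a hyperplane, while `F.normalize` rescales to unit Euclidean norm, which generally moves the final coordinate away from one. For comparison, here is a sketch of one possible hyperplane-preserving rescaling. It assumes the homogeneous coordinate is stored last, and the helper `rescale_to_hyperplane` is hypothetical (it is not part of the layer above).

```python
import torch

def rescale_to_hyperplane(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Divide by the last coordinate so that it becomes 1 (hypothetical helper).

    Vectors whose last coordinate is within eps of zero are returned unchanged,
    since they cannot be rescaled onto the hyperplane t = 1.
    """
    t = v[..., -1:]
    safe = t.abs() > eps
    # Replace unsafe denominators by 1 so that the division never produces inf/nan.
    denom = torch.where(safe, t, torch.ones_like(t))
    return torch.where(safe, v / denom, v)

v = torch.tensor([2.0, -4.0, 6.0, 2.0])
on_plane = rescale_to_hyperplane(v)               # last coordinate becomes 1
unit = torch.nn.functional.normalize(v, dim=-1)   # unit norm instead

print(on_plane)
print(unit[-1].item())
```

Whether the result stays on the hyperplane after a pointwise activation such as ReLU is a separate question, since the activation acts on every coordinate.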


# Building $E(n)$ Equivariant Models

The next step is to build a model.  I'll stick with the $E(n)$ example, and representations of the form $(\mathbb R^{n+1})^a \times \mathbb R^b$ -- products of the fundamental and trivial representations.  I'll need a general method for constructing the sets of generating equivariant functions.

Let's consider equivariant functions
$$F \colon (\mathbb R^{n+1})^a \times \mathbb R^b \to (\mathbb R^{n+1})^c \times \mathbb R^d.$$
The set of $\mathrm O(n)$-equivariant functions we will wish to generate will consist of
$$\{x_{ij}, e_{kl}, \langle x_p, x_q \rangle x_{ij}, \langle x_p, x_q \rangle e_{kl}, e_r x_{ij}, e_r e_{kl}\}$$
where $x_{ij}, e_{kl}$ are the matrix element functions in the first and second factors respectively, and where $x_p, e_r$ are coordinate functions on the $p^\text{th}$ factor of $(\mathbb R^{n+1})^a$ and the $r^\text{th}$ factor of $\mathbb R^b$ respectively.  To be additionally translation equivariant we must restrict the inner product terms to combinations of differences, that is, to those generated by $\langle x_p - x_{p'}, x_q - x_{q'} \rangle$ (a single difference $\langle x_p, x_q \rangle - \langle x_p, x_{q'} \rangle$ still changes by $\langle b, x_q - x_{q'} \rangle$ under translation by $b$), and the $x_{ij}$ terms to linear combinations with coefficients summing to one.
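
The effect of translations on the inner product terms can be seen numerically. The sketch below works with the spatial coordinates only and compares a raw inner product $\langle x_p, x_q \rangle$ with an inner product of differences $\langle x_p - x_{p'}, x_q - x_{q'} \rangle$; the helper names `inner` and `diff_inner` are mine.

```python
import torch

torch.manual_seed(0)
n, a = 3, 4
dtype = torch.float64

C, _ = torch.linalg.qr(torch.randn(n, n, dtype=dtype))  # orthogonal part
b = torch.randn(n, dtype=dtype)                          # translation part

xs = torch.randn(a, n, dtype=dtype)  # a points in R^n, one per row
rotated = xs @ C.T                   # O(n) action only: x -> C x
moved = xs @ C.T + b                 # full E(n) action: x -> C x + b

def inner(pts, p, q):
    """The O(n)-invariant <x_p, x_q>."""
    return torch.dot(pts[p], pts[q])

def diff_inner(pts, p, p2, q, q2):
    """The inner product of differences <x_p - x_p2, x_q - x_q2>."""
    return torch.dot(pts[p] - pts[p2], pts[q] - pts[q2])

# Raw inner products survive rotations but not translations.
print(torch.allclose(inner(rotated, 0, 1), inner(xs, 0, 1)))
print(torch.allclose(inner(moved, 0, 1), inner(xs, 0, 1)))
# Inner products of differences survive the full E(n) action.
print(torch.allclose(diff_inner(moved, 0, 1, 2, 3), diff_inner(xs, 0, 1, 2, 3)))
```

Taking $p = q$ and $p' = q'$ recovers the squared distance $\|x_p - x_{p'}\|^2$ used in Example 2.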


In [9]:
from typing import List, Tuple, Callable, Optional

class EnRepresentation:
    """
    Handles E(n) representations of the form (R^(n+1))^a × R^b
    """
    def __init__(self, dim: int, num_vectors: int, num_scalars: int):
        """
        Args:
            dim: Dimension n of the ambient Euclidean space
            num_vectors: Number a of R^(n+1) factors
            num_scalars: Dimension b of the trivial (scalar) factor R^b
        """
        self.dim = dim
        self.num_vectors = num_vectors
        self.num_scalars = num_scalars

    def total_dim(self) -> int:
        """Total dimension of representation space."""
        return (self.dim + 1) * self.num_vectors + self.num_scalars

    def split_vector_scalar(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split tensor into vector and scalar parts."""
        vector_dim = (self.dim + 1) * self.num_vectors
        return x[:vector_dim], x[vector_dim:]

    def combine_vector_scalar(self, vectors: torch.Tensor, scalars: torch.Tensor) -> torch.Tensor:
        """Combine vector and scalar parts."""
        return torch.cat([vectors, scalars])

def generate_En_message_invariants(
    dim: int,
    vertex_rep: EnRepresentation,  # Vertex rep
    edge_rep: EnRepresentation,  # Edge rep
    out_rep: EnRepresentation   # Output rep
) -> List[Callable]:
    """
    Generate E(n)-equivariant functions for message passing.
    """
    invariants = []

    def make_linear_vector_invariant(i: int, j: int) -> Callable:
        """Create linear invariant function for vector factor."""
        def f(h1: torch.Tensor, h2: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
            v1, s1 = vertex_rep.split_vector_scalar(h1)
            v2, s2 = vertex_rep.split_vector_scalar(h2)
            ve, se = edge_rep.split_vector_scalar(e)

            out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h1.device)
            out_s = torch.zeros(out_rep.num_scalars, device=h1.device)

            if i < vertex_rep.num_vectors:
                out_v[j*(dim + 1):(j+1)*(dim + 1)] = v1[i*(dim + 1):(i+1)*(dim + 1)]
            elif i < vertex_rep.num_vectors * 2:
                i2 = i - vertex_rep.num_vectors
                out_v[j*(dim + 1):(j+1)*(dim + 1)] = v2[i2*(dim + 1):(i2+1)*(dim + 1)]

            return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_linear_scalar_invariant(k: int, l: int) -> Callable:
        """Create linear invariant function for scalar factor."""
        def f(h1: torch.Tensor, h2: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
            v1, s1 = vertex_rep.split_vector_scalar(h1)
            v2, s2 = vertex_rep.split_vector_scalar(h2)
            ve, se = edge_rep.split_vector_scalar(e)

            out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h1.device)
            out_s = torch.zeros(out_rep.num_scalars, device=h1.device)

            if k < vertex_rep.num_scalars:
                out_s[l] = s1[k]
            elif k < vertex_rep.num_scalars * 2:
                k2 = k - vertex_rep.num_scalars
                out_s[l] = s2[k2]

            return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_quadratic_vector_invariant(i: int, j: int, r: int) -> Callable:
        """Create quadratic invariant function x_ij e_r."""
        def f(h1: torch.Tensor, h2: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
            v1, s1 = vertex_rep.split_vector_scalar(h1)
            v2, s2 = vertex_rep.split_vector_scalar(h2)
            ve, se = edge_rep.split_vector_scalar(e)

            out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h1.device)
            out_s = torch.zeros(out_rep.num_scalars, device=h1.device)

            g = make_linear_vector_invariant(i,j)
            lin_v, lin_s = out_rep.split_vector_scalar(g(h1, h2, e))
            if r < vertex_rep.num_scalars:
                out_v = lin_v * s1[r]
            elif r < vertex_rep.num_scalars * 2:
                r2 = r - vertex_rep.num_scalars
                out_v = lin_v * s2[r2]

            return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_quadratic_scalar_invariant(k: int, l: int, r: int) -> Callable:
        """Create quadratic invariant function e_kl e_r."""
        def f(h1: torch.Tensor, h2: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
            v1, s1 = vertex_rep.split_vector_scalar(h1)
            v2, s2 = vertex_rep.split_vector_scalar(h2)
            ve, se = edge_rep.split_vector_scalar(e)

            out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h1.device)
            out_s = torch.zeros(out_rep.num_scalars, device=h1.device)

            g = make_linear_scalar_invariant(k,l)
            lin_v, lin_s = out_rep.split_vector_scalar(g(h1, h2, e))
            if r < vertex_rep.num_scalars:
                out_s = lin_s * s1[r]
            elif r < vertex_rep.num_scalars * 2:
                r2 = r - vertex_rep.num_scalars
                out_s = lin_s * s2[r2]

            return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_cubic_vector_invariant(i: int, j: int, p: int, q: int) -> Callable:
        """Create cubic invariant function <x_p, x_q> x_ij."""
        def f(h1: torch.Tensor, h2: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
          v1, s1 = vertex_rep.split_vector_scalar(h1)
          v2, s2 = vertex_rep.split_vector_scalar(h2)
          ve, se = edge_rep.split_vector_scalar(e)

          out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h1.device)
          out_s = torch.zeros(out_rep.num_scalars, device=h1.device)

          if p < vertex_rep.num_vectors:
            w1 = v1[p*(dim + 1):(p+1)*(dim + 1)]
          elif p < vertex_rep.num_vectors * 2:
            p2 = p - vertex_rep.num_vectors
            w1 = v2[p2*(dim + 1):(p2+1)*(dim + 1)]

          if q < vertex_rep.num_vectors:
            w2 = v1[q*(dim + 1):(q+1)*(dim + 1)]
          elif q < vertex_rep.num_vectors * 2:
            q2 = q - vertex_rep.num_vectors
            w2 = v2[q2*(dim + 1):(q2+1)*(dim + 1)]

          g = make_linear_vector_invariant(i,j)
          lin_v, lin_s = out_rep.split_vector_scalar(g(h1, h2, e))
          out_v = lin_v * torch.dot(w1, w2)

          return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_cubic_scalar_invariant(k: int, l: int, p: int, q: int) -> Callable:
        """Create cubic invariant function <x_p, x_q> e_kl."""
        def f(h1: torch.Tensor, h2: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
          v1, s1 = vertex_rep.split_vector_scalar(h1)
          v2, s2 = vertex_rep.split_vector_scalar(h2)
          ve, se = edge_rep.split_vector_scalar(e)

          out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h1.device)
          out_s = torch.zeros(out_rep.num_scalars, device=h1.device)

          if p < vertex_rep.num_vectors:
            w1 = v1[p*(dim + 1):(p+1)*(dim + 1)]
          elif p < vertex_rep.num_vectors * 2:
            p2 = p - vertex_rep.num_vectors
            w1 = v2[p2*(dim + 1):(p2+1)*(dim + 1)]

          if q < vertex_rep.num_vectors:
            w2 = v1[q*(dim + 1):(q+1)*(dim + 1)]
          elif q < vertex_rep.num_vectors * 2:
            q2 = q - vertex_rep.num_vectors
            w2 = v2[q2*(dim + 1):(q2+1)*(dim + 1)]

          g = make_linear_scalar_invariant(k,l)
          lin_v, lin_s = out_rep.split_vector_scalar(g(h1, h2, e))
          out_s = lin_s * torch.dot(w1, w2)

          return out_rep.combine_vector_scalar(out_v, out_s)
        return f


    # Add linear invariants
    for i in range(vertex_rep.num_vectors *2):
        for j in range(out_rep.num_vectors):
            invariants.append(make_linear_vector_invariant(i, j))

    for k in range(vertex_rep.num_scalars *2):
        for l in range(out_rep.num_scalars):
            invariants.append(make_linear_scalar_invariant(k, l))

    # Add quadratic invariants
    for i in range(vertex_rep.num_vectors *2):
        for j in range(out_rep.num_vectors):
            for r in range(vertex_rep.num_scalars * 2):
                invariants.append(make_quadratic_vector_invariant(i, j, r))

    for k in range(vertex_rep.num_scalars *2):
        for l in range(out_rep.num_scalars):
            for r in range(vertex_rep.num_scalars * 2):
              invariants.append(make_quadratic_scalar_invariant(k, l, r))

    # Add cubic invariants.  We need to generate differences like <x_p, x_q - x_{q+1}>x_ij

    def make_difference(f1, f2):
        def diff(h1, h2, e):
            return f1(h1, h2, e) - f2(h1, h2, e)
        return diff

    for i in range(vertex_rep.num_vectors *2):
        for j in range(out_rep.num_vectors):
            for p in range(vertex_rep.num_vectors *2):
                for q in range(vertex_rep.num_vectors):
                  next_q = (q + 1) % vertex_rep.num_vectors  # Cyclic index
                  f_current = make_cubic_vector_invariant(i, j, p, q)
                  f_next = make_cubic_vector_invariant(i, j, p, next_q)

                  invariants.append(make_difference(f_current, f_next))

                base_q = vertex_rep.num_vectors
                for q_offset in range(vertex_rep.num_vectors):
                   q = base_q + q_offset
                   next_q = base_q + ((q_offset + 1) % vertex_rep.num_vectors)
                   g_current = make_cubic_vector_invariant(i, j, p, q)
                   g_next = make_cubic_vector_invariant(i, j, p, next_q)

                   invariants.append(make_difference(g_current, g_next))

    for k in range(vertex_rep.num_scalars *2):
        for l in range(out_rep.num_scalars):
          for p in range(vertex_rep.num_vectors *2):
                for q in range(vertex_rep.num_vectors):
                  next_q = (q + 1) % vertex_rep.num_vectors  # Cyclic index
                  f_current = make_cubic_scalar_invariant(k, l, p, q)
                  f_next = make_cubic_scalar_invariant(k, l, p, next_q)

                  invariants.append(make_difference(f_current, f_next))

                base_q = vertex_rep.num_vectors
                for q_offset in range(vertex_rep.num_vectors):
                   q = base_q + q_offset
                   next_q = base_q + ((q_offset + 1) % vertex_rep.num_vectors)
                   g_current = make_cubic_scalar_invariant(k, l, p, q)
                   g_next = make_cubic_scalar_invariant(k, l, p, next_q)

                   invariants.append(make_difference(g_current, g_next))


    return invariants

def generate_En_vertex_edge_invariants(
    dim: int,
    input_rep: EnRepresentation,  # Vertex or edge rep
    internal_rep: EnRepresentation,  # Internal variable rep
    out_rep: EnRepresentation   # Output rep
) -> List[Callable]:
    """
    Generate E(n)-equivariant functions for the vertex and edge updates.
    """
    invariants = []

    def make_linear_vector_invariant(i: int, j: int) -> Callable:
        """Create linear invariant function for vector factor."""
        def f(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
            v, s = input_rep.split_vector_scalar(h)
            vi, si = internal_rep.split_vector_scalar(m)

            out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h.device)
            out_s = torch.zeros(out_rep.num_scalars, device=h.device)

            if i < input_rep.num_vectors:
                out_v[j*(dim + 1):(j+1)*(dim + 1)] = v[i*(dim + 1):(i+1)*(dim + 1)]
            elif i < input_rep.num_vectors + internal_rep.num_vectors:
                i2 = i - input_rep.num_vectors
                out_v[j*(dim + 1):(j+1)*(dim + 1)] = vi[i2*(dim + 1):(i2+1)*(dim + 1)]

            return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_linear_scalar_invariant(k: int, l: int) -> Callable:
        """Create linear invariant function for scalar factor."""
        def f(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
          v, s = input_rep.split_vector_scalar(h)
          vi, si = internal_rep.split_vector_scalar(m)

          out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h.device)
          out_s = torch.zeros(out_rep.num_scalars, device=h.device)

          if k < input_rep.num_scalars:
            out_s[l] = s[k]
          elif k < input_rep.num_scalars + internal_rep.num_scalars:
            k2 = k - input_rep.num_scalars
            out_s[l] = si[k2]

          return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_quadratic_vector_invariant(i: int, j: int, r: int) -> Callable:
        """Create quadratic invariant function x_ij e_r."""
        def f(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
            v, s = input_rep.split_vector_scalar(h)
            vi, si = internal_rep.split_vector_scalar(m)

            out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h.device)
            out_s = torch.zeros(out_rep.num_scalars, device=h.device)

            g = make_linear_vector_invariant(i,j)
            lin_v, lin_s = out_rep.split_vector_scalar(g(h, m))
            if r < input_rep.num_scalars:
                out_v = lin_v * s[r]
            elif r < input_rep.num_scalars + internal_rep.num_scalars:
                r2 = r - input_rep.num_scalars
                out_v = lin_v * si[r2]

            return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_quadratic_scalar_invariant(k: int, l: int, r: int) -> Callable:
        """Create quadratic invariant function e_kl e_r."""
        def f(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
          v, s = input_rep.split_vector_scalar(h)
          vi, si = internal_rep.split_vector_scalar(m)

          out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h.device)
          out_s = torch.zeros(out_rep.num_scalars, device=h.device)

          g = make_linear_scalar_invariant(k,l)
          lin_v, lin_s = out_rep.split_vector_scalar(g(h, m))
          if r < input_rep.num_scalars:
              out_s = lin_s * s[r]
          elif r < input_rep.num_scalars + internal_rep.num_scalars:
              r2 = r - input_rep.num_scalars
              out_s = lin_s * si[r2]

          return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_cubic_vector_invariant(i: int, j: int, p: int, q: int) -> Callable:
        """Create cubic invariant function <x_p, x_q> x_ij."""
        def f(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
          v, s = input_rep.split_vector_scalar(h)
          vi, si = internal_rep.split_vector_scalar(m)

          out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h.device)
          out_s = torch.zeros(out_rep.num_scalars, device=h.device)

          if p < input_rep.num_vectors:
            w1 = v[p*(dim + 1):(p+1)*(dim + 1)]
          elif p < input_rep.num_vectors + internal_rep.num_vectors:
            p2 = p - input_rep.num_vectors
            w1 = vi[p2*(dim + 1):(p2+1)*(dim + 1)]

          if q < input_rep.num_vectors:
            w2 = v[q*(dim + 1):(q+1)*(dim + 1)]
          elif q < input_rep.num_vectors + internal_rep.num_vectors:
            q2 = q - input_rep.num_vectors
            w2 = vi[q2*(dim + 1):(q2+1)*(dim + 1)]

          g = make_linear_vector_invariant(i,j)
          lin_v, lin_s = out_rep.split_vector_scalar(g(h, m))
          out_v = lin_v * torch.dot(w1, w2)

          return out_rep.combine_vector_scalar(out_v, out_s)
        return f

    def make_cubic_scalar_invariant(k: int, l: int, p: int, q: int) -> Callable:
        """Create cubic invariant function <x_p, x_q> e_kl."""
        def f(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
          v, s = input_rep.split_vector_scalar(h)
          vi, si = internal_rep.split_vector_scalar(m)

          out_v = torch.zeros((out_rep.num_vectors * (dim + 1)), device=h.device)
          out_s = torch.zeros(out_rep.num_scalars, device=h.device)

          if p < input_rep.num_vectors:
            w1 = v[p*(dim + 1):(p+1)*(dim + 1)]
          elif p < input_rep.num_vectors + internal_rep.num_vectors:
            p2 = p - input_rep.num_vectors
            w1 = vi[p2*(dim + 1):(p2+1)*(dim + 1)]

          if q < input_rep.num_vectors:
            w2 = v[q*(dim + 1):(q+1)*(dim + 1)]
          elif q < input_rep.num_vectors + internal_rep.num_vectors:
            q2 = q - input_rep.num_vectors
            w2 = vi[q2*(dim + 1):(q2+1)*(dim + 1)]

          g = make_linear_scalar_invariant(k,l)
          lin_v, lin_s = out_rep.split_vector_scalar(g(h, m))
          out_s = lin_s * torch.dot(w1, w2)

          return out_rep.combine_vector_scalar(out_v, out_s)
        return f


    # Add linear invariants
    for i in range(input_rep.num_vectors + internal_rep.num_vectors):
        for j in range(out_rep.num_vectors):
            invariants.append(make_linear_vector_invariant(i, j))

    for k in range(input_rep.num_scalars + internal_rep.num_scalars):
        for l in range(out_rep.num_scalars):
            invariants.append(make_linear_scalar_invariant(k, l))

    # Add quadratic invariants
    for i in range(input_rep.num_vectors + internal_rep.num_vectors):
        for j in range(out_rep.num_vectors):
            for r in range(input_rep.num_scalars + internal_rep.num_scalars):
                invariants.append(make_quadratic_vector_invariant(i, j, r))

    for k in range(input_rep.num_scalars + internal_rep.num_scalars):
        for l in range(out_rep.num_scalars):
            for r in range(input_rep.num_scalars + internal_rep.num_scalars):
              invariants.append(make_quadratic_scalar_invariant(k, l, r))

    # Add cubic invariants.  We need to generate differences like <x_p, x_q - x_{q+1}>x_ij

    def make_difference(f1, f2):
        def diff(h, m):
            return f1(h, m) - f2(h, m)
        return diff

    for i in range(input_rep.num_vectors + internal_rep.num_vectors):
        for j in range(out_rep.num_vectors):
            for p in range(input_rep.num_vectors + internal_rep.num_vectors):
                for q in range(input_rep.num_vectors):
                  next_q = (q + 1) % input_rep.num_vectors  # Cyclic index
                  f_current = make_cubic_vector_invariant(i, j, p, q)
                  f_next = make_cubic_vector_invariant(i, j, p, next_q)

                  invariants.append(make_difference(f_current, f_next))

                base_q = input_rep.num_vectors
                for q_offset in range(internal_rep.num_vectors):
                   q = base_q + q_offset
                   next_q = base_q + ((q_offset + 1) % internal_rep.num_vectors)
                   g_current = make_cubic_vector_invariant(i, j, p, q)
                   g_next = make_cubic_vector_invariant(i, j, p, next_q)

                   invariants.append(make_difference(g_current, g_next))

    for k in range(input_rep.num_scalars + internal_rep.num_scalars):
        for l in range(out_rep.num_scalars):
          for p in range(input_rep.num_vectors + internal_rep.num_vectors):
                for q in range(input_rep.num_vectors):
                  next_q = (q + 1) % input_rep.num_vectors  # Cyclic index
                  f_current = make_cubic_scalar_invariant(k, l, p, q)
                  f_next = make_cubic_scalar_invariant(k, l, p, next_q)

                  invariants.append(make_difference(f_current, f_next))

                base_q = input_rep.num_vectors
                for q_offset in range(internal_rep.num_vectors):
                   q = base_q + q_offset
                   next_q = base_q + ((q_offset + 1) % internal_rep.num_vectors)
                   g_current = make_cubic_scalar_invariant(k, l, p, q)
                   g_next = make_cubic_scalar_invariant(k, l, p, next_q)

                   invariants.append(make_difference(g_current, g_next))


    return invariants
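
The loop bounds above determine how many basis functions each call returns, and that count grows quickly with the number of vector and scalar slots. As a rough sketch, the hypothetical helper `count_message_invariants` below mirrors the loops of `generate_En_message_invariants` without calling any notebook code. Note that the edge features are split inside each function but do not enter the loop bounds.

```python
def count_message_invariants(nv: int, ns: int, ov: int, os_: int) -> dict:
    """Count the functions appended by the loops in generate_En_message_invariants.

    nv, ns:  num_vectors / num_scalars of the vertex representation
    ov, os_: num_vectors / num_scalars of the output representation
    (hypothetical helper; it only reproduces the loop bounds)
    """
    vec_in, sc_in = 2 * nv, 2 * ns  # both endpoints of the edge contribute
    counts = {
        "linear_vector": vec_in * ov,
        "linear_scalar": sc_in * os_,
        "quadratic_vector": vec_in * ov * sc_in,
        "quadratic_scalar": sc_in * os_ * sc_in,
        # q runs over nv cyclic differences in each of the two vertex blocks
        "cubic_vector": vec_in * ov * vec_in * (2 * nv),
        "cubic_scalar": sc_in * os_ * vec_in * (2 * nv),
    }
    counts["total"] = sum(counts.values())
    return counts


# Illustrative sizes: (nv, ns, ov, os_)
for sizes in [(1, 2, 1, 2), (2, 2, 2, 2), (3, 3, 3, 3)]:
    print(sizes, count_message_invariants(*sizes)["total"])
```

For these three sizes the totals are 90, 336 and 1548, so the cubic terms dominate as soon as there is more than one vector slot. This may account for much of the growth in per-layer running time.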

Now that we are able to generate invariant functions, we can build a model from several $E(n)$-equivariant layers.  Let's build an $E(3)$-equivariant model with three hidden layers.  The input and output layers both use the vertex representation $W_V = \mathbb{R}^4 \times \mathbb{R}^{n_f}$ and the edge representation $W_E = \mathbb{R}$.  Let's specifically set $n_f = 2$.

I'll keep the edge representation trivial in intermediate layers, but allow the vertex representation to vary.  Initially let's say we have the following intermediate layer vertex representations:
\begin{align}
W_V^{(2)} &= (\mathbb R^4)^2 \times \mathbb R^2 \\
W_V^{(3)} &= (\mathbb R^4)^3 \times \mathbb R^3 \\
W_V^{(4)} &= (\mathbb R^4)^2 \times \mathbb R^2.
\end{align}
In each layer I'll set the internal message representation equal to the vertex representation.  I won't use the double layer option for now, and I'll use ReLU activations throughout.
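
To keep track of sizes, here is a small sketch of the feature dimensions through the network. It assumes `EnRepresentation.total_dim()` returns `num_vectors * (dim + 1) + num_scalars`, which matches how the vector and scalar blocks are sliced in the generator functions; `rep_dim` is a hypothetical stand-in, not the notebook's class.

```python
from typing import List, Tuple


def rep_dim(num_vectors: int, num_scalars: int, dim: int = 3) -> int:
    """Dimension of (R^{dim+1})^num_vectors x R^num_scalars (assumed layout)."""
    return num_vectors * (dim + 1) + num_scalars


# (num_vectors, num_scalars) for the input, the three hidden layers, and the output
vertex_reps: List[Tuple[int, int]] = [(1, 2), (2, 2), (3, 3), (2, 2), (1, 2)]

# Each layer maps rep k to rep k+1; its message rep equals its input rep.
for k, (rep_in, rep_out) in enumerate(zip(vertex_reps, vertex_reps[1:]), start=1):
    print(f"layer {k}: vertex dim {rep_dim(*rep_in)} -> {rep_dim(*rep_out)}, "
          f"message dim {rep_dim(*rep_in)}")
```

With the representations chosen here the vertex feature dimension runs 6, 10, 15, 10, 6.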


In [10]:
dim = 3

num_vertices = 4
adj_matrix = torch.randint(0,2,(num_vertices, num_vertices))
adj_matrix = (adj_matrix + adj_matrix.T) / 2

# Define representations
input_vertex_rep = EnRepresentation(dim=3, num_vectors=1, num_scalars=2)
edge_rep = EnRepresentation(dim=3, num_vectors=0, num_scalars=1)
message_rep_1 = input_vertex_rep
hidden_vertex_rep_1 = EnRepresentation(dim=3, num_vectors=2, num_scalars=2)
hidden_message_rep_1 = hidden_vertex_rep_1
hidden_vertex_rep_2 = EnRepresentation(dim=3, num_vectors=3, num_scalars=3)
hidden_message_rep_2 = hidden_vertex_rep_2
hidden_vertex_rep_3 = EnRepresentation(dim=3, num_vectors=2, num_scalars=2)
hidden_message_rep_3 = hidden_vertex_rep_3
output_vertex_rep = EnRepresentation(dim=3, num_vectors=1, num_scalars=2)

# Layer 1
message_invariants_1 = generate_En_message_invariants(
    dim = dim,
    vertex_rep = input_vertex_rep,
    edge_rep = edge_rep,
    out_rep = message_rep_1
)
vertex_invariants_1 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = input_vertex_rep,
    internal_rep = message_rep_1,
    out_rep = hidden_vertex_rep_1
)
edge_invariants_1 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = edge_rep,
    internal_rep = message_rep_1,
    out_rep = edge_rep
)

layer1 = EGCL(
    num_vertices = num_vertices,
    adj_matrix = adj_matrix,
    vertex_inputs=input_vertex_rep.total_dim(),
    edge_inputs=edge_rep.total_dim(),
    vertex_outputs=hidden_vertex_rep_1.total_dim(),
    edge_outputs=edge_rep.total_dim(),
    inter_vars=message_rep_1.total_dim(),
    inter_invt_funs=message_invariants_1,
    vertex_invt_funs=vertex_invariants_1,
    edge_invt_funs=edge_invariants_1,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True
)

# Layer 2
message_invariants_2 = generate_En_message_invariants(
    dim = dim,
    vertex_rep = hidden_vertex_rep_1,
    edge_rep = edge_rep,
    out_rep = hidden_message_rep_1
)
vertex_invariants_2 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = hidden_vertex_rep_1,
    internal_rep = hidden_message_rep_1,
    out_rep = hidden_vertex_rep_2
)
edge_invariants_2 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = edge_rep,
    internal_rep = hidden_message_rep_1,
    out_rep = edge_rep
)

layer2 = EGCL(
    num_vertices = num_vertices,
    adj_matrix = adj_matrix,
    vertex_inputs=hidden_vertex_rep_1.total_dim(),
    edge_inputs=edge_rep.total_dim(),
    vertex_outputs=hidden_vertex_rep_2.total_dim(),
    edge_outputs=edge_rep.total_dim(),
    inter_vars=hidden_message_rep_1.total_dim(),
    inter_invt_funs=message_invariants_2,
    vertex_invt_funs=vertex_invariants_2,
    edge_invt_funs=edge_invariants_2,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True
)

# Layer 3
message_invariants_3 = generate_En_message_invariants(
    dim = dim,
    vertex_rep = hidden_vertex_rep_2,
    edge_rep = edge_rep,
    out_rep = hidden_message_rep_2
)
vertex_invariants_3 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = hidden_vertex_rep_2,
    internal_rep = hidden_message_rep_2,
    out_rep = hidden_vertex_rep_3
)
edge_invariants_3 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = edge_rep,
    internal_rep = hidden_message_rep_2,
    out_rep = edge_rep
)

layer3 = EGCL(
    num_vertices = num_vertices,
    adj_matrix = adj_matrix,
    vertex_inputs=hidden_vertex_rep_2.total_dim(),
    edge_inputs=edge_rep.total_dim(),
    vertex_outputs=hidden_vertex_rep_3.total_dim(),
    edge_outputs=edge_rep.total_dim(),
    inter_vars=hidden_message_rep_2.total_dim(),
    inter_invt_funs=message_invariants_3,
    vertex_invt_funs=vertex_invariants_3,
    edge_invt_funs=edge_invariants_3,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True
)

# Layer 4
message_invariants_4 = generate_En_message_invariants(
    dim = dim,
    vertex_rep = hidden_vertex_rep_3,
    edge_rep = edge_rep,
    out_rep = hidden_message_rep_3
)
vertex_invariants_4 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = hidden_vertex_rep_3,
    internal_rep = hidden_message_rep_3,
    out_rep = output_vertex_rep
)
edge_invariants_4 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = edge_rep,
    internal_rep = hidden_message_rep_3,
    out_rep = edge_rep
)

layer4 = EGCL(
    num_vertices = num_vertices,
    adj_matrix = adj_matrix,
    vertex_inputs=hidden_vertex_rep_3.total_dim(),
    edge_inputs=edge_rep.total_dim(),
    vertex_outputs=output_vertex_rep.total_dim(),
    edge_outputs=edge_rep.total_dim(),
    inter_vars=hidden_message_rep_3.total_dim(),
    inter_invt_funs=message_invariants_4,
    vertex_invt_funs=vertex_invariants_4,
    edge_invt_funs=edge_invariants_4,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True
)

In [11]:
class EnEquivariantNet(nn.Module):
    """Sequential stack of EGCL layers whose vertex and edge dimensions must chain."""

    def __init__(self, layers: List[EGCL]):
        """
        Args:
            layers: List of EGCL layers to be applied in sequence
        """
        super(EnEquivariantNet, self).__init__()
        self.layers = nn.ModuleList(layers)

        # Verify layers are compatible
        for i in range(len(layers)-1):
            if layers[i].vertex_outputs != layers[i+1].vertex_inputs:
                raise ValueError(f"Layer {i} output dimension {layers[i].vertex_outputs} "
                               f"doesn't match layer {i+1} input dimension {layers[i+1].vertex_inputs}")
            if layers[i].edge_outputs != layers[i+1].edge_inputs:
                raise ValueError(f"Layer {i} edge output dimension {layers[i].edge_outputs} "
                               f"doesn't match layer {i+1} edge input dimension {layers[i+1].edge_inputs}")

    def forward(self, h_graph: torch.Tensor, a_graph: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            h_graph: Input vertex features
            a_graph: Input edge features

        Returns:
            tuple(torch.Tensor, torch.Tensor): Final vertex and edge features
        """
        h, a = h_graph, a_graph
        for layer in self.layers:
            h, a = layer(h, a)
        return h, a

# Combine layers into model
model = EnEquivariantNet([layer1, layer2, layer3, layer4])

# Create some example input data
h_input = torch.randn(num_vertices, input_vertex_rep.total_dim())
a_input = torch.randn(num_vertices, num_vertices, edge_rep.total_dim())

# Forward pass through model
h_output, a_output = model(h_input, a_input)

print(f"Input vertex features shape: {h_input.shape}")
print(f"Output vertex features shape: {h_output.shape}")
print(f"Input edge features shape: {a_input.shape}")
print(f"Output edge features shape: {a_output.shape}")

EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0001s, Total: 0.0001s, Count: 1
intermediate_term    - Avg: 0.0339s, Total: 0.2037s, Count: 6
message_sums         - Avg: 0.0004s, Total: 0.0004s, Count: 1
vertex_update        - Avg: 0.0488s, Total: 0.1952s, Count: 4
edge_update          - Avg: 0.0044s, Total: 0.0266s, Count: 6
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.2313s, Total: 1.3880s, Count: 6
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.2125s, Total: 0.8499s, Count: 4
edge_update          - Avg: 0.0081s, Total: 0.0484s, Count: 6
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.7566s, Total: 4.5394s, Count: 6
message_sums         - Avg: 0.0010s, Total: 0.0010s, Count: 1
vertex_update       

Let's now try training the model on the MD17 dataset, using the benzene molecule at the DFT FHI-aims level of accuracy.  We'll try to learn the atomic forces as a function of the atomic positions.  I've reduced the size of layer 3 to $(\mathbb R^4)^2 \times \mathbb R^2$, matching the other hidden layers.
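
Positions and forces transform differently under $E(3)$: a position is a point, while a force is a vector that should rotate but not translate. In the linear action on $\mathbb R^{4}$ described at the start of the notebook this is encoded in the last coordinate, 1 for positions and 0 for forces. A minimal pure-Python sketch of that distinction follows; `apply_homogeneous` and the particular rotation and translation are illustrative choices, not part of the notebook's code.

```python
import math
from typing import List


def apply_homogeneous(angle: float, t: List[float], x: List[float]) -> List[float]:
    """Apply the 4x4 matrix [[R, t], [0, 1]] to x, with R a rotation about the z-axis."""
    c, s = math.cos(angle), math.sin(angle)
    R = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
    w = x[3]  # 1 for points, 0 for vectors
    rotated = [sum(R[i][j] * x[j] for j in range(3)) for i in range(3)]
    # The translation is weighted by the last coordinate, so it vanishes for vectors
    return [rotated[i] + t[i] * w for i in range(3)] + [w]


angle, t = math.pi / 2, [1.0, 2.0, 3.0]
position = [1.0, 0.0, 0.0, 1.0]  # point: last coordinate 1
force = [1.0, 0.0, 0.0, 0.0]     # vector: last coordinate 0

print(apply_homogeneous(angle, t, position))  # rotated and translated
print(apply_homogeneous(angle, t, force))     # rotated only
```

Up to floating-point error the position maps to $(1, 3, 3, 1)$ and the force to $(0, 1, 0, 0)$: the same group element moves the point but only rotates the force.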



In [12]:
# install PyTorch Geometric (PyG)
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [13]:
import torch
from torch_geometric.datasets import MD17
from torch_geometric.loader import DataLoader
import torch.optim as optim
from torch_geometric.data import Data

import numpy as np
from torch.utils.data import Dataset
import os

class MD17Dataset(Dataset):
    """
    Dataset class for MD17 molecular dynamics data for benzene (FHI-aims level).
    """
    def __init__(
        self,
        root_dir: str = 'data/md17',
        max_samples: int = 1000,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """
        Args:
            root_dir: Directory containing MD17 npz file
            max_samples: Maximum number of samples to use
            device: Device to store tensors on
        """
        self.device = device

        # Load benzene data
        data  = MD17(root=root_dir, name="benzene FHI-aims")

        # Get coordinates and forces
        self.coords = torch.stack([d.pos for d in data], dim=0).type(torch.float32).to(self.device)
        self.forces = torch.stack([d.force for d in data], dim=0).type(torch.float32).to(self.device)
        self.energies = torch.tensor([d.energy.item() for d in data], dtype=torch.float32).to(self.device)  # Extract energy values

        # Limit samples
        if max_samples and max_samples < len(self.coords):
            self.coords = self.coords[:max_samples]
            self.forces = self.forces[:max_samples]
            self.energies = self.energies[:max_samples]

        # Define benzene topology (adjacency matrix)
        self.n_atoms = 12  # C6H6
        self.adj_matrix = self._create_benzene_adjacency()

        # Create atomic features
        self.atom_features = self._create_atom_features()

    def _create_benzene_adjacency(self) -> torch.Tensor:
        """Create adjacency matrix for benzene."""
        adj = torch.zeros((self.n_atoms, self.n_atoms), device=self.device)

        # C-C bonds
        for i in range(6):
            j = (i + 1) % 6
            adj[i, j] = adj[j, i] = 1

        # C-H bonds
        for i in range(6):
            adj[i, i+6] = adj[i+6, i] = 1

        return adj

    def _create_atom_features(self) -> torch.Tensor:
        """
        Create atomic features using one-hot encoding
        """
        features = torch.zeros((self.n_atoms, 2), device=self.device)
        features[:6, 0] = 1  # Carbon atoms
        features[6:, 1] = 1  # Hydrogen atoms
        return features

    def __len__(self) -> int:
        return len(self.coords)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            coords: [n_atoms, 4] homogeneous coordinates
            forces: [n_atoms, 4] forces (padded)
            edge_features: [n_atoms, n_atoms, 1] edge features (currently all zeros)
        """
        # Get coordinates and convert to homogeneous coordinates
        coords = self.coords[idx]
        coords_homog = torch.cat([
            coords,
            torch.ones((self.n_atoms, 1), device=self.device)
        ], dim=1)

        # Get forces and pad with zeros
        forces = self.forces[idx]
        forces_homog = torch.cat([
            forces,
            torch.zeros((self.n_atoms, 1), device=self.device)
        ], dim=1)

        # Create edge features as zeros
        edge_features = torch.zeros(
            (self.n_atoms, self.n_atoms, 1),
            device=self.device
        )

        return coords_homog, forces_homog, edge_features

def create_md17_dataloaders(
    root_dir: str,
    batch_size: int = 32,
    max_samples: int = 1000,
    validation_split: float = 0.1,
    num_workers: int = 0,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation dataloaders for MD17 data."""

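    # Create the full dataset; max_samples=None disables the constructor's own
    # truncation so that the random subset below is drawn from every frame
    full_dataset = MD17Dataset(root_dir=root_dir, max_samples=None)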

    # Limit the number of samples if specified
    if max_samples and max_samples < len(full_dataset):
        indices = torch.randperm(len(full_dataset))[:max_samples]
        full_dataset = torch.utils.data.Subset(full_dataset, indices)

    # Calculate lengths for train/validation split
    total_size = len(full_dataset)
    val_size = int(validation_split * total_size)
    train_size = total_size - val_size

    # Create train/validation splits
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset,
        [train_size, val_size],
        generator=generator
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    return train_loader, val_loader
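
The bond graph built in `_create_benzene_adjacency` can be sanity-checked without loading any data. The sketch below rebuilds the same matrix with plain lists and checks vertex degrees; `benzene_adjacency` is a hypothetical stand-alone copy, and it inherits the assumption that carbons sit at indices 0 to 5 and hydrogens at 6 to 11.

```python
from typing import List


def benzene_adjacency(n_carbons: int = 6) -> List[List[int]]:
    """Adjacency matrix with carbons at 0..5 and hydrogens at 6..11 (assumed ordering)."""
    n = 2 * n_carbons
    adj = [[0] * n for _ in range(n)]
    for i in range(n_carbons):
        j = (i + 1) % n_carbons                            # ring C-C bond
        adj[i][j] = adj[j][i] = 1
        adj[i][i + n_carbons] = adj[i + n_carbons][i] = 1  # C-H bond
    return adj


adj = benzene_adjacency()
degrees = [sum(row) for row in adj]
n_bonds = sum(degrees) // 2  # each bond is counted once per endpoint

print(degrees)  # carbons have degree 3, hydrogens degree 1
print(n_bonds)  # 12 bonds: 6 C-C and 6 C-H
```

Whether the MD17 file really lists the six carbons first is worth confirming against the atomic numbers stored with the dataset, since both the adjacency matrix and the one-hot atom features depend on that ordering.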

In [14]:
import torch.nn as nn
import torch.optim as optim
from typing import Tuple, List
import time
from tqdm import tqdm

class ForcePredictor(nn.Module):
    def __init__(self, model: EnEquivariantNet):
        super().__init__()
        self.model = model

    def forward(self, coords: torch.Tensor, edge_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords: [batch_size, n_atoms, 4] homogeneous coordinates
            edge_features: [batch_size, n_atoms, n_atoms, 1] edge features

        Returns:
            forces: [batch_size, n_atoms, 4] predicted forces
        """
        print("ForcePredictor forward pass starting...")
        batch_size = coords.shape[0]
        device = coords.device
        n_atoms = coords.shape[1]
        print(f"batch_size: {batch_size}, n_atoms: {n_atoms}")

        # Create one-hot encoding for atom types (C6H6)
        atom_features = torch.zeros((n_atoms, 2), device=device)
        atom_features[:6, 0] = 1  # Carbon atoms
        atom_features[6:, 1] = 1  # Hydrogen atoms
        print("Created atom features")

        # Process each molecule in batch
        all_forces = []
        for i in range(batch_size):
            print(f"Processing molecule {i+1}/{batch_size}")
            # Reshape coords[i] to be [n_atoms, 4]
            curr_coords = coords[i].reshape(n_atoms, -1)

            # Concatenate along the feature dimension
            h_graph = torch.cat([
                curr_coords,  # [n_atoms, 4]
                atom_features # [n_atoms, 2]
            ], dim=-1)  # Result: [n_atoms, 6]
            print(f"Created h_graph with shape {h_graph.shape}")

            # Forward pass through model
            print("Starting model forward pass...")
            h_out, _ = self.model(h_graph, edge_features[i])
            print("Model forward pass complete")

            # Extract force predictions (first 4 components)
            forces = h_out[:, :4]
            all_forces.append(forces)
            print(f"Processed molecule {i+1}")

        return torch.stack(all_forces)

def train_epoch(
    model: ForcePredictor,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: str
) -> float:
    """Train for one epoch."""
    model.train()
    total_loss = 0

    for coords, forces, edge_features in train_loader:
        coords = coords.to(device)
        forces = forces.to(device)
        edge_features = edge_features.to(device)

        optimizer.zero_grad()
        pred_forces = model(coords, edge_features)

        # Compute loss only on spatial components
        loss = criterion(pred_forces[..., :3], forces[..., :3])

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def validate(
    model: ForcePredictor,
    val_loader: DataLoader,
    criterion: nn.Module,
    device: str
) -> float:
    """Compute validation loss."""
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for coords, forces, edge_features in val_loader:
            coords = coords.to(device)
            forces = forces.to(device)
            edge_features = edge_features.to(device)

            pred_forces = model(coords, edge_features)
            loss = criterion(pred_forces[..., :3], forces[..., :3])

            total_loss += loss.item()

    return total_loss / len(val_loader)

def train_model(
    model: ForcePredictor,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    learning_rate: float = 1e-3,
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
) -> Tuple[List[float], List[float]]:
    """Train model and return training history."""
    model = model.to(device)  # keep parameters on the same device as the batches
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    train_losses = []
    val_losses = []

    for epoch in range(n_epochs):
        start_time = time.time()

        # Training
        model.train()
        total_train_loss = 0
        train_batches = 0

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{n_epochs} [Train]')
        for coords, forces, edge_features in train_pbar:
            # Move to device
            coords = coords.to(device)
            forces = forces.to(device)
            edge_features = edge_features.to(device)

            optimizer.zero_grad()

            # Forward pass under mixed precision (autocast is only enabled on CUDA)
            device_type = torch.device(device).type
            with torch.autocast(device_type=device_type, enabled=(device_type == 'cuda')):
                pred_forces = model(coords, edge_features)
                loss = criterion(pred_forces[..., :3], forces[..., :3])

            loss.backward()
            optimizer.step()

            # Log loss
            total_train_loss += loss.item()
            train_batches += 1
            train_pbar.set_postfix({'loss': f'{loss.item():.6f}'})

            # Clear memory
            del pred_forces, loss
            torch.cuda.empty_cache()

        avg_train_loss = total_train_loss / train_batches
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        total_val_loss = 0
        val_batches = 0

        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{n_epochs} [Val]')
            for coords, forces, edge_features in val_pbar:
                coords = coords.to(device)
                forces = forces.to(device)
                edge_features = edge_features.to(device)

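                # torch.amp.autocast requires device_type; enable it only on CUDA
                device_type = torch.device(device).type
                with torch.autocast(device_type=device_type, enabled=(device_type == 'cuda')):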
                    pred_forces = model(coords, edge_features)
                    loss = criterion(pred_forces[..., :3], forces[..., :3])

                total_val_loss += loss.item()
                val_batches += 1
                val_pbar.set_postfix({'loss': f'{loss.item():.6f}'})

                # Clear memory
                del pred_forces, loss
                torch.cuda.empty_cache()

        avg_val_loss = total_val_loss / val_batches
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - start_time

        print(f"\nEpoch {epoch+1}/{n_epochs} Summary:")
        print(f"Train Loss: {avg_train_loss:.6f}")
        print(f"Val Loss: {avg_val_loss:.6f}")
        print(f"Time: {epoch_time:.2f}s")
        print("-" * 40)

    return train_losses, val_losses



In [15]:
train_loader, val_loader = create_md17_dataloaders(
    root_dir='data/md17',
    batch_size=2,
    max_samples=1000
)

In [16]:
dim = 3

num_vertices = 12

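# Fall back to the CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
adj_matrix = torch.zeros((num_vertices, num_vertices), device=device)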

# C-C bonds
for i in range(6):
    j = (i + 1) % 6
    adj_matrix[i, j] = adj_matrix[j, i] = 1

# C-H bonds
for i in range(6):
    adj_matrix[i, i+6] = adj_matrix[i+6, i] = 1

# Define representations
input_vertex_rep = EnRepresentation(dim=3, num_vectors=1, num_scalars=2)
edge_rep = EnRepresentation(dim=3, num_vectors=0, num_scalars=1)
message_rep_1 = input_vertex_rep
hidden_vertex_rep_1 = EnRepresentation(dim=3, num_vectors=2, num_scalars=2)
hidden_message_rep_1 = hidden_vertex_rep_1
hidden_vertex_rep_2 = EnRepresentation(dim=3, num_vectors=2, num_scalars=2)
hidden_message_rep_2 = hidden_vertex_rep_2
hidden_vertex_rep_3 = EnRepresentation(dim=3, num_vectors=2, num_scalars=2)
hidden_message_rep_3 = hidden_vertex_rep_3
output_vertex_rep = EnRepresentation(dim=3, num_vectors=1, num_scalars=2)

# Layer 1
message_invariants_1 = generate_En_message_invariants(
    dim = dim,
    vertex_rep = input_vertex_rep,
    edge_rep = edge_rep,
    out_rep = message_rep_1
)
vertex_invariants_1 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = input_vertex_rep,
    internal_rep = message_rep_1,
    out_rep = hidden_vertex_rep_1
)
edge_invariants_1 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = edge_rep,
    internal_rep = message_rep_1,
    out_rep = edge_rep
)

layer1 = EGCL(
    num_vertices = num_vertices,
    adj_matrix = adj_matrix,
    vertex_inputs=input_vertex_rep.total_dim(),
    edge_inputs=edge_rep.total_dim(),
    vertex_outputs=hidden_vertex_rep_1.total_dim(),
    edge_outputs=edge_rep.total_dim(),
    inter_vars=message_rep_1.total_dim(),
    inter_invt_funs=message_invariants_1,
    vertex_invt_funs=vertex_invariants_1,
    edge_invt_funs=edge_invariants_1,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True
)

# Layer 2
message_invariants_2 = generate_En_message_invariants(
    dim = dim,
    vertex_rep = hidden_vertex_rep_1,
    edge_rep = edge_rep,
    out_rep = hidden_message_rep_1
)
vertex_invariants_2 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = hidden_vertex_rep_1,
    internal_rep = hidden_message_rep_1,
    out_rep = hidden_vertex_rep_2
)
edge_invariants_2 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = edge_rep,
    internal_rep = hidden_message_rep_1,
    out_rep = edge_rep
)

layer2 = EGCL(
    num_vertices = num_vertices,
    adj_matrix = adj_matrix,
    vertex_inputs=hidden_vertex_rep_1.total_dim(),
    edge_inputs=edge_rep.total_dim(),
    vertex_outputs=hidden_vertex_rep_2.total_dim(),
    edge_outputs=edge_rep.total_dim(),
    inter_vars=hidden_message_rep_1.total_dim(),
    inter_invt_funs=message_invariants_2,
    vertex_invt_funs=vertex_invariants_2,
    edge_invt_funs=edge_invariants_2,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True
)

# Layer 3
message_invariants_3 = generate_En_message_invariants(
    dim = dim,
    vertex_rep = hidden_vertex_rep_2,
    edge_rep = edge_rep,
    out_rep = hidden_message_rep_2
)
vertex_invariants_3 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = hidden_vertex_rep_2,
    internal_rep = hidden_message_rep_2,
    out_rep = hidden_vertex_rep_3
)
edge_invariants_3 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = edge_rep,
    internal_rep = hidden_message_rep_2,
    out_rep = edge_rep
)

layer3 = EGCL(
    num_vertices = num_vertices,
    adj_matrix = adj_matrix,
    vertex_inputs=hidden_vertex_rep_2.total_dim(),
    edge_inputs=edge_rep.total_dim(),
    vertex_outputs=hidden_vertex_rep_3.total_dim(),
    edge_outputs=edge_rep.total_dim(),
    inter_vars=hidden_message_rep_2.total_dim(),
    inter_invt_funs=message_invariants_3,
    vertex_invt_funs=vertex_invariants_3,
    edge_invt_funs=edge_invariants_3,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True
)

# Layer 4
message_invariants_4 = generate_En_message_invariants(
    dim = dim,
    vertex_rep = hidden_vertex_rep_3,
    edge_rep = edge_rep,
    out_rep = hidden_message_rep_3
)
vertex_invariants_4 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = hidden_vertex_rep_3,
    internal_rep = hidden_message_rep_3,
    out_rep = output_vertex_rep
)
edge_invariants_4 = generate_En_vertex_edge_invariants(
    dim = dim,
    input_rep = edge_rep,
    internal_rep = hidden_message_rep_3,
    out_rep = edge_rep
)

layer4 = EGCL(
    num_vertices = num_vertices,
    adj_matrix = adj_matrix,
    vertex_inputs=hidden_vertex_rep_3.total_dim(),
    edge_inputs=edge_rep.total_dim(),
    vertex_outputs=output_vertex_rep.total_dim(),
    edge_outputs=edge_rep.total_dim(),
    inter_vars=hidden_message_rep_3.total_dim(),
    inter_invt_funs=message_invariants_4,
    vertex_invt_funs=vertex_invariants_4,
    edge_invt_funs=edge_invariants_4,
    inter_activation=torch.nn.ReLU(),
    vertex_activation=torch.nn.ReLU(),
    edge_activation=torch.nn.ReLU(),
    is_affine=True
)

In [17]:
class EnEquivariantNet(nn.Module):

    def __init__(self, layers: List[EGCL]):
        """
        Args:
            layers: List of EGCL layers to be applied in sequence
        """
        super().__init__()
        self.layers = nn.ModuleList(layers)

        # Verify layers are compatible
        for i in range(len(layers)-1):
            if layers[i].vertex_outputs != layers[i+1].vertex_inputs:
                raise ValueError(f"Layer {i} output dimension {layers[i].vertex_outputs} "
                               f"doesn't match layer {i+1} input dimension {layers[i+1].vertex_inputs}")
            if layers[i].edge_outputs != layers[i+1].edge_inputs:
                raise ValueError(f"Layer {i} edge output dimension {layers[i].edge_outputs} "
                               f"doesn't match layer {i+1} edge input dimension {layers[i+1].edge_inputs}")

    def forward(self, h_graph: torch.Tensor, a_graph: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            h_graph: Input vertex features
            a_graph: Input edge features

        Returns:
            tuple(torch.Tensor, torch.Tensor): Final vertex and edge features
        """
        h, a = h_graph, a_graph
        for layer in self.layers:
            h, a = layer(h, a)
        return h, a

# Combine layers into model
BenzeneEquivariantNet = EnEquivariantNet([layer1, layer2, layer3, layer4])

# Create some example input data
h_input = torch.randn(num_vertices, input_vertex_rep.total_dim())
a_input = torch.randn(num_vertices, num_vertices, edge_rep.total_dim())

# Forward pass through model
h_output, a_output = BenzeneEquivariantNet(h_input, a_input)

print(f"Input vertex features shape: {h_input.shape}")
print(f"Output vertex features shape: {h_output.shape}")
print(f"Input edge features shape: {a_input.shape}")
print(f"Output edge features shape: {a_output.shape}")

EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.0247s, Total: 0.2963s, Count: 12
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.0341s, Total: 0.4090s, Count: 12
edge_update          - Avg: 0.0032s, Total: 0.0384s, Count: 12
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.1428s, Total: 1.7131s, Count: 12
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.1382s, Total: 1.6583s, Count: 12
edge_update          - Avg: 0.0086s, Total: 0.1033s, Count: 12
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.1904s, Total: 2.2846s, Count: 12
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update

In [None]:
model = ForcePredictor(BenzeneEquivariantNet)
train_losses, val_losses = train_model(model, train_loader, val_loader)

Epoch 1/100 [Train]:   0%|          | 0/450 [00:00<?, ?it/s]

ForcePredictor forward pass starting...
batch_size: 2, n_atoms: 12
Created atom features
Processing molecule 1/2
Created h_graph with shape torch.Size([12, 6])
Starting model forward pass...
EGCL forward pass starting...


EGCL Layer Timing Summary:
init_messages        - Avg: 0.0001s, Total: 0.0001s, Count: 1
intermediate_term    - Avg: 0.0725s, Total: 0.8700s, Count: 12
message_sums         - Avg: 0.0002s, Total: 0.0002s, Count: 1
vertex_update        - Avg: 0.0818s, Total: 0.9817s, Count: 12
edge_update          - Avg: 0.0061s, Total: 0.0728s, Count: 12
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0001s, Total: 0.0001s, Count: 1
intermediate_term    - Avg: 0.1671s, Total: 2.0054s, Count: 12
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.1292s, Total: 1.5509s, Count: 12
edge_update          - Avg: 0.0076s, Total: 0.0907s, Count: 12
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.1393s, Total: 1.6719s, Count: 12
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.1506s, Total:

Epoch 1/100 [Train]:   0%|          | 1/450 [01:29<11:11:36, 89.75s/it, loss=1566.694702]

ForcePredictor forward pass starting...
batch_size: 2, n_atoms: 12
Created atom features
Processing molecule 1/2
Created h_graph with shape torch.Size([12, 6])
Starting model forward pass...
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0001s, Total: 0.0001s, Count: 1
intermediate_term    - Avg: 0.0242s, Total: 0.2904s, Count: 12
message_sums         - Avg: 0.0002s, Total: 0.0002s, Count: 1
vertex_update        - Avg: 0.0308s, Total: 0.3696s, Count: 12
edge_update          - Avg: 0.0033s, Total: 0.0398s, Count: 12
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.1717s, Total: 2.0609s, Count: 12
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.1552s, Total: 1.8618s, Count: 12
edge_update          - Avg: 0.0075s, Total: 0.0902s, Count: 12
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_mess

Epoch 1/100 [Train]:   0%|          | 2/450 [02:52<10:38:47, 85.55s/it, loss=499.160980]

ForcePredictor forward pass starting...
batch_size: 2, n_atoms: 12
Created atom features
Processing molecule 1/2
Created h_graph with shape torch.Size([12, 6])
Starting model forward pass...
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.0257s, Total: 0.3089s, Count: 12
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.0328s, Total: 0.3939s, Count: 12
edge_update          - Avg: 0.0034s, Total: 0.0403s, Count: 12
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_messages        - Avg: 0.0000s, Total: 0.0000s, Count: 1
intermediate_term    - Avg: 0.2277s, Total: 2.7329s, Count: 12
message_sums         - Avg: 0.0001s, Total: 0.0001s, Count: 1
vertex_update        - Avg: 0.1323s, Total: 1.5870s, Count: 12
edge_update          - Avg: 0.0073s, Total: 0.0879s, Count: 12
EGCL forward pass starting...

EGCL Layer Timing Summary:
init_mess

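The EGCL forward pass is running very slowly: with a batch size of 2, each training step takes roughly 85 to 90 seconds, which puts a single epoch of 450 batches at around 11 hours. I've added some timing to see why. The time goes almost entirely into the intermediate_term and vertex_update steps, which compute a linear combination of the invariant functions separately for each of the 12 vertices. Each of these steps takes roughly 0.3 to 2.7 seconds per layer, while edge_update stays at around 0.1 seconds or less and init_messages and message_sums are negligible. Right now I'm not able to proceed until I can figure out a way of speeding these two steps up.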
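
One direction that might help, sketched here under assumptions: if the invariant functions can be written to accept all vertices at once, the per-vertex Python loop can be replaced by a single tensor contraction. The generators in `toy_equivariants` below are hypothetical stand-ins, not the functions produced by `generate_En_message_invariants`, and `combine_loop` and `combine_batched` are illustrative names rather than parts of `EGCL`. The sketch only shows that the two ways of forming the linear combination agree; whether this removes the bottleneck in the real layer depends on how the invariant functions are evaluated there.

```python
import torch


def toy_equivariants(h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a list of equivariant polynomial generators.

    h, m: [..., 3] vectors. Returns [..., K, 3] with K = 4 generators, each of
    which commutes with a rotation applied to h and m simultaneously.
    """
    hh = (h * h).sum(-1, keepdim=True)  # invariant <h, h>
    hm = (h * m).sum(-1, keepdim=True)  # invariant <h, m>
    return torch.stack([h, m, hh * m, hm * h], dim=-2)


def combine_loop(h: torch.Tensor, m: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Per-vertex Python loop: the generators are evaluated once per vertex."""
    rows = []
    for i in range(h.shape[0]):
        basis = toy_equivariants(h[i], m[i])            # [K, 3]
        rows.append((weights[:, None] * basis).sum(0))  # [3]
    return torch.stack(rows)                            # [n, 3]


def combine_batched(h: torch.Tensor, m: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Same linear combination, evaluated for all vertices in one call."""
    basis = toy_equivariants(h, m)                      # [n, K, 3]
    return torch.einsum('k,nkd->nd', weights, basis)    # [n, 3]


torch.manual_seed(0)
n_vertices = 12                   # benzene: 6 carbon + 6 hydrogen atoms
h = torch.randn(n_vertices, 3)    # toy vertex vectors
m = torch.randn(n_vertices, 3)    # toy aggregated messages
weights = torch.randn(4)          # one coefficient per generator

out_loop = combine_loop(h, m, weights)
out_batched = combine_batched(h, m, weights)
print("loop and batched agree:", torch.allclose(out_loop, out_batched, atol=1e-4))

# Equivariance check: rotating the inputs should rotate the outputs.
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix
out_rotated = combine_batched(h @ Q.T, m @ Q.T, weights)
print("equivariant:", torch.allclose(out_rotated, out_batched @ Q.T, atol=1e-3))
```

The batched version does the same arithmetic but issues a fixed number of tensor operations regardless of the number of vertices, so the Python and kernel-launch overhead no longer scales with the graph size. The same idea extends to a leading batch dimension over molecules, which would also remove the per-molecule loop in `ForcePredictor.forward`.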