MARA – Modular Angular-Radial Attention

This repository contains the implementation of MARA, a spherical attention module integrated into the MACE framework for equivariant message passing on atomic systems.

How to use it

from MARA import SEAttention

# Initialization (e.g. inside your module's __init__);
# H × W is the spherical grid resolution (default 4 × 8)
self.spherical_attention = SEAttention(input_size=128,
                                       hidden_size=32,
                                       H=4,
                                       W=8)

# Forward pass: returns attention-weighted messages and the attention weights
m_ji, gate = self.spherical_attention(
    features,    # [N, F]
    positions,   # [N, 3]
    edge_index   # [2, E]
)
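
For context, here is a minimal sketch of how the module might be embedded in a PyTorch nn.Module. The wrapper class ExampleInteraction and its feat_dim argument are illustrative assumptions, not part of MARA itself; only SEAttention and its call signature come from this repository.

import torch
from torch import nn
from MARA import SEAttention

class ExampleInteraction(nn.Module):
    """Illustrative wrapper showing where SEAttention could sit in a model."""

    def __init__(self, feat_dim: int = 128):
        super().__init__()
        # 4 × 8 spherical grid, matching the default configuration
        self.spherical_attention = SEAttention(input_size=feat_dim,
                                               hidden_size=32,
                                               H=4,
                                               W=8)

    def forward(self, features, positions, edge_index):
        # features: [N, F], positions: [N, 3], edge_index: [2, E]
        m_ji, gate = self.spherical_attention(features, positions, edge_index)
        return m_ji, gate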

Overview

MARA introduces a spherical attention mechanism operating on a discretized spherical grid and is designed to enhance message passing in equivariant neural networks. The module is integrated into the RealAgnosticInteractionBlock and RealAgnosticResidualInteractionBlock of MACE and returns both attention-weighted messages and the corresponding attention weights.

An overview of the model architecture is shown in the figure below.

(Figure: model architecture)

Requirements

The implementation has been developed and tested with the following setup:

  • Python 3.10 with PyTorch 2.8.0 and CUDA 12.8

    pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu128

  • torch-scatter built for torch 2.8.0+cu128 – provides the scatter operations

    pip install torch-scatter -f https://data.pyg.org/whl/torch-2.8.0+cu128.html

We recommend using a CUDA-enabled GPU for both training and inference.

Spherical Attention Module

The spherical attention module operates on a discretized spherical grid; unless stated otherwise, the default resolution of 4 × 8 (H = 4, W = 8) is used. One plausible way of binning edge directions onto this grid is sketched after the list below.

The module returns:

  • attention-weighted messages
  • the corresponding attention weights (for analysis and visualization)
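
To make the grid concrete, the sketch below shows one plausible way to bin relative edge directions onto an H × W (polar × azimuthal) grid. The helper direction_to_grid_cell and its binning convention are assumptions for illustration only; the actual discretization inside SEAttention may differ.

import torch

def direction_to_grid_cell(positions, edge_index, H=4, W=8):
    # Illustrative binning of unit edge directions onto an H x W spherical grid.
    src, dst = edge_index[0], edge_index[1]
    vec = positions[src] - positions[dst]                    # [E, 3] relative vectors
    vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-12)     # unit directions

    theta = torch.acos(vec[:, 2].clamp(-1.0, 1.0))           # polar angle in [0, pi]
    phi = torch.atan2(vec[:, 1], vec[:, 0]) + torch.pi       # azimuth shifted to [0, 2*pi)

    h = (theta / torch.pi * H).long().clamp(max=H - 1)       # polar bin index
    w = (phi / (2 * torch.pi) * W).long().clamp(max=W - 1)   # azimuthal bin index
    return h * W + w                                         # flat cell index in [0, H*W)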

Integration into MACE

The module is integrated into the following MACE blocks:

  • RealAgnosticInteractionBlock
  • RealAgnosticResidualInteractionBlock

The message passing procedure is as follows:

# Tensor-product convolution over edges (standard MACE), producing per-edge messages
m_ji = self.conv_tp(
    node_feats[edge_index[0]],   # sender node features, gathered per edge
    edge_attrs,
    tp_weights
)

# MARA: re-weight the per-edge messages with spherical attention
m_ji, att = self.spherical_attention(
    m_ji,
    positions,
    edge_index,
    edge_feats
)

# Sum the attention-weighted messages onto the receiving nodes
message = scatter_sum(
    src=m_ji,
    index=edge_index[1],
    dim=0,
    dim_size=node_feats.shape[0]
)
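
Because the attention weights are returned alongside the messages, they can be aggregated per node for analysis and visualization. The averaging below is a sketch, not documented behavior: it assumes att holds one attention map per edge (e.g. shape [E, H*W]); the exact shape depends on the SEAttention implementation.

from torch_scatter import scatter_mean

# Average each edge's attention map onto its receiving node
node_att = scatter_mean(
    src=att,
    index=edge_index[1],              # receiving node of each edge
    dim=0,
    dim_size=node_feats.shape[0]
)
# node_att: per-node mean attention map, e.g. for visualization on the 4 × 8 grid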

Training and Hardware

  • Training was primarily performed on NVIDIA H100 GPUs
  • Inference benchmarks were conducted on an NVIDIA RTX 4090

About

MARA is a plug-and-play SE(3)-equivariant angular-radial attention module that improves accuracy and robustness of molecular force fields.
