<a href="https://colab.research.google.com/github/li-ziang/2025-Fall-MLG/blob/main/hw2_3d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Coding Practice 2: An Exploration of Graph Neural Networks with 3D Coordinates

Welcome to this hands-on exploration of Graph Neural Networks (GNNs) with a focus on 3D coordinates. In this assignment, we delve into the fascinating intersection of graph theory, neural networks, and geometric deep learning. You will have the opportunity to build and train GNNs that can effectively handle 3D structured data, a common occurrence in fields such as computational chemistry, material science, and computer vision.

Understanding how to manipulate and process graph data that includes 3D spatial information is critical for developing models that can learn from the geometric relationships inherent in many real-world datasets. Throughout this assignment, you will tackle the unique challenges posed by 3D data, learn to implement GNNs that are invariant or equivariant to 3D rotations and translations, and apply your knowledge to solve problems that require an understanding of the underlying spatial structure.

Get ready to enhance your machine learning toolkit with the capability to process 3D graph data, and prepare to unlock a new dimension of possibilities!

## Environment Setup

For a seamless execution of this notebook, ensure your Python environment is properly set up. Here's what you'll need:

- **Python Version**: We recommend using Python 3.8 or higher.
- **Required Packages**: Install the following libraries to delve into GNNs:
  - `torch`
  - `torch_geometric`
  - `torch_scatter`
  - `torch_sparse`
  - `torchmetrics`
  - `networkx`
  - `numpy`
  - `jupyter`
  - `rdkit`
  - `py3Dmol`
  - `pandas`
  - `seaborn`
  - `scipy`

- **For Local Testing**: If you wish to visualize and run tests outside this notebook, please also install:
  - `matplotlib`


In [None]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install rdkit py3Dmol
!pip install -q torchmetrics

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import os
import random
import time
from IPython.display import HTML

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import py3Dmol
import seaborn as sns
import torch
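try:
    from google.colab import files  # only available when running on Colab
except ImportError:
    files = None  # running locally: Colab file helpers are unavailable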
from rdkit.Chem import AllChem, Crippen, QED, rdMolDescriptors, rdmolops
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Geometry.rdGeometry import Point3D
from scipy.stats import ortho_group
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential
import torch.nn.functional as F

import rdkit.Chem as Chem
import torch_geometric
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9
from torch_geometric.utils import dense_to_sparse, remove_self_loops, to_dense_adj
import torch_geometric.transforms as T
from torch_scatter import scatter

print("Torch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

## Graph Neural Networks in Chemistry

In the realm of computational chemistry, graph neural networks (GNNs) have emerged as a powerful tool to model and predict the properties of molecules. Molecules can be naturally depicted as graphs, where atoms serve as nodes and chemical bonds act as edges. This representation is particularly advantageous for **Molecular Property Prediction**, where GNNs can learn from existing molecular data to predict physical and chemical properties crucial in drug discovery and material science.


A striking illustration of GNN utility is in predicting molecular activities, which has direct implications for **drug discovery**. For instance, GNNs have been instrumental in identifying novel compounds with therapeutic potential by predicting their interaction with biological targets. A notable success story is the discovery of [**Halicin**](https://en.wikipedia.org/wiki/Halicin), a compound with promising antibacterial properties found through GNN-driven screening.


## Delving into the QM9 Dataset

The QM9 dataset is a comprehensive collection of around **130,000 small molecules** characterized by 19 different regression targets, making it a gold standard for evaluating GNNs in molecular property prediction. This dataset has gained prominence following its adoption by the [MoleculeNet](https://arxiv.org/abs/1703.00564) benchmark.

Our focus will be on predicting the [HOMO-LUMO gap](https://en.wikipedia.org/wiki/HOMO/LUMO), a quantum property that is indicative of a molecule’s reactivity. The HOMO-LUMO gap is the energy difference between the highest occupied molecular orbital (HOMO) and the lowest unoccupied molecular orbital (LUMO).
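
Written as a formula, with symbols chosen here purely for illustration ($\varepsilon_{\text{HOMO}}$ and $\varepsilon_{\text{LUMO}}$ denote the two orbital energies):

$$\Delta\varepsilon = \varepsilon_{\text{LUMO}} - \varepsilon_{\text{HOMO}}$$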


For our purposes, it's not necessary to delve into the quantum mechanics governing the HOMO-LUMO gap. It suffices to understand that representing molecules as graphs with node features, edge features, and spatial coordinates allows us to leverage GNNs to predict such intricate properties from provided ground truth data.

With this understanding, let's proceed to load the QM9 dataset and examine the structure of these molecular graphs, a process simplified by the utilities provided by PyG.


In [None]:
class SetTarget(object):
    """
    This transform modifies the labels vector per data sample to only keep
    the label for a specific target (there are 19 targets in QM9).

    Note: for this practical, we allow the target to be set via the constructor.
    """
    def __init__(self, target=0):  # Allow target to be set at instantiation
        self.target = target

    def __call__(self, data):
        data.y = data.y[:, self.target]  # Use the target set in the constructor
        return data

class CompleteGraph(object):
    """
    This transform adds all pairwise edges into the edge index per data sample,
    then removes self loops, i.e., it builds a fully connected or complete graph.
    """
    def __call__(self, data):
        device = data.edge_index.device

        row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
        col = torch.arange(data.num_nodes, dtype=torch.long, device=device)

        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
        col = col.repeat(data.num_nodes)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
            size = list(data.edge_attr.size())
            size[0] = data.num_nodes * data.num_nodes
            edge_attr = data.edge_attr.new_zeros(size)
            edge_attr[idx] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data

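# Define the target outside of the classes for flexibility.
# The value is a column index into QM9's `data.y`; in PyG's ordering,
# index 0 is the dipole moment and index 4 is the HOMO-LUMO gap.
target = 0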

if 'IS_GRADESCOPE_ENV' not in os.environ:
    path = './qm9'

    # Transforms which are applied during data loading:
    # (1) Fully connect the graphs, (2) Select the target/label
    transform = T.Compose([CompleteGraph(), SetTarget(target)])

    # Load the QM9 dataset with the transforms defined
    dataset = QM9(path, transform=transform)

    # Normalize targets to have zero mean and unit variance
    mean, std = dataset.data.y[:, target].mean(), dataset.data.y[:, target].std()
    dataset.data.y[:, target] = (dataset.data.y[:, target] - mean) / std
    mean, std = mean.item(), std.item()


## Data Preparation and Dataset Splitting

The comprehensive QM9 dataset encompasses over **130,000** molecular graphs, providing a rich ground for training robust models in molecular property prediction.

For the scope of this assignment, we'll work with a curated subset of **4,500** molecular graphs. This subset size strikes a balance between computational manageability and sufficient data complexity. We will divide this subset into distinct training, validation, and test sets, each comprising 1,500 graphs. This partitioning allows for a thorough evaluation of our model's performance.

Later in the assignment, you'll have the opportunity to scale your experiments to larger portions of the QM9 dataset, challenging your model with an even broader array of molecular structures.


In [None]:
print(f"Total number of samples available for selection: {len(dataset)}.")

# Split datasets (our 4.5K subset)
train_dataset = dataset[:1500]
val_dataset = dataset[1500:3000]
test_dataset = dataset[3000:4500]
print(f"Dataset divisions established with {len(train_dataset)} samples for training, " +
      f"{len(val_dataset)} samples for validation, and {len(test_dataset)} samples for testing.")

# Create dataloaders with batch size = 256
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
data = train_dataset[0] # one data sample, i.e. molecular graph
print("One molecular graph contains:")
print(data)

## Understanding the Graph Representation of Molecules

In our dataset, each molecule is represented as a graph with specific attributes accessible via a `Data` object in PyTorch Geometric. Here's an overview of the graph features we'll work with:

**Node Features (`data.x`)**: Each node (atom) in the graph has an 11-dimensional feature vector that includes information such as atom type, atomic number, and other chemical properties.

**Edge Connectivity (`data.edge_index`)**: This is a tensor that defines which nodes (atoms) are connected to which, representing the bonds in the molecule.

**Edge Features (`data.edge_attr`)**: For each edge (bond), a 4-dimensional vector describes the type of bond between the connected atoms using one-hot encoding.

**Atomic Positions (`data.pos`)**: The 3D coordinates for each atom in the molecule are included, which will be crucial for our 3D graph neural network models.

**Target Property (`data.y`)**: We're interested in predicting a single property of the molecule, such as its electric dipole moment, represented here as a scalar value.

**Important Note**: In our approach, we consider fully-connected graphs, meaning every atom is connected to every other atom. The edge features will distinguish between actual chemical bonds and non-bonded pairs of atoms: real bonds are indicated by their bond type, while non-bonded pairs have zero vectors for their edge attributes.

This setup allows our model to learn from both the chemical bonds and the spatial structure of the molecules, providing a comprehensive view of each molecule's potential properties.
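
To make the fully connected convention concrete, here is a minimal sketch in plain PyTorch. The three-atom toy molecule, its bond list, and all variable names are made up for illustration; the code mirrors the idea behind a complete-graph transform rather than reproducing any particular implementation.

```python
import torch

num_nodes = 3  # hypothetical toy molecule, e.g. H-O-H with O as node 0
# Chemical bonds stored in both directions, one 4-dim one-hot bond type each.
bond_index = torch.tensor([[0, 1, 0, 2],
                           [1, 0, 2, 0]])
bond_attr = torch.tensor([[1., 0., 0., 0.]] * 4)  # assume all single bonds

# Enumerate all ordered pairs (i, j) and drop the self loops i == j.
row = torch.arange(num_nodes).repeat_interleave(num_nodes)
col = torch.arange(num_nodes).repeat(num_nodes)
mask = row != col
edge_index = torch.stack([row[mask], col[mask]], dim=0)

# Dense (n*n, 4) table: bonded pairs get their one-hot type, others stay zero.
dense_attr = torch.zeros(num_nodes * num_nodes, bond_attr.size(1))
dense_attr[bond_index[0] * num_nodes + bond_index[1]] = bond_attr
edge_attr = dense_attr[mask]

# A non-zero attribute row marks a real bond, a zero row a non-bonded pair.
is_bond = edge_attr.sum(dim=1) > 0
print(edge_index.shape)   # torch.Size([2, 6])
print(is_bond.tolist())   # [True, True, True, False, True, False]
```

Only the pairs (1, 2) and (2, 1) carry zero vectors here, because the two hydrogen atoms are not bonded to each other.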


In [None]:
atom_count = data.x.shape[0]
edge_count = data.edge_attr.shape[0]
atom_features = data.x.shape[1]
edge_features = data.edge_attr.shape[1]
coordinate_dimensions = data.pos.shape[1]
target_count = data.y.shape[0]

print("\nExploring the molecular graph:")
print("Number of atoms:", atom_count)
print("Number of edges (fully connected, both directions):", edge_count)
print("Features per atom:", atom_features)
print("Features per bond:", edge_features)
print("Dimensions of spatial coordinates per atom:", coordinate_dimensions)
print("Total targets to predict for the molecule:", target_count)

print("\nUp next, we'll dive into constructing a Graph Neural Network using Message Passing to utilize these features for predicting molecular properties.")
print("The significance of the spatial coordinates will be covered in an upcoming section of this tutorial.")


In [None]:
#export
import torch
from torch.nn import Linear, BatchNorm1d, ReLU, Sequential, Module
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter_add, scatter


class MPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        super().__init__(aggr=aggr)  # Initialize the parent class with the aggregation method.

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        # Message MLP: (2 * emb_dim + edge_dim) -> emb_dim
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + edge_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU()
        )

        # Update MLP: (2 * emb_dim) -> emb_dim
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
            BatchNorm1d(emb_dim),
            ReLU()
        )

    def forward(self, h, edge_index, edge_attr):
        # Propagate messages using the defined message and update functions.
        return self.propagate(edge_index, h=h, edge_attr=edge_attr)

    def message(self, h_i, h_j, edge_attr):
        # Constructs messages for each edge in the graph.
        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
        return self.mlp_msg(msg)

    def aggregate(self, inputs, index, dim_size=None):
        # Sums the incoming messages at each target node (always a sum, regardless of `aggr`).
        return scatter_add(inputs, index, dim=self.node_dim, dim_size=dim_size)

    def update(self, aggr_out, h):
        # Updates node features by combining aggregated messages with initial node features.
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self):
        return f'{self.__class__.__name__}(emb_dim={self.emb_dim}, edge_dim={self.edge_dim}, aggr={self.aggr})'


In [None]:
#export
class NaiveModel(Module):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        """
        Message Passing Neural Network (MPNN) model for graph property prediction.

        Parameters:
        - num_layers (int): Number of message passing layers.
        - emb_dim (int): Embedding dimension for node features.
        - in_dim (int): Dimension of initial node features.
        - edge_dim (int): Dimension of edge features.
        - out_dim (int): Dimension of output features.
        """
        super(NaiveModel, self).__init__()

        self.embedding = Linear(in_dim, emb_dim)
        self.layers = torch.nn.ModuleList([
            MPNNLayer(emb_dim, edge_dim) for _ in range(num_layers)
        ])
        self.readout = global_mean_pool
        self.output = Linear(emb_dim, out_dim)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Embed initial node features
        x = self.embedding(x)

        # Apply MPNN layers with residual connections
        for layer in self.layers:
            x = x + layer(x, edge_index, edge_attr)

        # Pool graph representations
        x = self.readout(x, batch)

        # Apply final output layer
        return self.output(x).squeeze()


In [None]:
def permute_graph(data, perm):
    """Helper function for permuting PyG Data object attributes consistently.
    """
    # Permute the node attribute ordering
    data.x = data.x[perm]
    data.pos = data.pos[perm]
    data.z = data.z[perm]
    data.batch = data.batch[perm]

    # Permute the edge index
    adj = to_dense_adj(data.edge_index)
    adj = adj[:, perm, :]
    adj = adj[:, :, perm]
    data.edge_index = dense_to_sparse(adj)[0]

    # Note:
    # (1) While we originally defined the permutation matrix P as only having
    #     entries 0 and 1, its implementation via `perm` uses indexing into
    #     torch tensors, instead.
    # (2) It is cumbersome to permute the edge_attr, so we set it to constant
    #     dummy values. For any experiments beyond unit testing, all GNN models
    #     use the original edge_attr.

    return data

def permutation_invariance_unit_test(module, dataloader):
    """Unit test for checking whether a module (GNN model) is
    permutation invariant.
    """
    it = iter(dataloader)
    data = next(it)

    # Set edge_attr to dummy values (for simplicity)
    data.edge_attr = torch.zeros(data.edge_attr.shape)

    # Forward pass on original example
    out_1 = module(data)

    # Create random permutation
    perm = torch.randperm(data.x.shape[0])
    data = permute_graph(data, perm)

    # Forward pass on permuted example
    out_2 = module(data)

    # Check whether output varies after applying transformations
    return torch.allclose(out_1, out_2, atol=1e-04)


def permutation_equivariance_unit_test(module, dataloader):
    """Unit test for checking whether a module (GNN layer) is
    permutation equivariant.
    """
    it = iter(dataloader)
    data = next(it)

    # Set edge_attr to dummy values (for simplicity)
    data.edge_attr = torch.zeros(data.edge_attr.shape)

    # Forward pass on original example
    out_1 = module(data.x, data.edge_index, data.edge_attr)

    # Create random permutation
    perm = torch.randperm(data.x.shape[0])
    data = permute_graph(data, perm)

    # Forward pass on permuted example
    out_2 = module(data.x, data.edge_index, data.edge_attr)

    # Check whether output varies after applying transformations
    return torch.allclose(out_1[perm], out_2, atol=1e-04)

In [None]:
layer = MPNNLayer(emb_dim=11, edge_dim=4)
model = NaiveModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)
dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Permutation invariance unit test for MPNN model
print(f"Is {type(model).__name__} permutation invariant? --> {permutation_invariance_unit_test(model, dataloader)}!")

# Permutation equivariance unit test for MPNN layer
print(f"Is {type(layer).__name__} permutation equivariant? --> {permutation_equivariance_unit_test(layer, dataloader)}!")

In [None]:
def train(model, train_loader, optimizer, device):
    model.train()
    total_loss = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        predictions = model(batch)
        loss = F.mse_loss(predictions, batch.y)
        loss.backward()
        total_loss += loss.item() * batch.num_graphs
        optimizer.step()

    average_loss = total_loss / len(train_loader.dataset)
    return average_loss

def eval(model, loader, device):
    model.eval()
    error = 0

    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            y_pred = model(data)
            # Mean Absolute Error using std (computed when preparing data)
            error += ((y_pred - data.y) * std).abs().sum().item()
    return error / len(loader.dataset)



def run_experiment(model, model_name, train_loader, val_loader, test_loader, n_epochs=100):
    print(f"Running experiment for {model_name}, training on {len(train_loader.dataset)} samples for {n_epochs} epochs.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("\ndevice:", device, "\nModel architecture:")
    print(model)
    total_params = sum(np.prod(p.size()) for p in model.parameters())
    print(f'Total parameters: {total_params}')

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.9, patience=5, min_lr=0.00001)

    print("\nStart training:")
    best_val_error = float('inf')
    best_test_error = float('inf')
    perf_per_epoch = []

    start_time = time.time()
    for epoch in range(1, n_epochs + 1):
        loss = train(model, train_loader, optimizer, device)
        val_error = eval(model, val_loader, device)

        if val_error < best_val_error:
            best_val_error = val_error
            best_test_error = eval(model, test_loader, device)

        if epoch % 10 == 0:
            current_lr = scheduler.optimizer.param_groups[0]['lr']
            print(f'Epoch: {epoch:03d}, LR: {current_lr:.6f}, Loss: {loss:.7f}, '
                  f'Val MAE: {val_error:.7f}, Test MAE: {best_test_error:.7f}')

        scheduler.step(val_error)
        perf_per_epoch.append((best_test_error, val_error, epoch, model_name))

    elapsed_time = time.time() - start_time
    train_time_minutes = elapsed_time / 60
    print(f"\nDone! Training took {train_time_minutes:.2f} mins. "
          f"Best validation MAE: {best_val_error:.7f}, corresponding test MAE: {best_test_error:.7f}.")

    return best_val_error, best_test_error, train_time_minutes, perf_per_epoch



In [None]:
results = {}

In [None]:
model = NaiveModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)
model_name = type(model).__name__
best_val_error, test_error, train_time, perf_per_epoch = run_experiment(
    model,
    model_name,
    train_loader,
    val_loader,
    test_loader,
    n_epochs=100
)
results[model_name] = (best_val_error, test_error, train_time)

In [None]:
results

## Task 1: Implement a Message Passing Neural Network Utilizing Atom Coordinates as Node Features [3 pts]

The baseline MPNN, labeled `NaiveModel`, does not consider the atom coordinates and relies solely on node features for message passing. This approach misses out on critical 3D structural information that could be pivotal for predicting the target property.

The objective of your first task is to enhance the `NaiveModel` by integrating the atom coordinates with the node features.

While we have outlined the structure of the `PositionModel` class, certain sections marked `TODO` remain incomplete and require your implementation.

Keep in mind that the 3D atom positions can be accessed via `data.pos`. At this stage, a straightforward approach, such as concatenating or summing the coordinates with the features, would suffice to make progress.
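
As a shape-level illustration of the two options mentioned above, the following sketch uses random stand-in tensors instead of a real batch. The layer names (`proj_cat`, `proj_x`, `proj_pos`) and the toy sizes are assumptions made for this example; it shows how the dimensions line up and is not the intended implementation of `PositionModel`.

```python
import torch
from torch.nn import Linear

torch.manual_seed(0)
num_atoms, in_dim, emb_dim = 5, 11, 64   # toy sizes; 11 matches the node features
x = torch.randn(num_atoms, in_dim)       # stand-in for data.x
pos = torch.randn(num_atoms, 3)          # stand-in for data.pos

# Option A: concatenate features and coordinates, then project to emb_dim.
proj_cat = Linear(in_dim + 3, emb_dim)
h_cat = proj_cat(torch.cat([x, pos], dim=-1))

# Option B: project each input separately, then sum the two embeddings.
proj_x, proj_pos = Linear(in_dim, emb_dim), Linear(3, emb_dim)
h_sum = proj_x(x) + proj_pos(pos)

print(h_cat.shape, h_sum.shape)  # both are (num_atoms, emb_dim)
```

Either way, the result has one `emb_dim`-sized row per atom, which is the shape the message passing layers expect.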


In [None]:
#export
### DO NOT CHANGE ANY CODE ABOVE THIS LINE IN THIS CELL ###

class PositionModel(NaiveModel):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        """Initializes a PositionModel which is an extension of NaiveModel that
        includes atom coordinates in addition to the node features.

        Parameters:
            num_layers (int): The number of message passing layers.
            emb_dim (int): The embedding dimension of node features.
            in_dim (int): The dimension of initial node features.
            edge_dim (int): The dimension of edge features.
            out_dim (int): The dimension of the model output.
        """
        super().__init__()

        # TODO: Define input linear layer
        # Input layer that projects initial node features and coordinates

        # TODO: Define message passing layers
        # self.layers =

        # Define global pooling function (mean pooling)
        self.pool = global_mean_pool

        # Output layer that predicts the graph property
        self.output = Linear(emb_dim, out_dim)

    def forward(self, data):
        """Performs a forward pass on the graph data.

        Parameters:
            data (PyG.Data): A batch of graphs in PyG format.

        Returns:
            torch.Tensor: The output predictions for each graph in the batch.
        """
        # TODO: decide how to handle the input information
        # Combine node features with atom positions

        # Apply message passing layers with residual connections
        for layer in self.layers:
            h = h + layer(h, data.edge_index, data.edge_attr)

        # Apply global mean pooling to get graph-level representation
        h_graph = self.pool(h, data.batch)

        # Predict the target property for each graph
        out = self.output(h_graph)

        # Flatten output for consistency
        return out.view(-1)



######################################################################
########## DON'T WRITE ANY CODE OUTSIDE THE CLASS! ###################
######## IF YOU WANT TO CALL OR TEST IT CREATE A NEW CELL ############
######################################################################

In [None]:
layer = MPNNLayer(emb_dim=11, edge_dim=4, aggr='add')
model = PositionModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)

dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Permutation invariance unit test for Position model
print(f"Is {type(model).__name__} permutation invariant? --> {permutation_invariance_unit_test(model, dataloader)}!")

# Permutation equivariance unit test for MPNN layer
print(f"Is {type(layer).__name__} permutation equivariant? --> {permutation_equivariance_unit_test(layer, dataloader)}!")

In [None]:
model = PositionModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)

model_name = type(model).__name__
best_val_error, test_error, train_time, perf_per_epoch = run_experiment(
    model,
    model_name,
    train_loader,
    val_loader,
    test_loader,
    n_epochs=100
)

results[model_name] = (best_val_error, test_error, train_time)

In [None]:
results

In [None]:
torch.save(model.state_dict(), "pos_gnn.pth")
print("Saved PyTorch Model State to pos_gnn.pth")

Upon successful implementation of the `PositionModel`, an interesting observation emerges:
the model's performance is comparable or slightly inferior to that of the basic `NaiveModel`.

This performance pattern suggests that the `PositionModel` may not be effectively leveraging the
3D structural information during its computations.

In the following sections, we aim to delve deeper into the reasons behind this phenomenon and seek
to establish a more methodical approach to utilizing 3D structural data.


## Understanding Invariance to 3D Symmetries

Despite using the additional coordinate data, the `PositionModel` did not outperform the `NaiveModel`. Before drawing conclusions from this result, it is important to understand the concept of 3D symmetries.

### Understanding Geometric Invariance

Molecular graphs come with 3D atomic coordinates. An important aspect we have not discussed yet is that these coordinates are relative, not absolute: they are expressed with respect to a reference frame, so the same molecule has different coordinates whenever it is moved or rotated.

Picture a molecule floating in 3D space, rotating and translating as it goes.

<!-- Image placeholder for molecule GIF -->

Despite the constant movement in coordinates, the intrinsic properties of the molecule remain constant. They are invariant to any changes in position or orientation in space.

This part of our study will focus on creating GNN layers and models that honor this invariance.
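
The claim can be checked numerically. The sketch below is a minimal illustration with random toy coordinates: it draws a random orthogonal matrix from a QR decomposition (an assumption made to keep the example self-contained), moves the points, and compares the raw coordinates with one example of an intrinsic quantity, the matrix of pairwise distances.

```python
import torch

torch.manual_seed(0)
pos = torch.randn(5, 3, dtype=torch.float64)     # toy coordinates for 5 atoms

# Random orthogonal matrix from a QR decomposition (rotation or reflection).
Q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
t = torch.rand(3, dtype=torch.float64)           # random translation
pos_moved = pos @ Q + t

# The raw coordinates change under the transformation ...
print(torch.allclose(pos, pos_moved))            # False

# ... but the pairwise distances between atoms do not.
dist = torch.cdist(pos, pos)
dist_moved = torch.cdist(pos_moved, pos_moved)
print(torch.allclose(dist, dist_moved))          # True
```

A model that consumes raw coordinates therefore sees a different input for every pose of the same molecule, while a model built on quantities like these distances does not.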

### Defining the Formalism

We will now define the concept of invariance within GNNs using matrix representation.

- Define $\mathbf{H} \in \mathbb{R}^{n \times d}$ as the feature matrix for a molecule's graph, with $n$ atoms and $d$ features per atom.
- Define $\mathbf{X} \in \mathbb{R}^{n \times 3}$ as the coordinate matrix for the graph's atoms.
- Define $\mathbf{A} \in \mathbb{R}^{n \times n}$ as the adjacency matrix indicating connections between atoms.
- Define $\mathbf{F}(\mathbf{H}, \mathbf{X}, \mathbf{A})$ as a GNN layer that processes these matrices to update node features.
- Define $f(\mathbf{H}, \mathbf{X}, \mathbf{A})$ as a GNN model that uses the matrices to predict a property at the graph level.

Note that the GNN layer $\mathbf{F}$ and the GNN model $f$ both take $\mathbf{X}$, the node coordinates, as an additional input.
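
With this notation, invariance to rotations and translations can be stated as follows. The symbols $\mathbf{Q}$ and $\mathbf{t}$ are introduced here as assumptions, chosen to match the transformation `pos @ Q + t` applied in the unit test that follows: $\mathbf{Q} \in \mathbb{R}^{3 \times 3}$ is an orthogonal matrix, $\mathbf{t} \in \mathbb{R}^{3}$ is a translation vector, and $\mathbf{1} \in \mathbb{R}^{n}$ is the all-ones vector that adds $\mathbf{t}$ to every row of $\mathbf{X}$.

$$
\mathbf{F}(\mathbf{H}, \mathbf{X}\mathbf{Q} + \mathbf{1}\mathbf{t}^\top, \mathbf{A}) = \mathbf{F}(\mathbf{H}, \mathbf{X}, \mathbf{A}),
\qquad
f(\mathbf{H}, \mathbf{X}\mathbf{Q} + \mathbf{1}\mathbf{t}^\top, \mathbf{A}) = f(\mathbf{H}, \mathbf{X}, \mathbf{A})
$$

In words: transforming the coordinates leaves both the updated node features and the graph-level prediction unchanged.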


In [None]:
def random_orthogonal_matrix(dim=3):
    """Helper function to build a random orthogonal matrix of shape (dim, dim)."""
    Q = torch.tensor(ortho_group.rvs(dim=dim)).float()
    return Q


def rot_trans_invariance_unit_test(module, dataloader):
    """Unit test for checking whether a module (GNN model/layer) is
    rotation and translation invariant.
    """
    it = iter(dataloader)
    data = next(it)

    if isinstance(module, NaiveModel):
        out_1 = module(data)
    else: # if isinstance(module, MessagePassing):
        out_1 = module(data.x, data.pos, data.edge_index, data.edge_attr)

    Q = random_orthogonal_matrix(dim=3)
    t = torch.rand(3)
    data.pos = data.pos @ Q + t

    if isinstance(module, NaiveModel):
        out_2 = module(data)
    else: # if isinstance(module, MessagePassing):
        out_2 = module(data.x, data.pos, data.edge_index, data.edge_attr)

    return torch.allclose(out_1, out_2, atol=1e-04)

In [None]:
# Instantiate temporary model, layer, and dataloader for unit testing
model = PositionModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)
dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Rotation and translation invariance unit test for MPNN model
print(f"Is {type(model).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(model, dataloader)}!")

## Task 2: Construct a Message Passing Layer and MPNN Model with 3D Invariance [3 pts]

Task Objective:
Develop a Message Passing Layer (`InvariantLayer`) and an MPNN Model (`InvariantModel`) that are inherently invariant to transformations in 3D space, such as rotations and translations.

Background:
The original MPNN model (`NaiveModel`) did not consider atom coordinates, relying exclusively on node features. The subsequent model (`PositionModel`) included coordinates but failed to maintain invariance to spatial transformations, missing out on the robustness required for accurate property prediction regardless of the molecule's orientation or position.

Your Role:
Create the `InvariantLayer` that effectively integrates atom coordinates and node features. This new layer should contribute to an MPNN Model whose predictions do not change when the molecule is rotated or translated in space.

We provide the skeletal structure of `InvariantLayer` with sections marked for your implementation. The `InvariantModel` is pre-defined and will utilize your newly created layer.

Hints for Implementation:
- Rethink the integration of coordinate data in message construction, rather than mixing it directly with node features.
- Focus on deriving a property from coordinate pairs that remains unchanged with spatial movements of the molecule.
- Utilize the naming convention in `propagate()` to map tensors to their respective nodes.
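
The naming convention in the last hint can be emulated in plain PyTorch. In PyG's default `source_to_target` flow, a tensor passed to `propagate()` as `h` becomes available in `message()` as `h_j` (rows gathered at the source nodes, `edge_index[0]`) and `h_i` (rows gathered at the target nodes, `edge_index[1]`). The sketch below reproduces that gathering and a sum aggregation by hand. The toy graph and the placeholder message are illustrative assumptions and deliberately leave out the coordinate handling the task asks for.

```python
import torch

h = torch.tensor([[1.0], [10.0], [100.0]])   # one feature per node, 3 nodes
edge_index = torch.tensor([[0, 1, 2, 2],     # source nodes j
                           [1, 0, 0, 1]])    # target nodes i

src, dst = edge_index
h_j = h[src]   # what message() would receive as h_j
h_i = h[dst]   # what message() would receive as h_i

msg = h_j      # placeholder message: forward the sender's feature unchanged

# Sum aggregation: add each message into the row of its target node.
aggr_out = torch.zeros_like(h).index_add_(0, dst, msg)
print(aggr_out.squeeze(-1).tolist())         # [110.0, 101.0, 0.0]
```

Any other per-node tensor passed to `propagate()` follows the same rule, so passing `pos=pos` makes `pos_i` and `pos_j` available to `message()` in the same way.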


In [None]:
#export
### DO NOT CHANGE ANY CODE ABOVE THIS LINE IN THIS CELL ###

class InvariantLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        """Initializes an MPNN layer that accounts for 3D coordinate invariance.

        Args:
            emb_dim (int): Size of each embedding vector.
            edge_dim (int): Size of each edge feature vector.
            aggr (str): Type of aggregation ('add', 'mean', or 'max').
        """
        super().__init__(aggr=aggr)
        self.emb_dim = emb_dim
        self.edge_dim = edge_dim
        # TODO: Define the mlp_msg
        # Message function network

        # Update function network
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

    def forward(self, h, pos, edge_index, edge_attr):
        """Forward pass for generating updated node features."""
        #TODO: implement forward
        pass

    def message(self, h_i, h_j, pos_i, pos_j, edge_attr):
        """Computes messages for each edge in the graph."""
        # TODO: Implement message
        pass

    def update(self, aggr_out, h):
        """Updates node features after message aggregation."""
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self):
        return '{}(emb_dim={}, edge_dim={}, aggr={})'.format(
            self.__class__.__name__, self.emb_dim, self.edge_dim, self.aggr)

class InvariantModel(NaiveModel):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        """
        Constructs a graph neural network model that ensures invariance to 3D coordinate transformations.

        The model is designed to process graph data where each node is associated with both
        features and 3D spatial coordinates, and it produces a prediction invariant to rotations
        and translations of the coordinate space.

        Parameters:
            num_layers (int): The count of message-passing layers within the network.
            emb_dim (int): The dimensionality of the embedding space for node features.
            in_dim (int): The size of the input feature vector for each node.
            edge_dim (int): The size of the feature vector for each edge.
            out_dim (int): The size of the output vector; set to 1 for scalar predictions.
        """
        # Call the constructor of the parent NaiveModel class
        super().__init__()

        # Feature transformation layer to embed input node features into a higher-dimensional space
        self.embedding = Linear(in_dim, emb_dim)

        # Construct a sequence of graph convolution layers that are invariant to spatial transformations
        self.layers = torch.nn.ModuleList([InvariantLayer(emb_dim, edge_dim, aggr='add') for _ in range(num_layers)])

        # Define a pooling operation that aggregates node embeddings across the graph to form a graph-level representation
        self.pool = global_mean_pool

        # Final linear layer that maps the graph-level representation to the prediction space
        self.output = Linear(emb_dim, out_dim)


    def forward(self, data):
        """
        Args:
            data: (PyG.Data) - batch of PyG graphs

        Returns:
            out: (batch_size, out_dim) - prediction for each graph
        """
        h = self.embedding(data.x)

        for conv in self.layers:
            h = h + conv(h, data.pos, data.edge_index, data.edge_attr)

        h_graph = self.pool(h, data.batch)

        out = self.output(h_graph)
        return out.view(-1)

######################################################################
########## DON'T WRITE ANY CODE OUTSIDE THE CLASS! ###################
######## IF YOU WANT TO CALL OR TEST IT CREATE A NEW CELL ############
######################################################################

In [None]:
layer = InvariantLayer(emb_dim=11, edge_dim=4, aggr='add')
model = InvariantModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)


dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

print(f"Is {type(model).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(model, dataloader)}!")

print(f"Is {type(layer).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(layer, dataloader)}!")

In [None]:
model = InvariantModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)

model_name = type(model).__name__
best_val_error, test_error, train_time, perf_per_epoch = run_experiment(
    model,
    model_name, # "MPNN w/ Features and Coordinates (Invariant Layers)",
    train_loader,
    val_loader,
    test_loader,
    n_epochs=100
)

results[model_name] = (best_val_error, test_error, train_time)

In [None]:
results

In [None]:
torch.save(model.state_dict(), "invar_gnn.pth")
print("Saved PyTorch Model State to invar_gnn.pth")

You've progressed from the basic `NaiveModel`, through the simple coordinate-aware `PositionModel`, to the geometrically principled `InvariantModel`. Next, we'll dig deeper into the geometry of molecular structures to extract more useful information.

## Message Passing with Equivariance to 3D Rotations and Translations

Following our exploration of invariance to 3D rotation and translation, we now turn our attention to developing a GNN for molecular property prediction that leverages message passing layers with equivariance to these transformations.

### The Importance of Geometric Equivariance

We draw a parallel between permutation symmetries in GNNs, translation symmetries in ConvNets for 2D images, and now geometric equivariance in molecular models.

#### Permutation Symmetry in GNNs and DeepSets

Previously, we discussed how a GNN layer should be permutation equivariant, meaning that reordering the nodes reorders the rows of the layer's output in the same way, whereas the overall GNN model should be permutation invariant for graph-level predictions, meaning that the prediction does not depend on the node order at all. This design allows the model to capture the relational structure of the graph. DeepSets, another permutation invariant model, can be used for graph-level predictions but may not capture the relational intricacies as effectively as GNNs, which use equivariant layers to build complex node representations.
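
To make the distinction concrete, here is a minimal NumPy sketch. The functions `layer` and `readout` are toy stand-ins introduced for this illustration (a sum over neighbours and a mean over nodes), not the models used in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 3

H = rng.normal(size=(n, d))                    # node features
A = (rng.random((n, n)) < 0.4).astype(float)   # random edges
A = np.triu(A, 1)
A = A + A.T                                    # symmetric adjacency, no self-loops

def layer(H, A):
    """Toy message passing layer: add the sum of neighbour features to each node."""
    return H + A @ H

def readout(H):
    """Toy graph-level readout: mean over nodes."""
    return H.mean(axis=0)

perm = rng.permutation(n)
P = np.eye(n)[perm]                            # permutation matrix

# Relabel the nodes: rows of H are permuted, rows and columns of A likewise.
H_perm, A_perm = P @ H, P @ A @ P.T

# Equivariance: permuting the input permutes the layer output in the same way.
layer_equivariant = np.allclose(layer(H_perm, A_perm), P @ layer(H, A))
# Invariance: the pooled graph representation does not change at all.
model_invariant = np.allclose(readout(layer(H_perm, A_perm)), readout(layer(H, A)))

print("layer is permutation equivariant:", layer_equivariant)
print("layer + readout is permutation invariant:", model_invariant)
```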

#### Translation Symmetry in ConvNets for 2D Images

Similarly, ConvNets for image processing are translation invariant overall but composed of translation equivariant convolution filters. These filters detect features like edges and patterns regardless of their position in the input space, allowing the ConvNet to build hierarchical features and learn complex visual concepts.
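
The same pattern can be checked on a small example. The sketch below simplifies the 2D image setting to a 1D signal with circular (wrap-around) boundaries, an assumption made so that shifts are exact; `circular_conv` is a hypothetical helper written for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.normal(size=16)
kernel = rng.normal(size=3)

def circular_conv(x, k):
    """Circular cross-correlation: out[i] = sum_m k[m] * x[(i + m) % len(x)]."""
    return sum(k[m] * np.roll(x, -m) for m in range(len(k)))

shift = 5
shifted = np.roll(signal, shift)

# Equivariance: filtering the shifted signal equals shifting the filtered signal.
conv_equivariant = np.allclose(
    circular_conv(shifted, kernel),
    np.roll(circular_conv(signal, kernel), shift),
)
# Invariance: a global max pool over the feature map ignores the shift entirely.
pool_invariant = np.isclose(
    circular_conv(shifted, kernel).max(),
    circular_conv(signal, kernel).max(),
)

print("convolution is shift equivariant:", conv_equivariant)
print("convolution + global max pool is shift invariant:", pool_invariant)
```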

In conclusion, designing GNN layers that are equivariant to 3D rotations and translations could allow us to better capture the geometric structure of molecules and enhance model performance, just as equivariant layers in ConvNets allow for the extraction of complex visual patterns.


### Understanding Equivariance in GNNs

We've established the significance of creating GNN layers that are equivariant to 3D rotations and translations. Now, let's define this concept with a mathematical framework.

- Consider a matrix $\mathbf{H} \in \mathbb{R}^{n \times d}$ representing node features in a molecular graph, where $n$ is the number of nodes, and each row $h_i$ denotes the feature vector of dimension $d$ for node $i$.
- Let $\mathbf{X} \in \mathbb{R}^{n \times 3}$ represent the coordinates of nodes in the molecular graph, with each row $x_i$ corresponding to the 3D coordinates of node $i$.
- The adjacency matrix $\mathbf{A} \in \mathbb{R}^{n \times n}$ indicates the connections between nodes, with $a_{ij}$ marking the link between nodes $i$ and $j$.
- A GNN layer $\mathbf{F}(\mathbf{H}, \mathbf{X}, \mathbf{A}): \mathbb{R}^{n \times d} \times \mathbb{R}^{n \times 3} \times \mathbb{R}^{n \times n} \rightarrow \mathbb{R}^{n \times d}\times \mathbb{R}^{n \times 3}$ accepts node features, coordinates, and the adjacency matrix to return updated node features and coordinates.
- The GNN model $f(\mathbf{H}, \mathbf{X}, \mathbf{A}): \mathbb{R}^{n \times d} \times \mathbb{R}^{n \times 3} \times \mathbb{R}^{n \times n} \rightarrow \mathbb{R}$ computes the graph-level property from these inputs.

The GNN model comprises several layers $\mathbf{F}^{\ell}(\mathbf{H}^{\ell}, \mathbf{X}^{\ell}, \mathbf{A})$ that are equivariant to rotations and translations.
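
Written out in this notation, the symmetry requirements can be sketched as follows. Here $\mathbf{Q} \in \mathbb{R}^{3 \times 3}$ is an orthogonal matrix and $\mathbf{t} \in \mathbb{R}^{3}$ a translation vector (both symbols are introduced for this statement), and $\mathbf{X}\mathbf{Q}^{\top} + \mathbf{t}$ means that every row $x_i$ is mapped to $\mathbf{Q} x_i + \mathbf{t}$. If $(\mathbf{H'}, \mathbf{X'}) = \mathbf{F}(\mathbf{H}, \mathbf{X}, \mathbf{A})$, then an equivariant layer satisfies

$$
\mathbf{F}(\mathbf{H}, \mathbf{X}\mathbf{Q}^{\top} + \mathbf{t}, \mathbf{A}) = (\mathbf{H'}, \mathbf{X'}\mathbf{Q}^{\top} + \mathbf{t}),
$$

while an invariant model satisfies

$$
f(\mathbf{H}, \mathbf{X}\mathbf{Q}^{\top} + \mathbf{t}, \mathbf{A}) = f(\mathbf{H}, \mathbf{X}, \mathbf{A}).
$$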

### Distinction from Invariant Message Passing

Equivariant message passing differs from invariant message passing because it updates not only the node features but also their coordinates:

$$
\mathbf{H}^{\ell+1}, \mathbf{X}^{\ell+1} = \mathbf{F}^{\ell} (\mathbf{H}^{\ell}, \mathbf{X}^{\ell}, \mathbf{A}).
$$

This method is especially useful when modeling dynamical systems where node coordinates change due to intermolecular forces.

Note these nuances regarding equivariant message passing layers $\mathbf{F}$:
- Updated node coordinates $\mathbf{X'}$ are equivariant to 3D transformations of the initial coordinates $\mathbf{X}$: rotating and translating $\mathbf{X}$ rotates and translates $\mathbf{X'}$ in the same way.
- Updated node features $\mathbf{H'}$ remain invariant to 3D transformations of $\mathbf{X}$, just as in invariant message passing.
- The overall MPNN model $f$ is invariant to 3D transformations because it predicts a single scalar quantity that does not change with the atoms' coordinate transformations.
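
As a quick numerical illustration of what "equivariant coordinates" means, the NumPy sketch below tests two toy coordinate maps. Neither is a message passing layer nor a solution to the task that follows; `contract_to_centroid`, `squash`, `is_equivariant`, and `random_orthogonal` are hypothetical names introduced only for this check.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_orthogonal(rng):
    """Hypothetical helper: random 3x3 orthogonal matrix via QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    return q * np.sign(np.diag(r))

def contract_to_centroid(X):
    """Toy update: move every point halfway towards the centroid of all points."""
    return X + 0.5 * (X.mean(axis=0) - X)

def squash(X):
    """Counterexample: an element-wise nonlinearity applied to raw coordinates."""
    return np.tanh(X)

def is_equivariant(fn, X, Q, t):
    """Check whether fn(X Q^T + t) equals fn(X) Q^T + t."""
    return np.allclose(fn(X @ Q.T + t), fn(X) @ Q.T + t)

X = rng.normal(size=(6, 3))
Q = random_orthogonal(rng)
t = rng.normal(size=3)

print("contract_to_centroid equivariant:", is_equivariant(contract_to_centroid, X, Q, t))
print("squash equivariant:", is_equivariant(squash, X, Q, t))
```

The first map only combines coordinates linearly with weights that sum to one, so rotations and translations pass straight through it. The second treats the x, y, and z axes as ordinary features, which ties the result to one particular coordinate frame.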

The final prediction of the GNN model is based on the graph embedding derived from the final node features after multiple layers of message passing, disregarding the final node coordinates.

We aim to explore how using equivariant message passing layers can enhance a GNN model that is invariant to 3D symmetries.

Let's proceed with our exploration.


## Task 3: Develop an Equivariant Message Passing Layer [4pts]

**Objective:** Create a message passing layer that updates both the node features and coordinates, ensuring equivariance to 3D rotations and translations for a given molecular graph.

**Guidance:**
- The layer should handle both node features and coordinates, outputting a tuple that contains the updated versions of these elements.
- Implement `message()`, `aggregate()`, and `update()` functions to work with these tuples, considering the distinct nature of invariant and equivariant quantities.
- Invariant quantities should remain unchanged under 3D transformations, while equivariant quantities should change correspondingly with the coordinates.

**Considerations:**
- There are various strategies to achieve this, and simple replication of existing solutions from libraries like PyG is not acceptable.
- Avoid trivial solutions such as leaving the coordinates unchanged. The goal is to design a coordinate message function that intelligently aggregates information from neighboring nodes' coordinates, adhering to 3D symmetries.

The challenge lies in engineering a method for node coordinate updates through message passing that honors the principles of 3D equivariance.
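
As background for `aggregate()`, the sketch below shows index-based sum aggregation in plain NumPy. It illustrates only the mechanics of combining per-edge quantities per node and does not address how to construct equivariant messages. `scatter_sum` is a hypothetical helper written here; on tensors, `torch_scatter.scatter` provides a comparable operation.

```python
import numpy as np

def scatter_sum(values, index, num_nodes):
    """Sum per-edge rows of `values` into the node slot given by `index`."""
    out = np.zeros((num_nodes,) + values.shape[1:], dtype=values.dtype)
    np.add.at(out, index, values)   # unbuffered add: repeated indices accumulate
    return out

# Three directed edges whose messages arrive at nodes 0, 2 and 2 of a 3-node graph.
index = np.array([0, 2, 2])
msg_h = np.array([[1.0, 1.0], [2.0, 0.0], [3.0, 5.0]])                    # feature messages
msg_pos = np.array([[0.1, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.3, 0.0]])   # coordinate messages

# Aggregating a tuple means aggregating each component with the same index.
agg_h = scatter_sum(msg_h, index, num_nodes=3)
agg_pos = scatter_sum(msg_pos, index, num_nodes=3)

print(agg_h)
print(agg_pos)
```

Node 1 receives no messages in this example, so its aggregated rows stay at zero.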


In [None]:
#export

### DO NOT CHANGE ANY CODE ABOVE THIS LINE IN THIS CELL ###

class EquivariantLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        """
        An MPNN layer that maintains equivariance to 3D geometric transformations.

        Parameters:
            emb_dim (int): The size of the hidden feature dimension.
            edge_dim (int): The dimensionality of edge features.
            aggr (str): The aggregation function to use.
        """
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        # TODO: Define the essential variables





    def forward(self, h, pos, edge_index, edge_attr):
        """
        Conducts a message passing update on node features and coordinates.

        Parameters:
            h (Tensor): Initial node features.
            pos (Tensor): Node positions.
            edge_index (LongTensor): Edge indices.
            edge_attr (Tensor): Edge attributes.

        Returns:
            A tuple of updated node features and positions.
        """
        # Message passing with updated coordinates
        # TODO: implement forward
        pass



    def message(self, h_i, h_j, pos_i, pos_j, edge_attr):
        """
        Generates messages for each node based on neighboring nodes and edge attributes.

        Parameters:
            h_i (Tensor): The features of the destination nodes.
            h_j (Tensor): The features of the source nodes.
            pos_i (Tensor): Positions of destination nodes.
            pos_j (Tensor): Positions of source nodes.
            edge_attr (Tensor): Edge attributes.

        Returns:
            A tuple of message vectors and position updates.
        """
        # TODO: implement message
        # Determine message and position changes
        pass

    def aggregate(self, inputs, index):
        """
        Aggregates messages from neighboring nodes using the specified method.

        Parameters:
            inputs (tuple): Messages and positional updates from source nodes.
            index (Tensor): For each edge, the index of the destination node that receives its message.

        Returns:
            A tuple of aggregated messages and position updates.
        """
        # TODO: implement aggregate
        # Unpack inputs and aggregate
        pass


    def update(self, aggr_out, h, pos):
        """
        Updates node features and positions using the aggregated messages.

        Parameters:
            aggr_out (tuple): Aggregated messages and position updates.
            h (Tensor): Initial node features.
            pos (Tensor): Initial node positions.

        Returns:
            A tuple of updated node features and positions.
        """
        # TODO: implement update
        # Extract aggregated results and apply updates
        pass

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')


class EquivariantModel(NaiveModel):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        """
        Constructs an MPNN model that predicts graph properties while considering node features and spatial coordinates.

        Parameters:
            num_layers (int): The number of message passing layers in the model.
            emb_dim (int): The size of the hidden feature dimension.
            in_dim (int): The dimensionality of the initial node features.
            edge_dim (int): The dimensionality of edge features.
            out_dim (int): The dimension of the model's output.
        """
        super().__init__()

        self.embedding = Linear(in_dim, emb_dim)  # Linear transformation for input features
        self.layers = torch.nn.ModuleList([EquivariantLayer(emb_dim, edge_dim, aggr='add') for _ in range(num_layers)])  # Equivariant layers
        self.pool = global_mean_pool  # Mean pooling layer
        self.output = Linear(emb_dim, out_dim)  # Prediction layer


    def forward(self, data):
        """
        Feeds data through the model to generate predictions.

        Parameters:
            data (Data): The input graph data.

        Returns:
            The predicted property for each graph in the batch.
        """
        h = self.embedding(data.x)
        pos = data.pos

        for conv in self.layers:
            h, pos = conv(h, pos, data.edge_index, data.edge_attr)  # Update features and positions

        pooled_features = self.pool(h, data.batch)  # Pool features to graph-level

        return self.output(pooled_features).view(-1)  # Predict graph properties

######################################################################
########## DON'T WRITE ANY CODE OUTSIDE THE CLASS! ###################
######## IF YOU WANT TO CALL OR TEST IT CREATE A NEW CELL ############
######################################################################

In [None]:
def rot_trans_equivariance_unit_test(module, dataloader):
    """Unit test for checking whether a module (GNN layer) is
    rotation and translation equivariant.
    """
    it = iter(dataloader)
    data = next(it)

    # Original forward pass
    out_1, pos_1 = module(data.x, data.pos, data.edge_index, data.edge_attr)

    # Create random rotation and translation
    Q = random_orthogonal_matrix(dim=3)
    t = torch.rand(3)

    # Apply rotation and translation
    rotated_translated_pos = (data.pos @ Q.T) + t

    # Forward pass on rotated + translated example
    out_2, pos_2 = module(data.x, rotated_translated_pos, data.edge_index, data.edge_attr)

    # Check whether output node features are unchanged (they should be invariant)
    feature_invariance = torch.allclose(out_1, out_2, atol=1e-4)

    # Check whether output positions are rotated and translated versions of pos_1:
    # equivariance means pos_2 = pos_1 @ Q.T + t, i.e. pos_2 - t = pos_1 @ Q.T
    pos_equivariance = torch.allclose((pos_2 - t), (pos_1 @ Q.T), atol=1e-4)

    return feature_invariance and pos_equivariance

In [None]:
# Instantiate the EquivariantLayer with appropriate dimensions
layer = EquivariantLayer(emb_dim=11, edge_dim=4, aggr='add')

# Instantiate the EquivariantModel with a specified number of layers and dimensions
model = EquivariantModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)

dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Rotation and translation invariance unit test for MPNN model
print(f"Is {type(model).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(model, dataloader)}!")

# Rotation and translation equivariance unit test for MPNN layer
print(f"Is {type(layer).__name__} rotation and translation equivariant? --> {rot_trans_equivariance_unit_test(layer, dataloader)}!")

We have now built the `EquivariantLayer` and `EquivariantModel` and checked with unit tests that the layer is equivariant, and the model invariant, to 3D rotations and translations.

Now, it's time to test our most advanced model that incorporates geometric principles.

## Train and Test the EquivariantModel
Train your `EquivariantModel` and evaluate its performance. Then, think about the outcomes in comparison to the earlier models: the basic `NaiveModel`, the `PositionModel` with straightforward coordinate usage, and the `InvariantModel`. Determine whether the new model shows superior performance and whether the improvement is substantial or marginal.


For a fair comparison, configure the `EquivariantModel` with four message passing layers and a hidden dimension size of 64, aligning with the previous `NaiveModel`, `PositionModel`, and `InvariantModel` configurations.
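
Once all runs have finished, the `results` dictionary can be turned into a small comparison table. The sketch below assumes each entry has the form `(best_val_error, test_error, train_time)`, as stored in the cells of this notebook. `summarize` and `format_table` are hypothetical helpers, and the numbers in `example_results` are placeholders, not measured values.

```python
def summarize(results):
    """Return (name, val_error, test_error, train_time) rows sorted by test error."""
    rows = [(name, *metrics) for name, metrics in results.items()]
    return sorted(rows, key=lambda row: row[2])

def format_table(rows):
    """Render the rows as a fixed-width text table."""
    lines = [f"{'model':<20}{'best val':>12}{'test':>12}{'train time':>12}"]
    for name, val, test, time_taken in rows:
        lines.append(f"{name:<20}{val:>12.4f}{test:>12.4f}{time_taken:>12.2f}")
    return "\n".join(lines)

# Placeholder numbers for illustration only; they are NOT measured results.
example_results = {
    "ModelA": (0.30, 0.35, 10.0),
    "ModelB": (0.20, 0.25, 12.0),
}

print(format_table(summarize(example_results)))
```

In the notebook, calling `print(format_table(summarize(results)))` after the final experiment would list all trained models from lowest to highest test error.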


In [None]:
model = EquivariantModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)

model_name = type(model).__name__
best_val_error, test_error, train_time, perf_per_epoch = run_experiment(
    model,
    model_name, # "MPNN w/ Features and Coordinates (Equivariant Layers)",
    train_loader,
    val_loader,
    test_loader,
    n_epochs=100
)

results[model_name] = (best_val_error, test_error, train_time)

In [None]:
results

Great job! You've progressed through various models, starting with the basic `NaiveModel`, moving to the elementary implementation of coordinate information in `PositionModel`, advancing to a geometrically informed `InvariantModel`, and culminating with the `EquivariantModel`. This latest model maintains **invariance to 3D rotations and translations** and integrates **message passing layers that are equivariant** to these spatial transformations.


In [None]:
torch.save(model.state_dict(), "equ_gnn.pth")
print("Saved PyTorch Model State to equ_gnn.pth")

## Submission Guidelines

Ensure you've thoroughly tested your code locally before submitting it for evaluation. Your submission to Gradescope should be a zip file containing specific files related to your solution and the trained models.

### Submission Checklist:

Ensure your zip file contains the following items:

1. **Notebook File**:
   - `hw2_3d.ipynb`: The Jupyter notebook containing all your code and answers.

2. **Model Files**:
   - `pos_gnn.pth`: The saved model file for your `PositionModel`.
   - `invar_gnn.pth`: The saved model file for your `InvariantModel`.
   - `equ_gnn.pth`: The saved model file for your `EquivariantModel`.

All model files should adhere to the structure defined within your notebook.

### Submission Instructions:

- **File Format**: Submit all your files in a **ZIP** format.
- **File Structure**: Avoid including a root directory in the zip file. Ensure all your files are compressed directly without a containing folder.
- **Validation**: Before submitting, verify your code runs as expected and all outputs align with anticipated results.
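
One way to produce an archive without a root directory is Python's standard `zipfile` module. The following is a minimal sketch: `build_submission` is a hypothetical helper, and the archive name `submission.zip` is an assumption, since the guidelines do not prescribe one.

```python
import os
import zipfile

def build_submission(paths, zip_path="submission.zip"):
    """Zip `paths` flat, so that no root directory appears inside the archive."""
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(f"Missing files: {missing}")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for p in paths:
            # arcname drops any directory prefix, keeping only the file name.
            zf.write(p, arcname=os.path.basename(p))
    # Re-open the archive and report what it contains.
    with zipfile.ZipFile(zip_path) as zf:
        return zf.namelist()
```

For this assignment the call would be `build_submission(["hw2_3d.ipynb", "pos_gnn.pth", "invar_gnn.pth", "equ_gnn.pth"])`. The returned list should contain exactly those four names, with no folder prefix.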
  
### Additional Notes:

- Ensure your models maintain the same structure as defined within your notebook.
- Be mindful of ensuring all necessary components are included to avoid discrepancies during the evaluation process.

**CRUCIAL**: Test your zip file in a fresh environment to confirm that it runs without errors and that all required components are included.

Best of luck with your submission!
