# Comformer Tutorial: Graph Neural Networks for Material Properties

This tutorial provides a comprehensive introduction to Comformer, a Graph Neural Network (GNN) model for predicting material properties. We'll walk through the entire process from loading data to inference, explaining each step in detail.

## Introduction

Comformer is a GNN model designed for materials science applications. Unlike traditional Convolutional Neural Networks (CNNs), which operate on regular grid-like data (e.g., images), GNNs can process irregular structures such as the arrangement of atoms in a material. This makes them well suited to predicting material properties from atomic arrangements.

In this tutorial, we'll explore:
1. Loading material data
2. Creating graph representations of materials
3. Understanding feature generation
4. Setting up and training the Comformer model
5. Making predictions with the trained model

Let's begin!

## 1. Setting Up the Environment

First, let's import the necessary libraries and set up our environment.

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from jarvis.core.atoms import Atoms, pmg_to_atoms
from pymatgen.core import Structure
import matplotlib as mpl
import seaborn as sns

# Set up paths to the Comformer code
import sys
sys.path.append('../') # Add parent directory to path

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 2. Loading and Exploring the Dataset

We'll use the sample data provided in the repository to demonstrate how to work with material data. Let's first load and explore this dataset.

In [None]:
# Load the sample dataset
sample_data_path = "../sample_data/surface_prop_data_set_top_bottom.csv"
df = pd.read_csv(sample_data_path, on_bad_lines="skip")

# Display the first few rows
print(f"Dataset shape: {df.shape}")
df.head()

### Exploring dataset properties

Let's explore the dataset to better understand what we're working with.

In [None]:
# Look at the available columns
print("Available columns:")
print(df.columns.tolist())

# Basic statistics of numerical properties
print("\nBasic statistics of target properties:")
targets = ['WF_bottom', 'WF_top', 'cleavage_energy']
print(df[targets].describe())

# Check for missing values
print("\nMissing values:")
print(df[targets].isna().sum())

### Visualizing the target distributions

Let's visualize the distribution of our target properties.

In [None]:
# Plot distributions of target properties
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for i, target in enumerate(targets):
    sns.histplot(df[target], kde=True, ax=axes[i])
    axes[i].set_title(f'Distribution of {target}')
    axes[i].set_xlabel(target)
    axes[i].set_ylabel('Count')

plt.tight_layout()
plt.show()

### Exploring the structure data

The key element of materials datasets is the atomic structure data. Let's examine how the structure data is stored and represented.

In [None]:
# Let's look at the structure data format
print("Example of structure data:")
print(df['slab'].iloc[0][:500] + '...')

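# Parse one structure to understand the format.
# ast.literal_eval only accepts Python literals, so it is safer than eval() on file contents.
import ast
structure_dict = ast.literal_eval(df['slab'].iloc[0])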
structure = Structure.from_dict(structure_dict)
atoms = pmg_to_atoms(structure)

print("\nStructure information:")
print(f"Number of atoms: {atoms.num_atoms}")
print(f"Chemical species: {set(atoms.elements)}")
print(f"Lattice parameters: a={atoms.lattice.a:.4f}, b={atoms.lattice.b:.4f}, c={atoms.lattice.c:.4f}")
print(f"Angles: alpha={atoms.lattice.alpha:.2f}, beta={atoms.lattice.beta:.2f}, gamma={atoms.lattice.gamma:.2f}")
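
The `slab` column stores each structure as the string form of a Python dictionary. As a minimal sketch of how such a string can be parsed without executing arbitrary code, the example below uses `ast.literal_eval` from the standard library. The `record` string is a hypothetical, heavily simplified stand-in and does not follow the real pymatgen schema.

```python
import ast

# Hypothetical, simplified structure record for illustration only;
# a real pymatgen dictionary contains more keys than this.
record = "{'lattice': {'a': 3.0, 'b': 3.0, 'c': 20.0}, 'sites': [{'label': 'Cu', 'abc': [0.0, 0.0, 0.5]}]}"

parsed = ast.literal_eval(record)  # accepts only Python literals
print(parsed["sites"][0]["label"])

# Anything that is not a plain literal is rejected instead of being executed.
try:
    ast.literal_eval("__import__('os').getcwd()")
    rejected = False
except (ValueError, SyntaxError):
    rejected = True
print(rejected)
```

This only works when the stored strings contain plain literals (numbers, strings, lists, dictionaries, `None`, booleans), which is an assumption about the dataset.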

## 3. Data Preparation for Graph Neural Networks

Now let's implement the key functions from the Comformer codebase to prepare the data for the GNN. This involves:
1. Creating a unique ID for each structure
2. Preprocessing the structure data
3. Converting structures to graphs

Let's implement these steps one by one.

In [None]:
# Step 1: Create a unique ID for each structure
df["jid"] = df["mpid"].astype(str) + df["miller"].astype(str) + df["term"].astype(str)

# Step 2: Rename 'slab' to 'atoms' if needed (to match the expected format)
if "slab" in df.columns:
    df = df.rename(columns={"slab": "atoms"})

# Choose our target property
target = "WF_top"  # We'll predict the work function of the top surface

# Print information about the prepared dataset
print(f"Number of samples: {len(df)}")
print(f"Target property: {target}")
print(f"Target range: {df[target].min():.4f} to {df[target].max():.4f}")

### 3.1 Graph Generation Explained

Before we convert the structures to graphs, let's understand what this process involves:

1. **Nodes**: Each atom in the structure becomes a node in the graph
2. **Node Features**: Each node has features like atomic number, which captures chemical information
3. **Edges**: Connections between atoms, usually based on distance (e.g., k-nearest neighbors)
4. **Edge Features**: Properties of the connections, like interatomic distances
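
To make the four ingredients above concrete, here is a minimal sketch of k-nearest-neighbor edge construction in plain Python. It is an illustration, not the Comformer implementation: the function name `knn_edges` and the coordinates are made up, and periodic images of the crystal are ignored, which a real graph builder for materials has to handle.

```python
import math

def knn_edges(positions, k):
    """Return directed edges (src, dst, distance): each dst receives its k nearest src."""
    edges = []
    for i, p in enumerate(positions):
        # Distance from atom i to every other atom (no periodic images here).
        dists = [(math.dist(p, q), j) for j, q in enumerate(positions) if j != i]
        for d, j in sorted(dists)[:k]:
            edges.append((j, i, d))  # message flows from neighbor j to atom i
    return edges

# Three atoms on a line (coordinates in angstroms, made up for illustration).
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
edges = knn_edges(positions, k=1)
for src, dst, d in edges:
    print(f"{src} -> {dst}: {d:.1f}")
```

Note that k-nearest-neighbor edges are not necessarily symmetric: atom 2 receives a message from atom 1, but atom 1's single nearest neighbor is atom 0.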

Now let's implement the graph conversion function from the Comformer codebase.

In [None]:
# Standard-library parser for literal strings (safer than eval)
import ast

# Import the necessary modules from Comformer
from comformer.graphs import PygGraph

def atoms_to_graph(atoms_string,
                   neighbor_strategy="k-nearest",
                   cutoff=8.0,
                   max_neighbors=12,
                   use_canonize=False,
                   use_lattice=False,
                   use_angle=False):
    """Convert a structure string to a PyG graph."""
    # Parse the structure; ast.literal_eval accepts only Python literals, unlike eval()
    structure = pmg_to_atoms(Structure.from_dict(ast.literal_eval(atoms_string)))
    
    # Create the graph
    graph = PygGraph.atom_dgl_multigraph(
        structure,
        neighbor_strategy=neighbor_strategy,
        cutoff=cutoff,
        atom_features="atomic_number",
        max_neighbors=max_neighbors,
        compute_line_graph=False,
        use_canonize=use_canonize,
        use_lattice=use_lattice,
        use_angle=use_angle,
    )
    return graph

# Convert a sample structure to a graph to examine
sample_graph = atoms_to_graph(df['atoms'].iloc[0])

# Examine the graph properties
print(f"Graph properties:")
print(f"Number of nodes: {sample_graph.x.shape[0]}")
print(f"Node feature dimension: {sample_graph.x.shape[1]}")
print(f"Number of edges: {sample_graph.edge_index.shape[1]}")
print(f"Edge feature dimension: {sample_graph.edge_attr.shape[1] if hasattr(sample_graph, 'edge_attr') else 'N/A'}")

# Show the first few node features
print("\nFirst 5 node features:")
print(sample_graph.x[:5])

### 3.2 Create a Custom Dataset Class

Now, let's implement a simplified version of the `PygStructureDataset` class to prepare our dataset for training.

In [None]:
from torch_geometric.data import Data
from torch_geometric.data.batch import Batch
from torch.utils.data import Dataset

class SimplePygStructureDataset(Dataset):
    """A simplified version of the Comformer PygStructureDataset."""
    
    def __init__(self, df, target, atom_features="atomic_number"):
        self.df = df
        self.target = target
        self.ids = self.df["jid"]
        
        # Convert targets to tensor
        self.labels = torch.tensor(self.df[target].values).float()
        
        # Normalize labels
        self.mean = self.labels.mean()
        self.std = self.labels.std()
        self.labels = (self.labels - self.mean) / self.std
        
        print(f"Target normalization: mean={self.mean:.4f}, std={self.std:.4f}")
        
        # Convert structures to graphs
        print("Converting structures to graphs...")
        self.graphs = []
        for i, atoms_str in enumerate(self.df['atoms']):
            if i % 5 == 0:  # Print progress every 5 samples
                print(f"Processing structure {i+1}/{len(self.df)}")
            graph = atoms_to_graph(atoms_str)
            self.graphs.append(graph)
        
        # Get atomic features lookup
        from jarvis.core.specie import chem_data, get_node_attributes
        
        # Load atomic features
        max_z = max(v["Z"] for v in chem_data.values())
        template = get_node_attributes("C", atom_features)
        features = np.zeros((1 + max_z, len(template)))
        
        for element, v in chem_data.items():
            z = v["Z"]
            x = get_node_attributes(element, atom_features)
            if x is not None:
                features[z, :] = x
        
        # Update node features in graphs
        for g in self.graphs:
            z = g.x
            g.atomic_number = z
            z = z.type(torch.IntTensor).squeeze()
            f = torch.tensor(features[z]).type(torch.FloatTensor)
            if g.x.size(0) == 1:
                f = f.unsqueeze(0)
            g.x = f
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        graph = self.graphs[idx]
        label = self.labels[idx]
        return graph, graph, graph, label  # Return graph triplet and label as expected by model
    
    @staticmethod
    def collate(samples):
        """Collate function for DataLoader."""
        graphs, line_graphs, lattice, labels = map(list, zip(*samples))
        batched_graph = Batch.from_data_list(graphs)
        batched_line_graph = Batch.from_data_list(line_graphs)
        # Stack the 0-d label tensors into a single [batch_size] tensor
        return batched_graph, batched_line_graph, batched_line_graph, torch.stack(labels)
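
The class above normalizes labels with the statistics of whatever dataframe it receives. A common alternative is to compute the mean and standard deviation on the training split only and reuse them for the validation and test splits, so that all labels live on the scale the model is trained on. The sketch below shows that arithmetic in plain Python; the helper names and the target values are hypothetical.

```python
import statistics

def fit_normalizer(train_values):
    """Return (mean, std) computed from the training targets only."""
    mean = statistics.fmean(train_values)
    std = statistics.stdev(train_values)  # sample std (n - 1), like torch.Tensor.std()
    return mean, std

def normalize(values, mean, std):
    return [(v - mean) / std for v in values]

def denormalize(values, mean, std):
    return [v * std + mean for v in values]

# Made-up work-function values in eV, for illustration only.
train_targets = [4.0, 5.0, 6.0]
test_targets = [4.5, 5.5]

mean, std = fit_normalizer(train_targets)
test_norm = normalize(test_targets, mean, std)  # reuse the training statistics
print(mean, std, test_norm)
```

Applying `denormalize` with the same `mean` and `std` recovers the original values, which is what the evaluation step relies on.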

### 3.3 Create Train/Validation/Test Splits

Let's create data splits for training, validation, and testing.

In [None]:
def get_data_splits(df, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=42):
    """Split the dataframe into train, validation, and test sets."""
    # Ensure ratios sum to 1
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-10
    
    # Create random indices
    np.random.seed(seed)
    indices = np.random.permutation(len(df))
    
    # Calculate split sizes
    n_train = int(len(df) * train_ratio)
    n_val = int(len(df) * val_ratio)
    
    # Create splits
    train_indices = indices[:n_train]
    val_indices = indices[n_train:n_train+n_val]
    test_indices = indices[n_train+n_val:]
    
    # Create dataframes
    train_df = df.iloc[train_indices].reset_index(drop=True)
    val_df = df.iloc[val_indices].reset_index(drop=True)
    test_df = df.iloc[test_indices].reset_index(drop=True)
    
    return train_df, val_df, test_df

# Create data splits
train_df, val_df, test_df = get_data_splits(df)

print(f"Train set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Test set size: {len(test_df)}")

### 3.4 Create Data Loaders

To keep this tutorial fast, we'll use only a small subset of the data. With so few samples, the resulting losses and errors are illustrative only.

In [None]:
# Use a smaller subset for this tutorial
small_train_df = train_df.head(10)  # Just use 10 samples for training
small_val_df = val_df.head(5)      # 5 samples for validation
small_test_df = test_df.head(5)    # 5 samples for testing

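# Create datasets
train_dataset = SimplePygStructureDataset(small_train_df, target)
val_dataset = SimplePygStructureDataset(small_val_df, target)
test_dataset = SimplePygStructureDataset(small_test_df, target)

# Each dataset normalized its labels with its own statistics. Re-normalize the
# validation and test labels with the training mean/std so that all three splits
# share the label scale the model is trained on.
for ds in (val_dataset, test_dataset):
    raw_labels = ds.labels * ds.std + ds.mean
    ds.mean, ds.std = train_dataset.mean, train_dataset.std
    ds.labels = (raw_labels - ds.mean) / ds.std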

# Create data loaders
batch_size = 2
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=train_dataset.collate,
    drop_last=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=val_dataset.collate,
    drop_last=True,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=test_dataset.collate,
)
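
The loader settings above determine how many batches each epoch sees. As a small worked example, the hypothetical helper `num_batches` below reproduces the batch count of a `DataLoader` over a map-style dataset: floor division when `drop_last=True`, ceiling division otherwise.

```python
import math

def num_batches(n_samples, batch_size, drop_last):
    """Number of batches a DataLoader yields for a map-style dataset."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

print(num_batches(10, 2, drop_last=True))   # training loader
print(num_batches(5, 2, drop_last=True))    # validation loader: one sample is skipped
print(num_batches(5, 2, drop_last=False))   # validation loader without drop_last
print(num_batches(5, 1, drop_last=False))   # test loader
```

With five validation samples and a batch size of 2, `drop_last=True` means one validation structure never contributes to the validation loss.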

## 4. Understanding the Model Architecture

Now, let's implement a simplified version of the Comformer model to understand its architecture.

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter

class RBFExpansion(nn.Module):
    """Radial basis function expansion module."""
    
    def __init__(self, vmin, vmax, bins):
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        self.gap = (vmax - vmin) / bins
        self.centers = nn.Parameter(
            torch.linspace(vmin, vmax, bins), requires_grad=False
        )
        
    def forward(self, dist):
        # Compute RBF values
        dist = dist.view(-1, 1)
        centers = self.centers.view(1, -1)
        diff = dist - centers
        coef = -0.5 / (self.gap**2)
        rbf = torch.exp(coef * (diff**2))
        return rbf

class SimpleComformerAttention(nn.Module):
    """A simplified version of the Comformer attention mechanism."""
    
    def __init__(self, in_channels, out_channels, edge_channels, heads=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        
        # Key, query, value projections
        self.lin_key = nn.Linear(in_channels, heads * out_channels)
        self.lin_query = nn.Linear(in_channels, heads * out_channels)
        self.lin_value = nn.Linear(in_channels, heads * out_channels)
        self.lin_edge = nn.Linear(edge_channels, heads * out_channels)
        
        # Output projection
        self.lin_concat = nn.Linear(heads * out_channels, out_channels)
        
        # Message passing networks
        self.lin_msg_update = nn.Sequential(
            nn.Linear(out_channels * 3, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels)
        )
        
        # Key update network
        self.key_update = nn.Sequential(
            nn.Linear(out_channels * 3, out_channels),
            nn.SiLU(),
            nn.Linear(out_channels, out_channels)
        )
        
        # Other layers
        self.bn = nn.BatchNorm1d(out_channels)
        self.bn_att = nn.BatchNorm1d(out_channels)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, edge_index, edge_attr):
        """Forward pass.
        
        Args:
            x: Node features [num_nodes, in_channels]
            edge_index: Edge indices [2, num_edges]
            edge_attr: Edge features [num_edges, edge_channels]
            
        Returns:
            Updated node features [num_nodes, out_channels]
        """
        H, C = self.heads, self.out_channels
        
        # Project features
        query = self.lin_query(x).view(-1, H, C)
        key = self.lin_key(x).view(-1, H, C)
        value = self.lin_value(x).view(-1, H, C)
        edge = self.lin_edge(edge_attr).view(-1, H, C)
        
        # Get source and destination node indices
        src_idx, dst_idx = edge_index
        
        # Get features for source and destination nodes
        query_i = query[dst_idx]
        key_i = key[dst_idx]
        key_j = key[src_idx]
        value_i = value[dst_idx]
        value_j = value[src_idx]
        
        # Compute attention scores
        key_j = self.key_update(torch.cat((key_i, key_j, edge), dim=-1))
        alpha = (query_i * key_j) / (C ** 0.5)
        
        # Compute message updates
        out = self.lin_msg_update(torch.cat((value_i, value_j, edge), dim=-1))
        out = out * self.sigmoid(self.bn_att(alpha.view(-1, C)).view(-1, H, C))
        
        # Aggregate messages
        out = out.view(-1, H * C)
        out = self.lin_concat(out)
        
        # Sum messages for each destination node
        out = scatter(out, dst_idx, dim=0, reduce="sum", dim_size=x.size(0))
        
        # Apply residual connection and batch norm
        return self.softplus(x + self.bn(out))

class SimpleComformer(nn.Module):
    """A simplified version of the Comformer model."""
    
    def __init__(self, 
                 atom_input_features=92,
                 edge_features=256,
                 node_features=256,
                 output_features=1,
                 n_conv_layers=3):
        super().__init__()
        
        # Node embedding
        self.atom_embedding = nn.Linear(atom_input_features, node_features)
        
        # Edge embedding
        self.rbf = nn.Sequential(
            RBFExpansion(vmin=-4.0, vmax=0.0, bins=edge_features),
            nn.Linear(edge_features, node_features),
            nn.Softplus(),
        )
        
        # Attention layers
        self.att_layers = nn.ModuleList([
            SimpleComformerAttention(
                in_channels=node_features,
                out_channels=node_features,
                edge_channels=node_features)
            for _ in range(n_conv_layers)
        ])
        
        # Readout layers
        self.fc = nn.Sequential(
            nn.Linear(node_features, node_features),
            nn.SiLU()
        )
        
        # Output layer
        self.fc_out = nn.Linear(node_features, output_features)
        
    def forward(self, data):
        """Forward pass.
        
        Args:
            data: Tuple of (graph, line_graph, lattice)
            
        Returns:
            Predicted properties
        """
        graph, _, _ = data
        
        # Embed nodes
        node_features = self.atom_embedding(graph.x)
        
        # Compute edge features
        edge_feat = -0.75 / torch.norm(graph.edge_attr, dim=1)
        edge_features = self.rbf(edge_feat)
        
        # Apply attention layers
        for att_layer in self.att_layers:
            node_features = att_layer(node_features, graph.edge_index, edge_features)
        
        # Global pooling (average node features per graph)
        features = scatter(node_features, graph.batch, dim=0, reduce="mean")
        
        # Apply final layers
        features = self.fc(features)
        out = self.fc_out(features)
        
        return torch.squeeze(out)
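
To see what the edge embedding receives, the sketch below redoes the arithmetic of `RBFExpansion` in plain Python for a single edge. It follows the formulas in the cell above (centers evenly spaced between `vmin` and `vmax`, width `gap = (vmax - vmin) / bins`, and the `-0.75 / distance` transform); the distance value and the small `bins=5` are made up to keep the output readable.

```python
import math

def rbf_expand(value, vmin, vmax, bins):
    """Expand a scalar into `bins` Gaussian features, mirroring RBFExpansion above."""
    gap = (vmax - vmin) / bins
    # Evenly spaced centers including both endpoints, like torch.linspace.
    centers = [vmin + i * (vmax - vmin) / (bins - 1) for i in range(bins)]
    return [math.exp(-0.5 * ((value - c) / gap) ** 2) for c in centers]

distance = 1.5                   # interatomic distance in angstroms (made up)
edge_scalar = -0.75 / distance   # same transform as in SimpleComformer.forward
features = rbf_expand(edge_scalar, vmin=-4.0, vmax=0.0, bins=5)
print(edge_scalar)
print([round(f, 4) for f in features])
```

Each feature is largest when the transformed distance sits on the corresponding center, so the vector acts as a smooth, soft one-hot encoding of distance.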

## 5. Model Training

Now let's set up a simple training loop for our model.

In [None]:
# Initialize the model
# Determine the atom input features from the first graph
atom_features_dim = train_dataset.graphs[0].x.shape[1]

model = SimpleComformer(
    atom_input_features=atom_features_dim,
    edge_features=64,  # Reduced for faster training
    node_features=64,  # Reduced for faster training
    output_features=1,
    n_conv_layers=2    # Reduced for faster training
)

model = model.to(device)

# Print model summary
print(model)
print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

In [None]:
# Set up training
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        # Move batch to device
        batch = [b.to(device) for b in batch]
        g, lg, lattice, target = batch
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output = model([g, lg, lattice])
        
        # Compute loss
        loss = criterion(output, target)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
    
    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            # Move batch to device
            batch = [b.to(device) for b in batch]
            g, lg, lattice, target = batch
            
            # Forward pass
            output = model([g, lg, lattice])
            
            # Compute loss
            loss = criterion(output, target)
            val_loss += loss.item()
    
    # Average losses
    train_loss /= len(train_loader)
    val_loss /= len(val_loader)
    
    # Save losses
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    # Print progress
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [None]:
# Plot training curves
plt.figure(figsize=(10, 6))
plt.plot(range(1, num_epochs+1), train_losses, 'b-', label='Training Loss')
plt.plot(range(1, num_epochs+1), val_losses, 'r-', label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

## 6. Model Evaluation and Inference

Now let's evaluate our model on the test set.

In [None]:
# Evaluate on test set
model.eval()
test_targets = []
test_predictions = []

with torch.no_grad():
    for batch in test_loader:
        # Move batch to device
        batch = [b.to(device) for b in batch]
        g, lg, lattice, target = batch
        
        # Forward pass
        output = model([g, lg, lattice])
        
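        # Denormalize. The model predicts on the label scale it was trained on, so
        # predictions use the training statistics; targets use the statistics that
        # were used to normalize them.
        output_denorm = output.cpu().numpy() * train_dataset.std.numpy() + train_dataset.mean.numpy()
        target_denorm = target.cpu().numpy() * test_dataset.std.numpy() + test_dataset.mean.numpy()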
        
        test_targets.append(target_denorm.item())
        test_predictions.append(output_denorm.item())

# Calculate MAE
mae = np.mean(np.abs(np.array(test_targets) - np.array(test_predictions)))
print(f"Test MAE: {mae:.4f}")

# Plot predictions vs targets
plt.figure(figsize=(8, 8))
plt.scatter(test_targets, test_predictions)
plt.plot([min(test_targets), max(test_targets)], [min(test_targets), max(test_targets)], 'r--')
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.title(f'Test Set: Predictions vs True Values\nMAE: {mae:.4f}')
plt.axis('equal')
plt.grid(True)
plt.show()
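
Because denormalization is an affine map, the MAE in original units equals the standard deviation times the MAE in normalized units, provided predictions and targets are denormalized with the same mean and standard deviation. The sketch below checks this relation on made-up numbers; `mae` is a hypothetical helper, not part of the Comformer code.

```python
def mae(a, b):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

# Made-up normalized predictions/targets and normalization statistics.
mean, std = 5.0, 0.5
pred_norm = [0.2, -1.0, 0.6]
true_norm = [0.0, -0.5, 1.0]

# Map both back to original units with the same statistics.
pred = [p * std + mean for p in pred_norm]
true = [t * std + mean for t in true_norm]

mae_norm = mae(pred_norm, true_norm)
mae_orig = mae(pred, true)
print(mae_norm, mae_orig)
```

The mean cancels in the difference, so only the standard deviation rescales the error.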

## 7. Making Predictions on New Structures

Now let's see how to use our trained model to make predictions on new structures.

In [None]:
def predict_property(model, structure_string, mean, std):
    """Predict property for a new structure."""
    # Convert structure to graph
    graph = atoms_to_graph(structure_string)
    
    # Wrap the single graph in a Batch so that graph.batch exists for pooling
    graph = Batch.from_data_list([graph]).to(device)
    batch = [graph, graph, graph]
    
    # Make prediction
    model.eval()
    with torch.no_grad():
        prediction = model(batch)
        # Denormalize prediction
        prediction_denorm = prediction.cpu().numpy() * std + mean
    
    return prediction_denorm.item()

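# Let's make a prediction on a structure from the test set.
# The training and evaluation loops rebound `target` to a label tensor,
# so we restore the column name before indexing the dataframe.
target = "WF_top"
sample_structure = test_df['atoms'].iloc[0]
true_value = test_df[target].iloc[0]

# Predict, denormalizing with the statistics of the training set the model was fitted on
prediction = predict_property(model, sample_structure, train_dataset.mean.item(), train_dataset.std.item())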

print(f"True {target}: {true_value:.4f}")
print(f"Predicted {target}: {prediction:.4f}")
print(f"Absolute error: {abs(true_value - prediction):.4f}")

## 8. Summary and Key Insights

In this tutorial, we've walked through the process of using Comformer, a Graph Neural Network for predicting material properties. Here's a summary of the key steps and insights:

### Data Processing Pipeline
1. **Loading Crystal Structures**: We loaded structures from CSV files, where each structure is represented as a dictionary containing atomic positions and lattice information.
2. **Converting to Graphs**: We converted crystal structures to graphs in which atoms are nodes and edges connect each atom to its nearest neighbors.
3. **Feature Generation**: Node features encode atomic properties (here, the atomic number), while edge features encode interatomic distances.

### Model Architecture
The Comformer model consists of:
1. **Node and Edge Embedding Layers**: Convert raw node and edge features to embeddings.
2. **Attention Layers**: Update node features based on their neighbors and the edge features.
3. **Readout Layer**: Pool node features to get a graph-level representation.
4. **Output Layer**: Predict the target property from the graph representation.
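
The readout step can be sketched in a few lines of plain Python. The hypothetical `mean_pool` below averages node feature vectors per graph using a batch index, which is the same idea as the `scatter(..., reduce="mean")` call in the model; the feature values are made up.

```python
def mean_pool(node_features, batch):
    """Average node feature vectors per graph; batch[i] is the graph index of node i."""
    n_graphs = max(batch) + 1
    dim = len(node_features[0])
    sums = [[0.0] * dim for _ in range(n_graphs)]
    counts = [0] * n_graphs
    for feat, g in zip(node_features, batch):
        counts[g] += 1
        for d, v in enumerate(feat):
            sums[g][d] += v
    return [[s / counts[g] for s in sums[g]] for g in range(n_graphs)]

# Two graphs batched together: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
node_features = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [3.0, 3.0], [6.0, 9.0]]
batch = [0, 0, 1, 1, 1]
pooled = mean_pool(node_features, batch)
print(pooled)
```

Each graph ends up with one fixed-size vector regardless of how many atoms it contains, which is what lets the output layer work on structures of different sizes.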

### Key Advantages of GNNs for Materials
1. **Permutation Invariance**: The model is invariant to the order of atoms in the structure.
2. **Local Structure Awareness**: The model can capture local bonding patterns and environments.
3. **Scalability**: The same model handles structures with different numbers of atoms.
4. **Transferability**: Patterns learned from local atomic environments in one material can carry over to other materials that contain similar environments.
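
Permutation invariance can be checked on a toy example. The sketch below is a deliberately simplified stand-in for a GNN layer, not the Comformer attention mechanism: it runs one round of neighbor-sum message passing followed by a mean readout, then repeats the computation after relabeling the atoms. All names and values are hypothetical.

```python
def message_pass_and_pool(features, edges):
    """One round of neighbor-sum message passing followed by a mean readout."""
    updated = list(features)
    for src, dst in edges:
        updated[dst] += features[src]   # each node adds its neighbors' features
    return sum(updated) / len(updated)  # graph-level scalar

# Toy graph with scalar node features (values are made up).
features = [1.0, 2.0, 4.0]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]

# Relabel the atoms: old index i becomes new index perm[i].
perm = [2, 0, 1]
features_p = [0.0] * 3
for old, new in enumerate(perm):
    features_p[new] = features[old]
edges_p = [(perm[s], perm[d]) for s, d in edges]

original = message_pass_and_pool(features, edges)
permuted = message_pass_and_pool(features_p, edges_p)
print(original, permuted)
```

Both calls return the same graph-level value because the aggregation (a sum over neighbors) and the readout (a mean over nodes) do not depend on the order in which atoms are listed.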

### Practical Considerations
1. **Data Normalization**: Important for stable training and accurate predictions.
2. **Graph Construction**: The choice of cutoff and number of neighbors impacts model performance.
3. **Model Size**: Deep GNNs can capture complex relationships but may overfit on small datasets.

GNNs like Comformer represent a powerful approach for predicting material properties directly from atomic structures, enabling rapid screening and discovery of novel materials with desired properties.