# Tutorial 3 - Fitting an MLFF

For this tutorial we also need to install PyTorch and PyTorch Geometric. We will build a pipeline to test an algorithm that you can define by yourself. 

## Environment setup 

We set up PyTorch and PyTorch Geometric for the `CPU`, as most consumer GPUs do not offer good 64-bit (double-precision) support. For simulations we need as much accuracy as we can get.

```bash 
conda create -n mlfftutorial3 python=3.11
conda activate mlfftutorial3
pip install torch_geometric
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install mond
```

## Dataset 

We use the dataset generated in the previous tutorial. If you just want to test the algorithm fast without generating a dataset for a couple of hours, that's fine. The repo provides a test dataset. If you want to use the provided dataset, make sure to change the dataset path to `methanol_aimd_data_prerun`

In [None]:
import torch
import numpy as np
from torch_geometric.data import Dataset, Data
from pathlib import Path

class NPZTrajectoryDataset(Dataset):
    """Graph dataset that serves one sample per ``.npz`` file.

    Every ``*.npz`` file found directly inside ``root_dir`` is one sample.
    The files are ordered by sorted path, so indexing is reproducible.
    Each file is read for the arrays ``positions``, ``forces``, ``e_pot``
    and ``atom_types`` and turned into a fully connected graph.
    """

    def __init__(self, root_dir, transform=None, pre_transform=None):
        """Collect the sample files.

        Args:
            root_dir: Directory that contains the ``.npz`` files.
            transform: Forwarded unchanged to the base ``Dataset``.
            pre_transform: Forwarded unchanged to the base ``Dataset``.
        """
        super().__init__(root_dir, transform, pre_transform)
        self.root_dir = Path(root_dir)
        self.files = sorted(self.root_dir.glob("*.npz"))

    def len(self):
        """Return the number of ``.npz`` files that were found."""
        return len(self.files)

    def get(self, idx):
        """Load sample ``idx`` from disk and build its graph.

        Returns:
            A ``Data`` object with the fields ``pos``, ``z``, ``e_pot``,
            ``force``, ``edge_index``, ``edge_attr`` and ``edge_vec``.
            All floating-point fields are float64.
        """
        # np.load returns a lazily-read archive for .npz files; the context
        # manager closes the underlying file handle once the arrays have been
        # copied into tensors (previously one handle leaked per sample).
        with np.load(self.files[idx]) as data:
            pos = torch.tensor(data["positions"], dtype=torch.float64)        # expected shape (N, 3)
            forces = torch.tensor(data["forces"], dtype=torch.float64)        # expected shape (N, 3)
            e_pot = torch.tensor(data["e_pot"], dtype=torch.float64).view(1)  # scalar -> shape (1,)
            z = torch.tensor(data["atom_types"], dtype=torch.float64)         # expected shape (N,)

        edge_index = self._fully_connected_edges(pos.shape[0])

        # Edge features: relative position vector plus its length.
        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
        edge_len = edge_vec.norm(dim=1, keepdim=True)
        edge_attr = torch.cat([edge_vec, edge_len], dim=1)  # [n_edges, 4]

        # z is already float64; it only needs a trailing feature dimension.
        return Data(pos=pos, z=z.unsqueeze(1), e_pot=e_pot, force=forces,
                    edge_index=edge_index, edge_attr=edge_attr, edge_vec=edge_vec)

    def _fully_connected_edges(self, num_atoms):
        """Return a ``[2, num_atoms**2]`` index tensor of all ordered pairs.

        The pairs (i, i) are included, so every atom also has an edge to
        itself (with a zero relative vector and zero length).
        """
        row, col = torch.meshgrid(
            torch.arange(num_atoms), torch.arange(num_atoms), indexing="ij"
        )
        return torch.stack([row.flatten(), col.flatten()], dim=0)

## Sanity Checks 

To check whether your algorithm behaves in a meaningful way, or whether we have to tune something in the algorithm or the training process, we need to plot the probability distribution of the data and, later, that of the algorithm's predictions. So first, we introduce data augmentation to make the algorithm invariant to translations and rotations of the molecule of interest.

Here, it is absolutely important to rotate the forces if we rotate the molecule. 

In [None]:
import torch
from torch_geometric.data import Data
import numpy as np

class RandomRotationTranslation:
    """Augmentation that randomly rotates a sample and places it inside a box.

    Positions are centred, rotated and then shifted so that the rotated
    bounding box lies between 0 and ``box_size`` on every axis. Forces are
    rotated with the same matrix; they are not translated.
    """

    def __init__(self, box_size):
        """
        box_size: list or tensor of shape (3,) representing box lengths [Lx, Ly, Lz]
        """
        self.box_size = torch.tensor(box_size, dtype=torch.float64)

    def __call__(self, data: Data):
        """Apply the augmentation to ``data`` in place and return it.

        Raises:
            ValueError: If the rotated bounding box is larger than the box
                along any axis.
        """
        # Rotate about the mean position so the rotation is independent of
        # where the sample sits in space.
        centered = data.pos - data.pos.mean(dim=0, keepdim=True)

        rot_t = random_rotation_matrix().T
        rotated_pos = centered @ rot_t
        rotated_force = data.force @ rot_t

        # Axis-aligned bounding box of the rotated sample.
        lower = rotated_pos.min(dim=0).values
        upper = rotated_pos.max(dim=0).values
        extent = upper - lower

        # Free room per axis once the bounding box is placed in the box.
        margin = self.box_size - extent
        if (margin < 0).any():
            raise ValueError(f"Molecule too large for bounding box: extent={extent}, box={self.box_size}")

        # Shift the lower corner to a random point in [0, margin].
        shift = torch.rand(3, dtype=torch.float64) * margin - lower

        data.pos = rotated_pos + shift
        data.force = rotated_force

        # Edge features depend on the positions, so rebuild them.
        senders, receivers = data.edge_index[0], data.edge_index[1]
        rel = data.pos[senders] - data.pos[receivers]
        data.edge_vec = rel
        data.edge_attr = torch.cat([rel, rel.norm(dim=1, keepdim=True)], dim=1)

        return data

def random_rotation_matrix():
    """Return a random 3x3 proper rotation matrix (float64, determinant +1).

    The matrix is the Q factor of the QR decomposition of a standard-normal
    3x3 matrix. QR is only unique up to the signs of the diagonal of R, so
    the columns of Q are rescaled to make that diagonal positive; without
    this step the distribution of Q depends on the sign convention of the
    QR routine and is not guaranteed to be uniform over rotations.

    Returns:
        torch.Tensor: Orthogonal matrix of shape (3, 3) with det = +1.
    """
    A = torch.randn(3, 3, dtype=torch.float64)
    Q, R = torch.linalg.qr(A)

    # Remove the sign ambiguity of the factorisation (Q D, D R with D = ±1).
    signs = torch.sign(torch.diagonal(R))
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    Q = Q * signs  # broadcasts over rows: column j is scaled by signs[j]

    # Ensure right-handed coordinate system (reflection -> rotation).
    if torch.det(Q) < 0:
        Q[:, 2] *= -1
    return Q

### Plotting the distribution

In [None]:
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt

def plot_force_distribution(dataset, file_name, batch_size=32, max_batches=10):
    """Plot histograms of the force labels stored in ``dataset``.

    Four panels are drawn (x, y and z components and the vector magnitude)
    and the figure is written to ``file_name``.

    Args:
        dataset: Dataset whose batched items expose a ``force`` attribute
            that can be viewed as ``[-1, 3]``.
        file_name: Output path passed to ``plt.savefig``.
        batch_size: Batch size of the internal ``DataLoader``.
        max_batches: Maximum number of batches to read.

    Raises:
        AttributeError: If a batch has no ``force`` attribute.
        ValueError: If no batch was read (empty dataset or max_batches < 1).
    """
    from itertools import islice

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # islice stops after exactly max_batches batches; the previous
    # "append, then check i >= max_batches" loop read one batch too many.
    forces = []
    for data in islice(loader, max(max_batches, 0)):
        if not hasattr(data, 'force'):
            raise AttributeError("Dataset items must have a `force` attribute.")
        forces.append(data.force.view(-1, 3))  # [num_nodes_in_batch, 3]

    if not forces:
        raise ValueError("No batches were read: the dataset is empty or max_batches < 1.")

    forces = torch.cat(forces, dim=0).detach().numpy()  # [N, 3]

    # Per-component values and the Euclidean norm of every force vector.
    magnitudes = (forces ** 2).sum(axis=1) ** 0.5
    panels = (
        ("Force X", forces[:, 0]),
        ("Force Y", forces[:, 1]),
        ("Force Z", forces[:, 2]),
        ("Force Magnitude", magnitudes),
    )

    plt.figure(figsize=(14, 4))
    for position, (title, values) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.hist(values, bins=100, alpha=0.7, color='black')
        plt.title(title)

    plt.tight_layout()
    plt.savefig(file_name, dpi=300)



# Plot the force distribution of the *augmented* dataset.
# The path points at the prebuilt dataset; change it if you generated your own data.
data_path = "methanol_aimd_data_prerun/"
box_size = [10, 10, 10]
transform = RandomRotationTranslation(box_size)  # the transform must be instantiated before use
# The dataset is built once, with the transform attached. A second,
# transform-free construction used to overwrite it, so the "data_aug" plot
# showed non-augmented data.
dataset = NPZTrajectoryDataset(root_dir=data_path, transform=transform)
plot_force_distribution(dataset=dataset, max_batches=200, file_name="force_distribution_dataset_data_aug.png")

## Algorithmic part 

Now comes the algorithmic part, where we explore some concepts of MLFFs. The algorithm of MLFFs needs to fulfil several requirements. 

* Needs to be rotation and translation invariant to the system as a whole (not just parts, that would be a different story)
* Output negative and positive forces of any scale (no activation function on the output layer) -> only activation functions in the hidden layers
* Fit a conservative force field using differentiable activation functions such as SiLU
* One could also derive the forces as the negative gradient of a predicted energy with respect to the positions, using automatic differentiation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import scatter

class ForceMPNN(nn.Module):
    """One-round message-passing network that outputs a 3-vector per node.

    Node inputs (``data.z``) have one feature and edge inputs
    (``data.edge_attr``) have four. Messages are summed per node and mapped
    linearly to the output, so the output layer has no activation.
    """

    def __init__(self, hidden_dim=64):
        """Build the layers; ``hidden_dim`` is the width of all hidden features."""
        super().__init__()

        # Node encoder: 1 input feature -> hidden.
        self.embedding = nn.Linear(1, hidden_dim)

        # Edge encoder: relative vector (3) + distance (1) -> hidden.
        edge_layers = [
            nn.Linear(4, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.edge_mlp = nn.Sequential(*edge_layers)

        # Applied to the product of node and edge features.
        self.message_mlp = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, hidden_dim))

        # Linear read-out to the force vector.
        self.update_mlp = nn.Sequential(nn.Linear(hidden_dim, 3))

    def forward(self, data):
        """Return predictions of shape ``[n_nodes, 3]`` for a (batched) graph."""
        node_feat = self.embedding(data.z)            # [n_nodes, hidden]
        edge_feat = self.edge_mlp(data.edge_attr)     # [n_edges, hidden]

        # edge_index[0] is the node that receives the sum,
        # edge_index[1] is the node whose features are gathered.
        receivers, senders = data.edge_index
        per_edge = self.message_mlp(node_feat[senders] * edge_feat)

        pooled = scatter(per_edge, receivers, dim=0, dim_size=node_feat.size(0), reduce='add')
        return self.update_mlp(pooled)                # [n_nodes, 3]

## Training Loop

The training loop is nothing special: we fit the forces directly, with data augmentation, and without deriving the forces as the gradient of an energy.

In [None]:
from tqdm import tqdm 
from torch_geometric.loader import DataLoader

num_steps = 5000  # number of passes over the whole dataset (epochs)
device = torch.device("cpu")

# Build the augmented dataset (the dataset and transform classes are defined above).
data_path = "methanol_aimd_data/"
box_size = [10, 10, 10]
transform = RandomRotationTranslation(box_size)  # the transform must be instantiated before use
dataset = NPZTrajectoryDataset(root_dir=data_path, transform=transform)
if len(dataset) == 0:
    # A wrong path otherwise only shows up as an opaque IndexError below.
    raise FileNotFoundError(f"No .npz files found in {data_path!r}; check data_path.")
print(dataset[0].z.shape)
print(dataset[0].z)
print(len(dataset))

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Model in double precision to match the float64 tensors produced by the dataset.
model = ForceMPNN(hidden_dim=256)
model = model.double().to(device)

# Loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training loop: loss_store receives one mean batch loss per epoch.
loss_store = []
for epoch in tqdm(range(1, num_steps + 1), desc="Training MPNN on Methanol Trajectory"):
    total_loss = 0.0
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(batch)
        loss = criterion(pred, batch.force)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    epoch_loss = total_loss / len(dataloader)
    loss_store.append(epoch_loss)
    if epoch % 100 == 0:
        print(f"Epoch {epoch:2d} | Loss: {epoch_loss:.4f}")

## Plot the loss and save the model for later

You know that for sure

In [None]:
import matplotlib.pyplot as plt 
# loss_store holds one mean loss per epoch, so the x-axis counts epochs
# (1-based, like the training loop) rather than optimizer steps.
epochs = range(1, len(loss_store) + 1)
plt.figure()  # start a fresh figure instead of drawing into a previously active one
plt.plot(epochs, loss_store)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.savefig("MPNN_5k_epoch_64_batch_size_256_hidden_dim_labels_data_aug-SiLU.png", dpi=300)

In [None]:
# Store the complete model object and read it straight back as a round-trip check.
# NOTE(review): this pickles the whole module and loads it with
# weights_only=False, which executes pickled code - only do this with
# checkpoint files you created yourself.
model_path = "mpnn_full_model_data_aug_SiLU_5000.pth"
torch.save(model, model_path)
model = torch.load(model_path, map_location="cpu", weights_only=False)
model

## Checking the probability distribution of the model 

We now check if the model is producing a reasonable probability distribution. Remember to change the data path if you are using different data

In [None]:
def plot_force_distribution_model(dataset, model, file_name, batch_size=32, max_batches=10):
    """Plot histograms of the forces predicted by ``model`` on ``dataset``.

    Four panels are drawn (x, y and z components and the vector magnitude)
    and the figure is written to ``file_name``.

    Args:
        dataset: Dataset whose batches can be passed to ``model``.
        model: Callable that maps a batch to an output viewable as ``[-1, 3]``.
        file_name: Output path passed to ``plt.savefig``.
        batch_size: Batch size of the internal ``DataLoader``.
        max_batches: Maximum number of batches to evaluate.

    Raises:
        AttributeError: If a batch has no ``force`` attribute.
        ValueError: If no batch was read (empty dataset or max_batches < 1).
    """
    from itertools import islice

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    forces = []
    # Only predictions are needed, so autograd is switched off; otherwise the
    # graph of every batch stays alive until the final detach().
    with torch.no_grad():
        # islice stops after exactly max_batches batches; the previous
        # "append, then check i >= max_batches" loop evaluated one too many.
        for data in islice(loader, max(max_batches, 0)):
            # Same input check as plot_force_distribution, kept so that
            # existing callers see the same error for unlabeled data.
            if not hasattr(data, 'force'):
                raise AttributeError("Dataset items must have a `force` attribute.")
            forces.append(model(data).view(-1, 3))  # [num_nodes_in_batch, 3]

    if not forces:
        raise ValueError("No batches were read: the dataset is empty or max_batches < 1.")

    forces = torch.cat(forces, dim=0).detach().numpy()  # [N, 3]

    # Per-component values and the Euclidean norm of every predicted vector.
    magnitudes = (forces ** 2).sum(axis=1) ** 0.5
    panels = (
        ("Force X", forces[:, 0]),
        ("Force Y", forces[:, 1]),
        ("Force Z", forces[:, 2]),
        ("Force Magnitude", magnitudes),
    )

    plt.figure(figsize=(14, 4))
    for position, (title, values) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.hist(values, bins=100, alpha=0.7, color='black')
        plt.title(title)

    plt.tight_layout()
    plt.savefig(file_name, dpi=300)

# Dataset without augmentation; adjust the path to wherever your data lives.
data_path = "methanol_aimd_data/"
dataset = NPZTrajectoryDataset(root_dir=data_path)

# Load the full pickled model saved earlier (weights_only=False is required for that format).
model_path = "mpnn_full_model_data_aug_SiLU_5000.pth"
model = torch.load(model_path, map_location="cpu", weights_only=False)

plot_force_distribution_model(
    model=model,
    dataset=dataset,
    max_batches=200,
    file_name="force_distribution_mp_model_5000_data_aug_SiLU.png",
)

That's it for the training. Good job!