In [15]:
import torch
import torch.nn as nn

import numpy as np
import math

import random, json

from ase.io import read

from torch.utils.data import Dataset, DataLoader

In [7]:
# Seed every RNG the notebook can draw from. `random` drives the train/test
# shuffle further down, so seeding torch alone does not make the split
# reproducible. torch is seeded last so the cell output is unchanged.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x10ed2bdb0>

In [3]:
# Report the best accelerator on this machine: CUDA first, then Apple MPS,
# otherwise the CPU.
if not torch.cuda.is_available():
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("MPS available")
    else:
        device = torch.device("cpu")
        print("Using CPU")
else:
    device = torch.device("cuda")
    print("GPU available:", torch.cuda.get_device_name(0))

# The detected device is overridden unconditionally: everything below runs on
# the CPU in single precision, whatever was printed above.
# NOTE(review): presumably deliberate (op coverage on MPS?) — confirm.
device, dtype = torch.device("cpu"), torch.float32

MPS available


In [8]:
# Read the training structures; index=":" is passed so that all frames of the
# file are returned rather than a single one.
atoms_all = read("data/train.extxyz", index=":")
print(f"Structures: {len(atoms_all)}")

Structures: 1127


In [9]:
def atoms_to_data(atoms):
    """Convert one ASE-style ``atoms`` object into a dict of torch tensors.

    Reads ``atoms.numbers``, ``atoms.positions``, ``atoms.info["REF_energy"]``,
    ``atoms.arrays["REF_forces"]``, ``atoms.cell.array`` and
    ``atoms.get_volume()``.

    Returns:
        dict with keys ``"Z"`` (long tensor), ``"pos"`` (flagged
        ``requires_grad=True``), ``"energy"``, ``"forces"``, ``"cell"`` and
        ``"volume"``. Everything except ``"Z"`` uses the module-level ``dtype``.
    """
    def as_float(value):
        # All floating-point fields share the notebook-wide dtype.
        return torch.tensor(value, dtype=dtype)

    # Positions must track gradients: forces are obtained as -dE/dpos.
    positions = as_float(atoms.positions).requires_grad_(True)

    return {
        "Z": torch.tensor(atoms.numbers, dtype=torch.long),
        "pos": positions,
        # Reference labels, stored under the REF_* keys of the input file.
        "energy": as_float(atoms.info["REF_energy"]),
        "forces": as_float(atoms.arrays["REF_forces"]),
        # Cell matrix and volume, kept for the stress term.
        "cell": as_float(atoms.cell.array),
        "volume": as_float(atoms.get_volume()),
    }

# One tensor dict per structure, in file order.
dataset = list(map(atoms_to_data, atoms_all))

In [12]:
# 80/20 train/test split over structure indices.
N = len(dataset)
indices = list(range(N))
# Uses the global `random` state, so the split is only reproducible if
# `random` has been seeded before this cell runs.
random.shuffle(indices)

split = int(0.8 * N)
train_idx = indices[:split]
test_idx = indices[split:]

# Persist the split for later reuse. The context manager guarantees the file
# is flushed and closed instead of leaving that to garbage collection.
with open("split.json", "w") as split_file:
    json.dump({"train": train_idx, "test": test_idx}, split_file)

print("Train:", len(train_idx), "Test:", len(test_idx))

Train: 901 Test: 226


In [13]:
def build_graph(pos, cell, cutoff=5.0):
    """Build the directed neighbour list of a periodic structure.

    Pair displacements ``pos[i] - pos[j]`` are wrapped with the minimum image
    convention (each fractional component is shifted by its nearest integer).
    Every ordered pair whose wrapped distance lies strictly between ``1e-6``
    and ``cutoff`` becomes an edge.

    Returns:
        ``(edge_index, edge_vec, edge_len)``: the ``(E, 2)`` index pairs
        ``(i, j)``, the wrapped displacement of each pair, and its length.
    """
    # Cartesian -> fractional, wrap into the nearest image, back to Cartesian.
    pair_disp = pos.unsqueeze(1) - pos.unsqueeze(0)
    frac = pair_disp @ cell.inverse()
    wrapped = (frac - frac.round()) @ cell

    # The lower bound removes zero-length pairs such as i == j.
    pair_dist = wrapped.norm(dim=-1)
    keep = (pair_dist > 1e-6) & (pair_dist < cutoff)

    edge_index = torch.nonzero(keep)
    i, j = edge_index.unbind(dim=1)
    edge_vec = wrapped[i, j]
    # Recomputed on the selected edges only, so no zero-length vector enters
    # this norm.
    edge_len = edge_vec.norm(dim=-1)

    return edge_index, edge_vec, edge_len

In [16]:
class BesselBasis(nn.Module):
    """Sine radial basis ``sin(k * pi * r / cutoff) / r`` for ``k = 1..n_rbf``.

    Args:
        n_rbf: number of basis functions (frequencies).
        cutoff: length that sets the base frequency ``pi / cutoff``.
    """

    def __init__(self, n_rbf=8, cutoff=5.0):
        super().__init__()
        freq = torch.arange(1, n_rbf + 1) * math.pi / cutoff
        # Registered as a buffer so the frequencies follow the module through
        # `.to(device)`. Non-persistent, so state_dict keys are unchanged.
        self.register_buffer("freq", freq, persistent=False)
        self.cutoff = cutoff

    def forward(self, r):
        """Expand distances ``r`` into features with a trailing ``n_rbf`` axis."""
        r = r.unsqueeze(-1)
        # `.to(r.device)` is a no-op when module and input already share a
        # device. The 1e-6 offset guards the division at r == 0.
        return torch.sin(self.freq.to(r.device) * r) / (r + 1e-6)

In [17]:
class EquivariantLayer(nn.Module):
    """Message-passing layer with radial-basis edge filters.

    Args:
        dim: width of the node features.
        cutoff: neighbour cutoff, used for both the radial basis and the
            neighbour search.
        n_rbf: number of radial basis functions fed to the edge MLP.
    """

    def __init__(self, dim=64, cutoff=5.0, n_rbf=8):
        super().__init__()
        self.cutoff = cutoff
        self.rbf = BesselBasis(n_rbf, cutoff)
        self.edge_mlp = nn.Sequential(
            nn.Linear(n_rbf, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, h, pos, cell):
        """Return ``h`` plus the messages aggregated from neighbouring atoms.

        Args:
            h: node features, indexed by atom along dim 0.
            pos: atomic positions, forwarded to ``build_graph``.
            cell: cell matrix, forwarded to ``build_graph``.
        """
        # Build the graph with this layer's own cutoff so the neighbour
        # search and the radial basis agree when a non-default cutoff is used.
        edge_index, edge_vec, edge_len = build_graph(pos, cell, cutoff=self.cutoff)

        src, dst = edge_index[:, 0], edge_index[:, 1]

        rbf = self.rbf(edge_len)
        m = self.edge_mlp(rbf)

        # Directional weighting (equivariant flavor)
        direction = edge_vec / (edge_len[:, None] + 1e-6)
        # NOTE(review): the norm of the normalised direction is
        # edge_len / (edge_len + 1e-6), i.e. ~1, so this factor carries no
        # angular information — confirm this is the intended message.
        msg = h[src] * m * direction.norm(dim=1, keepdim=True)

        # Sum incoming messages onto the index taken from column 1 of edge_index.
        agg = torch.zeros_like(h)
        agg.index_add_(0, dst, msg)

        return h + agg

In [18]:
class CustomMLIP(nn.Module):
    """Species embedding -> stack of ``EquivariantLayer`` -> energy readout.

    Args:
        n_species: size of the embedding table indexed by ``Z``.
        dim: width of the node features.
        layers: number of stacked ``EquivariantLayer`` blocks.
    """

    def __init__(self, n_species=100, dim=64, layers=4):
        super().__init__()

        self.embed = nn.Embedding(n_species, dim)

        blocks = []
        for _ in range(layers):
            blocks.append(EquivariantLayer(dim))
        self.layers = nn.ModuleList(blocks)

        readout = [nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, 1)]
        self.energy_head = nn.Sequential(*readout)

    def forward(self, Z, pos, cell):
        """Return the total energy: the sum of the per-atom readout outputs."""
        features = self.embed(Z)

        for block in self.layers:
            features = block(features, pos, cell)

        return self.energy_head(features).sum()

In [19]:
def energy_forces(model, Z, pos, cell, create_graph=True):
    """Evaluate the model energy and the forces ``F = -dE/dpos``.

    Args:
        model: callable ``model(Z, pos, cell)`` returning a scalar energy.
        Z: forwarded to ``model`` unchanged.
        pos: positions tensor; flagged ``requires_grad`` in place.
        cell: forwarded to ``model`` unchanged.
        create_graph: keep the graph of the force computation so that a loss
            on the forces can be back-propagated. Pass ``False`` only for
            inference, where neither ``E`` nor ``F`` is differentiated again.

    Returns:
        ``(E, F)`` where ``F`` has the same shape as ``pos``.
    """
    # Forces need autograd even when the caller is inside `torch.no_grad()`
    # (e.g. an evaluation loop); otherwise E has no grad_fn and
    # autograd.grad raises. Grad mode is therefore forced on locally.
    with torch.enable_grad():
        pos.requires_grad_(True)
        E = model(Z, pos, cell)
        F = -torch.autograd.grad(E, pos, create_graph=create_graph)[0]
    return E, F

In [20]:
def compute_stress(pos, forces, volume):
    """Return ``sum_n pos[n, i] * forces[n, j]`` divided by ``volume``.

    The result is indexed ``[i, j]`` over the trailing axes of ``pos`` and
    ``forces``.

    NOTE(review): this uses absolute positions, so for a periodic cell the
    value depends on how atoms are wrapped/translated — confirm this is the
    intended virial definition and sign convention.
    """
    outer_sum = torch.einsum("ni,nj->ij", pos, forces)
    return torch.div(outer_sum, volume)

In [21]:
from torch.optim import Adam
from tqdm import tqdm

# Model (default hyper-parameters) and its optimiser.
model = CustomMLIP()
model = model.to(device)
opt = Adam(model.parameters(), lr=1e-3)

# Loss weights for the energy, force and stress terms, respectively.
wE = 1.0
wF = 50.0
wS = 5.0

In [None]:
# Train one structure at a time (batch size 1) for a fixed number of epochs.
for epoch in range(50):
    model.train()
    total = 0.0

    for i in tqdm(train_idx):
        batch = dataset[i]

        Z = batch["Z"].to(device)
        # Detach from the stored tensor: the dataset's `pos` is a leaf with
        # requires_grad=True, so using it directly makes every backward()
        # accumulate into its `.grad`, which is never cleared. energy_forces
        # re-enables requires_grad on this detached tensor.
        pos = batch["pos"].detach().to(device)
        cell = batch["cell"].to(device)

        true_E = batch["energy"].to(device)
        true_F = batch["forces"].to(device)
        vol = batch["volume"].to(device)

        opt.zero_grad()

        pred_E, pred_F = energy_forces(model, Z, pos, cell)
        pred_S = compute_stress(pos, pred_F, vol)

        # NOTE(review): the stress term has no reference value — it penalises
        # the magnitude of the predicted stress itself. Confirm this is meant
        # as a regulariser rather than a missing `true_S`.
        loss = (
            wE * (pred_E - true_E)**2 +
            wF * (pred_F - true_F).pow(2).mean() +
            wS * pred_S.pow(2).mean()
        )

        loss.backward()
        opt.step()
        total += loss.item()

    # Mean per-structure loss over the epoch.
    print(f"Epoch {epoch}: {total/len(train_idx):.4f}")

In [None]:
model.eval()
E_err, F_err = [], []

# No `torch.no_grad()` here: forces are obtained by differentiating the energy
# with respect to the positions, and with autograd disabled the energy has no
# grad_fn, so `torch.autograd.grad` raises a RuntimeError.
for i in test_idx:
    batch = dataset[i]

    Z = batch["Z"].to(device)
    # Detached so the stored dataset tensor is not modified by evaluation.
    pos = batch["pos"].detach().to(device)
    cell = batch["cell"].to(device)

    pred_E, pred_F = energy_forces(model, Z, pos, cell)

    # Absolute energy error and mean absolute force-component error for this
    # structure; detach so the autograd graph is released each iteration.
    E_err.append(abs(pred_E.item() - batch["energy"].item()))
    F_err.append((pred_F.detach().cpu() - batch["forces"]).abs().mean().item())

print("Energy MAE:", sum(E_err)/len(E_err))
print("Force MAE:", sum(F_err)/len(F_err))