In [1]:
import collections
from functools import partial
from tqdm import tqdm
import numpy as np
import torch
import matgl
from matgl.config import DEFAULT_ELEMENTS
from matgl.ext.pymatgen import Structure2Graph
from matgl.graph.data import MGLDataLoader, MGLDataset, collate_fn_pes
from matgl.utils.training import ModelLightningModule
from gridmlip import Grid

In [5]:
# Build the symmetry-reduced grid of candidate Na (Z = 11) sites in the
# NaGaPO4F framework, then turn every retained grid point into a structure
# that the MLIP can evaluate.
grid = Grid.from_file(
    '../example_files/NaGaPO4F_sym.cif',
    specie=11,              # atomic number of the probed mobile ion (Na)
    resolution=0.25,        # grid spacing (presumably in Å — confirm in gridmlip docs)
    r_min=1.2,              # minimum allowed specie–framework distance
    empty_framework=True,   # strip the mobile species from the host first
)
# config_format accepts 'ase' or 'pymatgen'; matgl's Structure2Graph needs pymatgen.
cfgs = grid.construct_configurations(config_format='pymatgen')

Mesh Summary
────────────────────────────────────
  Mesh shape:                         (51, 25, 43)
  Symops:                             4
  Total points:                       54,825
  Irreducible points:                 12,600
  Irreducible points (r_min filter):  7,265
  Compression ratio:                    7.55×
────────────────────────────────────



creating configurations: 12600it [00:02, 5922.07it/s]


In [6]:
# MGLDataset needs per-structure labels, but only the model's energy
# predictions are used downstream, so zero-valued placeholders suffice.
# Forces use the per-atom Cartesian layout (n_atoms, 3). A flat (n_atoms,)
# vector only stacks cleanly when every structure has the same atom count
# (collate_fn_pes vstacks force labels — per matgl source).
labels = {
    "energies": [0.0] * len(cfgs),
    "forces": [np.zeros((len(cfg), 3)).tolist() for cfg in cfgs],
}

100%|█████████████████████████████████████████████████████████████████████████████| 7265/7265 [00:00<00:00, 331285.26it/s]


In [7]:
# Graph conversion settings: matgl's default element set and a 5.0 neighbour
# cutoff (presumably Å, matgl's convention).
element_types = DEFAULT_ELEMENTS
graph_settings = dict(element_types=element_types, cutoff=5.0)
converter = Structure2Graph(**graph_settings)

# Featurise every configuration. clear_processed=True / save_cache=False keep
# processed graphs from being reused or persisted between runs.
dataset = MGLDataset(
    structures=cfgs,
    converter=converter,
    labels=labels,
    clear_processed=True,
    save_cache=False,
)

100%|████████████████████████████████████████████████████████████████████████████████| 7265/7265 [00:13<00:00, 527.09it/s]


In [8]:
# Energy-only inference: collate without line graphs or stress labels.
collate_fn = partial(
    collate_fn_pes,
    include_line_graph=False,
    include_stress=False,
)

# MGLDataLoader always builds train/val/test loaders. The train/val splits are
# one-element dummies; only the test loader (over all grid configurations) is used.
# NOTE(review): grid.load_energies relies on predictions staying in dataset
# order; this assumes the test loader does not shuffle — confirm in matgl.
_train_loader, _val_loader, loader = MGLDataLoader(
    train_data=[None],
    val_data=[None],
    test_data=dataset,
    collate_fn=collate_fn,
    batch_size=128,
    num_workers=0,
)

In [12]:
# Use the first CUDA device when available; otherwise fall back to CPU so the
# notebook still runs end-to-end on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [13]:
# Load the pretrained MatPES TensorNet potential and keep its `.model`
# attribute, the network that the Lightning wrapper below calls.
model = matgl.load_model("TensorNet-MatPES-PBE-v2025.1-PES").model
model.to(device)
# Module.eval() returns the module itself. It switches the wrapper and the wrapped
# network out of training mode, so any dropout/normalisation layers behave
# deterministically during inference.
inference_module = ModelLightningModule(model).eval()

In [14]:
# Predict the energy of every grid configuration, batch by batch.
# With include_line_graph=False, matgl's collate_fn_pes yields
# (graph, lattice, state_attr, energies, forces, stresses) — there is NO line
# graph in the tuple (verify against the installed matgl version). Only the first
# three entries are model inputs; the trailing placeholder labels are ignored.
energy_batches = []
with torch.no_grad():  # pure inference: no autograd graph needed
    for g, lat, state_attr, *_labels in tqdm(loader):
        e_pred = inference_module(
            g=g.to(device),
            lat=lat.to(device),
            l_g=None,
            state_attr=state_attr.to(device),
        )
        # Flatten so a single-structure batch (possibly squeezed to a 0-d
        # tensor) still concatenates with the others.
        energy_batches.append(e_pred.cpu().numpy().reshape(-1))
energies = np.concatenate(energy_batches) if energy_batches else np.array([])

100%|█████████████████████████████████████████████████████████████████████████████████████| 57/57 [00:14<00:00,  4.02it/s]


In [15]:
# Attach the predicted energies to the grid. Presumably they are matched to
# configurations by position (one value per entry of `cfgs`), so check the
# count and reject NaN/inf predictions before loading.
assert len(energies) == len(cfgs), (
    f"expected {len(cfgs)} energies, got {len(energies)}"
)
assert np.all(np.isfinite(energies)), "model produced non-finite energies"
grid.load_energies(energies)

In [16]:
# Percolation barriers of the mobile ion for 1D, 2D and 3D connectivity
# (values presumably in eV). Displayed as the cell's result.
barriers = grid.percolation_barriers()
barriers

{'E_1D': 0.0977, 'E_2D': 0.1562, 'E_3D': 0.1562}

In [17]:
# Export the energy grid as a .grd file (path is relative to the working directory).
grd_path = 'NaGaPO4F_sym_grid.grd'
grid.write_grd(grd_path)