In [1]:
import collections
from functools import partial
from tqdm import tqdm
import numpy as np
import torch
import matgl
from matgl.config import DEFAULT_ELEMENTS
from matgl.ext.pymatgen import Structure2Graph
from matgl.graph.data import MGLDataLoader, MGLDataset, collate_fn_pes
from matgl.utils.training import ModelLightningModule
from pymatgen.io.ase import AseAtomsAdaptor

from gridmlip import Grid

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Build the insertion grid for the mobile ion inside the NaGaPO4F framework read from the CIF.
grid = Grid.from_file(
    '../example_files/NaGaPO4F_sym.cif',
    specie=11,        # atomic number of the inserted species (11 = Na)
    resolution=0.25,  # grid spacing (presumably in Å — confirm in gridmlip docs)
    r_min=1.5,        # minimum allowed distance between the specie and the framework
)
# Candidate structures with the specie placed on grid points; presumably points that
# violate r_min are skipped (the saved output scans 50400 points but yields 17176 configs).
cfgs = grid.construct_configurations()

creating configurations: 50400it [00:01, 46861.56it/s]


In [3]:
# Convert each ASE configuration to a pymatgen Structure and attach dummy labels.
# The labels are never used for inference; they exist only because the dataset/collate
# pipeline expects "energies" and "forces" entries for every structure.
structures = []
labels = collections.defaultdict(list)

for cfg in tqdm(cfgs):
    structures.append(AseAtomsAdaptor.get_structure(cfg))
    labels["energies"].append(0.0)
    # Forces are per-atom 3-vectors, so the placeholder must have shape (n_atoms, 3).
    # A flat (n_atoms,) array only collates while every configuration has the same atom count.
    labels["forces"].append(np.zeros((len(cfg), 3)).tolist())

100%|██████████| 17176/17176 [00:12<00:00, 1326.30it/s]


In [4]:
# Turn every structure into a crystal graph for the network.
# NOTE(review): element_types and the 5.0 cutoff presumably have to match what the
# pretrained TensorNet-MatPES model was trained with — confirm against the loaded model.
element_types = DEFAULT_ELEMENTS
cry_graph = Structure2Graph(element_types=element_types, cutoff=5.0)
dataset = MGLDataset(
    converter=cry_graph,
    structures=structures,
    labels=labels,
)

100%|██████████| 17176/17176 [00:29<00:00, 586.84it/s]


In [5]:
# Batch graphs for inference: no line graphs (not needed here) and no stress labels.
collate_fn = partial(collate_fn_pes, include_line_graph=False, include_stress=False)
# Plain inference loader over the whole dataset. shuffle=False is essential:
# grid.load_energies later receives a bare array, so predictions must stay in `cfgs` order.
# (This replaces MGLDataLoader, which required placeholder train/val datasets.)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0,
)

In [6]:
# Use the GPU when available; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
# Load the pretrained PES model and keep only its inner network (`.model`);
# the loop below only needs per-structure energies.
# NOTE(review): taking `.model` presumably bypasses any element-reference offsets or
# energy scaling applied by the outer wrapper — fine for barriers (energy differences)
# only if that scaling is the identity; confirm for this checkpoint.
model = matgl.load_model("TensorNet-MatPES-PBE-v2025.1-PES").model
model.to(device)
model.eval()  # inference: switch off any training-only behaviour (e.g. dropout)
inference_module = ModelLightningModule(model)

  state = torch.load(f, map_location=map_location)
  d = torch.load(f, map_location=map_location)


In [8]:
# Predict one energy per configuration, preserving loader (= cfgs) order.
# With include_line_graph=False, matgl's collate_fn_pes yields
# (g, lat, state_attr, energies, forces, stresses) — there is no line graph, so the
# trailing dummy labels are discarded and l_g is left at its default (None).
energies = []
with torch.no_grad():
    for g, lat, state_attr, *_ in tqdm(loader):
        e_pred = inference_module(g=g.to(device), lat=lat.to(device), state_attr=state_attr.to(device))
        # reshape(-1): a 1-structure final batch may come back squeezed to 0-d,
        # which cannot be iterated by list.extend.
        energies.extend(e_pred.reshape(-1).cpu().numpy())
energies = np.array(energies)

100%|██████████| 135/135 [00:19<00:00,  6.87it/s]


In [9]:
# Attach the predicted energies to the grid; they are matched to configurations by position.
assert len(energies) == len(cfgs), "expected exactly one predicted energy per configuration"
grid.load_energies(energies)

In [10]:
# Percolation barriers for 1D, 2D and 3D connectivity of the energy grid
# (presumably in eV, following the model's energy units).
barriers = grid.percolation_barriers()
barriers

{'E_1D': 0.0977, 'E_2D': 0.1562, 'E_3D': 0.1562}

In [11]:
# Export the energy grid as a .grd file (presumably a volumetric grid for external
# visualization tools — confirm the format in the gridmlip docs).
grid.write_grd('NaGaPO4F_sym_grid.grd')