In [1]:
import os
import sys
import time
from typing import Callable, List, Optional
from glob import glob

from ase.io import read
from tqdm import tqdm
import torch
import torch.multiprocessing as mp

from sevenn.atom_graph_data import AtomGraphData
from sevenn.nn.sequential import AtomGraphSequential
from sevenn.train.dataload import unlabeled_atoms_to_graph
import sevenn.train.dataload as dl
import sevenn._keys as KEY
import sevenn.util as util
from sevenn.sevennet_calculator import SevenNetCalculator

# for test
from torch_geometric.loader import DataLoader # test
from torch_geometric.data import Batch # batch test

In [2]:
def unlabeled_graph_build(
    atoms_list: List,
    cutoff: float,
    num_cores: int = 1,
    transfer_info: bool = True,
    y_from_calc: bool = False,
) -> List[AtomGraphData]:
    """
    parallel version of graph_build
    build graph from atoms_list and return list of AtomGraphData
    Args:
        atoms_list (List): list of ASE atoms
        cutoff (float): cutoff radius of graph
        num_cores (int): number of cores to use
        transfer_info (bool): if True, copy info from atoms to graph,
                              defaults to True
        y_from_calc (bool): Get reference y labels from calculator, defaults to False
    Returns:
        List[AtomGraphData]: list of AtomGraphData
    """
    serial = num_cores == 1
    inputs = [(atoms, cutoff) for atoms in atoms_list]

    if not serial:
        pool = mp.Pool(num_cores)
        graph_list = pool.starmap(
            unlabeled_atoms_to_graph,
            tqdm(inputs, total=len(atoms_list), desc=f'graph_build ({num_cores})'),
        )
        pool.close()
        pool.join()
    else:
        graph_list = [
            unlabeled_atoms_to_graph(*input_)
            for input_ in tqdm(inputs, desc='graph_build (1)')
        ]

    graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list]

    return graph_list

In [3]:
material = 'Mg4Ta8O24'
working_dir = os.getcwd()
device='cpu'
# Read the structure paths
structure_paths = glob(os.path.join(working_dir, material, "*.cif"))
structures = [] # List of ASE Atoms objects
start = time.time()
for i, structure_path in enumerate(structure_paths):
    structure = read(structure_path)
    structures.append(structure)
    
end = time.time()
print(f"Reading structures took {end-start:.2f} seconds.")

Reading structures took 0.04 seconds.


In [4]:
structure = structures[0]

In [5]:
calc = SevenNetBatchCalculator(device=device)

NameError: name 'SevenNetBatchCalculator' is not defined

In [6]:
structure.calc = calc

NameError: name 'calc' is not defined

In [7]:
structure.get_potential_energy()

RuntimeError: Atoms object has no calculator.

In [18]:
calc.atoms

In [3]:

from ase import Atoms

class AtomsList(list):
    """A custom list class to hold ASE Atoms objects and add functionality."""

    def __init__(self, *args):
        """Initialize the AtomsList with a list of Atoms objects."""
        super().__init__(*args)
        self._validate_atoms()

    def _validate_atoms(self):
        """Ensure all elements in the list are ASE Atoms objects."""
        for item in self:
            if not isinstance(item, Atoms):
                raise TypeError(f"All elements must be ASE Atoms objects. Found: {type(item)}")

    def append(self, item):
        """Override append to validate Atoms objects."""
        if not isinstance(item, Atoms):
            raise TypeError(f"Only ASE Atoms objects can be added. Found: {type(item)}")
        super().append(item)

    def extend(self, items):
        """Override extend to validate multiple Atoms objects."""
        for item in items:
            if not isinstance(item, Atoms):
                raise TypeError(f"Only ASE Atoms objects can be added. Found: {type(item)}")
        super().extend(items)

    def get_energies(self):
        """Get potential energies for all Atoms objects in the list."""
        return [atoms.get_potential_energy() for atoms in self]

    def get_atomic_numbers(self):
        """Get atomic numbers for all Atoms objects."""
        return [atoms.get_atomic_numbers() for atoms in self]

atoms_list = AtomsList(structures)
atoms_list.get_atomic_numbers()

NameError: name 'structures' is not defined

In [4]:
import os
import pathlib
from typing import Any, Optional, Union

import numpy as np
import torch
import torch.jit
import torch.jit._script
from ase.calculators.calculator import Calculator, all_changes
from ase.data import chemical_symbols

import sevenn._keys as KEY
import sevenn.util as util
from sevenn.atom_graph_data import AtomGraphData
from sevenn.nn.sequential import AtomGraphSequential
from sevenn.train.dataload import unlabeled_atoms_to_graph

torch_script_type = torch.jit._script.RecursiveScriptModule


class SevenNetBatchCalculator(Calculator):
    """ASE calculator for SevenNet models

    Multi-GPU parallel MD is not supported for this mode.
    Use LAMMPS for multi-GPU parallel MD.
    This class is for convenience who want to run SevenNet models with ase.

    Note than ASE calculator is designed to be interface of other programs.
    But in this class, we simply run torch model inside ASE calculator.
    So there is no FileIO things.

    Here, free_energy = energy
    """

    def __init__(
        self,
        model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0',
        file_type: str = 'checkpoint',
        device: Union[torch.device, str] = 'auto',
        sevennet_config: Optional[Any] = None,  # hold meta information
        **kwargs,
    ):
        """Initialize the calculator

        Args:
            model (SevenNet): path to the checkpoint file, or pretrained
            device (str, optional): Torch device to use. Defaults to "auto".
        """
        super().__init__(**kwargs)
        self.sevennet_config = None

        if isinstance(model, pathlib.PurePath):
            model = str(model)

        file_type = file_type.lower()
        if file_type not in ['checkpoint', 'torchscript', 'model_instance']:
            raise ValueError('file_type should be checkpoint or torchscript')

        if isinstance(device, str):  # TODO: do we really need this?
            if device == 'auto':
                self.device = torch.device(
                    'cuda' if torch.cuda.is_available() else 'cpu'
                )
            else:
                self.device = torch.device(device)
        else:
            self.device = device

        if file_type == 'checkpoint' and isinstance(model, str):
            if os.path.isfile(model):
                checkpoint = model
            else:
                checkpoint = util.pretrained_name_to_path(model)
            model_loaded, config = util.model_from_checkpoint(checkpoint)
            model_loaded.set_is_batch_data(False)
            self.type_map = config[KEY.TYPE_MAP]
            self.cutoff = config[KEY.CUTOFF]
            self.sevennet_config = config
        elif file_type == 'torchscript' and isinstance(model, str):
            extra_dict = {
                'chemical_symbols_to_index': b'',
                'cutoff': b'',
                'num_species': b'',
                'model_type': b'',
                'version': b'',
                'dtype': b'',
                'time': b'',
            }
            model_loaded = torch.jit.load(
                model, _extra_files=extra_dict, map_location=self.device
            )
            chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8')
            sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)}
            self.type_map = {
                sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split())
            }
            self.cutoff = float(extra_dict['cutoff'].decode('utf-8'))
        elif isinstance(model, AtomGraphSequential):
            if model.type_map is None:
                raise ValueError(
                    'Model must have the type_map to be used with calculator'
                )
            if model.cutoff == 0.0:
                raise ValueError('Model cutoff seems not initialized')
            model.eval_type_map = torch.tensor(True)  # ?
            model.set_is_batch_data(False)
            model_loaded = model
            self.type_map = model.type_map
            self.cutoff = model.cutoff
        else:
            raise ValueError('Unexpected input combinations')

        if self.sevennet_config is None and sevennet_config is not None:
            self.sevennet_config = sevennet_config

        self.model = model_loaded

        self.model.to(self.device)
        self.model.eval()

        self.implemented_properties = [
            'free_energy',
            'energy',
            'forces',
            'stress',
            'energies',
        ]

    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        # call parent class to set necessary atom attributes
        # Calculator.calculate(self, atoms, properties, system_changes)
        if atoms is None:
            raise ValueError('No atoms to evaluate')
        data = AtomGraphData.from_numpy_dict(
            unlabeled_atoms_to_graph(atoms, self.cutoff)
        )

        data.to(self.device)  # type: ignore

        if isinstance(self.model, torch_script_type):
            data[KEY.NODE_FEATURE] = torch.tensor(
                [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]],
                dtype=torch.int64,
                device=self.device,
            )
            data[KEY.POS].requires_grad_(True)  # backward compatibility
            data[KEY.EDGE_VEC].requires_grad_(True)  # backward compatibility
            data = data.to_dict()
            del data['data_info']

        output = self.model(data)
        energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item()
        # Store results
        self.results = {
            'free_energy': energy,
            'energy': energy,
            'energies': (
                output[KEY.ATOMIC_ENERGY].detach().cpu().reshape(len(atoms)).numpy()
            ),
            'forces': output[KEY.PRED_FORCE].detach().cpu().numpy(),
            'stress': np.array(
                (-output[KEY.PRED_STRESS])
                .detach()
                .cpu()
                .numpy()[[0, 1, 2, 4, 5, 3]]  # as voigt notation
            ),
        }


In [5]:
structure.get_potential_energy()

NameError: name 'structure' is not defined

In [6]:
structure

Atoms(symbols='Mg4Ta8O24', pbc=True, cell=[[6.21514988, 0.0, 0.0], [-0.1320216956869857, 6.577446281775212, 0.0], [0.01274563472079938, 0.02048320062139419, 10.05122817772354]], spacegroup_kinds=...)

In [8]:
# _graph_build_f = dl._graph_build_ase
# # graph1 = dl.unlabeled_atoms_to_graph(structure, 5.0)
# calc = SevenNetCalculator(device=device)
# structure.calc = calc
# print(structure.get_potential_energy())

# _graph_build_f = dl._graph_build_matscipy
calc = SevenNetCalculator(device=device)
structure.calc = calc
print(structure.get_potential_energy())
# graph2 = dl.unlabeled_atoms_to_graph(structure, 5.0)

-314.3199768066406


In [23]:
_graph_build_f

<function sevenn.train.dataload._graph_build_matscipy(cutoff: float, pbc, cell, pos)>

In [11]:
graph['pos']

NameError: name 'graph' is not defined

In [13]:
util.pretrained_name_to_path('7net-0')

'/home/haekwan98/miniconda3/envs/seven_test/lib/python3.9/site-packages/sevenn/pretrained_potentials/SevenNet_0__11July2024/checkpoint_sevennet_0.pth'