In [15]:
import json, os
import pickle

import numpy as np
from ase.atoms import Atoms, units
from ase.calculators.test import numeric_stress
from ase.filters import FrechetCellFilter
from ase.io import Trajectory
from ase.optimize import LBFGS
from nequip.ase import NequIPCalculator
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor

class TrajectoryObserver:
    """Optimizer hook that records a snapshot of an ``Atoms`` object per call.

    Adapted from the CHGNet and M3GNet relaxation utilities. Attach an instance
    to an ASE optimizer (``optimizer.attach(observer, interval=...)``); each
    invocation appends the current energy, forces, positions and cell to
    parallel per-step lists.
    """

    def __init__(self, atoms: Atoms) -> None:
        """Start an empty trajectory for ``atoms``.

        Args:
            atoms (Atoms): the structure whose state is sampled on each call.
                It is held by reference, so later calls see its updated state.
        """
        self.atoms = atoms
        # Parallel histories, one entry per observed step. Stresses and
        # magnetic moments are intentionally not recorded.
        self.energies: list[float] = []
        self.forces: list[np.ndarray] = []
        self.atom_positions: list[np.ndarray] = []
        self.cells: list[np.ndarray] = []

    def __call__(self) -> None:
        """Append the current energy, forces, positions and cell to the history."""
        # Each history is paired with the reader that produces its next entry;
        # they are read and appended one at a time, in this order.
        readers = (
            (self.energies, self.compute_energy),
            (self.forces, self.atoms.get_forces),
            (self.atom_positions, self.atoms.get_positions),
            (self.cells, lambda: self.atoms.get_cell()[:]),
        )
        for history, read in readers:
            history.append(read())

    def __len__(self) -> int:
        """Return how many steps have been recorded so far."""
        return len(self.energies)

    def compute_energy(self) -> float:
        """Return the potential energy reported by the attached calculator.

        Returns:
            float: ``self.atoms.get_potential_energy()``.
        """
        return self.atoms.get_potential_energy()

    def save(self, filename: str) -> None:
        """Pickle the recorded trajectory to ``filename``.

        The pickle holds a dict with keys ``energy``, ``forces``,
        ``atom_positions``, ``cell`` and ``atomic_number``.

        Args:
            filename (str): destination path; an existing file is overwritten.
        """
        payload = dict(
            energy=self.energies,
            forces=self.forces,
            atom_positions=self.atom_positions,
            cell=self.cells,
            atomic_number=self.atoms.get_atomic_numbers(),
        )
        with open(filename, mode="wb") as handle:
            pickle.dump(payload, handle)


def allegro_relaxer(atoms, potential_path, species, device='cpu', fmax = 0.01, steps = 250, verbose=False, relax_cell=True, loginterval=1):
    """Relax a structure with a deployed NequIP/Allegro model using LBFGS.

    Args:
        atoms (Atoms): structure to relax. Modified in place: its calculator is
            replaced and its positions (and cell, when ``relax_cell``) are updated.
        potential_path (str): path to the deployed model file.
        species (dict): passed as ``species_to_type_name``; maps ASE chemical
            symbols to the model's type names.
        device (str): device the model is loaded on, e.g. ``'cpu'`` or ``'cuda'``.
        fmax (float): force convergence threshold passed to ``optimizer.run``.
        steps (int): maximum number of optimizer steps.
        verbose (bool): if True, LBFGS writes its per-step log to stdout;
            otherwise the log is discarded.
        relax_cell (bool): if True, relax the cell together with the positions
            through ``FrechetCellFilter`` (the calculator must provide stress);
            if False, only atomic positions are optimized and the cell is fixed.
        loginterval (int): record a trajectory frame every ``loginterval`` steps.

    Returns:
        dict: ``{"final_structure": Structure, "trajectory": TrajectoryObserver}``.
    """
    atoms.calc = NequIPCalculator.from_deployed_model(
        model_path=potential_path,
        device=device,
        species_to_type_name=species,
    )
    # Only wrap in a cell filter when the cell is meant to move. A
    # constant-volume filter would still change the cell shape and would
    # still require stress from the calculator.
    target = FrechetCellFilter(atoms) if relax_cell else atoms
    obs = TrajectoryObserver(atoms)
    optimizer = LBFGS(target, logfile="-" if verbose else None)
    optimizer.attach(obs, interval=loginterval)
    optimizer.run(fmax=fmax, steps=steps)

    # The optimizer updated `atoms` in place (directly or through the filter),
    # so it already holds the relaxed geometry.
    struct = AseAtomsAdaptor.get_structure(atoms)
    return {"final_structure": struct, "trajectory": obs}

In [4]:
# Load the converted vacancy dataset used to test the relaxer; it is a dict
# keyed by structure name (see the lookup in the structure-selection cell).
data_path = '../Visualization/Job_Structures/Post_VASP/VCrTiWZr_Summit/gen_0_4/Vacancies/converted_vcrtiwzr_gen_0_4_vacancies_data.json'
with open(data_path, 'r') as data_file:
    data = json.load(data_file)

In [13]:
# Map each ASE chemical symbol to the type name used by the deployed model.
# NOTE(review): assumes the model's type names are the element symbols
# themselves; confirm against the training config.
species = {symbol: symbol for symbol in ("Ti", "V", "Cr", "Zr", "W")}

In [11]:
# Pick one dataset entry and take the final frame of its stored trajectory
# (presumably the last ionic step) as the reference structure, energy and forces.
structure_key = 'supercell_gen0_comp11_struct3_vac_site3_start'
structure_entry = data[structure_key]
structure = Structure.from_dict(structure_entry['structures'][-1])
test_energy = structure_entry['energies'][-1]
test_forces = structure_entry['forces'][-1]

atoms = AseAtomsAdaptor.get_atoms(structure)

In [19]:
pot_path = '../Potentials/vcrtiwzr_vac_deployed.pth'

# Single-point check: evaluate the deployed model on the reference structure
# selected above and compare it with the stored reference energy and forces.
atoms.calc = NequIPCalculator.from_deployed_model(
    model_path=pot_path,
    species_to_type_name=species,
)

predicted_energy = atoms.get_potential_energy()
predicted_forces = atoms.get_forces()
print(predicted_energy)
print(predicted_forces)

# NOTE(review): assumes test_energy is a total energy and test_forces an
# (n_atoms, 3) array, both in the model's units (eV, eV/Å); confirm against the data.
print("energy error per atom:", (predicted_energy - test_energy) / len(atoms))
reference_forces = np.asarray(test_forces, dtype=float)
if reference_forces.shape == predicted_forces.shape:
    print("force MAE:", np.abs(predicted_forces - reference_forces).mean())
else:
    print("reference forces have shape", reference_forces.shape,
          "but predicted forces have shape", predicted_forces.shape)



-583.4518972628949
[[ 1.14335003e-02 -5.48178350e-03  9.59480372e-03]
 [ 1.02342769e-02 -1.45829930e-03  1.07749358e-03]
 [-2.83645557e-03 -8.30297904e-03 -5.04277550e-04]
 [ 4.68724447e-03  6.46664165e-03  7.09093791e-03]
 [-1.92889874e-02 -4.89046764e-03  6.18245081e-03]
 [ 7.91208199e-03  4.80869082e-03  1.74734331e-02]
 [ 1.13995349e-02 -9.04563962e-03 -1.00054801e-02]
 [ 1.79275993e-02  1.35558272e-02  1.39590009e-02]
 [ 5.97147876e-03  5.98771223e-03 -1.20134679e-02]
 [ 7.37654092e-03 -1.51764768e-03 -8.58058320e-03]
 [ 1.52052458e-02 -7.60596133e-03 -1.58171444e-03]
 [-2.20839180e-02  1.23625875e-02  1.19971610e-02]
 [ 9.75285172e-03  1.91645982e-03 -6.27655446e-03]
 [-1.62893866e-02  6.35855033e-03 -4.07007604e-03]
 [ 8.63570280e-03 -1.05572456e-02  3.38544765e-03]
 [-4.46969863e-03 -1.69506917e-03  6.58106130e-03]
 [ 2.06692670e-02 -1.03895792e-03 -8.37129389e-03]
 [-1.00751298e-02  9.79879868e-03 -1.64734669e-03]
 [-5.38136053e-03  1.70957274e-02 -2.83268678e-03]
 [-6.0917886

In [20]:
# Finite-difference stress from the model's energies. Presumably used because
# the calculator may not return stress itself, which cell relaxation needs.
stress_voigt = numeric_stress(atoms)
print(stress_voigt)            # ASE Voigt order, eV/Å^3
print(stress_voigt / units.GPa)  # same values in GPa

[-1.93178405e-03  1.75682902e-04 -1.07303132e-03  3.47542718e-04
  4.95189106e-05 -5.52263396e-04]
