In [1]:
from torch_geometric.data import Data, Batch

from pet.pet import PET
from pet.hypers import load_hypers_from_file
from pet.data_preparation import get_all_species
from pet.molecule import Molecule, batch_to_dict
import torch
import ase.io

from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import (
    MetatensorAtomisticModel,
    ModelCapabilities,
    NeighborsListOptions,
    ModelOutput,
    System,
)

import numpy as np

from pet_wrapper import PETMetatensorWrapper

# native PET prediction

In [2]:
hypers = load_hypers_from_file('../pet/default_hypers/default_hypers.yaml')

FITTING_SCHEME = hypers.FITTING_SCHEME
ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS
ARCHITECTURAL_HYPERS.D_OUTPUT = 1
ARCHITECTURAL_HYPERS.TARGET_TYPE = 'structural'
ARCHITECTURAL_HYPERS.TARGET_AGGREGATION = 'sum'

model = PET(ARCHITECTURAL_HYPERS, 0.0, 2).to('cpu')

In [3]:
structure = ase.io.read('../pet/example/methane_train.xyz', index = 0)

molecule = Molecule(structure, ARCHITECTURAL_HYPERS.R_CUT, 
                            ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
                            ARCHITECTURAL_HYPERS.USE_LONG_RANGE, ARCHITECTURAL_HYPERS.K_CUT)
        
graph = molecule.get_graph(molecule.get_max_num(), get_all_species([structure]), molecule.get_num_k())
batch = Batch.from_data_list([graph])
prediction = model(batch_to_dict(batch))
#print(batch_to_dict(batch))
print(prediction)

tensor([[-2.3753]], grad_fn=<ScatterAddBackward0>)


# PETMetatensorWrapper prediction

In [4]:
from metatensor.torch.atomistic.ase_calculator import _compute_ase_neighbors

structure = ase.io.read('../pet/example/methane_train.xyz', index = 0)
print(structure)
system = System(species = torch.IntTensor(structure.get_atomic_numbers()),
                positions = torch.FloatTensor(structure.positions),
                cell = torch.zeros(3, 3, dtype = torch.float32))

options = NeighborsListOptions(model_cutoff=ARCHITECTURAL_HYPERS.R_CUT, full_list=True)

system.add_neighbors_list(options, _compute_ase_neighbors(structure, options).to(torch.float32))


Atoms(symbols='CH4', pbc=False, forces=..., calculator=SinglePointCalculator(...))


In [5]:
all_species = np.array([1, 6], dtype = int) # Hydrogen, Carbon
wrapped_model = PETMetatensorWrapper(model, all_species)
print(wrapped_model([system]))

tensor([[-2.3753]], grad_fn=<ScatterAddBackward0>)
