In [1]:
from typing import List, Union

import torch
from metatensor.learn.data.dataset import _BaseDataset
from metatensor.torch import Labels, TensorBlock
from metatensor.torch.atomistic import (
    NeighborsListOptions,
    System,
    register_autograd_neighbors,
)
from rascaline.torch import NeighborList


# Names of the sample dimensions kept when get_rascaline_neighbors_list()
# rebuilds the neighbor list block; they are also used, in this order, as the
# names of the resulting samples Labels.
REQUIRED_NL_SAMPLES = [
    "first_atom",
    "second_atom",
    "cell_shift_a",
    "cell_shift_b",
    "cell_shift_c",
]

In [2]:
def get_rascaline_neighbors_list(
    systems: Union[System, List[System]], options: NeighborsListOptions
) -> Union[TensorBlock, List[TensorBlock]]:
    """
    Compute neighbor lists with the ``rascaline.torch.NeighborList`` calculator.

    The cutoff and the full/half list choice are taken from ``options``. For
    each system, the calculator output is merged into a single block (all keys
    moved to samples) and rebuilt as a new ``TensorBlock`` that keeps only the
    sample dimensions listed in ``REQUIRED_NL_SAMPLES``, a single ``"xyz"``
    component and zeroed property labels.

    No call to ``register_autograd_neighbors`` is made here.

    :param systems: A single system or a list of systems.
    :param options: A NeighborsListOptions object providing the cutoff and
        whether a full neighbor list is requested.

    :return: A single TensorBlock if a single system was given, otherwise a
        list with one TensorBlock per system (also for lists of length one).
    :raises ValueError: if the calculator output lacks one of the sample
        dimensions in ``REQUIRED_NL_SAMPLES``.
    """

    # Remember the input kind so that a one-element list still yields a list.
    single_system = not isinstance(systems, list)
    if single_system:
        systems = [systems]

    nl_calculator = NeighborList(
        cutoff=options.model_cutoff, full_neighbor_list=options.full_list
    )

    nl_list = []
    for system in systems:
        nl_tmap = nl_calculator.compute(system)
        # Move every key dimension to the samples so that one block remains.
        tmp_nl = nl_tmap.keys_to_samples(nl_tmap.keys.names).block()

        # Look the columns up by name, in the order of REQUIRED_NL_SAMPLES, so
        # that the values always match the names given to the new Labels.
        sample_names = tmp_nl.samples.names
        required_indices = [sample_names.index(name) for name in REQUIRED_NL_SAMPLES]

        samples = Labels(
            names=REQUIRED_NL_SAMPLES, values=tmp_nl.samples.values[:, required_indices]
        )
        components = Labels(names=["xyz"], values=tmp_nl.components[0].values)
        properties = Labels(
            names=tmp_nl.properties.names,
            values=torch.zeros_like(tmp_nl.properties.values),
        )
        nl_list.append(
            TensorBlock(
                samples=samples,
                components=[components],
                properties=properties,
                values=tmp_nl.values,
            )
        )

    return nl_list[0] if single_system else nl_list

In [6]:
# Build the calculator at module level, outside any scripted function.
nl_calculator = NeighborList(cutoff=5.0, full_neighbor_list=True)

In [7]:
# Scripting an already-constructed NeighborList instance; the recorded output
# is a RecursiveScriptModule, i.e. this call succeeded.
torch.jit.script(nl_calculator)

RecursiveScriptModule(original_name=NeighborList)

In [5]:
def test_rascaline_nl_torch_script():
    """Construct a ``NeighborList`` calculator inside the function body."""
    nl_calculator = NeighborList(cutoff=5.0, full_neighbor_list=True)

In [6]:
# Scripting a function that constructs NeighborList in its body; the recorded
# output is a FrontendError ("Cannot instantiate class 'NeighborList' in a
# script function"), which is what this cell demonstrates.
torch.jit.script(test_rascaline_nl_torch_script)

FrontendError: Cannot instantiate class 'NeighborList' in a script function:
  File "/var/folders/0v/zgrhvkh12xz59z8f0vwvj86c0000gn/T/ipykernel_42824/109602194.py", line 2
def test_rascaline_nl_torch_script():
    nl_calculator = NeighborList(
                    ~~~~~~~~~~~~ <--- HERE
            cutoff=5.0, full_neighbor_list=True
    )


In [None]:
# NOTE(review): `random`, `np`, `read_structures`, `read_targets`, `OmegaConf`
# and `Dataset` are not imported in any visible cell — confirm they are
# imported elsewhere, otherwise this cell fails on a fresh kernel.
DATASET_PATH = "examples/alchemical_model/hea_samples_bulk.xyz"
# Seed every random number generator used here.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

structures = read_structures(DATASET_PATH)

# Compute one neighbor list per structure and attach it to that structure
# under the same options object.
nl_options = NeighborsListOptions(model_cutoff=5.0, full_list=True)
nls = get_rascaline_neighbors_list(structures, nl_options)
for structure, nl in zip(structures, nls):
    structure.add_neighbors_list(nl_options, nl)

# Target description: a single "energy" target read from the same file as the
# structures, with forces, stress and virial disabled.
conf = {
    "energy": {
        "quantity": "energy",
        "read_from": DATASET_PATH,
        "file_format": ".xyz",
        "key": "energy",
        "forces": False,
        "stress": False,
        "virial": False,
    }
}
targets = read_targets(OmegaConf.create(conf))
dataset = Dataset(structure=structures, energy=targets["energy"])

In [13]:
# Plain iteration over the dataset; the recorded output is
# "ValueError: Index 100 not in dataset".
# NOTE(review): presumably Dataset has no __iter__, so Python falls back to
# calling __getitem__ with 0, 1, 2, ... and only stops on IndexError, which a
# ValueError does not satisfy — confirm against the Dataset implementation.
for item in dataset:
    pass

ValueError: Index 100 not in dataset

In [9]:
# Recorded output: 100 — the same number as the index reported in the
# "Index 100 not in dataset" error recorded for the iteration cell.
len(dataset)

100

In [10]:
# The original cell read `dataset[]`, which is a syntax error: the index was
# missing. Index 0 is assumed here — TODO confirm which sample the recorded
# output belongs to.
dataset[0]

Sample(structure=System with 48 atoms, periodic cell: [7.43546, 0, 0, 0, 7.43546, 0, 0, 0, 11.1532])

In [8]:
# Private size attribute of the dataset; the recorded output is 100, the same
# value recorded for len(dataset).
dataset._size

100

In [4]:
# Reduced reproduction of the iteration failure: structures only, no targets
# and no neighbor lists, with the size passed explicitly.
# NOTE(review): `read_structures` and `Dataset` are not imported in any
# visible cell — confirm.
DATASET_PATH = "examples/alchemical_model/hea_samples_bulk.xyz"
structures = read_structures(DATASET_PATH)
dataset = Dataset(structure=structures, size=len(structures))

# The recorded output is again "ValueError: Index 100 not in dataset".
for item in dataset:
    pass

ValueError: Index 100 not in dataset

In [None]:
# Unfinished cell: only a dangling `for` keyword was left here, which is a
# syntax error. Nothing to run.

In [13]:
# The recorded error below mentions a `max_neighbors` keyword that this call
# does not pass, so that output comes from an earlier version of the cell.
NeighborsListOptions(model_cutoff=5.0, full_list=True)

RuntimeError: Unknown keyword argument 'max_neighbors' for operator '__init__'. Schema: __init__(__torch__.torch.classes.metatensor.NeighborsListOptions _0, float model_cutoff, bool full_list, str requestor="") -> NoneType _0

In [10]:
# Show the structure of the current `item`.
# NOTE(review): `item` is only bound by the `for item in dataset` loops in
# other cells, so this relies on kernel state left by one of them — confirm.
item.structure

System with 48 atoms, periodic cell: [6.70861, 0, 0, 0, 6.70861, 0, 0, 0, 10.0629]

In [92]:
# Neighbor lists for all structures; `structures` and `nl_options` must
# already exist in the kernel from an earlier cell.
nl_list = get_rascaline_neighbors_list(structures, nl_options)

In [93]:
# Keep the first structure and its neighbor list for the inspection cells
# that follow.
structure = structures[0]
nl = nl_list[0]

In [62]:
from typing import Union

import ase
import torch
from metatensor.torch.atomistic import NeighborsListOptions, System
from metatensor.torch.atomistic.ase_calculator import _compute_ase_neighbors


def get_ase_neighbors_list(
    structure: Union[ase.Atoms, System],
    nl_options: NeighborsListOptions,
):
    """
    Compute a neighbor list through metatensor's ASE helper.

    A ``System`` (detected as a ``torch.ScriptObject``) is first converted to
    ``ase.Atoms`` from its species, positions and cell; the resulting atoms
    are always marked as periodic along all three directions. An ``ase.Atoms``
    input is used as is.

    :param structure: The atoms or system to compute neighbors for.
    :param nl_options: Options forwarded to ``_compute_ase_neighbors``.

    :return: A tuple ``(nl, nl_options)`` with the result of
        ``_compute_ase_neighbors`` and the options that were passed in.
    """
    if isinstance(structure, torch.ScriptObject):
        # Detach and move to CPU before the NumPy conversion: `.numpy()` raises
        # for tensors living on another device or requiring gradients.
        structure = ase.Atoms(
            numbers=structure.species.detach().cpu().numpy(),
            positions=structure.positions.detach().cpu().numpy(),
            cell=structure.cell.detach().cpu().numpy(),
            pbc=[True, True, True],
        )
    nl = _compute_ase_neighbors(structure, nl_options)
    return nl, nl_options

In [67]:
# get_ase_neighbors_list returns a (neighbors, options) tuple, so `nl_ase[0]`
# is the neighbor list itself.
# NOTE(review): `atoms_list` is not defined in any visible cell — confirm.
nl_ase = get_ase_neighbors_list(atoms_list[0], nl_options)

In [94]:
# Property label values of the ASE-based neighbor list (first element of the
# returned tuple).
nl_ase[0].properties.values

tensor([[0]], dtype=torch.int32)

In [95]:
# Property label values of the rascaline-based neighbor list; the recorded
# output is tensor([[0]], dtype=torch.int32).
nl.properties.values

tensor([[0]], dtype=torch.int32)

In [96]:
# Attach the rascaline-based neighbor list to the selected structure under
# `nl_options`.
structure.add_neighbors_list(nl_options, nl)

In [3]:


# NOTE(review): `atoms_list`, `structures`, `DATASET_PATH`,
# `get_primitive_neighbors_list`, `read_targets`, `OmegaConf` and `Dataset`
# are not defined or imported in this cell — confirm they exist in the kernel.
for atoms, structure in zip(atoms_list, structures):
    nl, nl_options = get_primitive_neighbors_list(
        atoms, model_cutoff=5.0, full_list=True
    )
    structure.add_neighbors_list(nl_options, nl)

# Target description: a single "energy" target read from DATASET_PATH, with
# forces, stress and virial disabled.
conf = {
    "energy": {
        "quantity": "energy",
        "read_from": DATASET_PATH,
        "file_format": ".xyz",
        "key": "energy",
        "forces": False,
        "stress": False,
        "virial": False,
    }
}

targets = read_targets(OmegaConf.create(conf))
# The only target declared in `conf` is "energy", so look it up under that
# name (the original cell asked for "U0", which `conf` does not define).
dataset = Dataset(structure=structures, energy=targets["energy"])