Import modules

In [None]:
from featomic.torch.clebsch_gordan import (
    EquivariantPowerSpectrumByPair,
    EquivariantPowerSpectrum,
)
from featomic.torch import SphericalExpansion, SphericalExpansionByPair
import metatensor.torch as mts
import ase.io
import numpy as np
import torch

from metatrain.experimental.nanopet.modules.augmentation import RotationalAugmenter
from metatrain.experimental.nanopet_on_basis.utils import get_system_transformations
from metatrain.utils.data.target_info import TargetInfo

import os
import sys

# Make the local ``mlelec`` checkout importable. The location can be overridden
# through the MLELEC_SRC environment variable instead of editing this cell; the
# default is the original hard-coded path. The membership check keeps the cell
# idempotent: re-running it does not append duplicate entries to sys.path.
MLELEC_SRC = os.environ.get("MLELEC_SRC", "/home/pegolo/Software/nanomlelec/src")
if MLELEC_SRC not in sys.path:
    sys.path.append(MLELEC_SRC)
from mlelec.nn import HamiltonianDescriptor

# Floating-point precision passed to every calculator and system in this notebook.
dtype = torch.float64

Convenience functions

In [None]:
def compute_node_edge_lambda_soap(spex, spex_by_pair, systems):
    """Compute equivariant power-spectrum node and edge features.

    Only the blocks with ``o3_lambda`` in {0, 1, 2} are computed, and neighbor
    types are moved to the properties dimension.

    Parameters
    ----------
    spex : SphericalExpansion
        Single-center density expansion.
    spex_by_pair : SphericalExpansionByPair
        Pair density expansion, used for the edge features.
    systems : list
        Systems to featurize.

    Returns
    -------
    tuple
        ``(node_features, edge_features)``, each sorted with ``mts.sort``.
    """
    # Keyword arguments shared by both calculator calls.
    call_options = dict(
        selected_keys=mts.Labels.range("o3_lambda", 3),
        neighbors_to_properties=True,
    )

    node_features = EquivariantPowerSpectrum(spex, dtype=dtype)(systems, **call_options)
    edge_features = EquivariantPowerSpectrumByPair(spex, spex_by_pair, dtype=dtype)(
        systems, **call_options
    )

    return mts.sort(node_features), mts.sort(edge_features)


def compute_node_edge_symm(spex, spex_by_pair, systems):
    """Compute node and edge features with mlelec's ``HamiltonianDescriptor``.

    Each of the two outputs is post-processed the same way: blocks with
    ``o3_lambda`` from 3 to 10 are dropped, the property dimensions are reordered
    with the permutation ``[2, 3, 0, 1, 4, 5]``, and the result is sorted.

    Parameters
    ----------
    spex : SphericalExpansion
        Single-center density expansion.
    spex_by_pair : SphericalExpansionByPair
        Pair density expansion.
    systems : list
        Systems to featurize.

    Returns
    -------
    tuple
        ``(node_features, edge_features)`` after post-processing.
    """
    # NOTE(review): neighbor_types looks like a list of atomic numbers (H, C, N, S);
    # confirm it covers every species present in the dataset.
    descriptor = HamiltonianDescriptor(
        spex, spex_by_pair, neighbor_types=[1, 6, 7, 16], dtype=dtype
    )
    raw_node, raw_edge = descriptor(systems)

    unwanted_lambdas = mts.Labels("o3_lambda", torch.arange(3, 11).reshape(-1, 1))
    property_order = [2, 3, 0, 1, 4, 5]

    def _postprocess(features):
        # Drop high-lambda blocks, reorder the property dimensions, then sort.
        kept = mts.drop_blocks(features, unwanted_lambdas)
        reordered = mts.permute_dimensions(kept, "properties", property_order)
        return mts.sort(reordered)

    return _postprocess(raw_node), _postprocess(raw_edge)

def compute_target_info(target_node, target_edge):
    """Build metatrain ``TargetInfo`` objects describing the node and edge targets.

    Each layout is a ``TensorMap`` with the same keys, components and properties
    as the corresponding target, but with zero samples.

    Parameters
    ----------
    target_node : TensorMap
        Node target; its layout uses samples named ``["system", "atom"]``.
    target_edge : TensorMap
        Edge target; its layout uses samples named ``["system", "first_atom",
        "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"]``.

    Returns
    -------
    dict
        ``{"mtt::node": TargetInfo, "mtt::edge": TargetInfo}``.
    """

    def _empty_layout(tensor, sample_names):
        # Zero-sample copy of ``tensor`` keeping keys, components and properties.
        # The values follow each block's dtype (float64 in this notebook) rather
        # than torch's default float32, so the layout matches the actual target.
        return mts.TensorMap(
            tensor.keys,
            [
                mts.TensorBlock(
                    samples=mts.Labels.empty(sample_names),
                    components=block.components,
                    properties=block.properties,
                    values=torch.empty(
                        (0, *block.values.shape[1:]), dtype=block.values.dtype
                    ),
                )
                for block in tensor
            ],
        )

    return {
        "mtt::node": TargetInfo(
            quantity="node",
            layout=_empty_layout(target_node, ["system", "atom"]),
        ),
        "mtt::edge": TargetInfo(
            quantity="edge",
            layout=_empty_layout(
                target_edge,
                [
                    "system",
                    "first_atom",
                    "second_atom",
                    "cell_shift_a",
                    "cell_shift_b",
                    "cell_shift_c",
                ],
            ),
        ),
    }


def apply_augmentation(rotational_augmenter, systems, target_node, target_edge):
    """Apply one shared rotation/inversion to the systems and both targets.

    Parameters
    ----------
    rotational_augmenter : RotationalAugmenter
        Augmenter built from a ``target_info`` containing ``"mtt::node"`` and
        ``"mtt::edge"``.
    systems : list
        Systems to transform.
    target_node, target_edge : TensorMap
        Node and edge targets to transform consistently with the systems.

    Returns
    -------
    tuple
        ``(systems_train, targets_train_node, targets_train_edge)``: the
        transformed systems and the transformed, sorted node and edge targets.
    """
    # Define a random transformation for each training system
    rotations, inversions = get_system_transformations(systems)

    # A single call transforms the systems once and applies the same
    # transformation to both targets.
    systems_train, targets_train = rotational_augmenter.apply_augmentations(
        systems,
        {"mtt::node": target_node, "mtt::edge": target_edge},
        rotations,
        inversions,
    )

    targets_train_node = mts.sort(targets_train["mtt::node"])
    targets_train_edge = mts.sort(targets_train["mtt::edge"])

    return systems_train, targets_train_node, targets_train_edge

Load the system and define the spherical-expansion hyperparameters

In [None]:
# Read the structure at index 0 of the xyz file and convert it to a torch
# system in the notebook's floating-point precision.
frame = ase.io.read("qm7x_reduced_100.xyz", index="0")
systems = [mts.atomistic.systems_to_torch(frame).to(dtype=dtype)]

# Single-center expansion: 3.5 cutoff radius, angular channels up to 5.
hypers = dict(
    cutoff=dict(radius=3.5, smoothing=dict(type="ShiftedCosine", width=0.1)),
    density=dict(type="Gaussian", width=0.3),
    basis=dict(
        type="TensorProduct",
        max_angular=5,
        radial=dict(type="Gto", max_radial=2),
    ),
)
spex = SphericalExpansion(**hypers)

# Pair expansion: same smoothing, density and radial basis, but a larger
# cutoff radius (6.0) and angular channels only up to 1.
hypers = dict(
    cutoff=dict(radius=6.0, smoothing=dict(type="ShiftedCosine", width=0.1)),
    density=dict(type="Gaussian", width=0.3),
    basis=dict(
        type="TensorProduct",
        max_angular=1,
        radial=dict(type="Gto", max_radial=2),
    ),
)
spex_by_pair = SphericalExpansionByPair(**hypers)

# Non-symmetrized features (equivariant power spectrum)

In [None]:
target_node, target_edge = compute_node_edge_lambda_soap(spex, spex_by_pair, systems)
target_info = compute_target_info(target_node, target_edge)
rotational_augmenter = RotationalAugmenter(target_info)

# Over 10 random augmentations, check that (i) augmentation keeps the metadata
# but changes the values, and (ii) features recomputed on the transformed
# systems reproduce the augmented targets. Assertion messages name the trial,
# the target and the failing check.
for trial in range(10):
    systems_train, targets_train_node, targets_train_edge = apply_augmentation(
        rotational_augmenter, systems, target_node, target_edge
    )
    for label, reference, augmented in (
        ("node", target_node, targets_train_node),
        ("edge", target_edge, targets_train_edge),
    ):
        assert mts.equal_metadata(reference, augmented), (
            f"trial {trial}: augmented {label} metadata differs from the original"
        )
        assert not mts.allclose(reference, augmented), (
            f"trial {trial}: augmented {label} values are unchanged"
        )

    # Actually compute node and edge features for the rotated systems
    rotated_node, rotated_edge = compute_node_edge_lambda_soap(
        spex, spex_by_pair, systems_train
    )
    for label, recomputed, augmented in (
        ("node", rotated_node, targets_train_node),
        ("edge", rotated_edge, targets_train_edge),
    ):
        assert mts.equal_metadata(recomputed, augmented), (
            f"trial {trial}: recomputed {label} metadata differs from augmented target"
        )
        assert mts.allclose(recomputed, augmented), (
            f"trial {trial}: recomputed {label} values differ from augmented target"
        )

# Hamiltonian descriptor (symmetrized features)

In [None]:
target_node, target_edge = compute_node_edge_symm(spex, spex_by_pair, systems)
target_info = compute_target_info(target_node, target_edge)
rotational_augmenter = RotationalAugmenter(target_info)

# Over 10 random augmentations, check that (i) augmentation keeps the metadata
# but changes the values, and (ii) descriptor features recomputed on the
# transformed systems reproduce the augmented targets. Assertion messages name
# the trial, the target and the failing check.
for trial in range(10):
    systems_train, targets_train_node, targets_train_edge = apply_augmentation(
        rotational_augmenter, systems, target_node, target_edge
    )
    for label, reference, augmented in (
        ("node", target_node, targets_train_node),
        ("edge", target_edge, targets_train_edge),
    ):
        assert mts.equal_metadata(reference, augmented), (
            f"trial {trial}: augmented {label} metadata differs from the original"
        )
        assert not mts.allclose(reference, augmented), (
            f"trial {trial}: augmented {label} values are unchanged"
        )

    # Actually compute node and edge features for the rotated systems
    rotated_node, rotated_edge = compute_node_edge_symm(
        spex, spex_by_pair, systems_train
    )
    for label, recomputed, augmented in (
        ("node", rotated_node, targets_train_node),
        ("edge", rotated_edge, targets_train_edge),
    ):
        assert mts.equal_metadata(recomputed, augmented), (
            f"trial {trial}: recomputed {label} metadata differs from augmented target"
        )
        assert mts.allclose(recomputed, augmented), (
            f"trial {trial}: recomputed {label} values differ from augmented target"
        )