In [None]:
import numpy as np

import torch
torch.set_default_dtype(torch.float64)

from ase.io import read
from ase.build import bulk

from mlelec.data.dataset import QMDataset
from mlelec.features.acdc import compute_features

import metatensor.torch as mts

import rascaline.torch
from rascaline.torch.utils.clebsch_gordan import EquivariantPowerSpectrumByPair

In [None]:
# Torch device string handed to the mlelec code in this notebook (QMDataset).
device = "cpu"

In [None]:
orbitals = {'gthszv':  {14: [[3,0,0], [3,1,1], [3,1,-1], [3,1,0]]}}
ORBS = 'gthszv'
frames = [bulk('Si', crystalstructure='diamond')]
kmesh = [1, 1, 1]

# The Fock matrices are placeholders (only the structures feed the descriptors
# below), but they must still have a consistent shape: one
# (n_kpoints, n_basis, n_basis) array per frame, where n_basis counts every
# orbital of every atom in the frame. The primitive diamond cell holds two Si
# atoms, so n_basis is 2 * 4 = 8 here, not 4.
rng = np.random.default_rng(0)  # seeded so that re-runs see identical data
n_kpoints = int(np.prod(kmesh))
fock_kspace = []
for frame in frames:
    n_basis = sum(len(orbitals[ORBS][int(z)]) for z in frame.numbers)
    raw = rng.standard_normal((n_kpoints, n_basis, n_basis))
    # Symmetrize each k-block: a Fock matrix is Hermitian (real symmetric here).
    fock_kspace.append(0.5 * (raw + raw.transpose(0, 2, 1)))

qmdata = QMDataset(frames = frames, 
                   kmesh = kmesh, 
                   dimension = 3,
                   fock_kspace = fock_kspace, 
                   device = device, 
                   orbs = orbitals[ORBS],
                   orbs_name = ORBS)

In [None]:
# Radial/angular resolution and the pair-expansion density settings.
max_radial = 3
max_angular = 3
atomic_gaussian_width = 0.3
cutoff = 6


def _expansion_hypers(radial_cutoff, gaussian_width):
    """Build a fresh rascaline spherical-expansion hypers dict.

    Only the cutoff and the Gaussian width differ between the two expansions
    used here. Every call returns new nested dicts, so the pair and atom hypers
    never share mutable state.
    """
    return {
        'cutoff': radial_cutoff,
        'max_radial': max_radial,
        'max_angular': max_angular,
        'atomic_gaussian_width': gaussian_width,
        'center_atom_weight': 1,
        "radial_basis": {"Gto": {}},
        "cutoff_function": {"ShiftedCosine": {"width": 0.5}},
    }


# Pair (by-pair) expansion uses the shared cutoff/width defined above.
hypers_pair = _expansion_hypers(cutoff, atomic_gaussian_width)
# Single-centre expansion: shorter cutoff, broader Gaussians.
hypers_atom = _expansion_hypers(4, 0.5)

# Feature-construction switches (not referenced by the cells below).
return_rho0ij = False
both_centers = False
LCUT = 3

In [None]:
# mlelec features with `return_rhoij=True` (pair features), presumably for
# comparison with rascaline's EquivariantPowerSpectrumByPair output (desc2).
desc1_ = compute_features(qmdata,
                          hypers_atom,
                          hypers_pair=hypers_pair,
                          lcut=None,
                          return_rhoij=True,
                          device=device)

# mlelec key name -> rascaline/metatensor key name.
_KEY_RENAMES = {
    'spherical_harmonics_l': 'o3_lambda',
    'inversion_sigma': 'o3_sigma',
    'species_center': 'first_atom_type',
    'species_neighbor': 'second_atom_type',
}

# Swap key dimensions 1 and 2 — presumably to match rascaline's key order;
# TODO confirm against desc2.keys.names — then rename to rascaline's names.
desc1 = mts.permute_dimensions(desc1_, axis='keys', dimensions_indexes=[0, 2, 1, 3, 4])
for old_name, new_name in _KEY_RENAMES.items():
    desc1 = mts.rename_dimension(desc1, axis='keys', old=old_name, new=new_name)

In [None]:
# Atom types present across all structures, derived from the frames rather
# than hard-coded so this cell stays correct if other elements are added.
global_atom_types = sorted({int(z) for frame in frames for z in frame.numbers})

calc = EquivariantPowerSpectrumByPair(
    spherical_expansion_hypers=hypers_atom,
    spherical_expansion_by_pair_hypers=hypers_pair,
    atom_types=global_atom_types,
)

# rascaline by-pair features for the same structures used to build desc1.
desc2 = calc(rascaline.torch.systems_to_torch(frames))

In [None]:
%%timeit -r 1 -n 1
desc1_ = compute_features(qmdata, 
                         hypers_atom, 
                         hypers_pair = hypers_pair, 
                         lcut = None,
                         device = 'cpu')

In [None]:
%%timeit -r 1 -n 1

calc = EquivariantPowerSpectrumByPair(
    spherical_expansion_hypers=hypers_atom,
    spherical_expansion_by_pair_hypers=hypers_pair,
    atom_types=global_atom_types
)
desc2 = calc(rascaline.torch.systems_to_torch(frames))