In [7]:
import numpy as np
import metatensor
from ase.io import read
from metatensor import TensorBlock, TensorMap,Labels
from itertools import product
import ase

from rascaline import NeighborList

# from anisotropic_gaussian_moments_expansion import * 

In [177]:
# Define function to compute all moments for a general dilation matrix.
# The implementation focuses on conceptual simplicity, while sacrificing
# memory efficiency.
def compute_moments_inefficient_implementation(A, a, maxdeg):
    """
    Compute all monomial moments of a trivariate Gaussian up to a maximum degree.

    Parameters:
    - A: symmetric 3x3 matrix (np.ndarray of shape (3,3))
        Dilation matrix of the Gaussian that determines its shape. It enters the
        exponent of the Gaussian directly (see the formula below), so the
        covariance matrix used in the recurrence is its inverse, cov = inv(A).
    - a: np.ndarray of shape (3,) or (3,1)
        Contains the position vector for the center of the trivariate Gaussian.
    - maxdeg: int
        Maximum degree for which the moments need to be computed (at least 1).

    Returns:
    - np.ndarray of shape (maxdeg+1, maxdeg+1, maxdeg+1) whose element
      [n0, n1, n2] is the moment
        <x^n0 * y^n1 * z^n2> = integral (x^n0 * y^n1 * z^n2) * exp(-0.5*(r-a).T@A@(r-a)) dxdydz
      for every n0 + n1 + n2 <= maxdeg. Elements of higher total degree are
      left at zero.
      Note that the term "moments" in probability theory refers to normalized
      Gaussian distributions. Here the Gaussian is NOT normalized: every element
      carries the factor (2*pi)^1.5 / sqrt(det(A)), which is therefore the value
      of the [0, 0, 0] element.

    Raises:
    - AssertionError if A is not a symmetric 3x3 matrix, if a is not a
      3-dimensional vector, or if maxdeg < 1.
    """
    A = np.asarray(A)
    a = np.asarray(a)

    # Make sure that the provided arrays have the correct dimensions and properties
    assert A.shape == (3,3), "Dilation matrix needs to be 3x3"
    assert np.sum((A-A.T)**2) < 1e-14, "Dilation matrix needs to be symmetric"
    assert a.shape in ((3,), (3,1)), "Center of Gaussian has to be given by a 3-dim. vector"
    assert maxdeg > 0, "The maximum degree needs to be at least 1"

    # Work with a flat vector so that a[i] is a scalar for both accepted shapes.
    a = a.reshape(3)
    cov = np.linalg.inv(A) # the covariance matrix is the inverse of the matrix A
    global_factor = (2*np.pi)**1.5 / np.sqrt(np.linalg.det(A)) # normalization of Gaussian

    # moments[n0, n1, n2] will be set to <x^n0 * y^n1 * z^n2> (before the global factor).
    # This representation is memory inefficient, since only the elements with
    # n0 + n1 + n2 <= maxdeg are ever filled.
    # The advantage, however, is the simplicity in later use.
    moments = np.zeros((maxdeg+1, maxdeg+1, maxdeg+1))
    moments[0,0,0] = 1.

    def raise_degree(axis, exponents):
        """Return <x_axis * x^n0 * y^n1 * z^n2> from moments of lower degree."""
        value = a[axis] * moments[exponents]
        for j in range(3):
            # A vanishing exponent contributes nothing; skipping it also avoids
            # reading the array at index -1.
            if exponents[j] > 0:
                lowered = list(exponents)
                lowered[j] -= 1
                value += cov[axis, j] * exponents[j] * moments[tuple(lowered)]
        return value

    # Iterate over the total degree of the monomials: from every monomial of
    # degree "deg" we generate moments of degree deg+1. Starting at degree 0
    # also produces the linear and quadratic moments, so they need no special case.
    for deg in range(maxdeg):
        for n0 in range(deg+1):
            for n1 in range(deg+1-n0):
                n2 = deg - n0 - n1
                exponents = (n0, n1, n2)

                # The x-iteration reaches every target with a positive x exponent
                moments[n0+1,n1,n2] = raise_degree(0, exponents)

                # Targets without x need the y-iteration ...
                if n0 == 0:
                    moments[n0,n1+1,n2] = raise_degree(1, exponents)

                    # ... and targets that are pure powers of z need the z-iteration
                    if n1 == 0:
                        moments[n0,n1,n2+1] = raise_degree(2, exponents)

    return global_factor * moments

In [181]:
# Hyperparameters of the expansion, gathered in a single dictionary.
hypers = {
    # cutoff radius handed to the neighbor list calculator
    "interaction_cutoff": 4.5,
    # anisotropy / dilation matrix (identity here)
    "A": np.eye(3),
    # maximum degree of the expansion
    "maxdeg": 5,
}

In [163]:
# Read the first two configurations of the water dataset.
frames = read('/Users/jigyasa/scratch/data_papers/data/water/dataset/water_randomized_1000.xyz', ':2')

# Switch periodic boundary conditions off for every frame.
# Periodic alternative, kept for reference:
#     f.pbc=True
#     f.cell = [5,5,5]
#     f.center()
for f in frames:
    f.pbc = False

# Sorted, de-duplicated atomic numbers occurring in any of the frames.
global_species = np.unique(np.concatenate([frame.numbers for frame in frames]))

In [170]:
# Use rascaline to compute the neighbor list of all frames within the interaction cutoff.
# NOTE(review): the positional `True` presumably requests the full (not half) neighbor
# list - confirm against the rascaline NeighborList signature.
nl = NeighborList(hypers["interaction_cutoff"], True).compute(frames)

# nl is a TensorMap with keys ('species_first_atom', 'species_second_atom').
# Depending on the cutoff, some species pairs may not appear.
# Self pairs are not present, but under PBC pairs between periodic copies of the
# same atom are accounted for.

# nl.keys_to_properties('species_second_atom')

In [182]:
# Width parameter used to rescale the dilation matrix from the hyperparameters.
sigma = 0.3
maxdeg = hypers["maxdeg"]
A = (sigma**2) * hypers["A"]

In [193]:
# Shape of one moments array after appending a trailing axis of length 1.
# A fresh example is computed here instead of reading `sample_value`: that name is
# only assigned in a later cell, so using it fails on a clean top-to-bottom run.
compute_moments_inefficient_implementation(A, np.zeros((3, 1)), maxdeg)[..., np.newaxis].shape

(6, 6, 6, 1)

In [195]:
# Accumulate the blocks for the pairwise expansion: one block per
# (center species, neighbor species) pair present in the neighbor list,
# holding the Gaussian moments of every pair listed in that block.
desc_blocks = []
desc_keys = []
for center_species in global_species:
    for neighbor_species in global_species:
        if (center_species, neighbor_species) in nl.keys:
            nl_block = nl.block(species_first_atom=center_species, species_second_atom=neighbor_species)
            desc_block_values = []
            for isample in range(nl_block.values.shape[0]):
                x, y, z = nl_block.values[isample, 0], nl_block.values[isample, 1], nl_block.values[isample, 2]
                # Moments for the pair: an array of shape (maxdeg+1, maxdeg+1, maxdeg+1)
                sample_value = compute_moments_inefficient_implementation(A, np.array([x, y, z]), maxdeg)
                # The trailing axis of length 1 becomes the single "dummy" property
                desc_block_values.append(sample_value.reshape(sample_value.shape + (1,)))

            # Record the key next to its block so that keys and blocks are
            # guaranteed to be in the same order when the TensorMap is built.
            desc_keys.append((center_species, neighbor_species))
            desc_blocks.append(
                TensorBlock(
                    # The explicit reshape keeps the rank correct even for a block without samples
                    values=np.asarray(desc_block_values).reshape((-1,) + (maxdeg + 1,) * 3 + (1,)),
                    samples=nl_block.samples,
                    # One component axis per monomial exponent n0, n1, n2
                    components=[
                        Labels([name], np.arange(maxdeg + 1, dtype=np.int32).reshape(-1, 1))
                        for name in ("n0", "n1", "n2")
                    ],
                    properties=Labels(["dummy"], np.zeros((1, 1), dtype=np.int32)),
                )
            )

pair_aniso_desc = TensorMap(
    Labels(
        ["species_first_atom", "species_second_atom"],
        np.asarray(desc_keys, dtype=np.int32).reshape(-1, 2),
    ),
    desc_blocks,
)


In [89]:
# To get the final descriptor, the pairwise moments are summed over all neighbors
# (of every neighbor species) of each central atom. The result has one block per
# center species. `pair_aniso_desc` must be a TensorMap keyed by
# ('species_first_atom', 'species_second_atom').

desc_blocks = []
desc_keys = []
for center_species in global_species:
    # Maps (structure, center) -> moments accumulated over all neighbors of that atom
    center_values = {}
    pair_block = None
    for neighbor_species in global_species:
        if (center_species, neighbor_species) not in pair_aniso_desc.keys:
            continue
        pair_block = pair_aniso_desc.block(species_first_atom=center_species, species_second_atom=neighbor_species)
        # NOTE(review): assumes the first two sample dimensions identify the
        # structure and the central atom of each pair - confirm against the
        # sample names of the pair blocks.
        structure_name, center_name = pair_block.samples.names[:2]
        for isample, sample in enumerate(pair_block.samples):
            center = (int(sample[structure_name]), int(sample[center_name]))
            if center in center_values:
                center_values[center] = center_values[center] + pair_block.values[isample]
            else:
                center_values[center] = pair_block.values[isample].copy()

    if pair_block is None:
        # This species has no pair block at all, so it gets no key.
        continue

    desc_samples = sorted(center_values)
    desc_blocks.append(
        TensorBlock(
            values=np.asarray([center_values[center] for center in desc_samples]),
            samples=Labels(["structure", "center"], np.asarray(desc_samples, dtype=np.int32).reshape(-1, 2)),
            # Summing over pairs leaves the component and property axes untouched
            components=pair_block.components,
            properties=pair_block.properties,
        )
    )
    desc_keys.append(center_species)

aniso_desc = TensorMap(
    Labels(["center_species"], np.asarray(desc_keys, dtype=np.int32).reshape(-1, 1)),
    desc_blocks,
)

Labels([(1, 0), (8, 0)],
       dtype=[('species_second_atom', '<i4'), ('distance', '<i4')])