In [None]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np
from scipy.linalg import fractional_matrix_power
from scipy.special import gamma
from scipy.special import eval_legendre

# Plotting
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly

# Atoms
from ase.io import read
from rascal.representations import SphericalInvariants
from rascal.neighbourlist.structure_manager import AtomsList
from rascal.neighbourlist.structure_manager import mask_center_atoms_by_species

# Utilities
import h5py
import json
from tqdm.notebook import tqdm
from project_utils import load_structures_from_hdf5

# SOAP
from soap import quippy_soap, librascal_soap

# Functions

In [None]:
# TODO: move functions to external module

In [None]:
def gto_sigma(cutoff, n, n_max):
    """
        Compute the GTO sigma for radial index `n`

        The value is max(sqrt(n), 1) * cutoff / n_max; the floor of 1
        keeps sigma non-zero for n = 0.

        ---Arguments---
        cutoff: radial cutoff
        n: radial index (scalar or array-like; NumPy broadcasting applies)
        n_max: number of radial basis functions

        ---Returns---
        sigma for each entry of `n`
    """
    radial_scale = np.maximum(np.sqrt(n), 1)
    return radial_scale * cutoff / n_max

def gto_width(sigma):
    """
        Compute the GTO width (the coefficient of r**2 in the
        Gaussian exponent), 1 / (2 * sigma**2)

        ---Arguments---
        sigma: GTO sigma (scalar or array)

        ---Returns---
        width for each entry of `sigma`
    """
    variance = sigma ** 2
    return 1.0 / (2 * variance)

def gto_prefactor(n, sigma):
    """
        Compute the GTO normalization prefactor,
        sqrt(2 / (sigma**(2n + 3) * Gamma(n + 3/2)))

        ---Arguments---
        n: radial index (scalar or array)
        sigma: GTO sigma (scalar or array, broadcast against `n`)

        ---Returns---
        prefactor for each (n, sigma) combination
    """
    sigma_power = sigma ** (2 * n + 3)
    norm_integral = sigma_power * gamma(n + 1.5)
    return np.sqrt(2 / norm_integral)

def gto(r, n, sigma):
    """
        Evaluate the normalized GTO of radial index `n` at `r`:
        N * r**(n + 1) * exp(-b * r**2),
        with b = gto_width(sigma) and N = gto_prefactor(n, sigma)

        ---Arguments---
        r: radial coordinate(s)
        n: radial index (broadcast against `r`)
        sigma: GTO sigma (broadcast against `r`)

        ---Returns---
        GTO values on the broadcast shape of the inputs
    """
    exponent_width = gto_width(sigma)
    norm = gto_prefactor(n, sigma)

    # On the "n + 1" exponent: the product of two of these functions is
    # N_n * N_m * r**(n + m + 2) * exp(-(b_n + b_m) * r**2), which is the
    # integrand whose plain integral over r is returned by gto_overlap.
    # NOTE(review): i.e. one factor of r (from the r**2 volume element)
    # appears to be folded into each function -- confirm this is intended.
    radial_power = r ** (n + 1)
    return norm * radial_power * np.exp(-exponent_width * r ** 2)

def gto_overlap(n, m, sigma_n, sigma_m):
    """
        Compute the overlap of two GTOs with radial indices `n` and `m`:
        0.5 * N_n * N_m * (b_n + b_m)**(-k) * Gamma(k),  k = (n + m + 3) / 2

        ---Arguments---
        n, m: radial indices (scalars or broadcastable arrays)
        sigma_n, sigma_m: the corresponding GTO sigmas

        ---Returns---
        overlap for each (n, m) combination
    """
    width_n = gto_width(sigma_n)
    width_m = gto_width(sigma_m)
    norm_n = gto_prefactor(n, sigma_n)
    norm_m = gto_prefactor(m, sigma_m)
    half_order = 0.5 * (3 + n + m)

    # On the leading 0.5: it comes from the closed form
    # int_0^inf r**p * exp(-b * r**2) dr = 0.5 * b**(-(p + 1) / 2) * Gamma((p + 1) / 2)
    # evaluated at p = n + m + 2 and b = b_n + b_m.
    return 0.5 * norm_n * norm_m * (width_n + width_m) ** (-half_order) * gamma(half_order)

def legendre_polynomials(l, x):
    """
        Evaluate the Legendre polynomial of degree `l` at `x`
        (thin wrapper around scipy.special.eval_legendre)

        ---Arguments---
        l: polynomial degree (scalar or array)
        x: evaluation point(s), broadcast against `l`

        ---Returns---
        P_l(x) on the broadcast shape of the inputs
    """
    values = eval_legendre(l, x)
    return values

def reshape_soaps(soaps, n_pairs, n_max, l_max):
    """
        Unflatten SOAP vectors into the shape
        (n_centers, n_species_pairs, n_max, n_max, l_max+1)

        ---Arguments---
        soaps: a single 1D SOAP vector, or an array whose first
               axis indexes the centers
        n_pairs: number of species pairs
        n_max: number of radial basis functions
        l_max: maximum angular index

        ---Returns---
        the reshaped array; a 1D input is treated as one center
    """
    n_centers = 1 if soaps.ndim == 1 else soaps.shape[0]
    target_shape = (n_centers, n_pairs, n_max, n_max, l_max + 1)
    return np.reshape(soaps, target_shape)

def compute_soap_density(n_max, l_max, cutoff, soaps,
                         r_grid, p_grid, chunk_size_r=0, chunk_size_p=0):
    """
        Compute the SOAP density on an (r, r', p) grid by contracting the
        SOAP vectors with orthogonalized GTO radial functions and
        Legendre polynomials

        ---Arguments---
        n_max: number of radial basis functions
        l_max: maximum angular index (l runs over 0, ..., l_max)
        cutoff: radial cutoff used to set the GTO sigmas
        soaps: SOAP vectors with shape
               (n_centers, n_species_pairs, n_max, n_max, l_max+1),
               e.g. the output of reshape_soaps
        r_grid: 1D array of radial grid points (used for both r and r')
        p_grid: 1D array of points at which the Legendre polynomials
                are evaluated (presumably cos(theta) in [-1, 1] --
                TODO confirm)
        chunk_size_r: number of radial grid points handled per chunk;
                      a value <= 0 processes the whole radial grid at once
        chunk_size_p: number of `p_grid` points handled per chunk;
                      a value <= 0 processes the whole grid at once

        ---Returns---
        density: array of shape
                 (n_centers, n_species_pairs, len(r_grid), len(r_grid), len(p_grid))
    """

    n_grid = np.arange(0, n_max)
    l_grid = np.arange(0, l_max + 1)
    sigma_grid = gto_sigma(cutoff, n_grid, n_max)

    # Overlap matrix of the GTOs; its inverse square root is applied to
    # the raw GTOs below to orthogonalize them
    S = gto_overlap(n_grid[:, np.newaxis],
                    n_grid[np.newaxis, :],
                    sigma_grid[:, np.newaxis],
                    sigma_grid[np.newaxis, :])
    # NOTE(review): fractional_matrix_power can return a complex array;
    # if it does, the imaginary part is dropped when the result is
    # stored in the real-valued `density` -- confirm this is acceptable
    S = fractional_matrix_power(S, -0.5)

    # Orthogonalized radial functions on the radial grid: (n_max, len(r_grid))
    R_n = np.matmul(S, gto(r_grid[np.newaxis, :],
                           n_grid[:, np.newaxis],
                           sigma_grid[:, np.newaxis]))

    # Legendre polynomials on the angular grid: (l_max + 1, len(p_grid))
    P_l = legendre_polynomials(l_grid[:, np.newaxis],
                               p_grid[np.newaxis, :])

    n_r = len(r_grid)
    n_p = len(p_grid)

    # A non-positive chunk size means "no chunking": use a single chunk
    # spanning the whole axis (max(..., 1) keeps range() valid for an
    # empty grid)
    step_r = chunk_size_r if chunk_size_r > 0 else max(n_r, 1)
    step_p = chunk_size_p if chunk_size_p > 0 else max(n_p, 1)

    # Each axis is chunked with its OWN chunk size; the final slice may
    # run past the end of the grid, which NumPy clips
    slices_r = [slice(start, start + step_r) for start in range(0, n_r, step_r)]
    slices_p = [slice(start, start + step_p) for start in range(0, n_p, step_p)]

    density = np.zeros((soaps.shape[0], soaps.shape[1], n_r, n_r, n_p))

    for slice_n in slices_r:
        r_n = np.reshape(R_n[:, slice_n], (n_max, 1, 1, -1, 1, 1))
        for slice_m in slices_r:
            r_m = np.reshape(R_n[:, slice_m], (1, n_max, 1, 1, -1, 1))
            r_nm = r_n * r_m
            for slice_p in slices_p:
                p_l = np.reshape(P_l[:, slice_p], (1, 1, l_max + 1, 1, 1, -1))
                # Contract the (n, n', l) axes of `soaps` with the basis
                # functions evaluated on this chunk of the grid
                density[:, :, slice_n, slice_m, slice_p] = np.tensordot(soaps, r_nm * p_l, axes=3)

    return density

# Compute atom-resolved density

In [None]:
# Load SOAP hyperparameters
# (the parsed JSON is later unpacked as keyword arguments and indexed by key,
# so the file is expected to contain a JSON object)
with open('../Processed_Data/soap_hyperparameters.json', 'r') as f:
    soap_hyperparameters = json.load(f)

In [None]:
# Manually set a single cutoff for now
# `cutoff` is reused below for the radial grid and for the DEEM SOAP file path;
# writing it into the hyperparameters overrides whatever value was loaded
cutoff = 6.0
soap_hyperparameters['interaction_cutoff'] = cutoff

In [None]:
# Grids on which the density is evaluated:
# 50 radial points on [0, cutoff] and 50 Legendre-argument points on [-1, 1]
r_grid = np.linspace(0, cutoff, num=50)
p_grid = np.linspace(-1, 1, num=50)

In [None]:
# Make a SphericalInvariants representation
# (constant Gaussian sigma type; every other setting comes from the loaded
# hyperparameters, including the cutoff overridden above)
representation = SphericalInvariants(gaussian_sigma_type='Constant',
                                     **soap_hyperparameters)

## DEEM 10k

In [None]:
# Load structure
# index=':' requests every frame in the file, so `deem_10k` holds all structures
deem_10k = read('../Raw_Data/DEEM_10k/DEEM_10000.xyz', index=':')

In [None]:
# Load number of Si atoms in each structure
# (used below as per-structure weights when averaging the SOAP vectors)
n_Si = np.loadtxt('../Processed_Data/DEEM_10k/n_Si.dat', dtype=int)

In [None]:
# Read the structure-averaged SOAP vectors (one row per structure)
soaps_deem = load_structures_from_hdf5(
    f'../Processed_Data/DEEM_10k/Data/{cutoff}/soaps_full_avg.hdf5',
    datasets=None,
    concatenate=True,
)

# Collapse to a single SOAP vector for the whole dataset: weight each
# structure's average by its number of Si atoms, then normalize by the
# total Si count
soaps_deem = (soaps_deem * n_Si[:, np.newaxis]).sum(axis=0) / n_Si.sum()

In [None]:
# Map each SOAP feature index to its descriptor labels
feature_map_deem = representation.get_feature_index_mapping(deem_10k)

# Count the distinct ('a', 'b') label pairs among the features
# (presumably the two species of each pair -- TODO confirm)
n_pairs = len({(entry['a'], entry['b']) for entry in feature_map_deem.values()})

# Unflatten the averaged SOAP vector into (1, n_pairs, n_max, n_max, l_max + 1)
soaps_deem = reshape_soaps(
    soaps_deem,
    n_pairs,
    n_max=soap_hyperparameters['max_radial'],
    l_max=soap_hyperparameters['max_angular'],
)

In [None]:
# Evaluate the DEEM density on the (r, r', p) grid, 10 grid points per chunk
density_deem = compute_soap_density(
    n_max=soap_hyperparameters['max_radial'],
    l_max=soap_hyperparameters['max_angular'],
    cutoff=soap_hyperparameters['interaction_cutoff'],
    soaps=soaps_deem,
    r_grid=r_grid,
    p_grid=p_grid,
    chunk_size_r=10,
    chunk_size_p=10,
)

In [None]:
# Plot
# indexing='ij' gives the coordinate arrays the same (r, r', p) axis order as
# the density array, so the flattened coordinates and the flattened values
# line up point-for-point (the default 'xy' indexing swaps the first two axes)
rx_grid, ry_grid, tz_grid = np.meshgrid(r_grid, r_grid, p_grid, indexing='ij')

# density_deem[0][2]: first (only) center, species-pair index 2
# NOTE(review): which pair index 2 refers to depends on the feature ordering
# in feature_map_deem -- confirm before interpreting the plot
fig = go.Figure(data=go.Volume(x=rx_grid.flatten(),
                               y=ry_grid.flatten(),
                               z=tz_grid.flatten(),
                               value=density_deem[0][2].flatten(),
                               isomin=0.05,
                               isomax=None,
                               opacity=0.2,
                               surface_count=20))
fig.show()

## Sodalite

In [None]:
# Load structure
# index=':' requests every frame in the file, so `sod` is a list of structures
sod = read('../Raw_Data/SOD/sodalite.xyz', index=':')

In [None]:
# Compute the averaged SOAPs for sodalite and keep the first entry of the result
# ([14] presumably selects Si-centered environments -- TODO confirm)
soaps_sod = librascal_soap(
    sod,
    [14],
    **soap_hyperparameters,
    average=True,
)[0]

In [None]:
# Map each SOAP feature index to its descriptor labels
feature_map_sod = representation.get_feature_index_mapping(sod)

# Count the distinct ('a', 'b') label pairs among the features
# (presumably the two species of each pair -- TODO confirm)
n_pairs = len({(entry['a'], entry['b']) for entry in feature_map_sod.values()})

# Unflatten the SOAP vector(s) into (n_centers, n_pairs, n_max, n_max, l_max + 1)
soaps_sod = reshape_soaps(
    soaps_sod,
    n_pairs,
    n_max=soap_hyperparameters['max_radial'],
    l_max=soap_hyperparameters['max_angular'],
)

In [None]:
# Evaluate the sodalite density on the (r, r', p) grid, 10 grid points per chunk
density_sod = compute_soap_density(
    n_max=soap_hyperparameters['max_radial'],
    l_max=soap_hyperparameters['max_angular'],
    cutoff=soap_hyperparameters['interaction_cutoff'],
    soaps=soaps_sod,
    r_grid=r_grid,
    p_grid=p_grid,
    chunk_size_r=10,
    chunk_size_p=10,
)

In [None]:
# Plot
# indexing='ij' gives the coordinate arrays the same (r, r', p) axis order as
# the density array, so the flattened coordinates and the flattened values
# line up point-for-point (the default 'xy' indexing swaps the first two axes)
rx_grid, ry_grid, tz_grid = np.meshgrid(r_grid, r_grid, p_grid, indexing='ij')

# density_sod[0][0]: first center, species-pair index 0
# NOTE(review): which pair index 0 refers to depends on the feature ordering
# in feature_map_sod -- confirm before interpreting the plot
fig = go.Figure(data=go.Volume(x=rx_grid.flatten(),
                               y=ry_grid.flatten(),
                               z=tz_grid.flatten(),
                               value=density_sod[0][0].flatten(),
                               isomin=0.1,
                               isomax=None,
                               opacity=0.2,
                               surface_count=20))
fig.show()