In [None]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np
from scipy.linalg import fractional_matrix_power
from scipy.special import gamma
from scipy.special import eval_legendre

# Plotting
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly

# Atoms
from ase.io import read
from rascal.representations import SphericalInvariants
from rascal.neighbourlist.structure_manager import AtomsList
from rascal.neighbourlist.structure_manager import mask_center_atoms_by_species

# Utilities
import h5py
import json
from tqdm.notebook import tqdm
from project_utils import load_structures_from_hdf5

# SOAP
from soap import quippy_soap, librascal_soap
from soap import reshape_soaps, compute_soap_density

# Compute atom-resolved density

In [None]:
# Read the SOAP hyperparameters; the result is used as a keyword dict
# (indexed by name and **-unpacked into the SOAP calculators below)
with open('../Processed_Data/soap_hyperparameters.json', 'r') as json_file:
    soap_hyperparameters = json.load(json_file)

In [None]:
# Override the interaction cutoff with one fixed value (for now).
# `cutoff` is also reused for the radial grid and the data directory name.
cutoff = 6.0
soap_hyperparameters.update(interaction_cutoff=cutoff)

In [None]:
# Evaluation grids for the density: 50 radial points from 0 to the cutoff
# and 50 points of p spanning [-1, 1]
r_grid = np.linspace(start=0, stop=cutoff, num=50)
p_grid = np.linspace(start=-1, stop=1, num=50)

In [None]:
# librascal SphericalInvariants calculator built from the loaded
# hyperparameters, with a constant Gaussian width
representation = SphericalInvariants(
    **soap_hyperparameters,
    gaussian_sigma_type='Constant',
)

## DEEM 10k

In [None]:
# Read every frame (index=':') of the DEEM 10k structure file
deem_10k = read(
    '../Raw_Data/DEEM_10k/DEEM_10000.xyz',
    index=':',
)

In [None]:
# Integer Si count per DEEM structure; used below as averaging weights
n_Si = np.loadtxt(
    '../Processed_Data/DEEM_10k/n_Si.dat',
    dtype=int,
)

In [None]:
# Per-structure averaged SOAPs for the DEEM set, concatenated into one array
# (one row per structure, presumably -- it is weighted row-wise by n_Si)
soaps_deem = load_structures_from_hdf5(
    f'../Processed_Data/DEEM_10k/Data/{cutoff}/soaps_full_avg.hdf5',
    datasets=None,
    concatenate=True,
)

# Collapse to a single vector: weight each structure's SOAP by its Si count,
# then normalise by the total number of Si atoms
soaps_deem = (soaps_deem * n_Si[:, np.newaxis]).sum(axis=0) / n_Si.sum()

In [None]:
# Get feature index mapping
feature_map_deem = representation.get_feature_index_mapping(deem_10k)

# Number of distinct species pairs (a, b) appearing in the feature map;
# reshape_soaps needs it to unflatten the SOAP vector
n_pairs = len({(v['a'], v['b']) for v in feature_map_deem.values()})

# NOTE: this overwrites soaps_deem with its reshaped form, so the cell is not
# idempotent -- re-run the SOAP loading cell before re-running this one.
soaps_deem = reshape_soaps(soaps_deem, n_pairs,
                           soap_hyperparameters['max_radial'],
                           soap_hyperparameters['max_angular'])

In [None]:
# Reconstruct the density from the reshaped DEEM SOAPs on the r/p grids.
# Both grids use chunk size 10 (presumably to bound memory use).
density_deem = compute_soap_density(
    soap_hyperparameters['max_radial'],
    soap_hyperparameters['max_angular'],
    soap_hyperparameters['interaction_cutoff'],
    soaps_deem,
    r_grid,
    p_grid,
    chunk_size_r=10,
    chunk_size_p=10,
)

In [None]:
# Plot the DEEM density as a 3D volume.
# indexing='ij' makes the flattened coordinates follow the C-order of the
# density array (axis 0 -> x, axis 1 -> y, axis 2 -> z); the default 'xy'
# would pair x with axis 1 and y with axis 0, swapping the two r axes.
# NOTE(review): assumes density_deem[0][2] is indexed (r, r, p) -- TODO confirm.
# The sodalite plot uses index [0][0]; confirm which pair/channel is intended.
rx_grid, ry_grid, tz_grid = np.meshgrid(r_grid, r_grid, p_grid, indexing='ij')
fig = go.Figure(data=go.Volume(x=rx_grid.flatten(),
                               y=ry_grid.flatten(),
                               z=tz_grid.flatten(),
                               value=density_deem[0][2].flatten(),
                               isomin=0.05,
                               isomax=None,
                               opacity=0.2,
                               surface_count=20))
# p_grid spans [-1, 1]; presumably cos(theta) between neighbours -- TODO confirm
fig.update_layout(title='DEEM 10k: structure-averaged SOAP density',
                  scene=dict(xaxis_title='r1',
                             yaxis_title='r2',
                             zaxis_title='p'))
fig.show()

## Sodalite

In [None]:
# Read every frame (index=':') of the sodalite structure file
sod = read(
    '../Raw_Data/SOD/sodalite.xyz',
    index=':',
)

In [None]:
# Averaged sodalite SOAPs for species [14] (Si; presumably the centre
# species), keeping only the first entry of the returned collection
soaps_sod = librascal_soap(sod, [14], **soap_hyperparameters, average=True)[0]

In [None]:
# Get feature index mapping
feature_map_sod = representation.get_feature_index_mapping(sod)

# Number of distinct species pairs (a, b) appearing in the feature map;
# reshape_soaps needs it to unflatten the SOAP vector
n_pairs = len({(v['a'], v['b']) for v in feature_map_sod.values()})

# NOTE: this overwrites soaps_sod with its reshaped form, so the cell is not
# idempotent -- re-run the SOAP computation cell before re-running this one.
soaps_sod = reshape_soaps(soaps_sod, n_pairs,
                          soap_hyperparameters['max_radial'],
                          soap_hyperparameters['max_angular'])

In [None]:
# Reconstruct the density from the reshaped sodalite SOAPs on the r/p grids.
# Both grids use chunk size 10 (presumably to bound memory use).
density_sod = compute_soap_density(
    soap_hyperparameters['max_radial'],
    soap_hyperparameters['max_angular'],
    soap_hyperparameters['interaction_cutoff'],
    soaps_sod,
    r_grid,
    p_grid,
    chunk_size_r=10,
    chunk_size_p=10,
)

In [None]:
# Plot the sodalite density as a 3D volume.
# indexing='ij' makes the flattened coordinates follow the C-order of the
# density array (axis 0 -> x, axis 1 -> y, axis 2 -> z); the default 'xy'
# would pair x with axis 1 and y with axis 0, swapping the two r axes.
# NOTE(review): assumes density_sod[0][0] is indexed (r, r, p) -- TODO confirm.
# The DEEM plot uses index [0][2]; confirm which pair/channel is intended.
rx_grid, ry_grid, tz_grid = np.meshgrid(r_grid, r_grid, p_grid, indexing='ij')
fig = go.Figure(data=go.Volume(x=rx_grid.flatten(),
                               y=ry_grid.flatten(),
                               z=tz_grid.flatten(),
                               value=density_sod[0][0].flatten(),
                               isomin=0.1,
                               isomax=None,
                               opacity=0.2,
                               surface_count=20))
# p_grid spans [-1, 1]; presumably cos(theta) between neighbours -- TODO confirm
fig.update_layout(title='Sodalite: SOAP density',
                  scene=dict(xaxis_title='r1',
                             yaxis_title='r2',
                             zaxis_title='p'))
fig.show()