In [None]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np
from scipy.linalg import fractional_matrix_power
from scipy.special import gamma
from scipy.special import eval_legendre

# Plotting
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly
import nglview
import ipywidgets

# Atoms
from ase.io import read
from ase.neighborlist import neighbor_list
from rascal.representations import SphericalInvariants
from rascal.neighbourlist.structure_manager import AtomsList
from rascal.neighbourlist.structure_manager import mask_center_atoms_by_species

# ML
from kernels import sqeuclidean_distances

# Utilities
import h5py
import json
import itertools
from tqdm.notebook import tqdm
from project_utils import load_structures_from_hdf5

# SOAP
from soap import quippy_soap, librascal_soap
from soap import reshape_soaps, compute_soap_density

# Compute atom-resolved density

In [None]:
# Load SOAP hyperparameters: a JSON mapping that is unpacked below as
# keyword arguments for SphericalInvariants and librascal_soap
with open('../Processed_Data/soap_hyperparameters.json', 'r') as f:
    soap_hyperparameters = json.load(f)

In [None]:
# Manually set a single cutoff for now. This sets (or overrides) the
# 'interaction_cutoff' entry loaded from the JSON file; the same value is
# reused below for the radial grid and for the DEEM data path
cutoff = 6.0
soap_hyperparameters['interaction_cutoff'] = cutoff

In [None]:
# Set grids that are passed to compute_soap_density and used as plot axes
# r_grid: 50 evenly spaced radial points on [0, cutoff]
# p_grid: 50 evenly spaced points on [-1, 1]; plotted on the same axis as
#         the w component of the rr'w triplets in the sodalite plot
r_grid = np.linspace(0, cutoff, 50)
p_grid = np.linspace(-1, 1, 50)

In [None]:
# Make a SphericalInvariants representation from the loaded hyperparameters;
# below it is used to obtain the feature index mapping
representation = SphericalInvariants(gaussian_sigma_type='Constant',
                                     **soap_hyperparameters)

## DEEM 10k

In [None]:
# Load structures (index=':' asks ASE for every frame in the file)
deem_10k = read('../Raw_Data/DEEM_10k/DEEM_10000.xyz', index=':')

In [None]:
# Load number of Si atoms in each structure (integer array; used below as
# the per-structure weights when averaging the SOAPs)
n_Si = np.loadtxt('../Processed_Data/DEEM_10k/n_Si.dat', dtype=int)

In [None]:
# Load full structure-averaged SOAPs from the directory named after the cutoff
# NOTE(review): assumes one row per structure, in the same order as n_Si -- confirm
soaps_deem = load_structures_from_hdf5(f'../Processed_Data/DEEM_10k/Data/{cutoff}/soaps_full_avg.hdf5',
                                       datasets=None, concatenate=True)

# Convert to a single average over all structures, weighting each row by the
# number of Si atoms of its structure.
# NOTE: soaps_deem is overwritten here, so it no longer holds per-structure rows
soaps_deem = np.sum(soaps_deem * n_Si[:, np.newaxis], axis=0) / np.sum(n_Si)

In [None]:
# Get feature index mapping
feature_map_deem = representation.get_feature_index_mapping(deem_10k)

# Count the distinct (a, b) combinations that occur in the mapping
# (presumably the species pairs of the SOAP features -- TODO confirm)
n_pairs = len({(entry['a'], entry['b']) for entry in feature_map_deem.values()})

# Reshape the averaged SOAP vector by pair, radial and angular channels.
# NOTE: soaps_deem is overwritten, so re-running this cell on its own feeds
# the already reshaped array back into reshape_soaps
soaps_deem = reshape_soaps(
    soaps_deem,
    n_pairs,
    soap_hyperparameters['max_radial'],
    soap_hyperparameters['max_angular'],
)

In [None]:
# Compute density from the reshaped, data-set-averaged SOAPs on the r/p grids
# NOTE(review): chunk_size_r / chunk_size_p presumably set how many grid
# points are processed at a time -- confirm in soap.compute_soap_density
density_deem = compute_soap_density(soap_hyperparameters['max_radial'],
                                    soap_hyperparameters['max_angular'],
                                    soap_hyperparameters['interaction_cutoff'],
                                    soaps_deem, r_grid, p_grid,
                                    chunk_size_r=10, chunk_size_p=10)

In [None]:
# Plot the density as a volume: x and y run over r_grid, z runs over p_grid
# NOTE(review): density_deem[0][2] presumably selects the first center
# species and the third species pair -- TODO confirm against compute_soap_density
rx_grid, ry_grid, tz_grid = np.meshgrid(r_grid, r_grid, p_grid, indexing='ij')
fig = go.Figure(data=go.Volume(x=rx_grid.flatten(),
                               y=ry_grid.flatten(),
                               z=tz_grid.flatten(),
                               value=density_deem[0][2].flatten(),
                               isomin=0.05,
                               isomax=None,
                               opacity=0.2,
                               surface_count=20))
fig.show()

## Sodalite

In [None]:
# Load structures (index=':' asks ASE for every frame, so sod is indexed
# with sod[0] below to get the structure itself)
sod = read('../Raw_Data/SOD/sodalite.xyz', index=':')

In [None]:
def rrw_neighbors(frame, center_species, env_species, cutoff, self_interaction=False):
    """
        Build the (r, r', w) description of every pair of neighbors around
        each atom of the requested central species

        ---Arguments---
        frame: atomic structure (its `numbers` attribute is read and it is
            passed to ase.neighborlist.neighbor_list)
        center_species: atomic numbers of the atoms to use as centers
        env_species: atomic numbers of the atoms to include in the environment
        cutoff: atomic environment cutoff, forwarded to the neighbor list
        self_interaction: forwarded to the neighbor list
            (include the central atom as its own neighbor)

        ---Returns---
        rrw: list with one entry per center (ordered by center species, then
            by atom index). Each entry is a list with one numpy array per
            unordered pair of environment species, in the order produced by
            itertools.combinations_with_replacement(env_species, 2).
            Each array has shape (3, n_neighbors_a, n_neighbors_b), where the
            entries along the first axis are:
            [0]: distance r from the central atom to neighbor A
            [1]: distance r' from the central atom to neighbor B
            [2]: w, the dot product of the two center-to-neighbor vectors
                divided by r * r', i.e., the cosine of the angle between
                them (not the angle itself). Where r * r' <= 0 the dot
                product is divided by 1 instead to avoid dividing by zero
        idxs: same structure and shapes as rrw, but holds atom indices:
            [0]: index of the central atom
            [1]: index of neighbor A
            [2]: index of neighbor B
    """

    atomic_numbers = frame.numbers

    # Atom indices of every requested center and environment species
    centers_by_species = [np.flatnonzero(atomic_numbers == z) for z in center_species]
    env_by_species = [np.flatnonzero(atomic_numbers == z) for z in env_species]

    # Neighbor list for all atoms: first atom (i), second atom (j),
    # distance (d) and distance vector (D) of each pair
    quantities = 'ijdD'
    nl = dict(zip(quantities,
                  neighbor_list(quantities, frame, cutoff,
                                self_interaction=self_interaction)))

    rrw = []
    idxs = []

    # Centers are visited species by species, in index order within a species
    for center in itertools.chain.from_iterable(centers_by_species):

        # Keep only the neighbor list entries that start at this center
        from_center = nl['i'] == center
        neighbor_ids = nl['j'][from_center]
        neighbor_dists = nl['d'][from_center]
        neighbor_vecs = nl['D'][from_center]

        rrw_center = []
        idxs_center = []

        # One array per unordered pair of environment species
        for species_a, species_b in itertools.combinations_with_replacement(env_by_species, 2):
            in_a = np.isin(neighbor_ids, species_a)
            in_b = np.isin(neighbor_ids, species_b)

            # Distances (r, r') laid out on the (neighbor A, neighbor B) grid
            r_a, r_b = neighbor_dists[in_a], neighbor_dists[in_b]
            r_n, r_m = np.meshgrid(r_a, r_b, indexing='ij')

            # Cosine of the angle (w): dot product over the product of norms;
            # non-positive products (e.g., the center paired with itself)
            # are replaced by 1 so the division is always defined
            dots = neighbor_vecs[in_a] @ neighbor_vecs[in_b].T
            norms = np.outer(r_a, r_b)
            w = dots / np.where(norms <= 0.0, 1.0, norms)

            # Atom indices on the same (neighbor A, neighbor B) grid
            j_n, j_m = np.meshgrid(neighbor_ids[in_a], neighbor_ids[in_b], indexing='ij')
            j_center = np.full(j_n.shape, center, dtype=int)

            rrw_center.append(np.stack((r_n, r_m, w)))
            idxs_center.append(np.stack((j_center, j_n, j_m)))

        rrw.append(rrw_center)
        idxs.append(idxs_center)

    return rrw, idxs

def make_tuples(data):
    """
        Flatten every rr'w formatted 3D array (see rrw_neighbors) of a list
        of lists into a 2D array of shape (n_neighbor_pairs, 3), so that each
        row is one rr'w triplet with the columns in the order r, r', w.
        The nesting (centers, then species pairs) is preserved

        ---Arguments---
        data: list of lists of arrays to "reshape"

        ---Returns---
        center_tuple: "reshaped" data list
    """

    def as_rows(block):
        # Move the component axis (r, r', w) to the end and merge the two
        # neighbor axes, giving one row per (neighbor A, neighbor B) pair
        block_shape = np.shape(block)
        n_rows = np.prod(block_shape[1:])
        return np.moveaxis(block, 0, -1).reshape((n_rows, block_shape[0]))

    # Outer level: centers; inner level: species pairs
    return [[as_rows(block) for block in center] for center in data]

def unique_environments(data, idxs, tol=1.0E-8):
    """
        Extract the unique rr'w triplets for every center and species pair.
        Triplets are scanned in order, and a triplet is discarded as a
        duplicate when sqeuclidean_distances between it and any earlier
        triplet of the same center and species pair is at most tol.
        (The comparison is made against every earlier triplet, including
        ones that were themselves discarded.)

        ---Arguments---
        data: rr'w list of lists of arrays (see rrw_neighbors)
        idxs: indices for the rr'w list of lists of arrays
            (same structure as data, see rrw_neighbors)
        tol: threshold compared directly with the output of
            sqeuclidean_distances (presumably a *squared* Euclidean
            distance -- confirm against the kernels module)

        ---Returns---
        unique_data: list (one entry per center) of lists (one entry per
            species pair) of 2D arrays of shape (n_unique, 3) holding the
            retained rr'w triplets, in their original order
        unique_idxs: same structure as unique_data, holding the
            (center, neighbor A, neighbor B) atom indices of those triplets
    """

    # Reshape the rr'w and index lists
    data_tuples = make_tuples(data)
    idxs_tuples = make_tuples(idxs)

    unique_data = []
    unique_idxs = []

    # Loop over centers
    for center_data, center_idxs in zip(data_tuples, idxs_tuples):
        center_unique_data = []
        center_unique_idxs = []

        # Loop over species pairs
        for pair_data, pair_idxs in zip(center_data, center_idxs):
            if len(pair_data) == 0:
                # No triplets for this pair: nothing to compare, keep nothing
                keep = np.zeros(0, dtype=int)
            else:
                # https://stackoverflow.com/questions/37847053/uniquify-an-array-list-with-a-tolerance-in-python-uniquetol-equivalent
                D = sqeuclidean_distances(pair_data, pair_data)

                # The strict upper triangle (k=1) compares each triplet only
                # with earlier ones, which also excludes the self-distances
                is_duplicate = np.any(np.triu(D <= tol, k=1), axis=0)
                keep = np.nonzero(~is_duplicate)[0]

            center_unique_data.append(pair_data[keep])
            center_unique_idxs.append(pair_idxs[keep])

        unique_data.append(center_unique_data)
        unique_idxs.append(center_unique_idxs)

    return unique_data, unique_idxs

In [None]:
# rr'w triplets around every Si (Z=14) center of the first sodalite frame,
# with O (Z=8) and Si neighbors within 6.0; the species pairs are therefore
# ordered (8, 8), (8, 14), (14, 14). self_interaction is passed on to the
# ASE neighbor list
rrw, idxs = rrw_neighbors(sod[0], [14], [8, 14], 6.0, self_interaction=True)

In [None]:
# Look for unique environments among the sodalite rr'w triplets, using a
# loose tolerance of 1.0
unique_environments(rrw, idxs, tol=1.0)

In [None]:
# Flatten each (3, n_a, n_b) rr'w array into rows of (r, r', w) triplets
rrw_tuple = make_tuples(rrw)

In [None]:
# Compute SOAPs for sodalite with the same hyperparameters as above
# NOTE(review): [14] presumably restricts the centers to Si and average=True
# presumably averages over the centers of each structure -- confirm in
# soap.librascal_soap
soaps_sod = librascal_soap(sod, [14],
                           **soap_hyperparameters,
                           average=True)
# Keep only the first entry of the result
soaps_sod = soaps_sod[0]

In [None]:
# Get feature index mapping
feature_map_sod = representation.get_feature_index_mapping(sod)

# Count the distinct (a, b) combinations that occur in the mapping
# (presumably the species pairs of the SOAP features -- TODO confirm)
n_pairs = len({(entry['a'], entry['b']) for entry in feature_map_sod.values()})

# Reshape the SOAP vector by pair, radial and angular channels.
# NOTE: soaps_sod is overwritten, so re-running this cell on its own feeds
# the already reshaped array back into reshape_soaps
soaps_sod = reshape_soaps(
    soaps_sod,
    n_pairs,
    soap_hyperparameters['max_radial'],
    soap_hyperparameters['max_angular'],
)

In [None]:
# Compute density from the reshaped sodalite SOAPs on the r/p grids
# NOTE(review): chunk_size_r / chunk_size_p presumably set how many grid
# points are processed at a time -- confirm in soap.compute_soap_density
density_sod = compute_soap_density(soap_hyperparameters['max_radial'],
                                   soap_hyperparameters['max_angular'],
                                   soap_hyperparameters['interaction_cutoff'],
                                   soaps_sod, r_grid, p_grid,
                                   chunk_size_r=10, chunk_size_p=10)

In [None]:
# Plot the sodalite density for one species pair as a volume (x and y run
# over r_grid, z runs over p_grid) and overlay the rr'w triplets of the first
# center for the same pair index
# NOTE(review): index 1 of rrw is the (8, 14) pair; this assumes density_sod
# orders its species pairs the same way -- TODO confirm
species_idx = 1
rx_grid, ry_grid, tz_grid = np.meshgrid(r_grid, r_grid, p_grid, indexing='ij')
fig = go.Figure(data=go.Volume(x=rx_grid.flatten(),
                               y=ry_grid.flatten(),
                               z=tz_grid.flatten(),
                               value=density_sod[0][species_idx].flatten(),
                               isomin=0.01,
                               isomax=None,
                               opacity=0.2,
                               surface_count=20))

# One marker per (r, r', w) triplet. The hover text holds the three atom
# indices (center, neighbor A, neighbor B); tolist() converts them to plain
# Python ints so the label reads "(0, 1, 2)" on every NumPy version
# (NumPy >= 2 would otherwise show "np.int64(0)" inside the tuple)
fig.add_trace(go.Scatter3d(x=rrw[0][species_idx][0].flatten(),
                           y=rrw[0][species_idx][1].flatten(),
                           z=rrw[0][species_idx][2].flatten(),
                           mode='markers',
                           marker=dict(size=2,
                                       color='red'),
                           hovertemplate='x: %{x}<br>y: %{y}<br>z: %{z}<br>(center, a, b): %{text}',
                           text=[str(triplet) for triplet in zip(idxs[0][species_idx][0].flatten().tolist(),
                                                                 idxs[0][species_idx][1].flatten().tolist(),
                                                                 idxs[0][species_idx][2].flatten().tolist())]))
fig.show()

In [None]:
def view_grid(frame, idxs, n_col, n_views=3):
    """
        Build a grid of nglview viewers, one per (center, A, B) index triplet,
        with the central atom drawn in black and the two neighbors in blue

        Adapted from: https://github.com/nglviewer/nglview/blob/master/examples/users/ase.md

        ---Arguments---
        frame: atomic structure to display (passed to nglview.show_ase)
        idxs: index array for a single center and species pair
            (see rrw_neighbors), where idxs[0], idxs[1], and idxs[2] hold the
            indices of the central atom, neighbor A, and neighbor B
        n_col: number of viewers per row of the grid
        n_views: maximum number of triplets to show, taken in flattened
            order; None shows every triplet. The default of 3 keeps the
            number of (expensive) viewer widgets small

        ---Returns---
        ipywidgets.VBox holding one ipywidgets.HBox per row of viewers
    """

    # Walk through the triplets in flattened order, stopping after n_views
    triplets = zip(idxs[0].flatten(), idxs[1].flatten(), idxs[2].flatten())

    viewers = []
    for c, a, b in itertools.islice(triplets, n_views):
        # Plain Python ints keep the selections independent of how NumPy
        # integers are rendered
        c, a, b = int(c), int(a), int(b)

        viewer = nglview.show_ase(frame)
        # NOTE: _set_size is a private nglview method (as in the adapted example)
        viewer._set_size('500px', '500px')
        viewer.add_representation('ball+stick', selection=[c], color='black', radius=0.3)
        viewer.add_representation('ball+stick', selection=[a, b], color='blue', radius=0.3)
        viewers.append(viewer)

    # Split the viewers into rows of at most n_col widgets
    n_rows = int(np.ceil(len(viewers) / n_col))
    row_viewers = [ipywidgets.HBox(viewers[n*n_col:(n+1)*n_col]) for n in range(0, n_rows)]
    return ipywidgets.VBox(row_viewers)

In [None]:
# Build the viewer grid for the first center and the first species pair
# (8, 8) of the first sodalite frame, with 3 viewers per row. The widget is
# only assigned here, not displayed
view = view_grid(sod[0], idxs[0][0], 3)