In [None]:
# System
import os
import sys
sys.path.append('/home/helfrech/Tools/Toolbox/utils')

# Maths
import numpy as np
from scipy.linalg import fractional_matrix_power
from scipy.special import gamma
from scipy.special import eval_legendre

# Plotting
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly
import nglview
import ipywidgets

# Atoms
from ase.io import read
from ase.neighborlist import neighbor_list
from rascal.representations import SphericalInvariants
from rascal.neighbourlist.structure_manager import AtomsList
from rascal.neighbourlist.structure_manager import mask_center_atoms_by_species

# ML
from kernels import sqeuclidean_distances
from soap import rrw_neighbors, make_tuples
from soap import reshape_soaps, compute_soap_density

# Utilities
import h5py
import json
import itertools
from tqdm.notebook import tqdm
from project_utils import load_structures_from_hdf5

# SOAP
from soap import quippy_soap, librascal_soap
from soap import reshape_soaps, compute_soap_density

# Compute atom-resolved density

In [None]:
# Read the SOAP hyperparameters from their JSON file into a dict
with open('../Processed_Data/soap_hyperparameters.json', 'r') as f:
    soap_hyperparameters = json.loads(f.read())

In [None]:
# Manually set a single cutoff for now
cutoff = 6.0
soap_hyperparameters['interaction_cutoff'] = cutoff

# The density loop further down iterates over `cutoffs`; while only one cutoff
# is selected manually, that collection holds just this single value.
# NOTE(review): the value is also interpolated into directory names
# (e.g. f'{model_dir}/{cutoff}/...') -- confirm the on-disk folders are named '6.0'
cutoffs = (cutoff,)

In [None]:
# Set grids: 50 evenly spaced points on [0, cutoff] and 50 on [-1, 1]
r_grid, p_grid = np.linspace(0, cutoff, 50), np.linspace(-1, 1, 50)

## DEEM 10k

## TODO: also load IZA, and average over both IZA and DEEM soaps in the train set used to build the full KSVC-KPCovR models

## TODO: compute soap densities, weight densities, and triplet indices

In [16]:
# Linear model setup: number of species and, per spectrum type, the names
# of the feature groups (these names are also used as output sub-directories)
n_species = 2
#group_names = {'power': ['OO', 'OSi', 'SiSi', 
#                         'OO+OSi', 'OO+SiSi', 'OSi+SiSi',
#                         'OO+OSi+SiSi'], 
#               'radial': ['O', 'Si', 'O+Si']}
group_names = dict(
    power=['OO', 'OSi', 'SiSi'],
    radial=['O', 'Si'],
)

In [None]:
# Root directory of the fitted models; the density loop reads the linear SVC
# weights from sub-directories of this path
model_dir = '../Processed_Data/Models'

In [17]:
# Dataset names and the directories that hold their processed data
deem_name, iza_name = 'DEEM_10k', 'IZA_226'
deem_dir, iza_dir = f'../Processed_Data/{deem_name}/Data', f'../Processed_Data/{iza_name}/Data'

In [None]:
# Load structures
# index=':' requests every frame in each file, so both names are lists of
# structures (iza_226 is later edited with list.pop)
deem_10k = read('../Raw_Data/DEEM_10k/DEEM_10000.xyz', index=':')
iza_226 = read('../Raw_Data/GULP/IZA_226/IZA.xyz', index=':')

In [7]:
# Load IZA cantons: integer labels taken from column 1 of the cantons file
cantons_iza = np.loadtxt('../Raw_Data/GULP/IZA_226/cantons.txt', dtype=int, usecols=1)

# Position of the first structure labelled as canton 4; it is excluded below
RWY = np.flatnonzero(cantons_iza == 4)[0]

In [8]:
# Remove the RWY entry from the canton labels and count what is left
cantons_iza = np.delete(cantons_iza, RWY)
n_iza = cantons_iza.shape[0]

In [None]:
# Remove the RWY structure so the structure list matches cantons_iza
# NOTE(review): this mutates iza_226 in place -- re-running the cell removes
# one more structure each time
iza_226.pop(RWY)

In [None]:
# Load number of Si atoms in each structure; the RWY entry is dropped from
# the IZA counts so they line up with the filtered IZA set
n_Si_deem = np.loadtxt('../Processed_Data/DEEM_10k/n_Si.dat', dtype=int)
n_Si_iza = np.delete(np.loadtxt('../Processed_Data/IZA_226/n_Si.dat', dtype=int), RWY)

In [None]:
# Number of grid points and the chunk sizes handed to compute_soap_density
n_r_grid, chunk_size_r = 50, 10
n_p_grid, chunk_size_p = 50, 10

# Evenly spaced grids on [0, cutoff] and [-1, 1]
r_grid = np.linspace(0.0, cutoff, num=n_r_grid)
p_grid = np.linspace(-1.0, 1.0, num=n_p_grid)

In [None]:
# TODO: iteratively compute and save to HDF5 the rrw and indices for each structure as HDF5 datasets
# Do rrw and indices as separate groups, structures as separate datasets in each group
# NOTE: no rrw_neighbors call is made here on purpose. The placeholder call on
# `sod[0]` ran before `sod` is loaded (Sodalite section), which raises NameError
# on a fresh kernel; the same call is made in the Sodalite section.

In [None]:
# For every cutoff / spectrum type / species pairing, compute the density of
# the Si-weighted mean SOAP vector over IZA+DEEM and the density of the
# linear SVC weights.
# NOTE(review): `extract_species_pair_groups` is not imported in the import
# cell visible at the top of this notebook -- import it before running.
# NOTE(review): the densities use soap_hyperparameters['interaction_cutoff']
# and the module-level r_grid, neither of which follows the loop's `cutoff`.
for cutoff in tqdm(cutoffs, desc='Cutoff', leave=True):
    linear_dir = f'{model_dir}/{cutoff}/Linear_Models/LSVC-LPCovR'

    for spectrum_type in tqdm(('power', 'radial'), desc='Spectrum', leave=False):
        spectrum_name = spectrum_type.capitalize()

        # Load SOAPs
        deem_file = f'{deem_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'
        iza_file = f'{iza_dir}/{cutoff}/soaps_{spectrum_type}_full_avg_nonorm.hdf5'

        # TODO: what do we want to use to compute the SOAP density?
        # Just the train set? All IZA+DEEM structures?
        #soaps_train, soaps_test = utils.load_soaps(deem_file, iza_file,
        #                                           idxs_deem_train, idxs_deem_test,
        #                                           idxs_iza_train, idxs_iza_test,
        #                                           idxs_iza_delete=[RWY],
        #                                           train_test_concatenate=True)

        # Weight each structure's SOAP vector by its Si count and sum over
        # structures; keepdims keeps a (1, n_features) row so the column
        # indexing below still works.
        # (assumes one row per structure, as the "avg" in the file name suggests -- TODO confirm)
        soaps_deem = load_structures_from_hdf5(deem_file, datasets=None, concatenate=True)
        soaps_deem = np.sum(soaps_deem * n_Si_deem[:, np.newaxis], axis=0, keepdims=True)

        # Same for IZA, after dropping the RWY row to match n_Si_iza
        soaps_iza = load_structures_from_hdf5(iza_file, datasets=None, concatenate=True)
        soaps_iza = np.delete(soaps_iza, RWY, axis=0)
        soaps_iza = np.sum(soaps_iza * n_Si_iza[:, np.newaxis], axis=0, keepdims=True)

        # Si-weighted mean over both datasets: total weighted sum / total Si count
        soaps_all = (soaps_iza + soaps_deem) / (np.sum(n_Si_deem) + np.sum(n_Si_iza))

        n_features = soaps_all.shape[1]
        feature_groups = extract_species_pair_groups(n_features, n_species,
                                                     spectrum_type=spectrum_type,
                                                     combinations=False)

        for species_pairing, feature_idxs in zip(tqdm(group_names[spectrum_type],
                                                      desc='Species', leave=False),
                                                 feature_groups):

            # Use a separate name: overwriting soaps_all here would corrupt
            # the feature selection for every following species pairing
            soaps_group = reshape_soaps(soaps_all[:, feature_idxs], 1,
                                        soap_hyperparameters['max_radial'],
                                        soap_hyperparameters['max_angular'])

            # Compute density
            soap_density = compute_soap_density(soap_hyperparameters['max_radial'],
                                                soap_hyperparameters['max_angular'],
                                                soap_hyperparameters['interaction_cutoff'],
                                                soaps_group, r_grid, p_grid,
                                                chunk_size_r=chunk_size_r, chunk_size_p=chunk_size_p)

            for n_cantons in tqdm((2, 4), desc='Classes', leave=False):

                output_dir = f'{linear_dir}/{n_cantons}-Class/{spectrum_name}/{species_pairing}'
                weights_file = f'{output_dir}/svc_weights.dat'

                weights = np.loadtxt(weights_file)

                weights = reshape_soaps(weights, 1,
                                        soap_hyperparameters['max_radial'],
                                        soap_hyperparameters['max_angular'])

                weight_density = compute_soap_density(soap_hyperparameters['max_radial'],
                                                      soap_hyperparameters['max_angular'],
                                                      soap_hyperparameters['interaction_cutoff'],
                                                      weights, r_grid, p_grid,
                                                      chunk_size_r=chunk_size_r, chunk_size_p=chunk_size_p)

                # TODO: save soap density
                # TODO: save weight density

## Sodalite

In [None]:
# Load structure
# index=':' requests every frame, so `sod` is a list; later cells use sod[0]
sod = read('../Raw_Data/SOD/sodalite.xyz', index=':')

In [None]:
# Make a SphericalInvariants representation (for sodalite)
# Built from the loaded SOAP hyperparameters (interaction_cutoff was overridden above)
# plus a constant Gaussian sigma type
representation = SphericalInvariants(gaussian_sigma_type='Constant',
                                     **soap_hyperparameters)

In [None]:
# Compute SOAPs (unnormalized, averaged) for sodalite and keep only the
# first entry of the result

# TODO: do an average just like DEEM+IZA -- should be the same as a single environment
# as they are all equivalent, but do this for consistency
soaps_sod = librascal_soap(sod, [14],
                           normalize=False,
                           average=True,
                           **soap_hyperparameters)[0]

In [None]:
# Neighbor data for the first sodalite frame with a 6.0 cutoff, including self-interaction
# NOTE(review): [14] and [8, 14] are presumably the center (Si) and neighbor (O, Si)
# atomic numbers -- confirm against rrw_neighbors
rrw, idxs = rrw_neighbors(sod[0], [14], [8, 14], 6.0, self_interaction=True)

In [None]:
# TODO: set n_pairs
# NOTE(review): `n_pairs` is not assigned in any of the cells visible here, so
# this cell raises NameError until the TODO is resolved
# NOTE(review): soaps_sod is overwritten with its reshaped form, so the cell is
# not safe to re-run without recomputing soaps_sod first
soaps_sod = reshape_soaps(soaps_sod, n_pairs, 
                          soap_hyperparameters['max_radial'], 
                          soap_hyperparameters['max_angular'])

# Compute density
# Evaluated on r_grid x p_grid, with chunk sizes of 10 for both grids
density_sod = compute_soap_density(soap_hyperparameters['max_radial'],
                                   soap_hyperparameters['max_angular'],
                                   soap_hyperparameters['interaction_cutoff'],
                                   soaps_sod, r_grid, p_grid,
                                   chunk_size_r=10, chunk_size_p=10)

# TODO: save density, rrw, and indices