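To install rascal:
(NOTE: See the top-level README for the most up-to-date installation instructions.)
Run these from this notebook's directory, so that the build directory is `../build` (the path added to `sys.path` below):
+ mkdir ../build
+ cd ../build
+ cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTS=ON ..
+ make -j 4
+ make install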

In [None]:
!export OMP_NUM_THREADS=1
!export NUMBA_THREADING_LAYER=1
from mkl import set_num_threads
set_num_threads(1)

In [None]:
# Debug cell: print the environment prefix, the user site directory and the
# site-packages directory reported by distutils, to see where rascal is installed.
# NOTE(review): distutils is deprecated (removed in Python 3.12); the stdlib
# replacement is `sysconfig.get_paths()['purelib']`. `ndiff` is not used in
# the visible cells.
import sys, site
from difflib import ndiff
from distutils.sysconfig import get_python_lib
print(sys.prefix,site.USER_SITE,get_python_lib())

In [None]:
 get_python_lib().replace(sys.prefix+'/', '')  # site-packages path relative to the prefix

In [None]:
# WARNING: recursively deletes a hard-coded, machine-specific directory
# (presumably a stale rascal install); check the path before running elsewhere.
!rm -r /home/felix/miniconda2/envs/py36/rascal

In [None]:
# Locate the installed librascal shared libraries (machine-specific paths).
!ls /home/felix/miniconda2/envs/py36/lib/libras*

In [None]:
# Same libraries, reached relative to the installed rascal package.
!ls /home/felix/miniconda2/envs/py36/lib/python3.6/site-packages/rascal/lib/../../../../../lib/libras*

In [None]:
# Contents of the site-packages tree nested inside the installed egg.
!ls /home/felix/miniconda2/envs/py36/lib/python3.6/site-packages/rascal-0.3.5-py3.6-linux-x86_64.egg/lib/python3.6/site-packages/

In [None]:
!ls /home/felix/miniconda2/envs/py36/lib/librascal.so

In [None]:
# Print the RPATH entries of the installed librascal.so.
lines = !objdump -x /home/felix/miniconda2/envs/py36/lib/librascal.so 
for line in lines:
    if 'RPATH' in line:
        print(line)

In [None]:
# scikit-build output directory of the source checkout.
!ls /home/felix/git/librascal/_skbuild/linux-x86_64-3.6

In [None]:
# RPATH of the librascal.so produced by scikit-build, for comparison.
lines = !objdump -x /home/felix/git/librascal/_skbuild/linux-x86_64-3.6/cmake-install/librascal.so
for line in lines:
    if 'RPATH' in line:
        print(line)

In [None]:
# Python extension module shipped inside the installed egg.
!ls /home/felix/miniconda2/envs/py36/lib/python3.6/site-packages/rascal-0.3.5-py3.6-linux-x86_64.egg/lib/python3.6/site-packages/rascal/lib/_rascal.cpython-36m-x86_64-linux-gnu.so

In [None]:
# Full objdump header listing of the extension module (shown as the cell output).
lines = !objdump -x /home/felix/miniconda2/envs/py36/lib/python3.6/site-packages/rascal-0.3.5-py3.6-linux-x86_64.egg/lib/python3.6/site-packages/rascal/lib/_rascal.cpython-36m-x86_64-linux-gnu.so
lines

In [None]:
%matplotlib inline
from matplotlib import pylab as plt

import os, sys
from ase.io import read
sys.path.insert(0,"../build")
import sys
import time
import rascal
import json

import ase
from ase.io import read, write
from ase.build import make_supercell
from ase.visualize import view
import numpy as np
import sys
from copy import deepcopy
import json

from rascal.representations import SphericalInvariants as SOAP
from rascal.models import Kernel
# from rascal.utils import fps,FPSFilter

In [None]:
# First 100 small molecules.
frames = read('../reference_data/inputs/small_molecules-1000.xyz',':100')

# SOAP: Power spectrum

In [None]:
# NOTE: the second read overwrites the first, so only the CaCrP2O7
# structures are used from here on.
frames = read('../reference_data/inputs/diamond_2atom_distorted.json',':100')
frames = read('../reference_data/inputs/CaCrP2O7_mvc-11955_symmetrized.json',':100')

In [None]:
# Power-spectrum SOAP with a small basis (max_radial=2, max_angular=2), and an
# atom-wise (target_type='Atom') kernel with zeta=2.
hypers = dict(soap_type="PowerSpectrum",
              interaction_cutoff=3.5, 
              max_radial=2, 
              max_angular=2, 
              gaussian_sigma_constant=0.4,
              gaussian_sigma_type="Constant",
              cutoff_smooth_width=0.5,
              normalize=False,
              )
soap = SOAP(**hypers)
zeta=2
kernel1 = Kernel(soap, zeta=zeta, target_type='Atom')

In [None]:
# Dense feature matrix for the first 10 frames.
representation = soap.transform(frames[:10])
X = representation.get_features(soap)
X.shape

In [None]:
# Hand-built mapping from power-spectrum feature column to
# (species pair a <= b, n1, n2, l) for the frames in `representation`.
from itertools import product
species = []
for ii in range(len(representation)):
    manager = representation[ii]
    for center in manager:
        species.append(center.atom_type)

u_species = np.unique(species)
sp_pairs = [(sp1, sp2) for sp1 in u_species for sp2 in u_species if sp1 <= sp2]

n_max = soap.hypers['max_radial']
# Assumes l runs from 0 to max_angular inclusive, as get_power_spectrum_index_mapping
# (next cell) also does; range(max_angular) would drop the last l channel.
l_max = soap.hypers['max_angular'] + 1
feat_idx2coeff_idx = {}
for i_feat, (sp_pair, n1, n2, l) in enumerate(
        product(sp_pairs, range(n_max), range(n_max), range(l_max))):
    feat_idx2coeff_idx[i_feat] = dict(a=sp_pair[0], b=sp_pair[1], n1=n1, n2=n2, l=l)
u_species, sp_pairs, feat_idx2coeff_idx

In [None]:
def get_power_spectrum_index_mapping(soap, managers):
    """Map every global power-spectrum feature index to its expansion indices.

    Parameters
    ----------
    soap : SphericalInvariants
        Calculator whose ``hypers`` provide ``max_radial`` and ``max_angular``.
    managers : sequence
        ``ase.Atoms`` frames or rascal structure managers; only used to collect
        the species present.

    Returns
    -------
    dict
        ``{i_feat: dict(a=..., b=..., n1=..., n2=..., l=..., i_feat_sp=...)}``
        where ``(a, b)`` is a species pair with ``a <= b`` and ``i_feat_sp`` is
        the position of the feature inside that pair's block. Blocks follow the
        pair order; inside a block n1, n2 and l vary with l fastest.
    """
    n_max = soap.hypers['max_radial']
    n_l = soap.hypers['max_angular'] + 1

    found = []
    for i_man in range(len(managers)):
        structure = managers[i_man]
        if isinstance(structure, ase.Atoms):
            found.extend(structure.get_atomic_numbers())
        else:
            found.extend(center.atom_type for center in structure)

    u_species = np.unique(found)
    pairs = [(s1, s2) for s1 in u_species for s2 in u_species if s1 <= s2]

    # (n1, n2, l) combinations of one species-pair block, in output order
    block = list(product(range(n_max), range(n_max), range(n_l)))
    mapping = {}
    for i_pair, (a, b) in enumerate(pairs):
        offset = i_pair * len(block)
        for i_local, (n1, n2, l) in enumerate(block):
            mapping[offset + i_local] = dict(a=a, b=b, n1=n1, n2=n2, l=l,
                                             i_feat_sp=i_local)
    return mapping

# Mapping for the 10 frames in `representation`, to compare with the hand-built one.
get_power_spectrum_index_mapping(soap, representation)

In [None]:
def get_index_mappings_sample_per_species(managers, sps):
    """Index the atomic centers of ``managers`` separately for each species.

    Parameters
    ----------
    managers : sequence of structure managers
        ``managers[i]`` iterates over centers exposing ``atom_type``.
    sps : iterable of int
        Species allowed as central atoms.

    Returns
    -------
    strides_by_sp : dict
        ``sp -> np.ndarray`` of length ``len(managers) + 1``. The centers of
        species ``sp`` in manager ``i`` have per-species global indices in
        ``[strides[i], strides[i + 1])``.
    global_counter : dict
        ``sp -> number of centers of that species`` over all managers.
    map_by_manager : list of dict
        One dict per manager, ``per-species global index -> position of the
        center in that manager``. NOTE: all species share the same dict per
        manager, so entries of different species with equal indices overwrite
        each other.
    indices_by_sp : dict
        ``sp -> list`` of positions, over all centers of all managers in
        iteration order, whose species is ``sp``.

    Raises
    ------
    ValueError
        If a center's species is not in ``sps``.
    """
    # materialize so generators / dict views can be traversed several times
    sps = list(sps)
    types = []
    strides_by_sp = {sp: [0] for sp in sps}
    global_counter = {sp: 0 for sp in sps}
    indices_by_sp = {sp: [] for sp in sps}
    map_by_manager = [{} for ii in range(len(managers))]
    for i_man in range(len(managers)):
        man = managers[i_man]
        counter = {sp: 0 for sp in sps}
        for i_at, at in enumerate(man):
            types.append(at.atom_type)
            if at.atom_type in sps:
                map_by_manager[i_man][global_counter[at.atom_type]] = i_at
                counter[at.atom_type] += 1
                global_counter[at.atom_type] += 1
            else:
                # This is a free function: report the allowed species `sps`
                # (referencing `self` here raised NameError instead).
                raise ValueError('Atom type {} has not been specified in fselect: {}'.format(
                    at.atom_type, sps))
        for sp in sps:
            strides_by_sp[sp].append(counter[sp])

    for sp in sps:
        strides_by_sp[sp] = np.cumsum(strides_by_sp[sp])

    for ii, sp in enumerate(types):
        indices_by_sp[sp].append(ii)

    return strides_by_sp, global_counter, map_by_manager, indices_by_sp

def get_index_mappings_sample(managers):
    """Index all atomic centers of ``managers`` with a single global counter.

    Returns
    -------
    strides : np.ndarray
        Cumulative center counts of length ``len(managers) + 1``; manager ``i``
        owns the global indices ``[strides[i], strides[i + 1])``.
    global_counter : int
        Total number of centers.
    map_by_manager : list of dict
        Per manager, ``global index -> position of the center in the manager``.
    """
    sizes = [0]
    map_by_manager = []
    offset = 0
    for i_man in range(len(managers)):
        local = {offset + pos: pos for pos, _ in enumerate(managers[i_man])}
        map_by_manager.append(local)
        sizes.append(len(local))
        offset += len(local)
    return np.cumsum(sizes), offset, map_by_manager

def convert_selected_global_index2rascal_sample_per_species(managers, selected_ids_by_sp, strides_by_sp, map_by_manager, sps):
    """Convert per-species global center indices into per-manager center indices.

    Parameters
    ----------
    managers : sequence of structure managers
        ``managers[i]`` iterates over centers exposing ``atom_type``.
    selected_ids_by_sp : dict
        ``sp -> iterable of int``: indices into the list of all centers of
        species ``sp`` (managers in order, centers in iteration order), i.e.
        row indices of the per-species feature matrix.
    strides_by_sp, map_by_manager :
        Accepted for backward compatibility and not used. The mapping is
        rebuilt from ``managers`` because a single ``map_by_manager`` dict per
        manager cannot hold the overlapping per-species indices of several
        species.
    sps : iterable of int
        Species whose selections are converted.

    Returns
    -------
    list of list of int
        For each manager, the sorted positions of its selected centers.
    """
    # (i_manager, i_center) of every center, grouped by species, in the same
    # traversal order that defines the per-species global index.
    positions_by_sp = {sp: [] for sp in sps}
    for i_man in range(len(managers)):
        for i_at, at in enumerate(managers[i_man]):
            if at.atom_type in positions_by_sp:
                positions_by_sp[at.atom_type].append((i_man, i_at))

    selected_ids = [[] for _ in range(len(managers))]
    for sp in sps:
        positions = positions_by_sp[sp]
        for idx in selected_ids_by_sp[sp]:
            i_man, i_at = positions[idx]
            selected_ids[i_man].append(i_at)
    return [sorted(ids) for ids in selected_ids]

def convert_selected_global_index2rascal_sample(managers, selected_ids_global, strides, map_by_manager):
    """Convert global center indices into per-manager center indices.

    Parameters
    ----------
    managers : sequence
        Structures; only ``len(managers)`` is used.
    selected_ids_global : iterable of int
        Global center indices, in any order (FPS returns them unsorted).
    strides : array-like of int
        Cumulative center counts of length ``len(managers) + 1``, as returned
        by ``get_index_mappings_sample``.
    map_by_manager : list of dict
        Per manager, ``global index -> position of the center in the manager``.

    Returns
    -------
    list of list of int
        For each manager, the sorted positions of its selected centers.

    Raises
    ------
    IndexError
        If a selected index is outside ``[0, strides[-1])``.
    """
    strides = np.asarray(strides)
    n_total = strides[-1]
    selected_ids = [[] for _ in range(len(managers))]
    for idx in selected_ids_global:
        if not 0 <= idx < n_total:
            raise IndexError('selected index {} is out of range [0, {})'.format(idx, n_total))
        # Last manager whose first global index is <= idx. Unlike a forward-only
        # scan this works for unsorted input and skips empty managers.
        i_manager = int(np.searchsorted(strides, idx, side='right')) - 1
        selected_ids[i_manager].append(map_by_manager[i_manager][idx])
    return [np.sort(ids).tolist() for ids in selected_ids]

class FPSFilter(object):
    """Farthest Point Sampling (FPS) to select samples or features in a given feature matrix.

    Wrapper around ``rascal.utils.fps`` for convenience.

    Parameters
    ----------
    representation : Calculator
        Representation calculator associated with the kernel.
    Nselect : int or dict
        Number of points to select. If ``act_on='sample per specie'`` it must
        be a dictionary mapping atom type to the number of samples, e.g.
        ``Nselect = {1:200,6:100,8:50}``.
    act_on : str
        How to apply the selection: ``'sample'``, ``'sample per specie'``
        (``'sample per species'`` is accepted as an alias) or ``'feature'``.
    starting_index : int
        Forwarded as ``starting_index`` to ``rascal.utils.fps`` (presumably the
        index of the first selected point).

    Raises
    ------
    ValueError
        If ``act_on`` is not one of the supported modes.
    """

    # Accepted spellings of ``act_on`` -> canonical value stored in ``self.act_on``.
    _ACT_ON_ALIASES = {
        'sample': 'sample',
        'sample per specie': 'sample per specie',
        'sample per species': 'sample per specie',
        'feature': 'feature',
    }

    def __init__(self, representation, Nselect, act_on='sample per specie', starting_index=0):
        super(FPSFilter, self).__init__()
        self._representation = representation
        self.Nselect = Nselect
        self.starting_index = starting_index
        # Raising a plain str is itself a TypeError in Python 3, so raise a
        # real exception that carries the intended message.
        if act_on not in self._ACT_ON_ALIASES:
            raise ValueError('Wrong input: {}'.format(act_on))
        self.act_on = self._ACT_ON_ALIASES[act_on]

        self.selected_ids = None
        self.fps_minmax_d2_by_sp = None
        self.fps_minmax_d2 = None

    def fit(self, managers):
        """Perform FPS selection of samples/features.

        Parameters
        ----------
        managers : AtomsList
            List of structures containing features computed with ``representation``.

        Returns
        -------
        FPSFilter
            ``self`` in every mode, so that ``fit(...).transform(...)`` chains.

        Raises
        ------
        NotImplementedError
            If ``act_on`` is not a supported mode.
        """
        from rascal.utils import fps as do_fps
        # get the dense feature matrix
        X = managers.get_features(self._representation)

        if self.act_on == 'sample per specie':
            sps = list(self.Nselect.keys())

            # get various info from the structures about the center atom species and indexing
            (strides_by_sp, global_counter, map_by_manager,
             indices_by_sp) = get_index_mappings_sample_per_species(managers, sps)

            print('The number of pseudo points selected by central atom species is: {}'.format(
                self.Nselect))

            # organize features w.r.t. central atom type
            X_by_sp = {}
            for sp in sps:
                X_by_sp[sp] = X[indices_by_sp[sp]]
            self._XX = X_by_sp

            # run FPS separately on the rows of each center species
            self.selected_ids_by_sp = {}
            self.fps_minmax_d2_by_sp = {}
            self.fps_hausforff_d2_by_sp = {}
            for sp in sps:
                print('Selecting species: {}'.format(sp))
                fps_out = do_fps(X_by_sp[sp], self.Nselect[sp], starting_index=self.starting_index)
                self.selected_ids_by_sp[sp] = fps_out['fps_indices']
                self.fps_minmax_d2_by_sp[sp] = fps_out['fps_minmax_d2']
        elif self.act_on == 'feature':
            # FPS over the columns of X, i.e. over features
            fps_out = do_fps(X.T, self.Nselect, starting_index=self.starting_index)
            self.selected_ids = fps_out['fps_indices']
            self.fps_minmax_d2 = fps_out['fps_minmax_d2']
        elif self.act_on == 'sample':
            # FPS over the rows of X, i.e. over all centers of all structures
            fps_out = do_fps(X, self.Nselect, starting_index=self.starting_index)
            self.selected_ids_global = fps_out['fps_indices']
            self.fps_minmax_d2 = fps_out['fps_minmax_d2']
        else:
            raise NotImplementedError("method: {}".format(self.act_on))
        return self

    def transform(self, managers):
        """Apply the selection made by ``fit`` to ``managers``.

        Returns
        -------
        SparsePoints
            For ``'sample'`` and ``'sample per specie'``: the first ``Nselect``
            selected centers (per species for the latter).
        dict
            For ``'feature'``: ``dict(coefficient_subselection=...)`` with the
            expansion indices of the first ``Nselect`` selected features, plus
            their global ids under ``'selected_features_global_ids'``; it can be
            merged into the calculator hypers.
        """
        # NOTE(review): SparsePoints is not imported anywhere in the visible
        # notebook (the rascal.utils import is commented out); import it before
        # using the sample modes.
        if self.act_on == 'sample per specie':
            sps = list(self.Nselect.keys())
            # get various info from the structures about the center atom species and indexing
            (strides_by_sp, global_counter, map_by_manager,
             indices_by_sp) = get_index_mappings_sample_per_species(managers, sps)
            selected_ids_by_sp = {key:val[:self.Nselect[key]] for key,val in self.selected_ids_by_sp.items()}
            self.selected_ids = convert_selected_global_index2rascal_sample_per_species(
                managers, selected_ids_by_sp, strides_by_sp, map_by_manager, sps)
            # build the pseudo points
            sparse_points = SparsePoints(self._representation)
            sparse_points.extend(managers, self.selected_ids)
            return sparse_points
        elif self.act_on == 'sample':
            selected_ids_global = self.selected_ids_global[:self.Nselect]
            strides, _, map_by_manager = get_index_mappings_sample(managers)
            self.selected_ids = convert_selected_global_index2rascal_sample(
                managers, selected_ids_global, strides, map_by_manager)
            # build the pseudo points
            sparse_points = SparsePoints(self._representation)
            sparse_points.extend(managers, self.selected_ids)
            return sparse_points
        elif self.act_on == 'feature':
            feat_idx2coeff_idx = get_power_spectrum_index_mapping(self._representation, managers)
            selected_features = {key:[] for key in feat_idx2coeff_idx[0].keys()}
            for idx in self.selected_ids[:self.Nselect]:
                coef_idx = feat_idx2coeff_idx[idx]
                for key in selected_features.keys():
                    selected_features[key].append(int(coef_idx[key]))
            # keep the global indices for ease of use
            selected_features['selected_features_global_ids'] = self.selected_ids[:self.Nselect].tolist()
            return dict(coefficient_subselection=selected_features)
        raise NotImplementedError("method: {}".format(self.act_on))

    def plot(self):
        """Plot, on a log scale, the FPS min-max squared distance of each selection step."""
        if self.fps_minmax_d2_by_sp is None:
            plt.semilogy(self.fps_minmax_d2,label=self.act_on)

        else:
            for sp in self.fps_minmax_d2_by_sp:
                plt.semilogy(self.fps_minmax_d2_by_sp[sp],
                            label='{} species {}'.format(self.act_on, sp))
            plt.legend()
        plt.title('FPSFilter')
        plt.ylabel('fps minmax d^2')

    def fit_transform(self, managers):
        """Shortcut for ``fit(managers).transform(managers)``."""
        return self.fit(managers).transform(managers)

# Per-species sample selection (FPSFilter's default mode) would look like:
# Nselect = {1:300, 6:300, 7:20, 8:20}
# fps_filter = FPSFilter(soap, Nselect, act_on='sample per specie')
# Here FPS selects 12 feature columns instead.
Nselect = 12
fps_filter = FPSFilter(soap, Nselect, act_on='feature')

In [None]:
# Representations of all loaded frames.
managers = soap.transform(frames)

In [None]:
# Run FPS over the feature columns and plot the min-max d^2 of each step.
fps_filter.fit(managers)
fps_filter.plot()

In [None]:
# NOTE(review): `ids` is only defined in the feature-mapping cell further
# down; this cell depends on that cell having been run first.
X = managers.get_features(soap)
X[1][ids]

In [None]:
X = managers.get_features(soap)
X[1][ids[:8]]

In [None]:
# Compare the kernel of the feature-subselected calculator with one built by
# hand from the FPS-selected columns of the full feature matrix.
# NOTE(review): `soap_s`/`managers_s` are created in a later cell, and k_ref is
# a plain dot product while kernel_s uses zeta=zeta; confirm the comparison is
# meant to hold for zeta != 1.
kernel_s = Kernel(soap_s, zeta=zeta, target_type='Atom')
kk = kernel_s(managers_s, managers_s)
X = managers.get_features(soap)[:,fps_filter.selected_ids]
k_ref = np.dot(X,X.T)
np.allclose(kk,k_ref)

In [None]:
# coefficient_subselection dict for the FPS-selected features.
fps_filter.transform(managers)

In [None]:
soap.get_feature_index_mapping?

In [None]:
# Build a coefficient_subselection that keeps every feature, in the order of
# the calculator's own feature index mapping.
mapping = soap.get_feature_index_mapping(representation)

selected_features = {key:[] for key in mapping[0].keys()}
ids = [key for key in mapping.keys()]
# np.random.shuffle(ids)
# print(ids)
for idx in ids:
    coef_idx = mapping[idx]
    for key in selected_features.keys():
        selected_features[key].append(int(coef_idx[key]))
selected_features['selected_features_global_ids'] = ids
mapp = dict(coefficient_subselection=selected_features)
print(mapp)

In [None]:
# Calculator restricted to the subselection in `mapp`; its features should
# equal the corresponding columns of the full feature matrix.
# NOTE(review): `mapp` keeps all `ids`, so X_s presumably has len(ids) columns
# while X[:, ids[:8]] has 8; confirm the intended slice.
# hypers.update(fps_filter.transform(managers))
hyp = deepcopy(hypers)
hyp.update(mapp)
soap_s = SOAP(**hyp)

managers_s = soap_s.transform(frames)

X = managers.get_features(soap)
X_s = managers_s.get_features(soap_s)
np.allclose(X[:,ids[:8]],X_s)

In [None]:
# Save the subselected calculator's hypers and the adaptor parameters as reference inputs.
import json
with open('../reference_data/sparse_soap_input.json','w') as j:
    json.dump(soap_s.hypers, j)
with open('../reference_data/adaptor_input.json','w') as j:
    json.dump(json.loads(managers_s.managers.get_parameters()), j)

In [None]:
soap_s.hypers['coefficient_subselection']['a']

In [None]:
# Build reference inputs for the feature-sparsification tests. For each input
# file, keep either all features (Nselect=-1) or 12 chosen by FPS, print
# whether the subselected calculator reproduces the matching columns of the
# full feature matrix, and record its hypers and adaptor parameters.
import json
from itertools import product
root = '../'
fns = ['reference_data/inputs/diamond_2atom_distorted.json', 
       'reference_data/inputs/CaCrP2O7_mvc-11955_symmetrized.json',
       'reference_data/inputs/methane.json']
# fns = ['reference_data/inputs/CaCrP2O7_mvc-11955_symmetrized.json',]
soap_types = ["PowerSpectrum"]
Nselects = [-1, 12]
sparsification_inputs = []
for fn, soap_type, Nselect in product(fns,soap_types,Nselects):
    frames = read(root+fn,':')

    hypers = dict(soap_type=soap_type,
                  interaction_cutoff=3.5, 
                  max_radial=2, 
                  max_angular=2, 
                  gaussian_sigma_constant=0.4,
                  gaussian_sigma_type="Constant",
                  cutoff_smooth_width=0.5,
                  normalize=False,
                  expansion_by_species_method='structure wise',
                  )
    
    soap = SOAP(**hypers)
    managers = soap.transform(frames)
    X = managers.get_features(soap)
    print(X.shape)
    # -1 means "select every feature"; the resolved count is what gets recorded
    if Nselect == -1:
        Nselect = X.shape[1]
    fps_filter = FPSFilter(soap, Nselect, act_on='feature')
    fps_filter.fit(managers)
    
    hyp = deepcopy(hypers)
    hyp.update(fps_filter.transform(managers))
    
    soap_s = SOAP(**hyp)

    managers_s = soap_s.transform(frames)
    X_s = managers_s.get_features(soap_s)
    print(np.allclose(X[:,hyp['coefficient_subselection']['selected_features_global_ids']],X_s))
    sparsification_inputs.append(dict(hypers=dict(rep=soap_s.hypers, 
                                                  adaptors=managers_s.managers.get_parameters()), 
                                     filename=fn, Nselect=Nselect))
    
# Writing the reference file is disabled; uncomment to regenerate it.
# with open('../reference_data/tests_only/sparsification_inputs.json','w') as j:
#     json.dump(sparsification_inputs, j)
    

In [None]:
SOAP?

In [None]:
# NOTE(review): `fps_filter`, `managers`, `frames` and `hypers` are whatever the
# last iteration of the loop above left behind, and `ids` comes from the
# feature-mapping cell; confirm the ids[:8] comparison is intended.
hyp = deepcopy(hypers)
hyp.update(fps_filter.transform(managers))
# hyp.update(mapp)
soap_s = SOAP(**hyp)

managers_s = soap_s.transform(frames)

X = managers.get_features(soap)
X_s = managers_s.get_features(soap_s)
np.allclose(X[:,ids[:8]],X_s)

In [None]:
# Quick check of dict.update semantics (used above to merge hypers).
aa = dict(a=1,b=3)
bb = {'c':5}
aa.update(bb)
aa

In [None]:
# Per-structure feature matrices restricted to species [1, 6, 7, 8], for a
# hand-rolled kernel timing.
Xs = []
for frame in frames:
    representation = soap.transform([frame])
    X = representation.get_features(soap, species=[1, 6, 7, 8])
    Xs.append(X)

In [None]:
%%time
# Naive loop: for every pair of structures, sum (X . Y^T)^zeta over all center pairs.
for ii,X in enumerate(Xs):
    for jj,Y in enumerate(Xs):
        # if jj < ii: continue
        aa = np.sum(np.power(np.dot(X, Y.T), zeta))

In [None]:
# Features of all frames in one matrix.
representation = soap.transform(frames)
X = representation.get_features(soap)

In [None]:
%%time 
# Center-wise kernel via a single matrix product, for comparison.
kk = np.power(np.dot(X, X.T), zeta)

In [None]:
# Same timing through rascal's Kernel.
%time kernel1(representation)

In [None]:
# Larger, normalized basis with an explicit ('user defined') species list.
hypers = dict(soap_type="PowerSpectrum",
              interaction_cutoff=3.5, 
              max_radial=6, 
              max_angular=4, 
              gaussian_sigma_constant=0.4,
              gaussian_sigma_type="Constant",
              cutoff_smooth_width=0.5,
              expansion_by_species_method='user defined',
              normalize=True,
              global_species=[1, 6, 7, 8],
              )
soap = SOAP(**hypers)
zeta=2
kernel1 = Kernel(soap, zeta=zeta, target_type='Atom')
representation = soap.transform(frames)

In [None]:
%time kernel1(representation)

In [None]:
# Check that the one-argument call matches kernel1(representation, representation).
aa = kernel1(representation)
bb = kernel1(representation, representation)
np.allclose(aa, bb)

In [None]:
bb[:10,:10]

In [None]:
aa[:20,:20]

In [None]:
aa

In [None]:
# Timing ratios noted while benchmarking; the trailing numbers presumably
# label the number of frames used.
140/43. # 900

In [None]:
1.8/0.513 #100

# Learning the formation energies of small molecules

In [None]:
# Load the first 600 small molecules (used for the KRR examples below).
frames = read('../reference_data/inputs/small_molecules-1000.xyz',':600')

## learning utilities

In [None]:
def compute_representation(representation,frames):
    """Compute the representation of ``frames`` with the given calculator.

    Parameters
    ----------
    representation : calculator
        Object providing ``transform(frames)`` (e.g. a SphericalInvariants calculator).
    frames : list
        Structures to transform.

    Returns
    -------
    The result of ``representation.transform(frames)``.
    """
    # Use the calculator passed in, not the notebook-global `soap`: a KRR model
    # stores its own calculator and must keep using it after `soap` is
    # redefined in later cells.
    expansions = representation.transform(frames)
    return expansions

def compute_kernel(zeta, rep1, rep2=None):
    """Compute a global cosine kernel through ``rep1.cosine_kernel_global``.

    With ``rep2`` omitted, the kernel is computed for ``rep1`` against itself;
    otherwise between ``rep1`` and ``rep2``. ``zeta`` is forwarded unchanged
    (presumably the kernel exponent).
    """
    if rep2 is not None:
        return rep1.cosine_kernel_global(rep2, zeta)
    return rep1.cosine_kernel_global(zeta)

def extract_energy(frames):
    """Return ``frame.info['dft_formation_energy_per_atom_in_eV']`` of every frame as a NumPy array."""
    energies = [frame.info['dft_formation_energy_per_atom_in_eV'] for frame in frames]
    return np.array(energies)

def split_dataset(frames, test_fraction, seed=10):
    """Randomly split ``frames`` and their formation energies into two subsets.

    NOTE: despite its name, ``test_fraction`` is the fraction of frames that go
    into the FIRST (training) subset: ``int(len(frames) * test_fraction)``
    shuffled frames become the training set and the rest the test set.

    Parameters
    ----------
    frames : list
        Structures with ``info['dft_formation_energy_per_atom_in_eV']``.
    test_fraction : float
        Fraction of frames assigned to the training subset (see note above).
    seed : int
        Seed passed to ``np.random.seed`` (this reseeds NumPy's global generator).

    Returns
    -------
    frames_train, y_train, frames_test, y_test
    """
    order = np.arange(len(frames))
    np.random.seed(seed)
    np.random.shuffle(order)
    cut = int(len(frames) * test_fraction)
    train_ids, test_ids = order[:cut], order[cut:]
    energies = extract_energy(frames)
    return ([frames[i] for i in train_ids], energies[train_ids],
            [frames[i] for i in test_ids], energies[test_ids])

def get_mae(ypred, y):
    """Return the mean absolute error between predictions and references."""
    abs_err = np.abs(ypred - y)
    return np.mean(abs_err)


def get_rmse(ypred, y):
    """Return the root-mean-square error."""
    sq_err = (ypred - y) ** 2
    return np.sqrt(np.mean(sq_err))


def get_sup(ypred, y):
    """Return the largest absolute error (sup norm)."""
    return np.max(np.abs(ypred - y))


def get_r2(y_pred, y_true):
    """Return the coefficient of determination R^2, averaged over output columns."""
    ss_res = np.sum((y_true - y_pred) ** 2, axis=0, dtype=np.float64)
    ss_tot = np.sum((y_true - np.average(y_true, axis=0)) ** 2, axis=0, dtype=np.float64)
    return np.mean(1 - ss_res / ss_tot)


# metric name -> function(ypred, y)
score_func = {
    'MAE': get_mae,
    'RMSE': get_rmse,
    'SUP': get_sup,
    'R2': get_r2,
}

def get_score(ypred,y):
    """Evaluate every metric in ``score_func``; returns ``{name: value}``."""
    return {name: metric(ypred, y) for name, metric in score_func.items()}

class KRR(object):
    """Kernel ridge regression model, as returned by ``train_krr_model``.

    Keeps the regression weights together with the training features so that
    ``predict`` can build the kernel between the training and new structures.
    """

    def __init__(self,zeta,weights,representation,X):
        self.zeta = zeta
        self.weights = weights
        self.representation = representation
        self.X = X  # representation of the training structures

    def predict(self,frames):
        """Return ``weights . K(train, frames)`` for the given structures."""
        new_features = compute_representation(self.representation, frames)
        k_train_new = compute_kernel(self.zeta, self.X, new_features)
        return np.dot(self.weights, k_train_new)
    
def train_krr_model(zeta,Lambda,representation,frames,y,jitter=1e-8):
    """Fit a kernel ridge regression model on ``frames`` with targets ``y``.

    Returns
    -------
    (KRR, np.ndarray)
        The fitted model and the training kernel, whose diagonal already
        includes the regularization ``Lambda**2 / delta**2 + jitter``.
    """
    features = compute_representation(representation, frames)
    K = compute_kernel(zeta, features)
    # scale the regularizer: delta relates the spread of y to the mean kernel diagonal
    delta = np.std(y) / np.mean(K.diagonal())
    K[np.diag_indices_from(K)] += Lambda ** 2 / delta ** 2 + jitter
    weights = np.linalg.solve(K, y)
    return KRR(zeta, weights, representation, features), K



## With the full power spectrum

In [None]:
# Full power spectrum: max_radial=6, max_angular=6.
hypers = dict(soap_type="PowerSpectrum",
              interaction_cutoff=3.5, 
              max_radial=6, 
              max_angular=6, 
              gaussian_sigma_constant=0.4,
              gaussian_sigma_type="Constant",
              cutoff_smooth_width=0.5,
              )
soap = SOAP(**hypers)

In [None]:
# Random split; 0.8 is the training fraction (80% train / 20% test).
frames_train, y_train, frames_test, y_test = split_dataset(frames,0.8)

In [None]:
# Regularized KRR fit on the training set.
zeta = 2
Lambda = 5e-3
krr,k = train_krr_model(zeta, Lambda, soap, frames_train, y_train)

In [None]:
# Errors on the held-out set.
y_pred = krr.predict(frames_test)
get_score(y_pred, y_test)

In [None]:
# Parity plot: reference DFT energies on x, model predictions on y, matching the axis labels.
plt.scatter(y_test, y_pred, s=3)
plt.axis('scaled')
plt.xlabel('DFT energy / (eV/atom)')
plt.ylabel('Predicted energy / (eV/atom)')

## With just the radial spectrum

In [None]:
# Radial spectrum only (max_angular=0).
hypers = dict(soap_type="RadialSpectrum",
              interaction_cutoff=3.5, 
              max_radial=6, 
              max_angular=0, 
              gaussian_sigma_constant=0.4,
              gaussian_sigma_type="Constant",
              cutoff_smooth_width=0.5,
              )
soap = SOAP(**hypers)

In [None]:
# Same split as above (same seed and training fraction).
frames_train, y_train, frames_test, y_test = split_dataset(frames,0.8)

In [None]:
# Regularized KRR fit on the training set.
zeta = 2
Lambda = 5e-4
krr,k = train_krr_model(zeta, Lambda, soap, frames_train, y_train)

In [None]:
# Errors on the held-out set.
y_pred = krr.predict(frames_test)
get_score(y_pred, y_test)

In [None]:
# Parity plot: reference DFT energies on x, model predictions on y, matching the axis labels.
plt.scatter(y_test, y_pred, s=3)
plt.axis('scaled')
plt.xlabel('DFT energy / (eV/atom)')
plt.ylabel('Predicted energy / (eV/atom)')

# Make a map of the dataset

## utils

In [None]:
def compute_representation(representation,frames):
    """Compute the representation of ``frames`` with the given calculator.

    Parameters
    ----------
    representation : calculator
        Object providing ``transform(frames)`` (e.g. a SphericalInvariants calculator).
    frames : list
        Structures to transform.

    Returns
    -------
    The result of ``representation.transform(frames)``.
    """
    # Use the calculator passed in rather than the notebook-global `soap`, which
    # later cells redefine.
    expansions = representation.transform(frames)
    return expansions

def compute_kernel(zeta, rep1, rep2=None):
    """Compute a global cosine kernel through ``rep1.cosine_kernel_global``.

    With ``rep2`` omitted the kernel is computed for ``rep1`` against itself,
    otherwise between ``rep1`` and ``rep2``; ``zeta`` is forwarded unchanged
    (presumably the kernel exponent).
    """
    args = (zeta,) if rep2 is None else (rep2, zeta)
    return rep1.cosine_kernel_global(*args)

In [None]:
def link_ngl_wdgt_to_ax_pos(ax, pos, ngl_widget):
    r"""Link a 2D point map in ``ax`` to an nglview widget showing one frame per point.

    Clicking in ``ax`` moves dashed crosshair lines to the click, highlights
    the nearest point of ``pos`` and sets ``ngl_widget.frame`` to that point's
    index. When the widget's ``frame`` changes, the highlight moves to the
    matching point.

    Initial idea for this function comes from @arose, the rest is @gph82 and @clonker

    Parameters
    ----------
    ax : matplotlib Axes
        Axes in which the points are displayed.
    pos : np.ndarray, shape (n_points, 2)
        2D coordinates; row ``i`` corresponds to frame ``i`` of the widget.
    ngl_widget : nglview widget
        Widget whose ``frame`` trait is set and observed.
    """
    from matplotlib.widgets import AxesWidget
    from scipy.spatial import cKDTree

    kdtree = cKDTree(pos)
    #assert ngl_widget.trajectory_0.n_frames == pos.shape[0]
    x, y = pos.T

    lineh = ax.axhline(ax.get_ybound()[0], c="black", ls='--')
    linev = ax.axvline(ax.get_xbound()[0], c="black", ls='--')
    dot, = ax.plot(pos[0,0],pos[0,1], 'o', c='red', ms=7)

    ngl_widget.isClick = False

    def move_dot(idx):
        """Put the red highlight marker on point ``idx``."""
        # Line2D.set_xdata/set_ydata take sequences; a bare scalar is deprecated
        # since Matplotlib 3.7 and rejected by later releases.
        dot.set_xdata([x[idx]])
        dot.set_ydata([y[idx]])

    def onclick(event):
        """Move the crosshair to the click and show the nearest point's frame."""
        # Releases outside this axes have no (or foreign) data coordinates.
        if event.inaxes is not ax:
            return
        linev.set_xdata((event.xdata, event.xdata))
        lineh.set_ydata((event.ydata, event.ydata))
        _, index = kdtree.query(x=[event.xdata, event.ydata], k=1)
        index = int(index)  # plain int for the widget's (presumably integer) frame trait
        move_dot(index)
        ngl_widget.isClick = True
        ngl_widget.frame = index

    def my_observer(change):
        """Move the highlight to the point of the widget's new frame."""
        ngl_widget.isClick = False
        _idx = change["new"]
        try:
            move_dot(_idx)
        except IndexError:
            move_dot(0)
            print("caught index error with index %s (new=%s, old=%s)" % (_idx, change["new"], change["old"]))

    # Connect axes to widget
    axes_widget = AxesWidget(ax)
    axes_widget.connect_event('button_release_event', onclick)

    # Connect widget to axes
    ngl_widget.observe(my_observer, "frame", "change")

## Make a map with a kernel PCA projection

In [None]:
# Load the first 600 small molecules, from the same file the KRR section reads
# ('./reference_data/small_molecules-1000.xyz' does not match the
# '../reference_data/inputs/' location used everywhere else in this notebook).
frames = read('../reference_data/inputs/small_molecules-1000.xyz',':600')

In [None]:
# Same power-spectrum settings as in the full power spectrum KRR example.
hypers = dict(soap_type="PowerSpectrum",
              interaction_cutoff=3.5, 
              max_radial=6, 
              max_angular=6, 
              gaussian_sigma_constant=0.4,
              gaussian_sigma_type="Constant",
              cutoff_smooth_width=0.5,
              )
soap = SOAP(**hypers)

In [None]:
# Global cosine kernel between all loaded structures.
zeta = 2

features = compute_representation(soap, frames)

kernel = compute_kernel(zeta,features)

In [None]:
from sklearn.decomposition import KernelPCA

In [None]:
# 2-component kernel PCA on the precomputed kernel.
kpca = KernelPCA(n_components=2,kernel='precomputed')
kpca.fit(kernel)

In [None]:
# Coordinates of each structure in the 2D projection.
X = kpca.transform(kernel)

In [None]:
# Static 2D map of the dataset.
plt.scatter(X[:,0],X[:,1],s=3)

## make an interactive map

In [None]:
# package to visualize the structures in the notebook
# https://github.com/arose/nglview#released-version
import nglview

In [None]:
# One nglview frame per structure, in the same order as the rows of X.
iwdg = nglview.show_asetraj(frames)
# set up the visualization
iwdg.add_unitcell()
iwdg.add_spacefill()
iwdg.remove_ball_and_stick()
iwdg.camera = 'orthographic'
iwdg.parameters = { "clipDist": 0 }
iwdg.center()
iwdg.update_spacefill(radiusType='covalent',
                                   scale=0.6,
                                   color_scheme='element')
# NOTE(review): `_remote_call` is a private nglview method and may change between versions.
iwdg._remote_call('setSize', target='Widget',
                               args=['%dpx' % (600,), '%dpx' % (400,)])
iwdg.player.delay = 200.0

In [None]:
# Link clicks on the current axes to the widget's frame, draw the map and show the widget.
# NOTE: linking happens before the scatter is drawn, so the initial crosshair
# positions come from the axes bounds at that moment.
link_ngl_wdgt_to_ax_pos(plt.gca(), X, iwdg)
plt.scatter(X[:,0],X[:,1],s=3)
iwdg