In [1]:
import numpy as np
from ase.io import read

from skcosmo.feature_selection import FPS
from rascal.representations import SphericalInvariants as SOAP
from rascal.utils import 

# read all of the frames and book-keep the centers and species
filename = "./make_tensor_data/train_tensor/CSD-3k+S546_shift_tensors.xyz"
frames = np.asarray(
    read(filename,format="extxyz",index=":"),
    dtype=object,
)

for frame in frames:
    frame.wrap(eps=1e-10)
    
n_centers = [len(frame) for frame in frames]
#could also be solved with np.cumsum
n_env_accum = [sum(n_centers[: i + 1]) for i in range(len(n_centers))]
n_env = sum(n_centers)

#get concatenated atomic numbers over whole dataset
numbers = np.concatenate([frame.numbers for frame in frames])
#get indices, where elements are stored. maybe use argwhere -> np.argwhere(numbers == 1)
#allows numbers[np.argwhere(numbers == 1)] in multidimensional array?
numbers = np.concatenate([frame.numbers for frame in frames])
number_loc = np.array([np.where(numbers == i)[0] for i in [1, 6, 7, 8]], dtype=object)


# compute radial soap vectors as first pass
hypers = dict(
    soap_type="PowerSpectrum",
    interaction_cutoff=3.5,
    max_radial=6,
    max_angular=0,
    gaussian_sigma_type="Constant",
    gaussian_sigma_constant=0.4,
    cutoff_smooth_width=0.5
)
soap = SOAP(**hypers)
X_raw = soap.transform(frames).get_features(soap)


# select 100 diverse samples
i_selected = FPS(n_to_select=100).fit(X_raw.T).selected_idx_

# book-keep which frames these samples belong in
# should be substracted by n_env a
# runs in O(N^2) ?
# solution with O(N) could simply zip through both arrays
# for i in i_selected:
#        while i_selected < 
# use numpy.digitize()! bins are sorted and are given by n_envs accumulated
# data is given by 


frames_select = [np.where(n_env_accum > i)[0][0] for i in i_selected]
reduced_frames_select = list(sorted(set(frames_select)))

properties_select = [
    frame.arrays["cs_iso"] for frame in frames[reduced_frames_select]
]

n_centers_select = [len(frame) for frame in frames[reduced_frames_select]]
n_env_accum_select = [
    sum(n_centers_select[: i + 1]) for i in range(len(n_centers_select))
]
n_env_select = sum(n_centers_select)


# compute a larger power spectrum for these frames
hypers["max_angular"] = 6
soap_select = SOAP(**hypers)
X_raw_select = soap_select.transform(frames[reduced_frames_select]).get_features(
    soap_select
)


# pull the soap vectors only pertaining to the selected environments
i_select_reduced = []
properties_select_reduced = np.zeros(len(i_selected), dtype=float)
for i in range(len(i_selected)):
    my_orig_frame = frames_select[i]
    my_frame = reduced_frames_select.index(my_orig_frame)
    if my_orig_frame != 0:
        orig_loc = i_selected[i] - n_env_accum[my_orig_frame - 1]
        new_loc = orig_loc + n_env_accum_select[my_frame - 1]
    else:
        orig_loc = i_selected[i]
        new_loc = i_selected[i]
    i_select_reduced.append(new_loc)
    properties_select_reduced[i] = frames[my_orig_frame].arrays["cs_iso"][orig_loc]

X_sample_select = X_raw_select[i_select_reduced]


# select 100 / 2520 soap features
n_select = 100
X_select = FPS(n_to_select=n_select).fit_transform(X_sample_select)
Y_select = properties_select_reduced.reshape(-1, 1)

data = dict(X=X_select, Y=Y_select)
#np.savez("skcosmo/datasets/data/csd-1000r.npz", **data)


SyntaxError: invalid syntax (<ipython-input-1-e32dc5a0d0fb>, line 6)

In [2]:
import numpy as np
from ase.io import read

from skcosmo.feature_selection import FPS
from rascal.representations import SphericalInvariants as SOAP


# read all of the frames and book-keep the centers and species
filename = "./make_tensor_data/train_tensor/CSD-3k+S546_shift_tensors.xyz"
frames = np.asarray(
    read(filename,format="extxyz",index=":"),
    dtype=object,
)

for frame in frames:
    frame.wrap(eps=1e-10)
    

In [7]:
%pip install joblib

You should consider upgrading via the '/ssd/scratch/kellner/miniconda3/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
from joblib import Parallel, delayed
from helpers import grouper

In [4]:
from nice.utilities import get_all_species

In [45]:
list(get_all_species(frames))

[1, 6, 7, 8, 16]

In [9]:
hypers = dict(
    soap_type="PowerSpectrum",
    interaction_cutoff=3.5,
    max_radial=8,
    max_angular=5,
    gaussian_sigma_type="Constant",
    gaussian_sigma_constant=0.4,
    cutoff_smooth_width=0.5,
    expansion_by_species_method="user defined",
    global_species = [1, 6, 7, 8, 16]
)
soap = SOAP(**hypers)
#replace global species later with function that determines it
#better: replace it in function call

In [6]:
type(soap.hypers)

dict

In [46]:
soap.update_hyperparameters()

{'max_radial': 6,
 'max_angular': 0,
 'soap_type': 'PowerSpectrum',
 'normalize': True,
 'inversion_symmetry': True,
 'expansion_by_species_method': 'user defined',
 'global_species': [1, 6, 7, 8, 16],
 'compute_gradients': False,
 'cutoff_function': {'type': 'ShiftedCosine',
  'cutoff': {'value': 3.5, 'unit': 'AA'},
  'smooth_width': {'value': 0.5, 'unit': 'AA'}},
 'gaussian_density': {'type': 'Constant',
  'gaussian_sigma': {'value': 0.4, 'unit': 'AA'}},
 'radial_contribution': {'type': 'GTO',
  'optimization': {'Spline': {'accuracy': 1e-08}}}}

In [7]:
def build_lambda(spex,structures,CG,sel_lambda = 2):
    feat_scaling = 1e6            # just a scaling to make coefficients O(1)
    feats = spex.transform(structures).get_features(spex)
    ref_feats = feat_scaling*spherical_expansion_reshape(feats, **hypers)
    lsoap = compute_lambda_soap(ref_feats, CG, sel_lambda, 1)
    return np.moveaxis(lsoap,-1,1).reshape((lsoap.shape[0]*lsoap.shape[-1], -1)) #is this correct?

In [4]:
from copy import deepcopy
#TODO: replace this to reduce dependency on nice
from nice.utilities import get_all_species 
from joblib import Parallel, delayed
from helpers import grouper

def retrieve_features(calculator, chunk):
    """helper function that allows for calling a class method in joblib
    """
    return calculator.transform(chunk).get_features(calculator)

def retrieve_equivariants(calculator, chunk, CG, sel_lambda):
    feat_scaling = 1e6            # just a scaling to make coefficients O(1)
    feats = spex.transform(chunk).get_features(spex)
    ref_feats = feat_scaling*spherical_expansion_reshape(feats, **hypers)
    lsoap = compute_lambda_soap(ref_feats, CG, sel_lambda, 1)
    return np.moveaxis(lsoap,-1,1).reshape((lsoap.shape[0]*lsoap.shape[-1], -1)) #is this correct?


def get_features_in_parallel(frames,calculator,blocksize=100,n_jobs=-1):
    """helper function that returns the features of a calculator (from calculator.transform())
       in parallel
    """
    #for np.concatenate. arrays in list should all have same shape
    calculator = deepcopy(calculator)
    hypers = calculator.hypers
    hypers["expansion_by_species_method"] = "user defined"
    hypers["global_species"] = get_all_species(frames).tolist()
    calculator.update_hyperparameters(**hypers)
    return Parallel(n_jobs=n_jobs)(delayed(retrieve_features)(calculator, chunk)\
                                              for chunk in grouper(blocksize,frames))
    
def get_equivariants_in_parallel(frames,calculator,sel_lambda,blocksize=100,n_jobs=-1):
    """helper function that returns the features of a calculator (from calculator.transform())
       in parallel
    """
    #for np.concatenate. arrays in list should all have same shape
    calculator = deepcopy(calculator)
    hypers = calculator.hypers
    hypers["expansion_by_species_method"] = "user defined"
    hypers["global_species"] = get_all_species(frames).tolist()
    calculator.update_hyperparameters(**hypers)
    CG = ClebschGordanReal(lmax=hypers["max_angular"])
    return np.concatenate(Parallel(n_jobs=n_jobs)(delayed(retrieve_equivariants)(calculator, chunk, CG, sel_lambda)\
                                              for chunk in grouper(blocksize,frames)))

In [110]:
get_all_species(frames).tolist()

[1, 6, 7, 8, 16]

In [169]:
from rascal.utils.io import from_dict, to_dict

In [107]:
type(list(get_all_species(frames)))  #CALL TO_LIST!!!

list

In [5]:
from rascal.representations import SphericalInvariants as SOAP
from rascal.representations import SphericalExpansion
from rascal.utils import (WignerDReal, ClebschGordanReal, 
                          spherical_expansion_reshape, spherical_expansion_conjugate,
                    lm_slice, real2complex_matrix, compute_lambda_soap, xyz_to_spherical, spherical_to_xyz)

In [11]:
hypers_2 = dict(
    interaction_cutoff=3.5,
    max_radial=9,
    max_angular=9,
    gaussian_sigma_type="Constant",
    gaussian_sigma_constant=0.4,
    cutoff_smooth_width=0.5
)
spex = SphericalExpansion(**hypers_2)

In [19]:
X_lamd_parallel = get_equivariants_in_parallel(frames[::5],spex,sel_lambda=0,blocksize=10)

exception calling callback for <Future at 0x7f8794687fd0 state=finished raised BrokenProcessPool>
joblib.externals.loky.process_executor._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/ssd/scratch/kellner/miniconda3/lib/python3.8/site-packages/joblib/externals/loky/process_executor.py", line 624, in wait_result_broken_or_wakeup
    result_item = result_reader.recv()
  File "/ssd/scratch/kellner/miniconda3/lib/python3.8/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
  File "/ssd/scratch/kellner/miniconda3/lib/python3.8/multiprocessing/connection.py", line 421, in _recv_bytes
    return self._recv(size)
  File "/ssd/scratch/kellner/miniconda3/lib/python3.8/multiprocessing/connection.py", line 386, in _recv
    buf.write(chunk)
MemoryError
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/ssd/scratch/kellner/miniconda3/lib/python3.8/site-packages/joblib/externals/loky/_b

BrokenProcessPool: A result has failed to un-serialize. Please ensure that the objects returned by the function are always picklable.

In [18]:
X_lamd_parallel.shape

(204168, 0)

In [None]:
X_raw_parallel = get_features_in_parallel(frames,spex,n_jobs=-1)

In [88]:
soap

<rascal.representations.spherical_invariants.SphericalInvariants at 0x7f75b1aa09d0>

In [6]:
%%time
X_raw = soap.transform(frames).get_features(soap)

CPU times: user 15.3 s, sys: 3.28 s, total: 18.6 s
Wall time: 18.6 s


In [8]:
%%time
#problem: some parts of the expansions could be reused
X_raw_parallel = get_features_in_parallel(frames,soap,n_jobs=4)

CPU times: user 4.2 s, sys: 9.03 s, total: 13.2 s
Wall time: 15.9 s


In [147]:
8.69/2.83

3.0706713780918724

In [131]:
for i in X_raw_parallel:
    print(i.shape)

(9822, 540)
(9802, 540)
(9348, 540)
(9816, 540)
(9062, 540)
(10114, 540)
(9624, 540)
(9780, 540)
(8552, 540)
(9130, 540)
(8764, 540)
(8835, 540)
(9346, 540)
(9225, 540)
(10910, 540)
(8608, 540)
(9030, 540)
(8942, 540)
(9596, 540)
(9794, 540)
(11004, 540)
(9769, 540)
(10638, 540)
(10504, 540)
(9464, 540)
(10488, 540)
(10792, 540)
(10100, 540)
(10001, 540)
(10443, 540)
(8708, 540)
(8924, 540)
(9546, 540)
(9005, 540)
(8966, 540)
(4489, 540)


In [117]:
#time = 6.

(340941, 540)

In [77]:
soap = SOAP(**hypers)
X_raw_parallel = np.concatenate(Parallel(n_jobs=-1)(delayed(retrieve_features)(soap, chunk) for chunk in grouper(100,frames)))

In [40]:
for i in X_raw_parallel:
    print(i.shape)

(9822, 540)
(9802, 540)
(9348, 540)
(9816, 540)
(9062, 540)
(10114, 540)
(9624, 540)
(9780, 540)
(8552, 540)
(9130, 540)
(8764, 540)
(8835, 540)
(9346, 540)
(9225, 540)
(10910, 540)
(8608, 540)
(9030, 540)
(8942, 540)
(9596, 540)
(9794, 540)
(11004, 540)
(9769, 540)
(10638, 540)
(10504, 540)
(9464, 540)
(10488, 540)
(10792, 540)
(10100, 540)
(10001, 540)
(10443, 540)
(8708, 540)
(8924, 540)
(9546, 540)
(9005, 540)
(8966, 540)
(4489, 540)


In [30]:
X_raw_parallel.shape

(340941, 540)

In [118]:
np.allclose(X_raw,X_raw_parallel)

True

In [19]:
X_raw

[array([[ 0.00892556,  0.0176354 ,  0.01627523, ...,  0.00603637,
          0.00692119,  0.0023751 ],
        [ 0.00892573,  0.01763565,  0.01627535, ...,  0.00603636,
          0.00692119,  0.0023751 ],
        [ 0.00892573,  0.01763563,  0.01627534, ...,  0.00603636,
          0.00692118,  0.00237509],
        ...,
        [ 0.42807663,  0.08824391, -0.03429232, ...,  0.        ,
          0.        ,  0.        ],
        [ 0.42807532,  0.0882438 , -0.0342927 , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.42807654,  0.08824414, -0.03429308, ...,  0.        ,
          0.        ,  0.        ]]),
 array([[ 3.02217178e-06, -2.00574161e-05, -8.30937402e-05, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 3.02189496e-06, -2.00569894e-05, -8.30917053e-05, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 3.02137524e-06, -2.00554506e-05, -8.30869293e-05, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+

In [3]:
#todo:
#load structures and exclude not passing structures

def filter_by_status(frames, status="PASSING"):
    """Helper function that filters structures by info dict 'STATUS' entry
    """
    return [frame for frame in frames if frame.info['STATUS'] == status]

#write a random selection of 80:20 split to a validation set


In [4]:
filtered_structs = filter_by_status(frames)

In [5]:
len(filtered_structs)

3430

In [6]:
len(frames)

3546

In [107]:
number_loc

array([array([     8,      9,     10, ..., 188097, 188098, 188099]),
       array([    20,     21,     22, ..., 188081, 188082, 188083]),
       array([    12,     13,     14, ..., 187993, 187994, 187995]),
       array([     0,      1,      2, ..., 188069, 188070, 188071])],
      dtype=object)

In [75]:
locs_gathered = []
for i in range(len(i_selected)):
    my_orig_frame = frames_select[i]
    my_frame = reduced_frames_select.index(my_orig_frame)
    if my_orig_frame != 0:
        orig_loc = i_selected[i] - n_env_accum[my_orig_frame - 1]
        new_loc = orig_loc + n_env_accum_select[my_frame - 1]
    else:
        orig_loc = i_selected[i]
        new_loc = i_selected[i]
    locs_gathered.append(orig_loc )
    i_select_reduced.append(new_loc)
    properties_select_reduced[i] = frames[my_orig_frame].arrays["CS"][orig_loc]

In [117]:
number_loc = np.array([np.where(numbers == i)[0] for i in [1, 6, 7, 8]], dtype=object)

In [118]:
number_loc = np.array([np.where(numbers == i)[0] for i in [1, 6, 7, 8]], dtype=object)

In [120]:
#extracting indices of masked structures
for frame in frames: mask_center_atoms_by_species(frame,species_select=[1])
masked_loc = np.where(numbers_masked == True)[0]
#index through index set, through obtain full indices
#masked_loc[selected_indices]

In [122]:
frames[1].arrays["center_atoms_mask"]

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False,  True,  True])

In [125]:
number_loc[0]

array([     8,      9,     10, ..., 188097, 188098, 188099])

In [126]:
np.where(numbers_masked == True)[0]

array([     0,      3,      8, ..., 188097, 188098, 188099])

In [123]:
numbers_masked = np.concatenate([frame.arrays["center_atoms_mask"] for frame in frames])
#number_loc = np.array([])

In [77]:
np.array(locs_gathered)

array([  0,  78,  21,  47,  87,  13,   5,  17,  99,  25,  63,   1,  49,
        75,  99,  91,  50,  86,  55,  70,  58,  85,  47,  29,  36,   7,
       146,   8,  18,  63,  24,   0,  19,  22,   9,   8,   6, 131,  38,
        12,   2,  17,  70,  60,  25,  10,   3,   2,  21,   2,  83,  12,
        65,   9,   9, 162,  76, 157,  72,   8,  64,  59,  23,  43,  69,
        33,  87,  35,  81,  11,  15,  15,  15,   5,   1,  24,  73,  32,
       127,   3,  34,  29,  36,  44, 120, 131,   2,  23,  56,  16,   3,
        29,  11, 105,  27,   3,  50,  53,  33,  50])

In [79]:
i_selected[0] 

0

In [80]:
correct_is = i_selected
correct_is[1:] = i_selected[1:] - bins[np.digitize(i_selected,np.array(n_env_accum))-1][1:]
#and then do loop through indeces to mask atoms

In [127]:
correct_is

array([  0,  78,  21,  47,  87,  13,   5,  17,  99,  25,  63,   1,  49,
        75,  99,  91,  50,  86,  55,  70,  58,  85,  47,  29,  36,   7,
       146,   8,  18,  63,  24,   0,  19,  22,   9,   8,   6, 131,  38,
        12,   2,  17,  70,  60,  25,  10,   3,   2,  21,   2,  83,  12,
        65,   9,   9, 162,  76, 157,  72,   8,  64,  59,  23,  43,  69,
        33,  87,  35,  81,  11,  15,  15,  15,   5,   1,  24,  73,  32,
       127,   3,  34,  29,  36,  44, 120, 131,   2,  23,  56,  16,   3,
        29,  11, 105,  27,   3,  50,  53,  33,  50])

In [None]:
# use number_loc 
# number_loc = np.array([np.where(numbers == i)[0] for i in [1, 6, 7, 8]], dtype=object)
# or even better:
# 

In [None]:
#get indexes of some selector
#DURING ALL THE LOOPING ONLY CONSIDER STATUS PASSING, only append those to list
#alternative of all this is using dicts wirth identifier

indices = FPS(n_to_select=100).fit(X_raw.T).selected_idx_

#get accumulated count of atomic envs
#when masking was used: len(frames[0][frames[0].numbers == 1])?
n_centers = [len(frame) for frame in frames]
#could also be solved with np.cumsum
n_env_accum = [sum(n_centers[: i + 1]) for i in range(len(n_centers))]


#extracting indices of masked structures
for frame in frames: mask_center_atoms_by_species(frame,species_select=[1])
numbers_masked = np.concatenate([frame.arrays["center_atoms_mask"] for frame in frames])
masked_loc = np.where(numbers_masked == True)[0]
#or: use: 
# numbers = np.concatenate([frame.numbers for frame in frames])
# number_loc = np.array([np.where(numbers == i)[0] for i in [1, 6, 7, 8]], dtype=object)
#index through index set, through obtain full indices
#masked_loc[selected_indices]

#bin them
i_selected = i_selected
correct_is[1:] = i_selected[1:] - bins[np.digitize(i_selected,np.array(n_env_accum))-1][1:] #indices of frames the selected envs belong to


#TODO: Write simple test, selecting every sample up to, or something


In [98]:
from rascal.neighbourlist.structure_manager import mask_center_atoms_by_species, mask_center_atoms_by_id, 

In [103]:
len(frames[0])

156

In [104]:
mask_center_atoms_by_species(frames[0],species_select=[])


In [105]:
mask_center_atoms_by_id(frames[0],id_select=[0,3])

In [106]:
frames[0].arrays.pop('center_atoms_mask')

{'numbers': array([8, 8, 8, 8, 8, 8, 8, 8, 1, 1, 1, 1, 7, 7, 7, 7, 1, 1, 1, 1, 6, 6,
        6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1,
        6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1,
        1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6,
        1, 1, 1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 6, 6,
        6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1]),
 'positions': array([[ 5.74953047,  0.66494065,  2.36463517],
        [ 3.37387218,  5.4527765 ,  4.89762156],
        [ 2.13721082,  8.91073104, 12.15988865],
        [ 4.51286911,  4.1228952 ,  9.62690227],
        [ 3.93328511,  3.8033387 ,  3.04713615],
        [ 5.19011754,  8.59117455,  4.21512058],
        [ 3.95345618,  5.77232264, 11.47738767],
        [ 2.69662376,  0.98448679, 10.30940324],
        [ 3.65029968,  4.45127847,  3.8052864 ],
 

In [108]:
frames[0].arrays

{'numbers': array([8, 8, 8, 8, 8, 8, 8, 8, 1, 1, 1, 1, 7, 7, 7, 7, 1, 1, 1, 1, 6, 6,
        6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1,
        6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1,
        1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6,
        1, 1, 1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 6, 6,
        6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1]),
 'positions': array([[ 5.74953047,  0.66494065,  2.36463517],
        [ 3.37387218,  5.4527765 ,  4.89762156],
        [ 2.13721082,  8.91073104, 12.15988865],
        [ 4.51286911,  4.1228952 ,  9.62690227],
        [ 3.93328511,  3.8033387 ,  3.04713615],
        [ 5.19011754,  8.59117455,  4.21512058],
        [ 3.95345618,  5.77232264, 11.47738767],
        [ 2.69662376,  0.98448679, 10.30940324],
        [ 3.65029968,  4.45127847,  3.8052864 ],
 

In [84]:
%config IPCompleter.use_jedi = False

In [86]:
this = frames[0]

In [None]:
this.nu

In [112]:
frames[0].arrays

{'numbers': array([8, 8, 8, 8, 8, 8, 8, 8, 1, 1, 1, 1, 7, 7, 7, 7, 1, 1, 1, 1, 6, 6,
        6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1,
        6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1,
        1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6,
        1, 1, 1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 6, 6, 6, 6, 6, 6,
        6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1]),
 'positions': array([[ 5.74953047,  0.66494065,  2.36463517],
        [ 3.37387218,  5.4527765 ,  4.89762156],
        [ 2.13721082,  8.91073104, 12.15988865],
        [ 4.51286911,  4.1228952 ,  9.62690227],
        [ 3.93328511,  3.8033387 ,  3.04713615],
        [ 5.19011754,  8.59117455,  4.21512058],
        [ 3.95345618,  5.77232264, 11.47738767],
        [ 2.69662376,  0.98448679, 10.30940324],
        [ 3.65029968,  4.45127847,  3.8052864 ],
 

In [116]:
frames[0]

Atoms(symbols='C68H76N4O8', pbc=True, cell=[[10.360074368, 0.0, 0.0], [0.0, 9.5756716973, 0.0], [-2.47333307439, 0.0, 14.5245238218]], CS=..., center_atoms_mask=...)

In [115]:
frames[0][frames[0].arrays["center_atoms_mask"]]

Atoms(symbols='H76O2', pbc=True, cell=[[10.360074368, 0.0, 0.0], [0.0, 9.5756716973, 0.0], [-2.47333307439, 0.0, 14.5245238218]], CS=..., center_atoms_mask=...)

In [87]:
len(frames[0][frames[0].numbers == 1])

Atoms(symbols='H76', pbc=True, cell=[[10.360074368, 0.0, 0.0], [0.0, 9.5756716973, 0.0], [-2.47333307439, 0.0, 14.5245238218]], CS=...)

In [81]:
correct_is

array([  0,  78,  21,  47,  87,  13,   5,  17,  99,  25,  63,   1,  49,
        75,  99,  91,  50,  86,  55,  70,  58,  85,  47,  29,  36,   7,
       146,   8,  18,  63,  24,   0,  19,  22,   9,   8,   6, 131,  38,
        12,   2,  17,  70,  60,  25,  10,   3,   2,  21,   2,  83,  12,
        65,   9,   9, 162,  76, 157,  72,   8,  64,  59,  23,  43,  69,
        33,  87,  35,  81,  11,  15,  15,  15,   5,   1,  24,  73,  32,
       127,   3,  34,  29,  36,  44, 120, 131,   2,  23,  56,  16,   3,
        29,  11, 105,  27,   3,  50,  53,  33,  50])

In [74]:
new_loc

3942

In [73]:
orig_loc

50

In [38]:
type(i_selected)

numpy.ndarray

In [40]:
n_env_select

8972

In [52]:
bins = np.array(n_env_accum)

In [62]:
len(np.unique(np.digitize(i_selected,np.array(n_env_accum)),return_index=True)[1])

95

In [67]:
frames_select_inds = np.digitize(i_selected,np.array(n_env_accum)) #indices of frames the selected envs belong to

In [72]:
i_select_reduced

[0,
 2142,
 2709,
 4259,
 491,
 7053,
 7361,
 173,
 6823,
 2805,
 6569,
 2575,
 1573,
 1067,
 7931,
 7759,
 2314,
 6120,
 7659,
 8626,
 590,
 7265,
 4897,
 5803,
 664,
 163,
 8946,
 3262,
 2910,
 4323,
 4974,
 4650,
 4761,
 8474,
 3153,
 5722,
 1786,
 5375,
 7994,
 5636,
 6304,
 5405,
 5928,
 1960,
 4449,
 1718,
 1451,
 2350,
 5565,
 6656,
 5085,
 8232,
 2541,
 393,
 8361,
 3786,
 4088,
 6503,
 6222,
 608,
 760,
 3003,
 1351,
 4507,
 3013,
 3447,
 3887,
 3247,
 4417,
 7143,
 1989,
 8083,
 3511,
 781,
 817,
 7532,
 4209,
 4562,
 8759,
 2727,
 4684,
 265,
 8160,
 6274,
 7024,
 1299,
 2838,
 5175,
 1160,
 792,
 6699,
 1677,
 215,
 377,
 4007,
 387,
 938,
 5145,
 7809,
 3942]

In [None]:
np.argwhere()

In [69]:
i_selected

array([     0,  27957,  49576,  77698,   5313, 138347, 147411,    431,
       133947,  53478, 128073,  48526,  19877,  13539, 164501, 155475,
        30079, 124818, 152607, 181220,   6890, 146953,  94086, 121007,
         9580,    421, 184764,  64569,  55455,  78272,  96248,  92367,
        92614, 175354,  60854, 118542,  22806, 110657, 167652, 114228,
       125800, 112281, 121196,  26355,  84892,  22658,  19389,  40235,
       112811, 129698, 101779, 169486,  42962,   2219, 171807,  67065,
        75671, 126983, 124920,   7708,  10460,  58450,  15761,  89536,
        58460,  65618,  68384,  60948,  78366, 144577,  26782, 168149,
        65894,  11005,  11121, 150512,  76284,  91717, 182509,  53008,
        92401,   2015, 169414, 125490, 137818,  15657,  53983, 103337,
        14674,  11016, 131619,  21841,   1067,   2203,  74226,   2213,
        11846, 102391, 155525,  69595])

In [71]:
i_selected - (bins[frames_select_inds]-bins[0]) #? is this the index of the atom in the selected structure?

array([  0,  34, 141, 155, 115,  77,   9, 125,  75, 125,  71,  43,  81,
       119, 131, 139, 122, 126, 147, 150, 146,  65, 103, 101, 124, 115,
       130,   4, 122, 143, 128,  64,  67,  74,  97, 104,  42, 143,  82,
        78, 114,  17,  50, 142, 141,  94,  83,  30,  97, 116, 149,  36,
       123, 145,  65, 142, 108, 153, 148, 136, 140,  15,  59, 133,  25,
       107, 151, 149, 149, 119,  81, 115,  43, 121,  85,  84, 153,  68,
       115, 103,  98, 149,  96, 128, 140, 127, 102,  87, 148, 132, 131,
       125, 135, 149, 151, 139, 102, 149, 133, 118])

In [None]:
i_selected[1:] - bins[np.digitize(i_selected,np.array(n_env_accum))-1][1:]

In [57]:
#arrays run from 0 - N_env-1. Bins run from 1, N_env, 
#problem? last one explodes. adding a zero?

i_selected[1:] - bins[np.digitize(i_selected,np.array(n_env_accum))-1][1:] #



array([ 78,  21,  47,  87,  13,   5,  17,  99,  25,  63,   1,  49,  75,
        99,  91,  50,  86,  55,  70,  58,  85,  47,  29,  36,   7, 146,
         8,  18,  63,  24,   0,  19,  22,   9,   8,   6, 131,  38,  12,
         2,  17,  70,  60,  25,  10,   3,   2,  21,   2,  83,  12,  65,
         9,   9, 162,  76, 157,  72,   8,  64,  59,  23,  43,  69,  33,
        87,  35,  81,  11,  15,  15,  15,   5,   1,  24,  73,  32, 127,
         3,  34,  29,  36,  44, 120, 131,   2,  23,  56,  16,   3,  29,
        11, 105,  27,   3,  50,  53,  33,  50])

In [51]:
i_selected.size

100

In [50]:
for n in range(i_selected.size):
    print(n_env_accum[i_selected[n]-1], "<=", i_selected[n], "<", n_env_accum[i_selected[n]])

188100 <= 0 < 156


IndexError: list index out of range

In [48]:
np.digitize(i_selected,np.array(n_env_accum))

array([   0,  317,  559,  874,   60, 1508, 1598,    5, 1464,  603, 1405,
        548,  227,  157, 1764, 1671,  340, 1370, 1645, 1929,   73, 1594,
       1049, 1332,  110,    5, 1966,  720,  624,  879, 1073, 1032, 1034,
       1872,  681, 1305,  263, 1223, 1793, 1261, 1382, 1244, 1334,  300,
        950,  261,  220,  455, 1249, 1420, 1128, 1813,  484,   29, 1837,
        743,  852, 1393, 1371,   85,  119,  656,  182, 1002,  656,  729,
        760,  682,  880, 1570,  305, 1799,  732,  126,  128, 1627,  859,
       1025, 1943,  597, 1032,   26, 1812, 1379, 1502,  180,  609, 1146,
        171,  126, 1442,  250,   14,   28,  834,   29,  136, 1135, 1672,
        776])

In [45]:
np.unique(np.digitize(i_selected,np.array(n_env_accum))) == np.array(reduced_frames_select)

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True])

In [44]:
reduced_frames_select

[0,
 5,
 14,
 26,
 28,
 29,
 60,
 73,
 85,
 110,
 119,
 126,
 128,
 136,
 157,
 171,
 180,
 182,
 220,
 227,
 250,
 261,
 263,
 300,
 305,
 317,
 340,
 455,
 484,
 548,
 559,
 597,
 603,
 609,
 624,
 656,
 681,
 682,
 720,
 729,
 732,
 743,
 760,
 776,
 834,
 852,
 859,
 874,
 879,
 880,
 950,
 1002,
 1025,
 1032,
 1034,
 1049,
 1073,
 1128,
 1135,
 1146,
 1223,
 1244,
 1249,
 1261,
 1305,
 1332,
 1334,
 1370,
 1371,
 1379,
 1382,
 1393,
 1405,
 1420,
 1442,
 1464,
 1502,
 1508,
 1570,
 1594,
 1598,
 1627,
 1645,
 1671,
 1672,
 1764,
 1793,
 1799,
 1812,
 1813,
 1837,
 1872,
 1929,
 1943,
 1966]

In [36]:
n_env_accum_select

[156,
 204,
 236,
 272,
 384,
 404,
 532,
 600,
 628,
 696,
 776,
 816,
 888,
 992,
 1104,
 1168,
 1328,
 1448,
 1524,
 1648,
 1708,
 1780,
 1900,
 1974,
 2064,
 2264,
 2348,
 2476,
 2574,
 2688,
 2724,
 2780,
 2836,
 2892,
 2944,
 3144,
 3212,
 3254,
 3414,
 3496,
 3624,
 3800,
 3892,
 3980,
 4012,
 4136,
 4212,
 4260,
 4336,
 4424,
 4464,
 4530,
 4650,
 4742,
 4850,
 4950,
 5002,
 5092,
 5152,
 5244,
 5388,
 5544,
 5624,
 5714,
 5774,
 5858,
 6034,
 6150,
 6230,
 6302,
 6346,
 6506,
 6654,
 6696,
 6724,
 6904,
 7040,
 7132,
 7180,
 7356,
 7508,
 7604,
 7668,
 7776,
 7832,
 7956,
 8068,
 8124,
 8220,
 8352,
 8452,
 8556,
 8632,
 8800,
 8972]

In [35]:
reduced_frames_select #-> not atomic indices

[0,
 5,
 14,
 26,
 28,
 29,
 60,
 73,
 85,
 110,
 119,
 126,
 128,
 136,
 157,
 171,
 180,
 182,
 220,
 227,
 250,
 261,
 263,
 300,
 305,
 317,
 340,
 455,
 484,
 548,
 559,
 597,
 603,
 609,
 624,
 656,
 681,
 682,
 720,
 729,
 732,
 743,
 760,
 776,
 834,
 852,
 859,
 874,
 879,
 880,
 950,
 1002,
 1025,
 1032,
 1034,
 1049,
 1073,
 1128,
 1135,
 1146,
 1223,
 1244,
 1249,
 1261,
 1305,
 1332,
 1334,
 1370,
 1371,
 1379,
 1382,
 1393,
 1405,
 1420,
 1442,
 1464,
 1502,
 1508,
 1570,
 1594,
 1598,
 1627,
 1645,
 1671,
 1672,
 1764,
 1793,
 1799,
 1812,
 1813,
 1837,
 1872,
 1929,
 1943,
 1966]

In [30]:
n_env_accum

[156,
 200,
 320,
 396,
 414,
 462,
 510,
 710,
 762,
 782,
 902,
 952,
 992,
 1056,
 1088,
 1116,
 1236,
 1388,
 1402,
 1446,
 1582,
 1622,
 1658,
 1730,
 1810,
 1986,
 2022,
 2098,
 2210,
 2230,
 2278,
 2438,
 2574,
 2714,
 2822,
 2834,
 2990,
 3102,
 3190,
 3298,
 3362,
 3446,
 3590,
 3710,
 3870,
 3914,
 3946,
 4078,
 4186,
 4222,
 4398,
 4478,
 4506,
 4614,
 4690,
 4736,
 4880,
 4996,
 5140,
 5226,
 5354,
 5546,
 5706,
 5794,
 5922,
 5998,
 6082,
 6248,
 6432,
 6552,
 6600,
 6784,
 6832,
 6900,
 6948,
 7100,
 7140,
 7204,
 7240,
 7280,
 7376,
 7476,
 7604,
 7652,
 7700,
 7728,
 7848,
 7932,
 8060,
 8128,
 8152,
 8196,
 8254,
 8374,
 8434,
 8492,
 8556,
 8596,
 8696,
 8724,
 8828,
 8988,
 9048,
 9232,
 9308,
 9340,
 9368,
 9448,
 9504,
 9544,
 9612,
 9672,
 9736,
 9832,
 10012,
 10108,
 10256,
 10336,
 10396,
 10476,
 10512,
 10588,
 10700,
 10820,
 10876,
 11000,
 11040,
 11120,
 11192,
 11280,
 11440,
 11500,
 11520,
 11596,
 11764,
 11796,
 11900,
 12052,
 12156,
 12238,
 12310,

In [None]:
n_env_accum

In [33]:
np.where(np.array(n_env_accum) > 1000)[0][0]

13

In [29]:
np.where(n_env_accum > 100)[0][0]

TypeError: '>' not supported between instances of 'list' and 'int'

In [27]:
i_selected

array([     0,  27957,  49576,  77698,   5313, 138347, 147411,    431,
       133947,  53478, 128073,  48526,  19877,  13539, 164501, 155475,
        30079, 124818, 152607, 181220,   6890, 146953,  94086, 121007,
         9580,    421, 184764,  64569,  55455,  78272,  96248,  92367,
        92614, 175354,  60854, 118542,  22806, 110657, 167652, 114228,
       125800, 112281, 121196,  26355,  84892,  22658,  19389,  40235,
       112811, 129698, 101779, 169486,  42962,   2219, 171807,  67065,
        75671, 126983, 124920,   7708,  10460,  58450,  15761,  89536,
        58460,  65618,  68384,  60948,  78366, 144577,  26782, 168149,
        65894,  11005,  11121, 150512,  76284,  91717, 182509,  53008,
        92401,   2015, 169414, 125490, 137818,  15657,  53983, 103337,
        14674,  11016, 131619,  21841,   1067,   2203,  74226,   2213,
        11846, 102391, 155525,  69595])

In [23]:
np.where(numbers == 1)[0]

array([     8,      9,     10, ..., 188097, 188098, 188099])

In [26]:
numbers[np.where(numbers == 1)[0]]

array([1, 1, 1, ..., 1, 1, 1])

In [21]:
i_selected

array([     0,  27957,  49576,  77698,   5313, 138347, 147411,    431,
       133947,  53478, 128073,  48526,  19877,  13539, 164501, 155475,
        30079, 124818, 152607, 181220,   6890, 146953,  94086, 121007,
         9580,    421, 184764,  64569,  55455,  78272,  96248,  92367,
        92614, 175354,  60854, 118542,  22806, 110657, 167652, 114228,
       125800, 112281, 121196,  26355,  84892,  22658,  19389,  40235,
       112811, 129698, 101779, 169486,  42962,   2219, 171807,  67065,
        75671, 126983, 124920,   7708,  10460,  58450,  15761,  89536,
        58460,  65618,  68384,  60948,  78366, 144577,  26782, 168149,
        65894,  11005,  11121, 150512,  76284,  91717, 182509,  53008,
        92401,   2015, 169414, 125490, 137818,  15657,  53983, 103337,
        14674,  11016, 131619,  21841,   1067,   2203,  74226,   2213,
        11846, 102391, 155525,  69595])

In [11]:
X_raw_select.shape

(8972, 2520)

In [10]:
X_sample_select.shape

(100, 2520)

In [12]:
X_raw.shape

(188100, 360)

In [13]:
n_env_accum

[156,
 200,
 320,
 396,
 414,
 462,
 510,
 710,
 762,
 782,
 902,
 952,
 992,
 1056,
 1088,
 1116,
 1236,
 1388,
 1402,
 1446,
 1582,
 1622,
 1658,
 1730,
 1810,
 1986,
 2022,
 2098,
 2210,
 2230,
 2278,
 2438,
 2574,
 2714,
 2822,
 2834,
 2990,
 3102,
 3190,
 3298,
 3362,
 3446,
 3590,
 3710,
 3870,
 3914,
 3946,
 4078,
 4186,
 4222,
 4398,
 4478,
 4506,
 4614,
 4690,
 4736,
 4880,
 4996,
 5140,
 5226,
 5354,
 5546,
 5706,
 5794,
 5922,
 5998,
 6082,
 6248,
 6432,
 6552,
 6600,
 6784,
 6832,
 6900,
 6948,
 7100,
 7140,
 7204,
 7240,
 7280,
 7376,
 7476,
 7604,
 7652,
 7700,
 7728,
 7848,
 7932,
 8060,
 8128,
 8152,
 8196,
 8254,
 8374,
 8434,
 8492,
 8556,
 8596,
 8696,
 8724,
 8828,
 8988,
 9048,
 9232,
 9308,
 9340,
 9368,
 9448,
 9504,
 9544,
 9612,
 9672,
 9736,
 9832,
 10012,
 10108,
 10256,
 10336,
 10396,
 10476,
 10512,
 10588,
 10700,
 10820,
 10876,
 11000,
 11040,
 11120,
 11192,
 11280,
 11440,
 11500,
 11520,
 11596,
 11764,
 11796,
 11900,
 12052,
 12156,
 12238,
 12310,

In [19]:
number_loc[4].shape

IndexError: index 4 is out of bounds for axis 0 with size 4