In [5]:
from dscribe.descriptors import SOAP
import qml

In [6]:
from glob import glob
import numpy as np

In [7]:
# Gather every .xyz geometry under qm7/ and order the paths lexicographically
# (plain string order, so "qm7_10" sorts right after "qm7_1").
xyz_pattern = "qm7/*.xyz"
database_xyzs = glob(xyz_pattern)
database_xyzs.sort()

In [8]:
# Parse each xyz file into a qml.Compound, keeping the order of database_xyzs.
database_mols = list(map(qml.Compound, database_xyzs))

In [14]:
# Lookup table: nuclear charge -> element symbol, used to turn the nuclear
# charges of a molecule into the symbols that ase.Atoms expects.
pt = dict(
    [
        (1, "H"),
        (6, "C"),
        (7, "N"),
        (8, "O"),
        (9, "F"),
        (16, "S"),
    ]
)

In [15]:
def get_ncharges_coords(mol):
    """Return the nuclear charges, symbols and coordinates of the non-hydrogen atoms.

    ``mol`` must expose ``nuclear_charges`` and ``coordinates`` indexed the
    same way (one entry per atom). Atoms whose nuclear charge equals 1 are
    dropped; the remaining atoms keep their original relative order.

    Returns a 3-tuple of lists: ``(heavy_ncharges, heavy_symbols, heavy_coords)``.
    Raises ``KeyError`` if any atom (hydrogens included) has a nuclear charge
    that is missing from the module-level ``pt`` table.
    """
    charges = mol.nuclear_charges
    keep = [idx for idx, z in enumerate(charges) if z != 1]

    # Every atom is translated, not only the kept ones, so an unknown element
    # anywhere in the molecule is reported before any filtering happens.
    all_symbols = [pt[z] for z in charges]

    heavy_ncharges = []
    heavy_symbols = []
    for idx in keep:
        heavy_ncharges.append(charges[idx])
        heavy_symbols.append(all_symbols[idx])

    positions = mol.coordinates
    heavy_coords = []
    for idx in keep:
        heavy_coords.append(positions[idx])

    return heavy_ncharges, heavy_symbols, heavy_coords

In [16]:
import ase

In [17]:
def get_rep(mol, elements=(6, 7, 8, 16)):
    """Compute the SOAP representation of the heavy atoms of one molecule.

    Hydrogens are stripped with ``get_ncharges_coords`` before the descriptor
    is evaluated, so only the non-hydrogen atoms act as centres and neighbours.

    Parameters
    ----------
    mol
        Object exposing ``nuclear_charges`` and ``coordinates`` (the notebook
        passes ``qml.Compound`` instances).
    elements : sequence of int, optional
        Atomic numbers handed to SOAP as ``species``. The default is an
        immutable tuple and a fresh list is built on every call, so the
        descriptor object never shares (or mutates) the default value.

    Returns
    -------
    tuple
        ``(ncharges, rep)`` where ``ncharges`` is the list of heavy-atom
        nuclear charges and ``rep`` is whatever ``SOAP.create`` returns for
        the heavy-atom ``ase.Atoms`` object.
    """
    ncharges, atomtypes, coords = get_ncharges_coords(mol)
    atomsobj = ase.Atoms(symbols=atomtypes, positions=coords)
    soap = SOAP(
             # Copy: SOAP may keep a reference to the species container.
             species=list(elements),
             rcut=5.0,
             nmax=8,
             lmax=8,
             sigma=0.2,
             periodic=False,
             crossover=True,
             sparse=False,
         )
    return ncharges, soap.create(atomsobj)

In [42]:
# pad size is based on largest target 

In [19]:
# Evaluate get_rep once per molecule (database order), then split the
# resulting (ncharges, representation) pairs into two parallel lists.
_rep_pairs = [get_rep(mol) for mol in database_mols]
database_ncharges = [pair[0] for pair in _rep_pairs]
database_reps = [pair[1] for pair in _rep_pairs]

In [20]:
# The per-molecule SOAP matrices have one row per heavy atom, so their first
# dimension varies between molecules. A bare np.array(...) on such a ragged
# list emits VisibleDeprecationWarning on older NumPy and raises ValueError on
# NumPy >= 1.24. Build the 1-D object array explicitly instead: one slot per
# molecule, each holding that molecule's SOAP matrix unchanged.
_rep_list = list(database_reps)
database_reps = np.empty(len(_rep_list), dtype=object)
for _idx, _rep in enumerate(_rep_list):
    database_reps[_idx] = _rep

  """Entry point for launching an IPython kernel.


In [21]:
# Label = file name without its directory and without the ".xyz" suffix,
# e.g. "qm7/qm7_0.xyz" -> "qm7_0". Order follows database_xyzs.
database_labels = [
    xyz_path.rpartition("/")[2].partition(".xyz")[0]
    for xyz_path in database_xyzs
]

In [22]:
# Turn the list of label strings into a NumPy string array for np.savez.
database_labels = np.asarray(database_labels)

In [23]:
# Each entry is the list of heavy-atom nuclear charges of one molecule, and the
# lists differ in length. A bare np.array(...) on this ragged input emits
# VisibleDeprecationWarning on older NumPy and raises ValueError on
# NumPy >= 1.24. Fill a 1-D object array explicitly instead: one slot per
# molecule, each holding that molecule's charge list unchanged.
_ncharge_list = list(database_ncharges)
database_ncharges = np.empty(len(_ncharge_list), dtype=object)
for _idx, _charges in enumerate(_ncharge_list):
    database_ncharges[_idx] = _charges

  """Entry point for launching an IPython kernel.


In [24]:
# Sanity check: shape of the first molecule's SOAP matrix.
first_rep = database_reps[0]
first_rep.shape

(1, 4752)

In [25]:
# Write labels, representations and nuclear charges into a single .npz archive.
# NOTE(review): if the reps / ncharges arrays have dtype=object, loading them
# back needs np.load(..., allow_pickle=True) -- confirm with the consumer.
np.savez(
    "database_SOAP.npz",
    **{
        "database_labels": database_labels,
        "database_reps": database_reps,
        "database_ncharges": database_ncharges,
    },
)

In [26]:
# Display the label array. The paths were sorted as plain strings, so the
# order is lexicographic ("qm7_0", "qm7_1", "qm7_10", ...), not numeric.
database_labels

array(['qm7_0', 'qm7_1', 'qm7_10', ..., 'qm7_997', 'qm7_998', 'qm7_999'],
      dtype='<U8')