In [1]:
import qml 

In [2]:
from glob import glob
import numpy as np

In [5]:
# Element symbol -> nuclear charge lookup used by read_xyz when parsing atoms.
NUCLEAR_CHARGE = dict(H=1, C=6, O=8, N=7, F=9, Cl=17, S=16)

In [4]:
def read_xyz(filename):
    """Parse the non-hydrogen atoms of an XYZ file.

    The first line is read as the atom count, the second line is skipped,
    and the following ``natoms`` lines are parsed as
    ``<symbol> <x> <y> <z> [...]``. Atoms whose symbol is ``'H'`` are
    left out of both returned lists.

    Parameters
    ----------
    filename : str
        Path of the XYZ file to read.

    Returns
    -------
    tuple of (list of int, list of list of float)
        Nuclear charges (looked up in ``NUCLEAR_CHARGE``) and the matching
        ``[x, y, z]`` coordinates, in file order.

    Raises
    ------
    KeyError
        If a non-hydrogen symbol is missing from ``NUCLEAR_CHARGE``.
    ValueError
        If the atom count or a coordinate cannot be converted to a number.
    """
    with open(filename, "r") as xyz_file:
        rows = list(xyz_file)

    atom_count = int(rows[0])
    charges, positions = [], []

    for row in rows[2:2 + atom_count]:
        fields = row.split()

        # A line with fewer than four fields ends parsing; atoms collected
        # so far are still returned.
        if len(fields) < 4:
            break

        symbol, xyz = fields[0], fields[1:4]
        if symbol == 'H':
            continue

        charges.append(NUCLEAR_CHARGE[symbol])
        positions.append([float(value) for value in xyz])

    return charges, positions

In [3]:
# Gather every target geometry file, then sort in place so the order is
# deterministic across runs.
target_xyzs = glob("../targets/*.xyz")
target_xyzs.sort()

In [6]:
# Show which target files were found.
target_xyzs

['../targets/penicillin.xyz', '../targets/qm9_0.xyz']

In [7]:
# Parse each target file into a (nuclear_charges, coordinates) pair.
conf_data = list(map(read_xyz, target_xyzs))

In [8]:
# Transpose the per-file (charges, coordinates) pairs into two parallel
# tuples with one entry per target file.
ncharges_list, coords_list = zip(*conf_data)

In [9]:
# Number of atoms kept for the first target (read_xyz leaves out hydrogens).
len(ncharges_list[0])

23

In [22]:
# The SLATM many-body types (mbtypes) are generated separately for each target.

In [10]:
# Build one SLATM representation per target with local=True. The many-body
# types (mbtypes) are derived from each target's own nuclear charges.
slatm_per_target = [
    np.array(
        qml.representations.generate_slatm(
            coords_list[i],
            ncharges_list[i],
            mbtypes=qml.representations.get_slatm_mbtypes([ncharges_list[i]]),
            local=True,
        )
    )
    for i in range(len(ncharges_list))
]

# The per-target arrays are not guaranteed to share a shape (presumably they
# differ in atom count and, with per-target mbtypes, in vector length -- TODO
# confirm). Wrapping such a ragged list in np.array(...) raises ValueError on
# NumPy >= 1.24, so fill an explicit object array with one slot per target.
# NOTE(review): this is always an object array, even if all shapes coincide.
target_reps = np.empty(len(slatm_per_target), dtype=object)
for rep_slot, rep in enumerate(slatm_per_target):
    target_reps[rep_slot] = rep

In [11]:
# Label each target by its file name: keep the text after the last "/" and
# cut it at the first ".xyz".
target_labels = [
    xyz_path.rsplit("/", 1)[-1].partition(".xyz")[0]
    for xyz_path in target_xyzs
]

In [12]:
# Show the labels derived from the target file names.
target_labels

['penicillin', 'qm9_0']

In [13]:
# The nuclear-charge lists can differ in length between targets. Passing the
# ragged tuple straight to np.savez makes NumPy convert it implicitly, which
# raises ValueError on NumPy >= 1.24. Pack it into an explicit object array
# (one Python list per target) instead.
target_ncharges = np.empty(len(ncharges_list), dtype=object)
for charge_slot, charges in enumerate(ncharges_list):
    target_ncharges[charge_slot] = list(charges)

# Object arrays are stored with pickle, so reading this archive back requires
# np.load(..., allow_pickle=True).
np.savez(
    "../representations/target_SLATM_data.npz",
    target_labels=target_labels,
    target_reps=target_reps,
    target_ncharges=target_ncharges,
)