In [None]:
# Automatically reload edited modules (e.g. the project's `xas_nne` package) before each cell runs.
%load_ext autoreload
%autoreload 2
# Disable jedi-based tab completion in the IPython completer.
%config Completer.use_jedi = False

In [None]:
from datetime import datetime
from pathlib import Path
import sys
import pickle
from copy import deepcopy

import numpy as np
from pymatgen.core.structure import Molecule
from scipy.interpolate import InterpolatedUnivariateSpline
from tqdm import tqdm

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm

In [None]:
# Plotting helpers from https://gist.github.com/x94carbone/f5201b1c44963ff9453b9cc1d5f768ac,
# importable as `mpl_utils` from ~/local.
# Guard the append so re-running this cell does not keep growing sys.path.
_MPL_UTILS_DIR = str(Path.home() / "local")
if _MPL_UTILS_DIR not in sys.path:
    sys.path.append(_MPL_UTILS_DIR)
from mpl_utils import MPLAdjutant
adj = MPLAdjutant()
adj.set_defaults()  # presumably applies the gist's matplotlib style defaults

In [None]:
# The project root (parent of the notebook's working directory) must be on sys.path
# *before* importing from it; otherwise a fresh kernel can only import `xas_nne`
# if it happens to be installed. Guarded so re-runs don't add duplicates.
_PROJECT_ROOT = str(Path.cwd().parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)
from xas_nne.feff import FeffWriter, load_completed_FEFF_results  # noqa

Append the project root (the parent of the notebook's working directory) to `sys.path` so that project-local modules can be imported.

In [None]:
# Add the project root (parent of the working directory) to sys.path.
# Idempotent: re-running this cell must not append duplicate entries.
_PROJECT_ROOT = str(Path.cwd().parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

In [None]:
path = Path("data/qm9/XANES-220622-C-N-O.pkl")
print(path.exists())
# Context manager closes the file handle deterministically.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(path, "rb") as fh:
    data = pickle.load(fh)

# Prepare the FEFF inputs

These are the 10 molecules chosen for the random-split tests. Each line gives the SMILES string, the `<qm9id>_<site>` label, and the index of the absorbing site.

```
N=C1N=NON=NC1=O 129158_7 7
O=C1C2COC11CNC21 43138_6 6
CC1(CNC=O)CC1O 87244_1 1
OC12C3NC1C1(CO1)C23 67255_2 2
O=COC1(CC#N)CC1 50994_1 1
COC1C2C3NC3(C)C12 110619_4 4
COC12CCCC1C2O 108590_6 6
OCC12CC3OC1C23 17249_7 7
CC1(C)OCC1(O)CO 104189_0 0
OC1(CC1)C1CCCC1 65272_2 2
```

In [None]:
# QM9 id -> index of the absorbing site studied for that molecule
# (the `<qm9id>_<site>` pairs in the markdown listing; SMILES shown per entry).
qm9_ids = {
    129158: 7,  # N=C1N=NON=NC1=O
    43138: 6,   # O=C1C2COC11CNC21
    87244: 1,   # CC1(CNC=O)CC1O
    67255: 2,   # OC12C3NC1C1(CO1)C23
    50994: 1,   # O=COC1(CC#N)CC1
    110619: 4,  # COC1C2C3NC3(C)C12
    108590: 6,  # COC12CCCC1C2O
    17249: 7,   # OCC12CC3OC1C23
    104189: 0,  # CC1(C)OCC1(O)CO
    65272: 2,   # OC1(CC1)C1CCCC1
}

In [None]:
# Rebuild each selected QM9 molecule from its serialized dict
# (entries in `data` are keyed by the string form of the QM9 id).
molecules = dict(
    (qm9id, Molecule.from_dict(data[str(qm9id)]["molecule"]))
    for qm9id in qm9_ids
)

In [None]:
# Absorbing elements for which FEFF inputs are written.
ABSORBERS = ["C"]
# Forwarded to FeffWriter(xanes=...).
xanes = True
# Perturbation magnitudes 0.01, 0.02, ..., 1.00 as sorted, de-duplicated plain floats
# (Angstrom, per the plot title further down).
distortions = sorted({round(0.01 * step, 2) for step in range(1, 101)})
# Number of perturbed copies generated per magnitude.
n_distort_per = 50

In [None]:
# Write FEFF inputs for every (magnitude, replica, molecule, absorber) combination.
# Each input gets its own deep-copied, perturbed molecule, so `molecules` is never modified.
# NOTE(review): inputs go to "<absorber>-XANES/" relative to the CWD, but the processing
# cells read "data/qm9/C-XANES-distorted/"; presumably the tree is moved after the FEFF
# runs -- confirm. No random seed is set, so regenerated perturbations will differ.
for distortion_magnitude in tqdm(distortions):
    for replica in range(n_distort_per):
        for qm9id, molecule in molecules.items():
            for absorber in ABSORBERS:
                perturbed = deepcopy(molecule)
                perturbed.perturb(distortion_magnitude)
                writer = FeffWriter(perturbed, xanes=xanes, name=qm9id)
                # Run directory name encodes "<qm9id>-<magnitude>-<replica>".
                run_dir = Path(
                    f"{absorber}-XANES",
                    f"{int(qm9id):06}-{distortion_magnitude:.02f}-{replica:02}",
                )
                writer.write_feff_inputs(str(run_dir), absorber=absorber)

# Process the FEFF output

In [None]:
def feff_inp_to_molecule(feff_inp_string):
    """Rebuild a pymatgen ``Molecule`` from the lines of a ``feff.inp`` file.

    Parameters
    ----------
    feff_inp_string : sequence of str
        The ``feff.inp`` contents as a list of lines (it is enumerated and
        sliced line by line, not parsed as a single string).

    Returns
    -------
    Molecule
        Coordinates taken from the ATOMS card, with each site's potential
        index replaced by the element symbol listed under POTENTIALS.

    Raises
    ------
    AssertionError
        If "ATOMS" or "POTENTIALS" does not occur on exactly one line.
    """

    def _only_line_containing(keyword):
        # Index of the single line that contains ``keyword``.
        matches = [idx for idx, text in enumerate(feff_inp_string) if keyword in text]
        assert len(matches) == 1
        return matches[0]

    # ATOMS rows span from the line after the card up to (not including) the
    # final line, presumably END. Every column but the last is a coordinate;
    # the last column is the potential index (ipot).
    atoms_at = _only_line_containing("ATOMS")
    atom_rows = [text.split() for text in feff_inp_string[atoms_at + 1:-1]]
    coords = [row[:-1] for row in atom_rows]
    ipots = [row[-1] for row in atom_rows]
    n_potentials = len(np.unique(ipots))

    # POTENTIALS rows start two lines below the card (one header line skipped)
    # and are tab-separated: field 0 is ipot, field 2 is the element symbol.
    # NOTE(review): exactly as many rows are read as there are distinct ipots in
    # ATOMS, i.e. this assumes every defined potential is used -- confirm.
    pot_at = _only_line_containing("POTENTIALS") + 2
    pot_fields = (text.split("\t") for text in feff_inp_string[pot_at:pot_at + n_potentials])
    ipot_to_symbol = {int(fields[0]): fields[2] for fields in pot_fields}

    species = [ipot_to_symbol[int(ipot)] for ipot in ipots]
    return Molecule(coords=np.array(coords, dtype=float), species=species)

In [None]:
# Every feff.inp under the distorted-C tree. Sorted so that iteration order -- and hence
# result order, the `[::5]` subsampling in the plot, and the row order of the saved
# arrays -- does not depend on filesystem enumeration order.
paths = sorted(Path("data/qm9/C-XANES-distorted/").rglob("feff.inp"))

In [None]:
# FEFF results keyed by "<run dir>/<subdir>", filled by the processing loop.
distortion_results = {}

In [None]:
# Collect every FEFF run that produced a spectrum. The key joins path components 3 and 4
# (the two directory levels below data/qm9/C-XANES-distorted/); the run-directory name
# "<qm9id>-<distortion>-<index>" supplies the distortion metadata.
for feff_inp in tqdm(paths):
    key = "/".join(feff_inp.parts[3:5])
    name_fields = feff_inp.parts[3].split("-")
    result = load_completed_FEFF_results(feff_inp.parent)
    if result["spectrum"] is None:
        continue  # this run has no spectrum
    distortion_results[key] = result
    result["distortion"] = float(name_fields[1])
    result["distortion_index"] = int(name_fields[2])
    result["molecule"] = feff_inp_to_molecule(result["feff.inp"]).as_dict()

# Plot the results

In [None]:
# 10-colour "rainbow" colormap, one colour per plotted distortion column.
# `matplotlib.cm.get_cmap` was removed in matplotlib 3.9; `pyplot.get_cmap` is the
# supported equivalent and accepts the same (name, lut) arguments.
cmap = plt.get_cmap("rainbow", 10)

In [None]:
# Waterfall figure: one row per molecule (reverse order of `qm9_ids`), one column per
# distortion magnitude 0.01-0.10. Each panel overlays every 5th spectrum for the
# molecule's absorbing site, coloured by column.
scale = 7
plot_distortions = [round(0.01 * (kk + 1), 2) for kk in range(10)]
plot_qm9ids = list(qm9_ids)[::-1]

# Rows index molecules and columns index distortions, matching `axs[jj, ii]` below.
# squeeze=False keeps `axs` 2-D even with a single row or column.
fig, axs = plt.subplots(
    len(plot_qm9ids),
    len(plot_distortions),
    figsize=(3 * scale, 2 * scale),
    sharex=True,
    sharey=True,
    squeeze=False,
)

for ii, distortion in enumerate(plot_distortions):
    for jj, qm9id in enumerate(plot_qm9ids):
        ax = axs[jj, ii]
        ax.axis("off")

        # `distortion` is rounded to 2 decimals, so it compares equal to the value
        # parsed from the run-directory name.
        selected = [
            value
            for value in distortion_results.values()
            if value["distortion"] == distortion
            and value["qm9id"] == str(qm9id)
            and value["site"] == qm9_ids[qm9id]
        ][::5]
        for result in selected:
            arr = np.array(result["spectrum"])
            ax.plot(arr[:, 0], arr[:, 3], color=cmap(ii), alpha=0.6)

plt.subplots_adjust(wspace=0.1)
# sharey=True, so this sets the y-range of every panel.
axs[0, 0].set_ylim(-1, 10)
# Figure-level title; pyplot.title would only label the last (bottom-right) panel.
# Raw string because "\A" is not a valid Python escape sequence.
fig.suptitle(r"Distortion~(\AA)")

out_path = Path("Figures/qm9_distortion_waterfall.svg")
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, bbox_inches="tight", dpi=300)

## Construct ACSF (atom-centered symmetry function) feature vectors

In [None]:
# Number of points on each interpolation grid.
N = 200
# Per-element grids onto which spectra are interpolated: each spans a fixed 54-unit
# window (presumably eV around that element's absorption edge -- confirm).
grids = {
    element: np.linspace(start, stop, N)
    for element, start, stop in (("O", 528, 582), ("N", 395, 449), ("C", 275, 329))
}

In [None]:
# Absorbing element whose features/spectra are assembled below: "C", "N" or "O".
# NOTE(review): the FEFF results are read from the hard-coded
# "data/qm9/C-XANES-distorted/" tree, so a value other than "C" would pair carbon
# spectra with another element's grid -- keep the two in sync.
CENTRAL_ATOM = "C"
if CENTRAL_ATOM not in grids:
    raise ValueError(f"CENTRAL_ATOM must be one of {sorted(grids)}, got {CENTRAL_ATOM!r}")
grid = grids[CENTRAL_ATOM]

In [None]:
from dscribe.descriptors import ACSF

In [None]:
# Atom-centred symmetry function (ACSF) descriptor, evaluated at the absorbing site.
species = ["H", "C", "O", "N", "F"]  # all elements that can appear (QM9 elements)
rcut = 6.0  # radial cutoff passed to dscribe
# G2 terms: [eta, R_s] pairs.
g2_params = [[eta, 0] for eta in (1.0, 0.1, 0.01)]
# G4 terms: [eta, zeta, lambda] triples (aenet paper).
# NOTE(review): zeta runs 1, 2, 4 for eta = 0.001 and 0.01 but 1, 2, 3 for eta = 0.1;
# confirm the 3.0 is intended before regenerating features.
g4_params = [[eta, zeta, -1.0] for eta in (0.001, 0.01) for zeta in (1.0, 2.0, 4.0)]
g4_params += [[0.1, zeta, -1.0] for zeta in (1.0, 2.0, 3.0)]
# NOTE(review): `rcut` (and `positions=` in `acsf.create`) is the dscribe 1.x API;
# dscribe 2.x renamed these to `r_cut` / `centers`.
acsf = ACSF(species=species, rcut=rcut, g2_params=g2_params, g4_params=g4_params)

In [None]:
from ase import Atom, Atoms

In [None]:
# Build one (ACSF feature vector, interpolated spectrum) pair per FEFF result.
molecule_site_pairs = []
acsf_array = []
spectra = []
names = []
skipped = []  # keys whose spectrum array could not be sliced into columns

for key, datum in tqdm(distortion_results.items()):
    qm9id = datum["qm9id"]
    absorber_site = datum["site"]

    # Interpolate first, so no ACSF work is wasted on entries that end up skipped.
    s = np.array(datum["spectrum"])
    try:
        # Column 0 is the abscissa matched against `grid`; column 3 is the target
        # signal (presumably mu(E) from FEFF's xmu.dat -- confirm).
        spline = InterpolatedUnivariateSpline(s[:, 0], s[:, 3])
    except IndexError:
        # Array is not 2-D with >= 4 columns (e.g. empty): record instead of dropping silently.
        skipped.append(key)
        continue

    # Convert the pymatgen Molecule to an ASE Atoms object for dscribe.
    molecule = Molecule.from_dict(datum["molecule"])
    atoms = Atoms([Atom(site.specie.symbol, site.coords) for site in molecule])
    tmp_acsf = acsf.create(atoms, positions=[absorber_site])

    # NOTE(review): the spline extrapolates (ext=0) wherever `grid` lies outside the
    # FEFF energy range -- confirm the grid is always covered.
    spectra.append(spline(grid))
    acsf_array.append(tmp_acsf.squeeze())
    molecule_site_pairs.append(f"{qm9id}_{absorber_site}")
    names.append(key)

if skipped:
    print(f"Skipped {len(skipped)} result(s) with malformed spectra")

acsf_array = np.array(acsf_array)
spectra = np.array(spectra)
assert len(acsf_array) == len(spectra) == len(names) == len(molecule_site_pairs)

Finally, save the features, spectra, and metadata to disk.

In [None]:
# Output file name is stamped with today's date (YYMMDD) and the central atom.
now = datetime.now().strftime("%y%m%d")
out_dir = Path("data") / "qm9"
fname = out_dir / f"XANES-{now}-ACSF-{CENTRAL_ATOM}-distorted.pkl"
print(fname)
assert out_dir.exists()

We take the convention that `"x"` is the input and `"y"` is the output. These are the only two required keys for the ML pipeline. The rest is considered metadata.

In [None]:
# Write the dataset; the context manager guarantees the file is flushed and closed.
with open(fname, "wb") as fh:
    pickle.dump(
        {"grid": grid, "y": spectra, "x": acsf_array, "molecule_site_pairs": molecule_site_pairs, "names": names},
        fh,
        protocol=pickle.HIGHEST_PROTOCOL,
    )