In [None]:
# Automatically re-import edited modules before each cell runs, so changes to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Use IPython's built-in tab completer instead of jedi.
%config Completer.use_jedi = False

In [None]:
from datetime import datetime
from pathlib import Path
import pickle
import sys
import numpy as np

import matplotlib.pyplot as plt
from pymatgen.core.structure import Molecule

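Optional custom plotting helpers. If the `mpl_utils` module (which provides `MPLAdjutant`) cannot be imported, `adj` falls back to a stand-in object whose methods do nothing (except `add_colorbar`, which uses a plain `plt.colorbar`). You can therefore run this cell safely even without `MPLAdjutant`.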

In [None]:
# Make the optional user-local ``mpl_utils`` package importable. Skip the append
# if the entry is already present so re-running this cell does not keep growing
# ``sys.path`` with duplicates.
_local_lib = str(Path.home() / "local")
if _local_lib not in sys.path:
    sys.path.append(_local_lib)


class NullClass:
    """No-op stand-in for ``MPLAdjutant`` when ``mpl_utils`` is unavailable.

    Any attribute that is not defined below resolves to ``do_nothing``, a
    callable that accepts arbitrary arguments and returns ``None``; calls such
    as ``adj.set_defaults()`` therefore silently do nothing. ``add_colorbar``
    is the exception: it falls back to a plain ``plt.colorbar``.
    """

    def do_nothing(self, *args, **kwargs):
        """Accept and ignore any arguments; always returns ``None``."""
        return None

    def add_colorbar(self, im, **kwargs):
        """Draw a default matplotlib colorbar for ``im``.

        ``kwargs`` are accepted for call compatibility but ignored.
        """
        return plt.colorbar(im)

    def __getattr__(self, _):
        # Only reached for attributes not found through normal lookup, i.e.
        # anything other than the methods defined above.
        return self.do_nothing


try:
    from mpl_utils import MPLAdjutant
    adj = MPLAdjutant()
    adj.set_defaults()
except ImportError:
    adj = NullClass()

In [None]:
import json


def save_json(d, path):
    """Write ``d`` to ``path`` as pretty-printed JSON.

    Keys are sorted and the output is indented by 4 spaces, so the file is
    deterministic for a given input and diffs cleanly.

    Parameters
    ----------
    d : object
        Any JSON-serializable object.
    path : str or os.PathLike
        Destination file; it is overwritten if it exists.
    """
    # Explicit UTF-8 so the file does not depend on the platform locale.
    with open(path, "w", encoding="utf-8") as outfile:
        json.dump(d, outfile, indent=4, sort_keys=True)


def read_json(path):
    """Load and return the JSON document stored at ``path``.

    Parameters
    ----------
    path : str or os.PathLike
        File to read, decoded as UTF-8 (the JSON standard encoding).

    Returns
    -------
    object
        The decoded JSON value.
    """
    with open(path, "r", encoding="utf-8") as infile:
        return json.load(infile)

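Add this project's root directory (the parent of the notebook's working directory) to `sys.path` so that project packages such as `xas_nne` can be imported.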

In [None]:
# Make the project root (parent of the notebook's working directory) importable.
# Skip the append if it is already present so re-running this cell does not
# keep adding duplicate entries to ``sys.path``.
_project_root = str(Path.cwd().parent)
if _project_root not in sys.path:
    sys.path.append(_project_root)

# Load the data

In [None]:
# ATOM_TYPE: element symbol used to select the dataset file and to name the
#   output ensemble directory (presumably the absorbing species — confirm).
# MAX_ABS: value interpolated into the "MAX_TRAINING_ABSORBERS" dataset and
#   directory names; kept as a string because it is only used in paths.
#   Set to None to load the standard random split instead.
ATOM_TYPE, MAX_ABS = "O", "3"

In [None]:
# Select the pickled train/test splits: the standard random split when
# MAX_ABS is None, otherwise the variant capped at MAX_ABS training absorbers.
if MAX_ABS is None:
    data_path = f"../data/qm9/ml_ready/XANES-220626-ACSF-{ATOM_TYPE}-RANDOM-SPLITS.pkl"
else:
    print("Loading abs data")
    data_path = f"../data/qm9/ml_ready/XANES-220629-ACSF-{ATOM_TYPE}-MAX_TRAINING_ABSORBERS-{MAX_ABS}.pkl"

# NOTE: unpickling can execute arbitrary code — only load files produced by
# this project's own data pipeline. The context manager closes the handle.
with open(data_path, "rb") as data_file:
    data = pickle.load(data_file)

# Build and train the ML ensemble

Construct an ML `Ensemble` of neural networks with randomly sampled architectures, train it on the training split, and save its description to disk.

In [None]:
import torch

In [None]:
from xas_nne.ml import Ensemble

In [None]:
# YYMMDD date stamp used to name this run's output directory.
now = datetime.now().strftime("%y%m%d")

# Search space for the randomly sampled network architectures.
from_random_architecture_kwargs = dict(
    min_layers=4,
    max_layers=7,
    min_neurons_per_layer=160,
    max_neurons_per_layer=300,
    dropout=0.0,
    batch_norm=True,
    activation="leaky_relu",
    last_activation="softplus",
    criterion="mae",
    last_batch_norm=False,
)

# Output directory; runs on the capped-absorber data get a distinguishing suffix.
root = f"Ensembles/{now}-{ATOM_TYPE}"
if MAX_ABS is not None:
    print("Root set abs data")
    root = f"{root}-MAX_TRAINING_ABSORBERS-{MAX_ABS}"

# 15 estimators with a fixed seed so the sampled architectures are reproducible.
ensemble = Ensemble.from_random_architectures(
    root=root,
    n_estimators=15,
    seed=125,
    from_random_architecture_kwargs=from_random_architecture_kwargs,
)

In [None]:
# Train every ensemble member on the training split for 1000 epochs,
# running n_jobs=3 trainings in parallel.
ensemble.train_ensemble_parallel(
    n_jobs=3,
    epochs=1000,
    ensemble_index=0,
    training_data=data["train"],
)

In [None]:
# Persist the ensemble's description next to its trained members.
path = Path(ensemble._root) / "Ensemble.json"
d = ensemble.as_dict()
save_json(d, path)

In [None]:
# Ground-truth spectra of the test split (indexed below as gt[sample, :]).
gt = data["test"]["y"]

In [None]:
# Ensemble predictions for the test inputs. Indexed below as
# pred[:, sample, :], so presumably shaped (n_estimators, n_test, n_grid) —
# TODO confirm against Ensemble.predict.
pred = ensemble.predict(data["test"]["x"])

## Plot some examples

In [None]:
# x-axis values for plotting spectra, taken from the training split
# (assumed to be shared with the test split — confirm).
grid = data["train"]["grid"]

In [None]:
# Index of the test sample to inspect (negative values count from the end).
ii = -9
ground_truth_spectra = gt[ii, :]
# Predictions for that sample, presumably one spectrum per ensemble member.
predicted_spectra = pred[:, ii, :]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 2))

# SMILES string of the molecule the selected test spectrum comes from.
print(data["test"]["origin_smiles"][ii])

ax.plot(grid, ground_truth_spectra, "k-", label="ground truth")

# One thin, translucent red curve per ensemble prediction. Only the first
# curve is labelled, so the legend has a single entry for the whole ensemble.
for member_idx, prediction in enumerate(predicted_spectra):
    ax.plot(
        grid,
        prediction,
        "r-",
        linewidth=0.5,
        alpha=0.5,
        label="ensemble" if member_idx == 0 else "_nolegend_",
    )

ax.set_xlabel("grid")
ax.set_ylabel("spectrum")
ax.legend(fontsize="x-small", frameon=False)

plt.show()
