## Learning a custom scalar field with `rho_learn`/`FHI-aims`: LUMO of water

**NOTE**: this notebook has only been tested on the HPC system "jed". Due to
specific compilation of the quantum chemistry package `FHI-aims`, the Python
interface to running QC calculations may not be generalized to other operating
systems or hardware.

In [None]:
%load_ext autoreload
%autoreload 2

import os
import glob
import shutil
import time
from functools import partial

# Useful standard and scientific ML libraries
import ase.io
import matplotlib.pyplot as plt
import numpy as np
# import py3Dmol
import torch

# M-Stack packages
import metatensor   # storage format for atomistic ML
import chemiscope  # interactive molecular visualization
import rascaline   # generating structural representations
from metatensor import Labels, TensorBlock, TensorMap
from rascaline.utils import clebsch_gordan

# Interfacing with FHI-aims
from rhocalc.aims import aims_calc, aims_parser

# Torch-based density learning
from rholearn import io, data, loss, models, predictor
import settings

## Visualize structures in dataset

* Use `chemiscope`

In [None]:
from settings import TOP_DIR, DATA_DIR, ML_DIR, DATA_SETTINGS

print("Top directory defined as: ", TOP_DIR)

# Every structure available to the notebook, as provided by the settings module
all_frames = DATA_SETTINGS["all_frames"]

# Reproducibly (seeded) shuffle all structure indices, then keep only the first
# `n_frames` of them as the working subset
idxs = np.arange(len(all_frames))
rng = np.random.default_rng(seed=DATA_SETTINGS["seed"])
rng.shuffle(idxs)
idxs = idxs[: DATA_SETTINGS["n_frames"]]

# Structures of the working subset, in the same (shuffled) order as `idxs`
frames = [all_frames[structure_idx] for structure_idx in idxs]

# Sub-directory indices used for the RI fitting ("restart") and the density
# rebuild calculations respectively
ri_restart_idx, ri_rebuild_idx = 0, 1

In [None]:
# Simple geometric properties of each water monomer. Atom indices follow the
# property labels below: 0 is O, 1 and 2 are H (TODO confirm atom ordering).
mean_oh_lengths = [
    np.mean([frame.get_distance(0, 1), frame.get_distance(0, 2)]) for frame in frames
]
hoh_angles = [frame.get_angle(1, 0, 2) for frame in frames]

chemiscope.show(
    frames,
    properties={
        "Mean O-H bond length, Angstrom": mean_oh_lengths,
        "H-O-H angle, degrees": hoh_angles,
    },
)

## Generate learning targets

* These are the RI coefficients of the HOMO eigenstate
* We generate these in 2 steps: 1) SCF and 2) RI fitting

In [None]:
# A callable that takes structure idx as an argument, returns path to AIMS SCF
# output data
def scf_dir(A):
    return os.path.join(DATA_DIR, f"{A}")

# A callable that takes structure idx as an argument, returns path to AIMS RI
# output data
def ri_dir(A, restart_idx):
    return os.path.join(scf_dir(A), f"{restart_idx}")

# A callable that takes structure idx as an argument, returns path to processed
# data (i.e. metatensor-format)
def processed_dir(A, restart_idx):
    return os.path.join(ri_dir(A, restart_idx), "processed")

# Define callable for saving predictions made during runtime
def eval_dir(A, epoch):
    return os.path.join(ML_DIR, "evaluation", f"epoch_{epoch}", f"{A}")

# Create dirs
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)
if not os.path.exists(ML_DIR):
    os.makedirs(ML_DIR)
if not os.path.exists(os.path.join(ML_DIR, "evaluation")):
    os.makedirs(os.path.join(ML_DIR, "evaluation"))

chkpt_dir = os.path.join(ML_DIR, "checkpoints")
if not os.path.exists(chkpt_dir):
    os.makedirs(chkpt_dir)

#### 1. Converge SCF for each structure

In [None]:
# Import the settings needed to run FHI-aims
from settings import AIMS_PATH, BASE_AIMS_KWARGS, SCF_KWARGS, RI_KWARGS, SBATCH_KWARGS

# Build a dict of settings for each calculation (i.e. structure)
# IMPORTANT: zip() is used to pair up the structure index and the structure
calcs = {
    A: {"atoms": frame, "run_dir": scf_dir(A)} for A, frame in zip(idxs, frames)
}

# And the general settings for all calcs
aims_kwargs = BASE_AIMS_KWARGS.copy()
aims_kwargs.update(SCF_KWARGS)

# Define paths to the aims.out files for RI calcs
all_aims_outs = [os.path.join(scf_dir(A), "aims.out") for A in idxs]
for aims_out in all_aims_outs:
    if os.path.exists(aims_out):
        os.remove(aims_out)

# Run the SCF in AIMS
aims_calc.run_aims_array(
    calcs=calcs,
    aims_path=AIMS_PATH,
    aims_kwargs=aims_kwargs,
    sbatch_kwargs=SBATCH_KWARGS,
    run_dir=scf_dir,
)

# Wait until all AIMS calcs have finished
all_finished = False
while len(all_aims_outs) > 0:
    for aims_out in all_aims_outs:
        if os.path.exists(aims_out):
            with open(aims_out, "r") as f:
                # Basic check to see if AIMS calc has finished
                if "Leaving FHI-aims." in f.read():
                    all_aims_outs.remove(aims_out)

#### 2. Perform RI fitting on the scalar field of interest (LUMO)

* First we need to identify the LUMO

* This can be done by parsing the Kohn Sham orbital information from the
  converged SCF calculation. 
  
* This info has been written to file "ks_orbital_info.out" above by using the
  `ri_fit_write_ks_orb_info: True` keyword set in `SCF_KWARGS`.

In [None]:
# Write KSO weights for just the HOMO to file
for A in idxs:
    # Parse the Kohn-Sham
    ks_info = aims_parser.get_ks_orbital_info(
        os.path.join(scf_dir(A), "ks_orbital_info.out")
    )
    weights = np.zeros(ks_info.shape[0])
    # kso_idxs = aims_parser.find_homo_kso_idxs(ks_info)  # HOMO
    kso_idxs = aims_parser.find_lumo_kso_idxs(ks_info)  # LUMO

    # Add a weighting of 1.0
    for kso_idx in kso_idxs:  # these are 1-indexed in AIMS
        weights[kso_idx - 1] = 1.0

    # Save these to file so AIMS can read them in
    np.savetxt(os.path.join(scf_dir(A), "ks_orbital_weights.in"), weights)

# And the general settings for all calcs
aims_kwargs = BASE_AIMS_KWARGS.copy()
aims_kwargs.update(RI_KWARGS)

# Define paths to the aims.out files for RI calcs
all_aims_outs = [os.path.join(ri_dir(A, ri_restart_idx), "aims.out") for A in idxs]
for aims_out in all_aims_outs:
    if os.path.exists(aims_out):
        os.remove(aims_out)

# Copy restart files
for A in idxs:
    if not os.path.exists(ri_dir(A, ri_restart_idx)):
        os.makedirs(ri_dir(A, ri_restart_idx))
    for density_matrix in glob.glob(os.path.join(scf_dir(A), "D*.csc")):
        shutil.copy(density_matrix, ri_dir(A, ri_restart_idx))
    shutil.copy(
            os.path.join(scf_dir(A), "ks_orbital_weights.in"), 
            ri_dir(A, ri_restart_idx),
        )

# Run the RI fitting procedure in AIMS
aims_calc.run_aims_array(
    calcs=calcs,
    aims_path=AIMS_PATH,
    aims_kwargs=aims_kwargs,
    sbatch_kwargs=SBATCH_KWARGS,
    run_dir=partial(ri_dir, restart_idx=ri_restart_idx),
)

# Wait until all AIMS calcs have finished
all_finished = False
while len(all_aims_outs) > 0:
    for aims_out in all_aims_outs:
        if os.path.exists(aims_out):
            with open(aims_out, "r") as f:
                # Basic check to see if AIMS calc has finished
                if "Leaving FHI-aims." in f.read():
                    all_aims_outs.remove(aims_out)

In [None]:
# Process AIMS results and print density fitting error
# Run in serial: takes approx 0.5 seconds per structure
df_errors = []
for A, frame in zip(idxs, frames):
    aims_parser.process_aims_ri_results(
        frame=frame,
        aims_output_dir=ri_dir(A, ri_restart_idx),  # includes the restart idx
        process_what=["coeffs", "ovlp"],
        structure_idx=A,
    )
    calc_info = io.unpickle_dict(os.path.join(processed_dir(A, ri_restart_idx), "calc_info.pickle"))
    df_errors.append(calc_info["df_error_percent"]["total"])
print("Mean density fitting error across all structures (%):", np.mean(df_errors))

In [None]:
# Try a rebuild of the density from the target RI coefficients
from functools import partial
import shutil
from settings import REBUILD_KWARGS

# And the general settings for all calcs
aims_kwargs = BASE_AIMS_KWARGS.copy()
aims_kwargs.update(REBUILD_KWARGS)

# Define paths to the aims.out files for RI calcs
all_aims_outs = [os.path.join(ri_dir(A, ri_rebuild_idx), "aims.out") for A in idxs]
for aims_out in all_aims_outs:
    if os.path.exists(aims_out):
        os.remove(aims_out)

for A in idxs:
    if not os.path.exists(ri_dir(A, ri_rebuild_idx)):
        os.makedirs(ri_dir(A, ri_rebuild_idx))

    # Copy restart files
    shutil.copy(
        os.path.join(ri_dir(A, ri_restart_idx), "ri_coeffs.out"), 
        os.path.join(ri_dir(A, ri_rebuild_idx), "ri_coeffs.in"),
    )

# Run the RI fitting procedure in AIMS
aims_calc.run_aims_array(
    calcs=calcs,
    aims_path=AIMS_PATH,
    aims_kwargs=aims_kwargs,
    sbatch_kwargs=SBATCH_KWARGS,
    run_dir=partial(ri_dir, restart_idx=ri_rebuild_idx),
)

# Wait until all AIMS calcs have finished
all_finished = False
while len(all_aims_outs) > 0:
    for aims_out in all_aims_outs:
        if os.path.exists(aims_out):
            with open(aims_out, "r") as f:
                # Basic check to see if AIMS calc has finished
                if "Leaving FHI-aims." in f.read():
                    all_aims_outs.remove(aims_out)

In [None]:
# from rhocalc.cube.rho_cube import RhoCube

# q = RhoCube("/home/abbott/rho/rho_learn/docs/example/field/data/89/1/rho_rebuilt.cube")

# x, y, z = q.get_slab_slice(axis=2, center_coord=q.ase_frame.positions[:, 2].max(), thickness=25.0)

# # Plot contours
# fig, axes = plt.subplots(1, 1, figsize=(5, 5))
# cs = axes.contourf(x, y, np.tanh(z.T), cmap='gray')
# cbar = fig.colorbar(cs)

## Build Structural Descriptors 

* Here we construct $\lambda$-SOAP equivariant descriptors for each structure
* First, find the angular orders present in the decomposition of the target
  scalar field onto the RI basis

In [None]:
# The basis set definition should is consistent for all atoms in the dataset,
# given consistent AIMS settings. Print this definiton from the parsed outputs
# from one of the structures
basis_set = io.unpickle_dict(
    os.path.join(processed_dir(A, ri_restart_idx), "calc_info.pickle")
)["basis_set"]

print(basis_set)

In [None]:
from settings import RASCAL_SETTINGS, CG_SETTINGS

# Generate a rascaline SphericalExpansion (2 body) representation. As we want to
# retain the original structure indices, we are going to pass * all * of the
# 1000 frames to rascaline, but only compute for the subset of structures in
# `idxs`.
calculator = rascaline.SphericalExpansion(**RASCAL_SETTINGS["hypers"])
density = calculator.compute(
    all_frames, 
    selected_samples=Labels(names=["structure"], values=idxs.reshape(-1, 1)),
    **RASCAL_SETTINGS["compute"],
)
density = density.keys_to_properties(
    keys_to_move=Labels(
        names=["species_neighbor"], values=np.array(DATA_SETTINGS["global_species"]).reshape(-1, 1)
    )
)

# Build a lambda-SOAP descriptor by a Clebsch-Gordan combination
# Build lambda-SOAP vector
lsoap = clebsch_gordan.correlate_density(
    density,
    correlation_order=CG_SETTINGS["correlation_order"],
    angular_cutoff=CG_SETTINGS["angular_cutoff"],
    selected_keys=Labels(
        names=["spherical_harmonics_l", "inversion_sigma"],
        values=np.array(CG_SETTINGS["selected_keys"], dtype=np.int32),
    ),
    skip_redundant=CG_SETTINGS["skip_redundant"],
)

# Check the resulting structure indices match those in `idxs`
assert np.all(
    np.sort(idxs)
    == metatensor.unique_metadata(lsoap, "samples", "structure").values.reshape(-1)
)

# Split into per-structure TensorMaps and save into separate directories.
# This is useful for batched training.
for A in idxs:
    lsoap_A = metatensor.slice(
        lsoap,
        "samples",
        labels=Labels(names="structure", values=np.array([A]).reshape(-1, 1)),
    )
    metatensor.save(os.path.join(processed_dir(A, ri_restart_idx), "lsoap.npz"), lsoap_A)

The settings used to build the descriptor from an `.xyz` file, as well as build
the desired target property (the real-space scalar field) from the model
prediction need to be stored so that the perform can make a true end-to-end
prediction.

In the `rholearn` module "predictor.py", the functions `descriptor_builder` and
`target_builder` are implemented to perform these transformations on the input
and output side of the model, respectively. Both take a
variable input that depends on the structure being predicted on, and some
settings for performing the relevant transformations of the data. The key
physics-related settings used in predicting on an unseen structure shouldbe the
same as were used for generating the data the model was trained on.

Here we store the relevant settings in dictionaries, which will be used to
initialize the ML model later.

## Build cross-validation `IndexedDataset` objects

* For cross-validation we create a train-test-val split of the data.
* Data is stored on the per-structure basis to help with mini-batching in
  training.
* `metatensor-learn` is used here to build the datasets and dataloaders

In [None]:
from settings import CROSSVAL_SETTINGS, ML_SETTINGS, TORCH_SETTINGS

# Perform a train/test/val split of structure idxs
train_idxs, test_idxs, val_idxs = data.group_idxs(
    idxs=idxs,
    n_groups=CROSSVAL_SETTINGS["n_groups"],
    group_sizes=CROSSVAL_SETTINGS["group_sizes"],
    shuffle=CROSSVAL_SETTINGS["shuffle"],
    seed=DATA_SETTINGS["seed"],
)
print(
    "num train_idxs:",
    len(train_idxs),
    "   num test_idxs:",
    len(test_idxs),
    "   num val_idxs:",
    len(val_idxs),
)
np.savez(
    os.path.join(ML_DIR, "idxs.npz"),
    idxs=idxs,
    train_idxs=train_idxs,
    test_idxs=test_idxs,
    val_idxs=val_idxs,
)

* Visualise the dataset again, this time colored by its cross-validation category

In [None]:
crossval_category = lambda A: 0 if A in train_idxs else (1 if A in test_idxs else 2)

chemiscope.show(
    frames,
    properties={
        "Mean O-H bond length, Angstrom": [
            np.mean([frame.get_distance(0, 1), frame.get_distance(0, 2)])
            for A, frame in zip(idxs, frames)
        ],
        "H-O-H angle, degrees": [
            frame.get_angle(1, 0, 2) for A, frame in zip(idxs, frames)
        ],
        "0: train, 1: test, 2: val": [crossval_category(A) for A, frame in zip(idxs, frames)]
    },
)

In [None]:
from metatensor.learn.data import Dataset, DataLoader, IndexedDataset
from settings import CROSSVAL_SETTINGS, TORCH_SETTINGS

def load_to_torch(
    path: str, torch_kwargs: dict, drop_blocks: bool = False
) -> TensorMap:
    """Loads a TensorMap from file and converts its backend to torch"""
    tensor = metatensor.io.load_custom_array(
        path,
        create_array=metatensor.io.create_torch_array,
    )
    if drop_blocks:
        tensor = metatensor.drop_blocks(
            tensor, 
            keys=Labels(
                names=["spherical_harmonics_l", "species_center"],
                values=np.array([[5, 1]]),
            ),
        )
    tensor = tensor.to(**torch_kwargs)
    tensor = metatensor.requires_grad(tensor, True)
    return tensor


train_dataset, test_dataset, val_dataset = [
    IndexedDataset(
        sample_ids=subset_idxs,
        frames=[all_frames[A] for A in subset_idxs],
        descriptors=[
            load_to_torch(
                os.path.join(processed_dir(A, ri_restart_idx), "lsoap.npz"), 
                TORCH_SETTINGS, 
                drop_blocks=True
            )
            for A in subset_idxs
        ],
        targets=[
            load_to_torch(os.path.join(processed_dir(A, ri_restart_idx), "ri_coeffs.npz"), TORCH_SETTINGS) 
            for A in subset_idxs
        ],
        auxiliaries=[
            load_to_torch(os.path.join(processed_dir(A, ri_restart_idx), "ri_ovlp.npz"), TORCH_SETTINGS) 
            for A in subset_idxs
        ],
    ) for subset_idxs in [train_idxs, test_idxs, val_idxs]
]

## Initialize model

* As typically the magnitude of the invariant RI-coefficients are many orders of
  magnitude larger than the covariant coefficients, it often helps model
  training if the invariant block models predict a baselined quantity. The
  baseline is then added back in on the TensorMap level before loss evaluation,
  acting as a kind of non-learnable bias.

* We can compute the mean features of the invariant blocks of the training data
  and initialize the model with this. Alternatively, one could use free-atom
  superpositions of the scalar field of interest as the baseline.

* We can also calculate the standard deviation of the training data. This is
  defined relative to this invariant baseline.

In [None]:
if ML_SETTINGS["model"]["use_invariant_baseline"]:
    invariant_baseline = data.get_dataset_invariant_means(
        train_dataset, field="targets", torch_kwargs=TORCH_SETTINGS,
    )
else:
    invariant_baseline = None

stddev = data.get_standard_deviation(
    dataset=train_dataset,
    field="targets",
    torch_kwargs=TORCH_SETTINGS,
    invariant_baseline=invariant_baseline,
    overlap_field="auxiliaries",
)
np.savez(os.path.join(ML_DIR, "stddev.npz"), stddev=stddev.detach().numpy())
stddev

* In order to perform end-to-end predictions, we also need to initialize the
  model with the settings used to build the equivariant descriptor and final
  target property. These will be the same as above, used to generate the
  training data.

In [None]:
from settings import RASCAL_SETTINGS, CG_SETTINGS, BASE_AIMS_KWARGS

# For descriptor building, we need to store the rascaline settings for
# generating a SphericalExpansion and performing Clebsch-Gordan combinations.
# The `descriptor_builder` function in "predictor.py" contains the 'recipe' for
# using these settings to transform an ASE Atoms object.
descriptor_kwargs = {
    "rascal_settings": RASCAL_SETTINGS,
    "cg_settings": CG_SETTINGS,
}

# For target building, the base AIMS settings need to be stored, along with the
# basis set definition.
basis_set = io.unpickle_dict(os.path.join(processed_dir(A, ri_restart_idx), "calc_info.pickle"))["basis_set"]

target_kwargs = {
    "aims_kwargs": {**BASE_AIMS_KWARGS},
    "basis_set": {**basis_set},
}

In [None]:
# Initialize model
model = models.RhoModel(
    # Standard model architecture
    model_type=ML_SETTINGS["model"]["model_type"],  # "linear" or "nonlinear"
    input=train_dataset[0].descriptors,   # example input data for init metadata
    output=train_dataset[0].targets,  # example output data for init metadata
    bias_invariants=ML_SETTINGS["model"]["bias_invariants"],

    # Architecture settings if using a nonlinear base model
    hidden_layer_widths=ML_SETTINGS["model"].get("hidden_layer_widths"),
    activation_fn=ML_SETTINGS["model"].get("activation_fn"),
    bias_nn=ML_SETTINGS["model"].get("bias_nn"),

    # Invariant baselining
    invariant_baseline=invariant_baseline,

    # Settings for descriptor/target building
    descriptor_kwargs=descriptor_kwargs,
    target_kwargs=target_kwargs,

    # Torch tensor settings
    **TORCH_SETTINGS
)

* Let's update the model with some settings needed to rebuild the
  real-space density from RI-coefficients in AIMS, needed for computing the L1 error.

In [None]:
from settings import AIMS_PATH, REBUILD_KWARGS, SBATCH_KWARGS

# Update the AIMS and SBATCH kwargs
tmp_aims_kwargs = {**model.target_kwargs["aims_kwargs"]}
tmp_aims_kwargs.update(REBUILD_KWARGS)
model.update_target_kwargs(
    {
        "aims_path": AIMS_PATH,
        "aims_kwargs": tmp_aims_kwargs,
        "sbatch_kwargs": SBATCH_KWARGS,
    }
)

## Initialize training objects: loaders, loss, optimizer

In [None]:
import metatensor.learn
from metatensor.learn.data import group

# Construct dataloaders
train_loader = metatensor.learn.data.DataLoader(
    train_dataset,
    collate_fn=group,
    batch_size=ML_SETTINGS["loading"]["batch_size"],
    **ML_SETTINGS["loading"]["args"],
)

val_loader = metatensor.learn.data.DataLoader(
    val_dataset,
    collate_fn=group,
    batch_size=ML_SETTINGS["loading"]["batch_size"],
    **ML_SETTINGS["loading"]["args"],
)

test_loader = metatensor.learn.data.DataLoader(
    test_dataset,
    collate_fn=group,
    batch_size=ML_SETTINGS["loading"]["batch_size"],
    **ML_SETTINGS["loading"]["args"],
)

In [None]:
# Initialize loss fxn and optimizer (don't use a scheduler for now)
loss_fn = loss.L2Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = None

## Run model training

* We train the model by gradient descent, evaluating the loss of the prediction
  against the density-fitted quantity. As the learning target is the real-space
  density, which has been expanded on a non-orthogonal basis, and not just the
  RI coefficients, the overlap matrices must be used in loss evaluation:

* For a structure $A$, the L2 loss of the predicted scalar field relative to the **RI
  approximated scalar field** is defined as:

  * $ \mathcal{L}_A^{\text{L2}} = \left( \textbf{c}_A^{\text{ML}} - \textbf{c}_A^{\text{RI}} \right) \ . \ \hat{S}_A \ . \ \left( \textbf{c}_A^{\text{ML}} - \textbf{c}_A^{\text{RI}} \right)  $


* And L1 (MAE) error of the of the ML prediction relative to the **true QM
  scalar field** is defined as: 

  * $ \mathcal{L}_A^{\text{L1}} = \frac{1}{\int d\textbf{r} \rho_A^{\text{QM}} (\textbf{r})} \int d\textbf{r} \vert \rho_A^{\text{ML}} (\textbf{r}) - \rho_A^{\text{QM}} (\textbf{r}) \vert $
  
* While the L2 loss is used to train the model, we will also periodically
  calculate the L1 loss to assess the performance of the model relative to the
  actual true QM scalar field.

In [None]:
from rholearn import train
from settings import ML_SETTINGS

# Define a log file for writing losses at each epoch
log_path = os.path.join(ML_DIR, "training.log")
io.log(log_path, "# epoch train_loss val_loss test_error time")

# Run training loop
for epoch in range(1, ML_SETTINGS["training"]["n_epochs"] + 1):

    # ====== Training step ======
    t0 = time.time()
    train_loss_epoch, val_loss_epoch = train.training_step(
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        scheduler=scheduler,
        check_args=epoch == 1  # Check metadata only on 1st epoch
    )

    # ====== Evaluation step ======
    test_error_epoch = np.nan
    if epoch % ML_SETTINGS["evaluation"]["interval"] == 0:
        loaders = {"train": train_loader, "val": val_loader, "test": test_loader}
        test_error_epoch = train.evaluation_step(
            model,
            dataloader=loaders[ML_SETTINGS["evaluation"]["subset"]],
            save_dir=partial(eval_dir, epoch=epoch),
            calculate_error=ML_SETTINGS["evaluation"]["calculate_error"],
            target_type=ML_SETTINGS["evaluation"]["target_type"],
            reference_dir=partial(ri_dir, restart_idx=ri_restart_idx),
        )

    # ====== Log results ======
    dt = time.time() - t0
    io.log(
        log_path,
        f"{epoch} {train_loss_epoch} {val_loss_epoch} {test_error_epoch} {dt}",
    )
    if epoch % ML_SETTINGS["training"]["save_interval"] == 0:
        torch.save(model, os.path.join(chkpt_dir, f"model_{epoch}.pt"))
        torch.save(optimizer.state_dict, os.path.join(chkpt_dir, f"opt_{epoch}.pt"))

* This process is slow in a notebook - better to make use of HPC resources! The
  training procedure can essentially be copied to a seperate python script and
  run using a job scheduler - see "run_training.py".

* We can load a model that has been pre-trained and validate its performance

* Here we load a model trained to over 1500 epochs, only a dataset of only 10
  water molecules.

## Evaluate model performance on test set

In [None]:
# Load model from checkpoint and the training log file
model = torch.load(os.path.join(chkpt_dir, "model_25.pt"))
log_file = np.loadtxt(os.path.join(ML_DIR, "training.log"))

In [None]:
# Unpack data from each row
epochs, train_loss, val_loss, test_error, times = losses.T

# Plot the various errors
fig, ax = plt.subplots()
ax.plot(epochs, train_loss, label="L2 Loss (ML/RI), train")
ax.plot(epochs, val_loss, label="L2 Loss (ML/RI), val")
ax.scatter(
    epochs, test_error / 100, label="L1 Error (ML/QM), train", marker="."
)

ax.set_xlabel("Epochs")
ax.set_ylabel("Error per structure")
ax.set_yscale("log")
ax.set_title("HOMO-learning w/ NN | water monomer | AIMS cluster calculation")
ax.legend()

## End-to-end prediction on the test set

* Now let's make an end-to-end prediction (from xyz -> real-space scalar field)
  on the test set, up until now unseen by the model

In [None]:
# Make a directory for storing the prediction
pred_dir = lambda A: os.path.join(ML_DIR, "evaluation", "final", f"{A}")
test_sample = test_dataset[0]

model.target_kwargs["aims_kwargs"]["output"] = ["cube ri_fit"]
model.target_kwargs["aims_kwargs"]["ri_fit_write_rebuilt_field_cube"] = True

predictions = model.predict(
    frames=[test_sample.frames],
    structure_idxs=[test_sample.sample_id],
    build_targets=True,
    return_targets=True,
    save_dir=pred_dir,
)

In [None]:
# Calculate MAE
grid = np.loadtxt(  # integration weights
    os.path.join(ri_dir(test_sample.sample_id, ri_restart_idx), "partition_tab.out")
    )
target = np.loadtxt(  # target scalar field
    os.path.join(ri_dir(test_sample.sample_id, ri_restart_idx), f"rho_ref.out")
)
percent_mae = aims_parser.get_percent_mae_between_fields(  # calc MAE
    input=predictions[0],
    target=target,
    grid=grid,
)
percent_mae

* Let's visualize these predictions using the cube file outputs

* Pick one validation structure and vizualize the target QM and predicted ML
  HOMO scalar fields

In [None]:
import py3Dmol

# Visualize the predicted density of a single example structure
A = test_sample.sample_id

qm_cube = os.path.join(ri_dir(A, ri_restart_idx), "rho_ref.cube")
ml_cube = os.path.join(pred_dir(A), "rho_rebuilt.cube")

for cube_file in [qm_cube, ml_cube]:
    v = py3Dmol.view()
    v.addModelsAsFrames(open(cube_file, "r").read(), "cube")
    v.setStyle({"stick": {}})
    v.addVolumetricData(
        open(cube_file, "r").read(),
        "cube",
        {"isoval": 0.002, "color": "blue", "opacity": 0.8},
    )
    v.show()