## Pre-requisites

### Requirements

* python > 3.10

### Required packages

For stability, it is best to install `rascaline` and `rho_learn` from specific commits.

1. `pip install ase chemiscope metatensor torch==2.1.0 wigners`
1. `pip install git+https://github.com/luthaf/rascaline.git@b2cedfe870541e6d037357db58de1901eb116c41`
1. `pip install git+https://github.com/jwa7/rho_learn.git@9139fd560ff663236bc416621118fb0de9a36009`

**NOTE**: this notebook has only been tested on the HPC system "jed". Due to
specific compilation of the quantum chemistry package `FHI-aims`, the Python
interface to running QC calculations may not be generalized to other operating
systems or hardware.

In [None]:
%load_ext autoreload
%autoreload 2

# Useful standard and scientific ML libraries
import os
import time
import ase.io
import matplotlib.pyplot as plt
import numpy as np
import py3Dmol
import torch

# M-Stack packages
import metatensor   # storage format for atomistic ML
import chemiscope  # interactive molecular visualization
import rascaline   # generating structural representations
from metatensor import Labels, TensorBlock, TensorMap
from rascaline.utils import clebsch_gordan

# Interfacing with FHI-aims
from rhocalc.aims import aims_calc, aims_parser

# Torch-based density learning
from rholearn import io, data, loss, models, predictor
import settings

## Visualize structures in dataset

* Use `chemiscope`

In [None]:
from settings import TOP_DIR, DATA_DIR, ML_DIR, DATA_SETTINGS

print("Top directory defined as: ", TOP_DIR)

# The complete dataset of frames is provided by the settings module
all_frames = DATA_SETTINGS["all_frames"]

# Build a reproducible (seeded) random ordering of all structure indices
rng = np.random.default_rng(seed=DATA_SETTINGS["seed"])
shuffled_idxs = np.arange(len(all_frames))
rng.shuffle(shuffled_idxs)

# Keep only the leading `n_frames` indices of the shuffled ordering, and gather
# the frames they refer to (in the same order)
n_frames = DATA_SETTINGS["n_frames"]
idxs = shuffled_idxs[:n_frames]
frames = [all_frames[i] for i in idxs]

In [None]:
# Per-structure geometric properties to display alongside the structures.
# NOTE(review): the property labels imply atom 0 is O and atoms 1, 2 are H in
# every frame - confirm against the dataset.
mean_oh_lengths = [
    np.mean([atoms.get_distance(0, 1), atoms.get_distance(0, 2)])
    for atoms in frames
]
hoh_angles = [atoms.get_angle(1, 0, 2) for atoms in frames]

chemiscope.show(
    frames,
    properties={
        "Mean O-H bond length, Angstrom": mean_oh_lengths,
        "H-O-H angle, degrees": hoh_angles,
    },
)

## Generate learning targets

* These are the RI coefficients of the HOMO eigenstate
* We generate these in 2 steps: 1) SCF and 2) RI fitting

In [None]:
def scf_dir(A):
    """Return the path to the AIMS SCF output data for structure index `A`."""
    structure_subdir = f"{A}"
    return os.path.join(DATA_DIR, structure_subdir)

def ri_dir(A, restart_idx):
    """
    Return the path to the AIMS RI output data for structure index `A` and
    restart calculation `restart_idx`. This is a subdirectory of the SCF
    directory of the same structure.
    """
    restart_subdir = f"{restart_idx}"
    return os.path.join(scf_dir(A), restart_subdir)

def processed_dir(A, restart_idx):
    """
    Return the path to the processed (i.e. metatensor-format) data for
    structure index `A` and restart calculation `restart_idx`. This is the
    "processed" subdirectory of the corresponding RI directory.
    """
    ri_output_dir = ri_dir(A, restart_idx)
    return os.path.join(ri_output_dir, "processed")

# Create the top-level data directory. `exist_ok=True` makes this a no-op if
# the directory is already there, without a separate check-then-create step.
os.makedirs(DATA_DIR, exist_ok=True)

#### 1. Converge SCF for each structure

In [None]:
# Import the settings needed to run FHI-aims
from settings import AIMS_PATH, BASE_AIMS_KWARGS, SCF_KWARGS, RI_KWARGS, SBATCH_KWARGS

# Per-calculation settings, keyed by structure index: the structure itself and
# the directory the calculation runs in.
# IMPORTANT: zip() pairs each structure index with its own structure
calcs = {
    structure_idx: {"atoms": atoms, "run_dir": scf_dir(structure_idx)}
    for structure_idx, atoms in zip(idxs, frames)
}

# General settings shared by all calcs: the base AIMS settings, with the
# SCF-specific ones layered on top (SCF values win on any shared key)
aims_kwargs = {**BASE_AIMS_KWARGS, **SCF_KWARGS}

# Run the SCF in AIMS
aims_calc.run_aims_array(
    calcs=calcs,
    AIMS_PATH=AIMS_PATH,
    aims_kwargs=aims_kwargs,
    SBATCH_KWARGS=SBATCH_KWARGS,
    run_dir=scf_dir,
)

In [None]:
def _aims_out_contains(calc_dir, text):
    """Return True if "aims.out" exists in `calc_dir` and contains `text`."""
    aims_out_path = os.path.join(calc_dir, "aims.out")
    if not os.path.exists(aims_out_path):
        return False
    with open(aims_out_path, "r") as f:
        return text in f.read()


# Running this code will ensure that all calculations have finished, i.e. that
# every "aims.out" contains the FHI-aims exit message.
# Should take < 10 seconds (1 node, 10 cores on 'jed' HPC)
# Sleep between polls so that waiting does not spin a CPU core and re-read every
# output file back-to-back.
# NOTE: there is no timeout - if a calculation never writes the exit message,
# this loop keeps waiting until it is interrupted.
all_finished = False
while not all_finished:
    all_finished = all(
        _aims_out_contains(scf_dir(A), "Leaving FHI-aims.") for A in idxs
    )
    if not all_finished:
        time.sleep(1)

# Check SCF converged: one flag per structure, in the same order as `idxs`. A
# missing "aims.out" counts as not converged.
converged = [
    _aims_out_contains(scf_dir(A), "Self-consistency cycle converged.")
    for A in idxs
]
if all(converged):
    print("All SCF calculations converged!")
else:
    print(
        "Some SCF calculations did not converge. structure idxs: ",
        [A for A, is_converged in zip(idxs, converged) if not is_converged],
    )

#### 2. Perform RI fitting on the scalar field of interest (HOMO)

* First we need to identify the HOMO
* This can be done by parsing the Kohn Sham orbital information from the
  converged SCF calculation. 
* This info got written to file "ks_orbital_info.out" above by using the
  `ri_fit_write_ks_orb_info: True` keyword set in `SCF_KWARGS`.

In [None]:
for restart_idx in [0]:  # 0: HOMO. (1: LUMO is handled below, but not run here)

    # Write KSO weights for just the selected orbital(s) to file
    for A in idxs:
        # Parse the Kohn-Sham orbital info from the SCF directory
        ks_info = aims_parser.get_ks_orbital_info(
            os.path.join(scf_dir(A), "ks_orbital_info.out")
        )
        weights = np.zeros(ks_info.shape[0])
        if restart_idx == 0:  # Find HOMO states
            kso_idxs = aims_parser.find_homo_kso_idxs(ks_info)
        elif restart_idx == 1:  # Find LUMO states
            kso_idxs = aims_parser.find_lumo_kso_idxs(ks_info)
        else:
            raise ValueError(
                f"unsupported restart_idx {restart_idx}: expected 0 (HOMO) or 1 (LUMO)"
            )

        # Add a weighting of 1.0
        for kso_idx in kso_idxs:  # these are 1-indexed in AIMS
            weights[kso_idx - 1] = 1.0

        # Save these to file so AIMS can read them in
        np.savetxt(os.path.join(scf_dir(A), "ks_orbital_weights.in"), weights)

    # And the general settings for all calcs: base AIMS settings, with the
    # RI-specific ones layered on top
    aims_kwargs = BASE_AIMS_KWARGS.copy()
    aims_kwargs.update(RI_KWARGS)

    # Run the RI fitting procedure in AIMS
    aims_calc.run_aims_array(
        calcs=calcs,
        AIMS_PATH=AIMS_PATH,
        aims_kwargs=aims_kwargs,
        SBATCH_KWARGS=SBATCH_KWARGS,
        run_dir=scf_dir,
        restart_idx=restart_idx,   # Restart from a converged SCF
        copy_files=["ks_orbital_weights.in"],  # Copy the KS weights into the restart dir
    )

In [None]:
for restart_idx in [0]:
    # Running this code will ensure that all calculations have finished, i.e.
    # that every "aims.out" in the RI directories contains the FHI-aims exit
    # message.
    # Should take < 30 seconds (1 node, 10 cores, on 'jed' HPC)
    # NOTE: there is no timeout - if a calculation never writes the exit
    # message, this loop keeps waiting until it is interrupted.
    all_finished = False
    while not all_finished:
        calcs_finished = []
        for A in idxs:
            aims_out_path = os.path.join(ri_dir(A, restart_idx), "aims.out")
            if os.path.exists(aims_out_path):
                with open(aims_out_path, "r") as f:
                    calcs_finished.append("Leaving FHI-aims." in f.read())  # AIMS finished
            else:
                calcs_finished.append(False)
        all_finished = np.all(calcs_finished)
        if not all_finished:
            # Pause between polls rather than spinning a CPU core and
            # re-reading every output file back-to-back
            time.sleep(1)

    print("All RI calculations finished!")

In [None]:
for restart_idx in [0]:
    # Process the AIMS results of each structure, collecting the total density
    # fitting error (in percent) as we go.
    # Run in serial: takes approx 0.5 seconds per structure
    df_errors = []
    for A, frame in zip(idxs, frames):
        aims_parser.process_aims_ri_results(
            frame=frame,
            aims_output_dir=ri_dir(A, restart_idx),  # includes the restart idx
            process_what=["coeffs", "ovlp"],
            structure_idx=A,
        )
        # Read the calculation info back from the processed directory
        # (presumably written there by the processing step above - confirm)
        calc_info_path = os.path.join(processed_dir(A, restart_idx), "calc_info.pickle")
        calc_info = io.unpickle_dict(calc_info_path)
        df_errors += [calc_info["df_error_percent"]["total"]]

    mean_df_error = np.mean(df_errors)
    print("Mean density fitting error across all structures (%):", mean_df_error)

## Build Structural Descriptors 

* Here we construct $\lambda$-SOAP equivariant descriptors for each structure
* First, find the angular orders present in the decomposition of the target
  scalar field onto the RI basis

In [None]:
for restart_idx in [0]:
    # The basis set definition should be consistent for all atoms in the dataset,
    # given consistent AIMS settings. Print this definition from the parsed
    # outputs of one of the structures. The structure is chosen explicitly (the
    # first one in `idxs`) rather than through a loop variable `A` left over
    # from whichever cell happened to run before this one.
    basis_set = io.unpickle_dict(
        os.path.join(processed_dir(idxs[0], restart_idx), "calc_info.pickle")
    )["basis_set"]

    print(basis_set["def"])

The maximum angular order here is $l = 8$. This should be reflected in the
settings used to generate the equivariant descriptor - specifically in
`CG_SETTINGS` in "settings.py".

In [None]:
from settings import RASCAL_SETTINGS, CG_SETTINGS

# Generate a rascaline SphericalExpansion (2 body) representation. To retain the
# original structure indices, * all * of the frames in the dataset are passed to
# rascaline, but the computation is restricted (via `selected_samples`) to the
# subset of structures in `idxs`.
calculator = rascaline.SphericalExpansion(**RASCAL_SETTINGS["hypers"])
selected_structures = Labels(names=["structure"], values=idxs.reshape(-1, 1))
nu_1_tensor = calculator.compute(
    all_frames,
    selected_samples=selected_structures,
    **RASCAL_SETTINGS["compute"],
)
nu_1_tensor = nu_1_tensor.keys_to_properties("species_neighbor")

# Build a lambda-SOAP descriptor by a Clebsch-Gordan combination
lsoap = clebsch_gordan.lambda_soap_vector(nu_1_tensor, **CG_SETTINGS)

# Check the structure indices present in the descriptor match those in `idxs`
lsoap_structure_idxs = metatensor.unique_metadata(
    lsoap, "samples", "structure"
).values.reshape(-1)
assert np.all(np.sort(idxs) == lsoap_structure_idxs)

# Split into per-structure TensorMaps and save each one into the processed
# directory of its structure. This is useful for batched training.
for A in idxs:
    structure_label = Labels(names="structure", values=np.array([A]).reshape(-1, 1))
    lsoap_A = metatensor.slice(lsoap, "samples", labels=structure_label)
    for restart_idx in [0]:
        lsoap_path = os.path.join(processed_dir(A, restart_idx), "lsoap.npz")
        metatensor.save(lsoap_path, lsoap_A)

The settings used to build the descriptor from an `.xyz` file, as well as those
used to build the desired target property (the real-space scalar field) from the
model prediction, need to be stored so that the model can make a true end-to-end
prediction.

In the `rholearn` module "predictor.py", the functions `descriptor_builder` and
`target_builder` are implemented to perform these transformations on the input
and output side of the model, respectively. Both take a
variable input that depends on the structure being predicted on, and some
settings for performing the relevant transformations of the data. The key
physics-related settings used in predicting on an unseen structure should be the
same as were used for generating the data the model was trained on.

Here we store the relevant settings in dictionaries, which will be used to
initialize the ML model later.

## Build `dataset`

* For cross-validation we create a train-test-val split of the data.
* Data is stored on the per-structure basis to help with mini-batching in
  training.


* Although we have generated data for both the HOMO and LUMO, let's just learn
  the HOMO. Fix the "restart_idx" to 0 (for the HOMO), then we can build a
  dataset and start training.

In [None]:
# Fix the restart index so that we just learn the HOMO. The cells below read
# this module-level value when building paths to the processed data.
restart_idx = 0

def pred_dir(A):
    """Return the directory where runtime predictions for structure `A` are saved."""
    predictions_root = os.path.join(ML_DIR, "predictions")
    return os.path.join(predictions_root, f"{A}")

# Define dir where model checkpoints are saved
chkpt_dir = os.path.join(ML_DIR, "checkpoints")

# Create the ML output directories. `exist_ok=True` makes each call a no-op if
# the directory is already there, without a separate check-then-create step.
os.makedirs(ML_DIR, exist_ok=True)
os.makedirs(os.path.join(ML_DIR, "predictions"), exist_ok=True)
os.makedirs(chkpt_dir, exist_ok=True)

In [None]:
from settings import CROSSVAL_SETTINGS, ML_SETTINGS, TORCH_SETTINGS

# Perform a train/test/val split of structure idxs. The grouping options come
# from the cross-validation settings; the seed is shared with the data settings.
split_kwargs = {
    key: CROSSVAL_SETTINGS[key] for key in ("n_groups", "group_sizes", "shuffle")
}
train_idxs, test_idxs, val_idxs = data.group_idxs(
    idxs=idxs,
    seed=DATA_SETTINGS["seed"],
    **split_kwargs,
)

# Report the size of each group
n_train, n_test, n_val = len(train_idxs), len(test_idxs), len(val_idxs)
print(
    "num train_idxs:",
    n_train,
    "   num test_idxs:",
    n_test,
    "   num val_idxs:",
    n_val,
)

# Save the full set of indices and each group to file in the ML directory
idxs_by_group = {
    "idxs": idxs,
    "train_idxs": train_idxs,
    "test_idxs": test_idxs,
    "val_idxs": val_idxs,
}
np.savez(os.path.join(ML_DIR, "idxs.npz"), **idxs_by_group)

* Visualise the dataset again, this time colored by its cross-validation category

In [None]:
def crossval_category(A):
    """Return 0 if structure `A` is in the train set, 1 if in the test set, else 2."""
    if A in train_idxs:
        return 0
    if A in test_idxs:
        return 1
    return 2


# Per-structure properties, all listed in the same order as `frames` / `idxs`
mean_oh_lengths = [
    np.mean([frame.get_distance(0, 1), frame.get_distance(0, 2)])
    for frame in frames
]
hoh_angles = [frame.get_angle(1, 0, 2) for frame in frames]
categories = [crossval_category(A) for A in idxs]

chemiscope.show(
    frames,
    properties={
        "Mean O-H bond length, Angstrom": mean_oh_lengths,
        "H-O-H angle, degrees": hoh_angles,
        "0: train, 1: test, 2: val": categories,
    },
)

* Although we have generated data for both the HOMO and LUMO, let's just learn
  the HOMO. Fix the "restart_idx" to 0 (for the HOMO), then we can build a
  dataset and start training.

In [None]:
def _processed_file(filename):
    """Return a callable mapping a structure index to `filename` in its processed dir."""
    # `restart_idx` is looked up when the returned callable is invoked, exactly
    # as it would be for a lambda written inline
    return lambda A: os.path.join(processed_dir(A, restart_idx), filename)


# Construct dataset. The input (descriptor), output (RI coefficients), and
# auxiliary (overlap) data are located by callables of the structure index.
rho_data = data.RhoData(
    idxs=idxs,
    in_path=_processed_file("lsoap.npz"),
    out_path=_processed_file("ri_coeffs.npz"),
    aux_path=_processed_file("ri_ovlp.npz"),
    keep_in_mem=ML_SETTINGS["loading"]["train"]["keep_in_mem"],
    **TORCH_SETTINGS,
)

## Initialize model

* As the invariant RI-coefficients are typically many orders of magnitude
  larger than the covariant coefficients, it often helps model
  training if the invariant block models predict a baselined quantity. The
  baseline is then added back in on the TensorMap level before loss evaluation,
  acting as a kind of non-learnable bias.
* We can compute the mean features of the invariant blocks of the training data
  and initialize the model with this. Alternatively, one could use free-atom
  superpositions of the scalar field of interest as the baseline.

In [None]:
# Compute the invariant baseline from the training outputs, if enabled in the
# settings. The blocks are printed inside this branch only: with baselining
# disabled `invariant_baseline` is None, and iterating over None raises a
# TypeError.
if ML_SETTINGS["model"]["use_invariant_baseline"]:
    invariant_baseline = rho_data.get_invariant_means(
        idxs=train_idxs, which_data="output"
    )
    for block in invariant_baseline:
        print(block)
        print(block.values)
else:
    invariant_baseline = None

* We can also calculate the standard deviation of the training data. This is
  defined relative to this invariant baseline.

In [None]:
# Standard deviation of the training outputs, evaluated with the settings below
stddev_kwargs = {
    "idxs": train_idxs,
    "which_data": "output",
    "invariant_baseline": invariant_baseline,
    "use_overlaps": True,
}
stddev = rho_data.get_standard_deviation(**stddev_kwargs)

# Save a numpy copy to file in the ML directory
stddev_path = os.path.join(ML_DIR, "stddev.npz")
stddev_array = stddev.detach().numpy()
np.savez(stddev_path, stddev=stddev_array)

# Display the value as the cell output
stddev

* In order to perform end-to-end predictions, we also need to initialize the
  model with the settings used to build the equivariant descriptor and final
  target property. These will be the same as above, used to generate the
  training data.

In [None]:
from settings import RASCAL_SETTINGS, CG_SETTINGS, BASE_AIMS_KWARGS

# For descriptor building, we need to store the rascaline settings for
# generating a SphericalExpansion and performing Clebsch-Gordan combinations.
# The `descriptor_builder` function in "predictor.py" contains the 'recipe' for
# using these settings to transform an ASE Atoms object.
descriptor_kwargs = {
    "RASCAL_SETTINGS": RASCAL_SETTINGS,
    "CG_SETTINGS": CG_SETTINGS,
}

# For target building, the base AIMS settings need to be stored, along with the
# basis set definition. The definition is read from the processed outputs of an
# explicitly chosen structure (the first one in `idxs`) rather than through a
# loop variable `A` left over from whichever cell happened to run before.
basis_set = io.unpickle_dict(
    os.path.join(processed_dir(idxs[0], restart_idx), "calc_info.pickle")
)["basis_set"]

# Store shallow copies of both dicts
target_kwargs = {
    "aims_kwargs": {**BASE_AIMS_KWARGS},
    "basis_set": {**basis_set},
}

In [None]:
# Initialize model
model_cfg = ML_SETTINGS["model"]
example_idx = idxs[0]  # structure whose data provides the example input/output

model = models.RhoModel(
    # Standard model architecture
    model_type=model_cfg["model_type"],  # "linear" or "nonlinear"
    input=rho_data[example_idx][1],   # example input data for init metadata
    output=rho_data[example_idx][2],  # example output data for init metadata
    bias_invariants=model_cfg["bias_invariants"],

    # Architecture settings if using a nonlinear base model. `.get` passes None
    # for any of these that is absent from the settings.
    hidden_layer_widths=model_cfg.get("hidden_layer_widths"),
    activation_fn=model_cfg.get("activation_fn"),
    bias_nn=model_cfg.get("bias_nn"),

    # Invariant baselining
    invariant_baseline=invariant_baseline,

    # Settings for descriptor/target building
    descriptor_kwargs=descriptor_kwargs,
    target_kwargs=target_kwargs,

    # Torch tensor settings
    **TORCH_SETTINGS
)

* Let's update the model with some settings needed to rebuild the
  real-space density from RI-coefficients in AIMS, needed for computing the L1 error.

In [None]:
# Settings specific to RI rebuild procedure. These get their own names rather
# than rebinding `RI_KWARGS` / `SBATCH_KWARGS`: those names hold the values
# imported from "settings.py" that the RI fitting cell uses, and overwriting them
# here would make that cell silently run with the rebuild settings if it were
# re-executed afterwards.
RI_REBUILD_KWARGS = {
    # Force no SCF
    "sc_iter_limit": 0,
    "postprocess_anyway": True,
    "ri_fit_assume_converged": True,
    # What we want to do
    "ri_fit_rebuild_from_coeffs": True,
    # What we want to output
    "ri_fit_write_rebuilt_field": True,
    "ri_fit_write_rebuilt_field_cube": True,
    "output": ["cube ri_fit"],  # needed for cube files
}

# AIMS settings for the rebuild: a copy of those stored on the model, with the
# rebuild-specific ones layered on top
tmp_aims_kwargs = {**model.target_kwargs["aims_kwargs"]}
tmp_aims_kwargs.update(RI_REBUILD_KWARGS)

# Settings for slurm
PRED_SBATCH_KWARGS = {
    "job-name": "h2o-pred",
    "nodes": 1,
    "time": "01:00:00",
    "mem-per-cpu": 2000,
    "partition": "standard",
    "ntasks-per-node": 10,
}

# Update the model's target-building settings (keys are unchanged)
model.update_target_kwargs(
    {
        "AIMS_PATH": AIMS_PATH,
        "aims_kwargs": tmp_aims_kwargs,
        "SBATCH_KWARGS": PRED_SBATCH_KWARGS,
    }
)

## Initialize training objects: loaders, loss, optimizer

In [None]:
# Construct dataloaders
loading_cfg = ML_SETTINGS["loading"]

# Train and test loaders also load the auxiliary data (the overlap matrix)
train_loader = data.RhoLoader(
    rho_data,
    idxs=train_idxs,
    get_aux_data=True,
    batch_size=loading_cfg["train"]["batch_size"],
)
test_loader = data.RhoLoader(
    rho_data,
    idxs=test_idxs,
    get_aux_data=True,
    batch_size=loading_cfg["test"]["batch_size"],
)

# The validation set is evaluated against the real-space scalar field (which
# requires calling AIMS), so the overlaps do not need to be loaded.
val_loader = data.RhoLoader(
    rho_data,
    idxs=val_idxs,
    get_aux_data=False,
    batch_size=None,
)

In [None]:
# Set up the training objects: the loss function, an Adam optimizer over the
# model parameters, and no learning rate scheduler for now
loss_fn = loss.L2Loss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = None

## Run model training

* We train the model by gradient descent, evaluating the loss of the prediction
  against the density-fitted quantity. As the learning target is the real-space
  density, which has been expanded on a non-orthogonal basis, and not just the
  RI coefficients, the overlap matrices must be used in loss evaluation:

* For a structure $A$, the L2 loss of the predicted scalar field relative to the **RI
  approximated scalar field** is defined as:

  * $ \mathcal{L}_A^{\text{L2}} = \left( \textbf{c}_A^{\text{ML}} - \textbf{c}_A^{\text{RI}} \right) \ . \ \hat{S}_A \ . \ \left( \textbf{c}_A^{\text{ML}} - \textbf{c}_A^{\text{RI}} \right)  $


* And the L1 (MAE) error of the ML prediction relative to the **true QM
  scalar field** is defined as: 

  * $ \mathcal{L}_A^{\text{L1}} = \frac{1}{\int d\textbf{r} \rho_A^{\text{QM}} (\textbf{r})} \int d\textbf{r} \vert \rho_A^{\text{ML}} (\textbf{r}) - \rho_A^{\text{QM}} (\textbf{r}) \vert $
  
* While the L2 loss is used to train the model, we will also periodically
  calculate the L1 loss to assess the performance of the model relative to the
  actual true QM scalar field.

In [None]:
from rholearn import train
from settings import ML_SETTINGS

# Define a log file for writing losses at each epoch
log_path = os.path.join(ML_DIR, "training.log")
io.log(log_path, "# epoch train_L2_loss test_L2_loss train_L1_error test_L1_error time")

# The train and test structures are the same at every epoch, so gather their
# indices and frames once up front instead of rebuilding them at every L1
# evaluation
tmp_idxs = np.concatenate([train_idxs, test_idxs])
tmp_frames = [all_frames[A] for A in tmp_idxs]

# Run training loop
for epoch in range(1, ML_SETTINGS["training"]["n_epochs"] + 1):
    # Training step
    t0 = time.time()
    train_loss_epoch, test_loss_epoch = train.training_step(
        train_loader=train_loader,
        test_loader=test_loader,
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        scheduler=scheduler,
        check_args=(epoch == 1),  # Switch off metadata checks after 1st epoch
    )

    # Calculate also the L1 error of train/test predictions against the
    # real-space QM quantity. Store values of -1 if not evaluating the L1 error
    # this epoch
    mean_maes = {"train": -1, "test": -1}
    if (epoch - 1) % ML_SETTINGS["validation"]["interval"] == 0:

        mean_maes = {"train": [], "test": []}

        # Make predictions (coefficients and real-space fields) for all train
        # and test structures, saved under the per-structure prediction dirs
        pred_coeffs, pred_fields = model.predict(
            structure_idxs=tmp_idxs,
            frames=tmp_frames,
            build_target=True,
            save_dir=pred_dir,
        )

        # Evaluate mean L1 Error
        for A, pred_field in zip(tmp_idxs, pred_fields):
            # Get grids and check they're the same in the SCF and ML directories
            grid = np.loadtxt(os.path.join(ri_dir(A, restart_idx), "partition_tab.out"))
            assert np.allclose(
                grid,
                np.loadtxt(os.path.join(pred_dir(A), "partition_tab.out")),
            )

            # Get L1 error vs real-space QM scalar field
            target_field = np.loadtxt(os.path.join(ri_dir(A, restart_idx), "rho_ref.out"))
            mae = aims_parser.get_percent_mae_between_fields(
                input=pred_field,
                target=target_field,
                grid=grid,
            )
            if A in train_idxs:
                mean_maes["train"].append(mae)
            elif A in test_idxs:
                mean_maes["test"].append(mae)

        # Reduce the per-structure errors to one mean per category
        for category in ["train", "test"]:
            mean_maes[category] = np.mean(mean_maes[category])

    dt = time.time() - t0
    print(
        f"epoch {epoch}  ",
        f"train_L2_loss {np.round(train_loss_epoch.detach().numpy(), 5)}  "
        f"test_L2_loss {np.round(test_loss_epoch.detach().numpy(), 5)}  "
        f"time {dt}",
    )
    io.log(
        log_path,
        f"{epoch} {train_loss_epoch} {test_loss_epoch} {mean_maes['train']} {mean_maes['test']} {dt}",
    )
    if epoch % ML_SETTINGS["training"]["save_interval"] == 0:
        torch.save(model, os.path.join(chkpt_dir, f"model_{epoch}.pt"))
        # `state_dict` must be *called*: saving the bound method itself would
        # not store a state dictionary that `load_state_dict` can consume
        torch.save(optimizer.state_dict(), os.path.join(chkpt_dir, f"opt_{epoch}.pt"))

In [None]:
# Load training log and unpack one array per column
losses = np.loadtxt(os.path.join(ML_DIR, "training.log"))
epochs, train_loss, test_loss, train_mae, test_mae, times = losses.T

fig, ax = plt.subplots()

# L2 losses are available at every epoch: plot as lines
for l2_values, l2_label in (
    (train_loss, "L2 Loss (ML/RI), train"),
    (test_loss, "L2 Loss (ML/RI), test"),
):
    ax.plot(epochs, l2_values, label=l2_label)

# L1 errors are logged as -1 at epochs where they were not evaluated: plot only
# the evaluated epochs, dividing the logged values by 100
for l1_values, l1_label in (
    (train_mae, "L1 Error (ML/QM), train"),
    (test_mae, "L1 Error (ML/QM), test"),
):
    evaluated = np.where(l1_values != -1)
    ax.scatter(
        epochs[evaluated],
        l1_values[evaluated] / 100,
        label=l1_label,
        marker=".",
    )

ax.set_xlabel("Epochs")
ax.set_ylabel("Error per structure")
ax.set_yscale("log")
ax.set_title("HOMO-learning w/ NN | water monomer | AIMS cluster calculation")
ax.legend()

* This process is slow in a notebook - better to make use of HPC resources! The
  training procedure can essentially be copied to a separate python script and
  run using a job scheduler - see "run_training.py".

* We can load a model that has been pre-trained and validate its performance

* Here we load a model trained for over 1500 epochs on a dataset of only 10
  water molecules.

## Evaluate model performance on validation set



In [None]:
# Load the pre-trained model and its training log file, both stored in the
# "pretrained_models" directory under the top directory
pretrained_dir = os.path.join(TOP_DIR, "pretrained_models")
model = torch.load(os.path.join(pretrained_dir, "model_1611.pt"))
log_file = np.loadtxt(os.path.join(pretrained_dir, "training.log"))

In [None]:
# Unpack one array per column of the log (`val_mae` is unpacked but not plotted)
epochs, train_loss, test_loss, train_mae, test_mae, val_mae, times = log_file.T

fig, ax = plt.subplots()

# L2 losses are available at every epoch: plot as lines
for l2_values, l2_label in (
    (train_loss, "L2 Loss (ML/RI), train"),
    (test_loss, "L2 Loss (ML/RI), test"),
):
    ax.plot(epochs, l2_values, label=l2_label)

# Entries equal to -1 are filtered out before plotting the L1 errors; the
# remaining values are divided by 100
for l1_values, l1_label in (
    (train_mae, "L1 Error (ML/QM), train"),
    (test_mae, "L1 Error (ML/QM), test"),
):
    evaluated = np.where(l1_values != -1)
    ax.scatter(
        epochs[evaluated],
        l1_values[evaluated] / 100,
        label=l1_label,
        marker=".",
    )

ax.set_xlabel("Epochs")
ax.set_ylabel("Error per structure")
ax.set_yscale("log")
ax.set_title("HOMO-learning w/ NN | water monomer | AIMS cluster calculation")
ax.legend()

## End-to-end prediction on the validation set

* Now let's make an end-to-end prediction (from xyz -> real-space scalar field)
  on the validation set

In [None]:
# Load the validation frames as ASE Atoms objects, picked out of the complete
# dataset by their structure indices (same order as `val_idxs`)
val_frames = [all_frames[A] for A in val_idxs]

def val_dir(A):
    """Return the directory where validation outputs for structure `A` are saved."""
    validation_root = os.path.join(ML_DIR, "validation")
    return os.path.join(validation_root, f"{A}")

# End-to-end prediction for the validation set, saved under the per-structure
# validation directories. Note: the descriptors already constructed above could
# be used instead, but the prediction starts from ASE frames for demonstrative
# purposes
pred_coeffs, pred_fields = model.predict(
    structure_idxs=val_idxs,
    frames=val_frames,
    build_target=True,
    save_dir=val_dir,
)

In [None]:
# Error of each validation prediction relative to the SCF-converged real-space
# scalar field, stored by structure index
val_df_errors = {}
for A, pred_field in zip(val_idxs, pred_fields):
    qm_dir = ri_dir(A, restart_idx=0)

    # The grids in the QM and ML directories must be the same
    grid = np.loadtxt(os.path.join(qm_dir, "partition_tab.out"))
    ml_grid = np.loadtxt(os.path.join(val_dir(A), "partition_tab.out"))
    assert np.allclose(grid, ml_grid)

    # Percent MAE of the prediction against the QM scalar field
    target_field = np.loadtxt(os.path.join(qm_dir, "rho_ref.out"))
    df_error = aims_parser.get_percent_mae_between_fields(
        input=pred_field,
        target=target_field,
        grid=grid,
    )
    print(f"Val structure {A}, DF error (%): {df_error}")
    val_df_errors[A] = df_error

* Let's visualize these predictions using the cube file outputs

* Pick one validation structure and visualize the target QM and predicted ML
  HOMO scalar fields

In [None]:
# Visualize the target (QM) and predicted (ML) fields of the first validation
# structure, one viewer per cube file
A = val_idxs[0]

qm_cube = os.path.join(ri_dir(A, restart_idx=0), "rho_ref.cube")
ml_cube = os.path.join(val_dir(A), "rho_rebuilt.cube")

for cube_file in [qm_cube, ml_cube]:
    # Read the cube file once, inside a context manager so the handle is
    # closed, and reuse the text for both the model and the volumetric data
    with open(cube_file, "r") as f:
        cube_data = f.read()

    v = py3Dmol.view()
    v.addModelsAsFrames(cube_data, "cube")
    v.setStyle({"stick": {}})
    v.addVolumetricData(
        cube_data,
        "cube",
        {"isoval": 0.001, "color": "blue", "opacity": 0.8},
    )
    v.show()

* Let's calculate the difference between ML and QM RI-coefficients, and use
  these 'delta-coeffs' to reconstruct a delta-scalar field

In [None]:
# First define a new directory to store the delta HOMOs. `exist_ok=True` makes
# this a no-op if the directory is already there, without a separate
# check-then-create step.
os.makedirs(os.path.join(ML_DIR, "delta"), exist_ok=True)

def delta_dir(A):
    """Return the directory where the delta-field outputs for structure `A` are saved."""
    delta_root = os.path.join(ML_DIR, "delta")
    return os.path.join(delta_root, f"{A}")

In [None]:
# Retrieve the target coeffs of each validation structure from its processed
# directory
target_paths = [
    os.path.join(processed_dir(structure_idx, restart_idx=0), "ri_coeffs.npz")
    for structure_idx in val_idxs
]
target_coeffs = [metatensor.load(path) for path in target_paths]

# Calculate the delta coeffs: predicted minus target, structure by structure
delta_coeffs = [
    metatensor.subtract(predicted, target)
    for predicted, target in zip(pred_coeffs, target_coeffs)
]

# Build the delta-HOMO from RI coefficients using AIMS, with the same
# target-building settings as stored on the model
delta_fields = predictor.target_builder(
    structure_idxs=val_idxs,
    frames=val_frames,
    predictions=delta_coeffs,
    save_dir=delta_dir,
    **model.target_kwargs
)

In [None]:
# Visualize the delta-HOMO of one of the validation structures
A = val_idxs[0]
delta_cube = os.path.join(delta_dir(A), "rho_rebuilt.cube")

for cube_file in [delta_cube]:
    # Read the cube file once, inside a context manager so the handle is
    # closed, and reuse the text for both the model and the volumetric data
    with open(cube_file, "r") as f:
        cube_data = f.read()

    v = py3Dmol.view()
    v.addModelsAsFrames(cube_data, "cube")
    v.setStyle({"stick": {}})
    v.addVolumetricData(
        cube_data,
        "cube",
        {"isoval": 0.001, "color": "blue", "opacity": 0.8},
    )
    v.show()