# How to Train Your Model (Live Demo #1)

![M-stack ecosystem](../figures/m_stack_ecosystem.png)

## 1. Import M-stack packages

In [None]:
# Useful standard and scientific ML libraries
import os
import ase.io
import matplotlib.pyplot as plt
import numpy as np
import pyscf
import py3Dmol
import torch

# M-Stack packages
import equistore   # storage format for atomistic ML
import chemiscope  # interactive molecular visualization
import rascaline   # generating structural representations
import qstack      # quantum chemistry toolkit

# Torch-based density learning
from rholearn import io, features, loss, plots, predictor, pretraining, training, utils
from settings import RASCAL_HYPERS, DATA_SETTINGS, ML_SETTINGS

## 2. Visualize and explore dataset: `chemiscope`

* `chemiscope` is an interactive structure and property viewer

In [None]:
# Load every frame (index=":") of the water monomer dataset
xyz_path = os.path.join(DATA_SETTINGS["data_dir"], "water_monomers_1k.xyz")
frames = ase.io.read(xyz_path, index=":")

# Per-structure geometric properties shown alongside the molecules. Atom 0 is
# the shared atom of both distances and the vertex of the angle (presumably the
# oxygen - confirm against the atom ordering in the xyz file).
mean_oh_lengths = []
hoh_angles = []
for frame in frames:
    mean_oh_lengths.append(
        np.mean([frame.get_distance(0, 1), frame.get_distance(0, 2)])
    )
    hoh_angles.append(frame.get_angle(1, 0, 2))

# Open the interactive chemiscope viewer as the cell output
chemiscope.show(
    frames,
    properties={
        "Mean O-H bond length, Angstrom": mean_oh_lengths,
        "H-O-H angle, degrees": hoh_angles,
    },
)

## 3. Generate $\lambda$-SOAP equivariant structural representation: `rascaline` + `equistore`

In [None]:
# Build the lambda-SOAP descriptor for all frames; rascaline computes a
# SphericalExpansion under the hood (~ 25 secs)
print("Computing lambda-SOAP representation...")
input = features.lambda_soap_vector(frames, RASCAL_HYPERS, even_parity_only=True)

# Remove the (l=5, hydrogen) block, which has no counterpart in the output
# electron density
keys_to_drop = equistore.Labels(input.keys.names, np.array([[5, 1]]))
input = equistore.drop_blocks(input, keys=keys_to_drop)

# Keep only structures 199-999, i.e. drop the first 199 structures
input_structures = equistore.Labels(
    ["structure"], np.arange(199, 1000).reshape(-1, 1)
)
input = equistore.slice(input, samples=input_structures)

print("Done.")

# Persist the descriptor and the hyperparameters used to generate it
data_dir = DATA_SETTINGS["data_dir"]
equistore.save(os.path.join(data_dir, "lambda_soap.npz"), input)
io.pickle_dict(os.path.join(data_dir, "rascal_hypers.pickle"), RASCAL_HYPERS)

In [None]:
# Inspect the lambda-SOAP descriptor: leaving the bare name as the last
# expression makes the notebook show the object's repr as the cell output
input

In [None]:
# Select the block keyed by spherical_harmonics_l=4 and species_center=8
# (l = 4, oxygen) and display it to inspect its metadata
block = input.block(spherical_harmonics_l=4, species_center=8)
block

In [None]:
# Inspect the data of the l = 4, oxygen block: show the first entry of
# `block.values`
block.values[0]

## 4. Load reference electron density coefficients: `Q-stack` + `equistore`

In [None]:
# Read the reference electron density data from disk
density_path = os.path.join(DATA_SETTINGS["data_dir"], "e_densities.npz")
output = equistore.load(density_path)

# Keep only structures 199-999, i.e. drop the first 199 structures of the
# output (the same selection is applied to the lambda-SOAP input)
output_structures = equistore.Labels(
    ["structure"], np.arange(199, 1000).reshape(-1, 1)
)
output = equistore.slice(output, samples=output_structures)

In [None]:
# Inspect the loaded electron density data (shown as the cell output)
output

In [None]:
# Inspect the shape of the samples metadata of the l = 4 block with
# species_center=1, i.e. hydrogen (oxygen blocks use species_center=8)
output.block(spherical_harmonics_l=4, species_center=1).samples.shape

## 5. Prepare data: `equistore`

### Train-test-validation split

In [None]:
from equisolve.utils import split_data

# Input and output must agree on their samples and components metadata before
# they can be split consistently
assert equistore.equal_metadata(input, output, check=["samples", "components"])

# Split input and output together into training, test and validation subsets,
# as configured in DATA_SETTINGS
split_tensors, grouped_labels = split_data(
    [input, output],
    axis=DATA_SETTINGS["axis"],
    names=DATA_SETTINGS["names"],
    n_groups=DATA_SETTINGS["n_groups"],
    group_sizes=DATA_SETTINGS["group_sizes"],
    seed=DATA_SETTINGS["seed"],
)
(in_train, in_test, in_val), (out_train, out_test, out_val) = split_tensors

# File name -> TensorMap for every subset that gets written to the data directory
tm_files = {
    "in_train.npz": in_train,
    "in_test.npz": in_test,
    "out_train.npz": out_train,
    "out_test.npz": out_test,
    "in_val.npz": in_val,
    "out_val.npz": out_val,
}
data_dir = DATA_SETTINGS["data_dir"]
for name, tm in tm_files.items():
    equistore.save(os.path.join(data_dir, name), tm)

# Report how many labels ended up in each group (train, test, validation order)
n_train = len(grouped_labels[0])
n_test = len(grouped_labels[1])
n_val = len(grouped_labels[2])
print(f"Data split sizes:\n\ntrain: {n_train}\ntest: {n_test}\nvalidation: {n_val}")

### Prepare Run Directory

In [None]:
run_dir = ML_SETTINGS["run_dir"]

# Make sure the run directory exists and store a pickled copy of the ML
# settings inside it
io.check_or_create_dir(run_dir)
settings_path = os.path.join(run_dir, "train_settings.pickle")
io.pickle_dict(settings_path, ML_SETTINGS)

# IMPORTANT! - the torch default dtype is set before any torch objects are
# constructed below
torch.set_default_dtype(ML_SETTINGS["torch"]["dtype"])

# Pre-construct the torch objects (i.e. models, loss fxns) for this run
pretraining.construct_torch_objects_in_train_dir(
    DATA_SETTINGS["data_dir"], run_dir, ML_SETTINGS
)

print(f"Simulation directory prepared at:\n\n{run_dir}")

## 6. Train model: `equistore` interfacing with PyTorch

In [None]:
train_cfg = ML_SETTINGS["training"]
restart_epoch = train_cfg["restart_epoch"]

# An empty relative path places the training subdirectory at the run
# directory itself
train_rel_dir = ""
train_run_dir = os.path.join(ML_SETTINGS["run_dir"], train_rel_dir)

# Fetch the training data together with the model, loss function and optimizer
data, model, loss_fn, optimizer = pretraining.load_training_objects(
    train_rel_dir, DATA_SETTINGS["data_dir"], ML_SETTINGS, restart_epoch
)
in_train, in_test, out_train, out_test = data

# Run the training loop, saving into the training subdirectory
print(f"\nTraining in subdirectory:\n\n{train_run_dir}\n")
training.train(
    in_train=in_train,
    out_train=out_train,
    in_test=in_test,
    out_test=out_test,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    n_epochs=train_cfg["n_epochs"],
    save_interval=train_cfg["save_interval"],
    save_dir=train_run_dir,
    restart=restart_epoch,
    print_level=1,
)

### Loss vs epoch plot

In [None]:
# Train and test losses stored in the run directory
losses = np.load(os.path.join(ML_SETTINGS["run_dir"], "losses.npz"))

# One entry per curve: (key in `losses`, divisor, line style).
# NOTE(review): the divisors 500 and 300 are presumably the number of training
# and test structures (the y-axis is "per structure") - confirm against the
# group sizes used for the data split.
curves = [
    ("train", 500, {"label": "linear, train", "color": "blue"}),
    ("test", 300, {"label": "linear, test", "color": "blue", "linestyle": "dashed"}),
]

# Log-log plot of the scaled losses against the epoch index
fig, ax = plt.subplots(1, 1, sharey=True)
for split, n_structures, line_kwargs in curves:
    ax.loglog(losses[split] / n_structures, **line_kwargs)
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss per structure")
ax.set_ylim(1e-5, 1e-2)
ax.legend()

## 7. Make a prediction on the validation structure

In [None]:
data_dir = DATA_SETTINGS["data_dir"]

# Validation input is loaded as a torch-backed TensorMap; the reference output
# is loaded with plain equistore
in_val = io.load_tensormap_to_torch(
    os.path.join(data_dir, "in_val.npz"), **ML_SETTINGS["torch"]
)
out_val = equistore.load(os.path.join(data_dir, "out_val.npz"))

# Take the first unique "structure" index of the validation input and read the
# matching frame from the original xyz file
val_idx = equistore.unique_metadata(in_val, axis="samples", names="structure")[0][0]
val_frame = ase.io.read(
    os.path.join(data_dir, "water_monomers_1k.xyz"), index=val_idx
)

# Build a pyscf Molecule object from (chemical symbol, position) pairs
atom_spec = list(zip(val_frame.get_chemical_symbols(), val_frame.positions))
val_mol = pyscf.gto.Mole().build(atom=atom_spec, basis="ccpvqz jkfit")

# Predict the density with the model checkpoint saved under "epoch_10"
out_val_pred, coeffs = predictor.predict_density_from_mol(
    in_val,
    val_mol,
    model_path=os.path.join(ML_SETTINGS["run_dir"], "epoch_10", "model.pt"),
    inv_means_path=os.path.join(data_dir, "inv_means.npz"),
)

### Parity plot: target vs predicted coefficients

In [None]:
# Summed MSE loss between the target and predicted validation TensorMaps,
# evaluated without gradient tracking and converted to numpy
torch_kwargs = ML_SETTINGS["torch"]
with torch.no_grad():
    target_torch = utils.tensor_to_torch(out_val, **torch_kwargs)
    pred_torch = utils.tensor_to_torch(out_val_pred, **torch_kwargs)
    val_loss = loss.MSELoss(reduction="sum")(target_torch, pred_torch).detach().numpy()

# Parity plot of target vs predicted coefficients, colored by the
# "spherical_harmonics_l" key
fig, ax = plots.parity_plot(
    target=out_val,
    predicted=out_val_pred,
    color_by="spherical_harmonics_l",
)

# Same limits on both axes and an equal aspect ratio
lim = [-0.05, 0.1]
ax.set_xlim(lim)
ax.set_ylim(lim)
ax.set_aspect("equal")
ax.set_xlabel("target density coefficient")
ax.set_ylabel("predicted density coefficient")

# The title reports the loss in units of 10^-6, rounded to 3 decimals
scaled_loss = round(val_loss * 1e6, 3)
ax.set_title(f"Validation MSE Error: {scaled_loss}" + r" $\times 10^{-6}$")
ax.legend()

## 8. Process densities with `Q-stack` and visualize

In [None]:
# Delta density TensorMap: absolute difference between prediction and target
out_val_delta = equistore.abs(equistore.subtract(out_val_pred, out_val))

new_key_names = ["spherical_harmonics_l", "element"]


def _to_coeff_vector(tensor):
    """Vectorize the coefficients of ``tensor`` for ``val_mol``.

    Drops the "structure" name from the samples metadata, renames the keys to
    ``new_key_names`` and passes the result to
    ``qstack.equio.tensormap_to_vector``.
    """
    without_structure = utils.drop_metadata_name(tensor, "samples", "structure")
    renamed = utils.rename_tensor(without_structure, keys_names=new_key_names)
    return qstack.equio.tensormap_to_vector(val_mol, renamed)


# Vectorize the coefficients of the target, predicted and delta TensorMaps
vect_coeffs_target = _to_coeff_vector(out_val)
vect_coeffs_input = _to_coeff_vector(out_val_pred)
vect_coeffs_delta = _to_coeff_vector(out_val_delta)

# Write one cube file per coefficient vector into <run_dir>/plots
plot_dir = os.path.join(ML_SETTINGS["run_dir"], "plots")
io.check_or_create_dir(plot_dir)
n = 60  # grid points per dimension
cube_jobs = [
    (vect_coeffs_target, "out_val.cube"),
    (vect_coeffs_input, "out_val_pred.cube"),
    (vect_coeffs_delta, "out_val_delta.cube"),
]
for coeffs, filename in cube_jobs:
    cube_path = os.path.join(plot_dir, filename)
    qstack.fields.density2file.coeffs_to_cube(
        val_mol, coeffs, cube_path, nx=n, ny=n, nz=n, resolution=None
    )

### Predicted electron density

In [None]:
# Visualize the predicted density
v = py3Dmol.view()
v.addModelsAsFrames(open(os.path.join(plot_dir, "out_val_pred.cube"), "r").read(), "cube")
v.setStyle({"stick": {}})
v.addVolumetricData(
    open(os.path.join(plot_dir, "out_val_pred.cube"), "r").read(),
    "cube",
    {"isoval": 0.05, "color": "blue", "opacity": 0.8},
)
v.show()

### "Delta electron density" - i.e. the ML error (100x magnification)

In [None]:
# Visualize the delta density
v = py3Dmol.view()
v.addModelsAsFrames(open(os.path.join(plot_dir, "out_val_delta.cube"), "r").read(), "cube")
v.setStyle({"stick": {}})
v.addVolumetricData(
    open(os.path.join(plot_dir, "out_val_delta.cube"), "r").read(),
    "cube",
    {"isoval": 0.0005, "color": "blue", "opacity": 0.8},
)
v.show()

# Extra Material

* DIY session at poster #14

* Torch-based electron density learning at https://github.com/m-stack-org/rho_learn ...

* ... with examples/tutorials for:

    * water
    
    * azoswitch molecules
    
    
![azoswitch density](../figures/azoswitch_density.png)