# How to Train Your Model (Live Demo #1)

## 1. Import M-stack packages

In [1]:
# %load_ext autoreload
# %autoreload 2

# %load_ext line_profiler

# Useful standard and scientific ML libraries
import os
import time
import ase.io
import matplotlib.pyplot as plt
import numpy as np
# import pyscf
# import py3Dmol
import torch

# M-Stack packages

import metatensor   # storage format for atomistic ML
import chemiscope  # interactive molecular visualization
import rascaline   # generating structural representations
# import qstack      # quantum chemistry toolkit

from metatensor import Labels, TensorBlock, TensorMap
from rascaline.utils import clebsch_gordan, old_clebsch_gordan, rotations

# Torch-based density learning
from rholearn import io, data, loss, models, training, utils
from settings import lsoap_settings, data_settings, ml_settings, torch_settings

ModuleNotFoundError: No module named 'metatensor.core'

In [4]:
# Scratch check of the modulo operator (the recorded output is 1).
4%3

1

## Generate $\lambda$-SOAP equivariant structural representation

In [None]:
# Read every frame (index ":") from the coordinate file and open an interactive
# chemiscope structure viewer on them.
# NOTE(review): this is an absolute, machine-specific path; the cell will not run
# elsewhere until it is replaced by a path relative to the notebook/data dir.
frames = ase.io.read("/Users/joe.abbott/Documents/phd/code/rho/rho_learn/docs/example/new_water/data/coords_1k.xyz", ":")
chemiscope.show(frames, mode="structure")

In [None]:
# Read every frame (index ":") from the local water-monomer file and compute the
# lambda-SOAP descriptor for them. All hyperparameters and filters come from the
# `lsoap_settings` dict imported from `settings`.
# NOTE(review): the keyword names used here (lambda_filter / sigma_filter /
# lambda_cut) differ from those in the commented-out call in section 3
# (lambdas / only_keep_parity) -- confirm against the installed rascaline version.
frames = ase.io.read("data/water_monomers_1k.xyz", ":")
lsoap = clebsch_gordan.lambda_soap_vector(
    frames,
    rascal_hypers=lsoap_settings["rascal_hypers"],
    lambda_filter=lsoap_settings["lambdas"],
    sigma_filter=lsoap_settings["sigmas"],
    lambda_cut=lsoap_settings["lambda_cut"],
)

In [None]:
# Create one directory per structure index under data/lsoap/.
# `exist_ok=True` keeps the cell safe to re-run, and any missing parent
# directories are created as well.
for i in range(1000):
    os.makedirs(f"data/lsoap/{i}", exist_ok=True)

In [None]:
# Load the TensorMap stored in data/dft/0/ri_coeffs.npz and display the block
# selected by spherical_harmonics_l=4, species_center=8.
# A single file needs no `metatensor.join`, which combines several tensors
# along an axis.
coeffs = metatensor.load("data/dft/0/ri_coeffs.npz")
coeffs.block(spherical_harmonics_l=4, species_center=8)

In [None]:
def _load_torch_tensormap(path):
    """Load the file at ``path`` with metatensor, using its torch array factory."""
    return metatensor.core.io.load_custom_array(
        path,
        create_array=metatensor.core.io.create_torch_array,
    )


# Example input / output for structure 0 plus the invariant means.
# NOTE(review): absolute, machine-specific paths -- these only resolve on the
# original author's machine.
# NOTE: `input` shadows the Python builtin, but later cells call `model(input)`,
# so the name is kept.
input = _load_torch_tensormap(
    "/Users/joe.abbott/Documents/phd/code/rho/rho_learn/docs/example/water/data/lsoap/0/x.npz"
)
output = _load_torch_tensormap(
    "/Users/joe.abbott/Documents/phd/code/rho/rho_learn/docs/example/water/data/rho/0/c.npz"
)
inv_means = _load_torch_tensormap(
    "/Users/joe.abbott/Documents/phd/code/rho/rho_learn/docs/example/water/data/rho/inv_means.npz"
)

# Local torch settings; this rebinds the `torch_settings` name imported from
# `settings` in the first cell.
torch_settings = dict(
    dtype=torch.float64,
    requires_grad=True,
    device=torch.device(type="cpu"),
)

# Nonlinear model with two hidden layers and a Tanh activation
model = models.RhoModel(
    model_type="nonlinear",
    input=input,
    output=output,
    bias_invariants=True,
    hidden_layer_widths=[128, 56],
    activation_fn=torch.nn.Tanh(),
    out_train_inv_means=inv_means,
    **torch_settings
)

# AdamW optimizer with a multi-step learning-rate schedule (milestone at 100)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100])

# Names of the model attributes targeted by the (disabled) conversion below
TO_TORCH_ATTRS = ["_in_metadata", "_out_metadata", "_out_train_inv_means"]

# Convert each attribute to torch
# if model._out_train_inv_means is None:
#     attrs = TO_TORCH_ATTRS[:2]
# else:
#     attrs = TO_TORCH_ATTRS
# for attr in attrs:
#     setattr(
#         model,
#         attr,
#         metatensor.to(
#             getattr(model, attr),
#             "torch",
#             dtype=model._torch_settings["dtype"],
#             device=model._torch_settings["device"],
#         ),
#     )

In [None]:
# Show the documentation of the RhoModel class. `help()` is used instead of the
# IPython-only `?` suffix so the cell is valid Python outside IPython too.
help(models.RhoModel)


In [None]:
# Serialize the whole model object to "model.pt" in the current working directory.
torch.save(model, "model.pt")

In [None]:
# Bundle the model and optimizer (and the scheduler, when one is defined) into a
# single dict and write it to "checkpoint.pt", the file name the next cell reads.
chkpt_dict = {
    "model": model,
    "optimizer": optimizer,
}
if scheduler is not None:
    chkpt_dict.update({"scheduler": scheduler})
torch.save(chkpt_dict, "checkpoint.pt")

In [None]:
# Read "checkpoint.pt" back from the current working directory and display it.
# NOTE(review): torch.load can unpickle arbitrary Python objects -- only use it
# on checkpoint files you created yourself or otherwise trust.
torch.load("checkpoint.pt")

In [None]:
# Display the first three entries of the model's `models` attribute.
model.models[:3]

In [None]:
# Serialize the whole model object to "model.pt", then run the model on `input`
# and display the values of block 0 of the prediction (for comparison with the
# reloaded model in the next cell).
torch.save(model, "model.pt")
model(input).block(0).values

In [None]:
# Reload the model from "model.pt" and display block 0 of its prediction on the
# same `input`.
# NOTE(review): `load_rho_model` is not imported or defined in any cell of this
# notebook, so this raises NameError on a fresh kernel -- confirm which rholearn
# module provides it and import it from there.
model2 = load_rho_model("model.pt")
# model2.load_to_torch(
#     dtype=torch.float64, requires_grad=True, device=torch.device("cpu")
# )
model2(input).block(0).values

In [None]:
# Display the torch settings stored on the model (a private attribute).
model._torch_settings

## 2. Visualize and explore dataset: `chemiscope`

In [None]:
# Read the water molecules from file
# frames = ase.io.read(os.path.join(data_settings["data_dir"], "water_monomers_1k.xyz"), index=":")
# structure_idxs = np.arange(len(frames))

# Display molecules with chemiscope
# chemiscope.show(
#     frames,
#     properties={
#         "Mean O-H bond length, Angstrom": [np.mean([f.get_distance(0, 1), f.get_distance(0, 2)]) for f in frames],
#         "H-O-H angle, degrees": [f.get_angle(1, 0, 2) for f in frames],
#     },
# )

## 3. Generate $\lambda$-SOAP descriptors

### Calculate

In [None]:
# from rascaline.utils import clebsch_gordan, old_clebsch_gordan
# ...
# rascal_hypers = {...}
# lambdas = np.arange(rascal_hypers["max_angular"] + 1)

# # Generate lambda-SOAP -- new version
# lsoap = clebsch_gordan.lambda_soap_vector(
#     frames,
#     rascal_hypers=rascal_hypers,
#     lambdas=lambdas,
#     only_keep_parity=+1,
# )

# # Generate lambda-SOAP -- old version
# lsoap = old_clebsch_gordan.lambda_soap_vector(
#     frames,
#     rascal_hypers=rascal_hypers,
#     lambdas=lambdas,
#     only_keep_parity=+1,
# )

### Check the equivariance condition

In [None]:
# # Check equivariance

# # Generate Wigner-D matrices, initialized with random angles
# wig = rotations.WignerDReal(lmax=rascal_hypers["max_angular"])
# print("Random rotation angles (rad):", wig.angles)

# # Apply an O(3) transformation to each frame 
# frames_o3 = [rotations.transform_frame_o3(frame, wig.angles) for frame in frames]
# assert not np.allclose(frames[0].positions, frames_o3[0].positions)

# # Generate lambda-SOAP for the transformed frames
# lsoap_o3 = clebsch_gordan.lambda_soap_vector(
#     frames_o3,
#     rascal_hypers=rascal_hypers,
#     lambdas=lambdas,
#     only_keep_parity=+1,
# )

# # Apply the O(3) transformation to the TensorMap
# lsoap_transformed = wig.transform_tensormap_o3(lsoap)

# # Check for equivariance!
# assert metatensor.equal_metadata(lsoap_transformed, lsoap_o3)
# assert metatensor.allclose(lsoap_transformed, lsoap_o3)
# print("O(3) EQUIVARIANT!")

### Save

In [None]:
# # Create a dir for lambda-SOAP
# lsoap_dir = os.path.join(data_settings["data_dir"], "lsoap")
# if not os.path.exists(lsoap_dir):
#     os.mkdir(path=lsoap_dir)

# # Split into separate TensorMaps for each structure
# lsoap_split = metatensor.split(
#     lsoap,
#     axis="samples",
#     grouped_labels=[
#         Labels(names="structure", values=np.array([A]).reshape(-1, 1))
#         for A in range(n_frames)
#     ],
# )

# # Save the lambda-SOAP features for each structure to a separate dir
# for A, frame in enumerate(frames[:n_frames]):
#     # Create dir
#     struct_dir = os.path.join(lsoap_dir, f"{A}")
#     if not os.path.exists(struct_dir):
#         os.mkdir(path=struct_dir)
#     # Save
#     metatensor.save(os.path.join(struct_dir, "x.npz"), lsoap_split[A])

## 4. Split data and standardize invariants

In [None]:
# # Get the grouped indices for train/test(/val) splits
# train_idxs, test_idxs, val_idxs = data.group_idxs(
#     structure_idxs,
#     n_groups=data_settings["n_groups"],
#     group_sizes=data_settings["group_sizes"],
#     shuffle=data_settings["shuffle"],
#     seed=data_settings["seed"],
# )

# # Define new dir for storing standardized features
# rho_std_dir = os.path.join(data_settings["data_dir"], "rho_std")
# if not os.path.exists(rho_std_dir):
#     os.mkdir(rho_std_dir)

# # Save the grouped indices
# np.savez(
#     os.path.join(rho_std_dir, "idxs.npz"),
#     train=train_idxs,
#     test=test_idxs,
#     val=val_idxs,
# )

# # Load all the structures into a single TensorMap
# c_list = [
#     metatensor.load(os.path.join(data_settings["data_dir"], "rho", f"{i}", "c.npz"))
#     for i in structure_idxs
# ]
# c_all = metatensor.join(c_list, axis="samples", remove_tensor_name=True)

# # Split to get another TensorMap with only the training structures
# c_train = metatensor.slice(
#     c_all,
#     axis="samples",
#     labels=Labels(names=["structure"], values=np.array([train_idxs]).reshape(-1, 1)),
# )

# # Get the invariant means and save
# inv_means = features.get_invariant_means(c_train)
# metatensor.save(os.path.join(rho_std_dir, "inv_means.npz"), inv_means)

# # Standardize the invariants of all structures
# c_all_std = features.standardize_invariants(c_all, inv_means)

# # Split into individual TensorMaps
# c_all_split = metatensor.split(
#     c_all_std,
#     axis="samples",
#     grouped_labels=[
#         Labels(names=["structure"], values=np.array([i]).reshape(-1, 1))
#         for i in structure_idxs
#     ],
# )

# # Save each structure in a separate directory
# for A, c_std in enumerate(c_all_split):
#     assert c_std.block(0).samples["structure"][0] == A
#     c_dir = os.path.join(rho_std_dir, f"{A}")
#     if not os.path.exists(c_dir):
#         os.mkdir(c_dir)
#     metatensor.save(os.path.join(c_dir, "c.npz"), c_std)

## 5. Build the torch dataset and dataloader

In [None]:
import os
import time
import numpy as np
import torch

import metatensor

from rholearn import io, data, training
from settings import data_settings, ml_settings, torch_settings

In [None]:
# Indices of every structure in the dataset
all_idxs = np.arange(1000)

# Partition the indices into train / test / validation groups according to the
# grouping options in `data_settings`
train_idxs, test_idxs, val_idxs = data.group_idxs(
    all_idxs=all_idxs,
    n_groups=data_settings["n_groups"],
    group_sizes=data_settings["group_sizes"],
    shuffle=data_settings["shuffle"],
    seed=data_settings["seed"],
)

# When training subsets are requested, keep only the leading part of the
# training indices, with the length taken from the selected subset size
n_train_subsets = data_settings["n_train_subsets"]
if n_train_subsets:
    subset_sizes = data.get_log_subset_sizes(len(train_idxs), n_train_subsets)
    i_train_subset = data_settings.get("i_train_subset")
    train_idxs = train_idxs[: subset_sizes[i_train_subset]]

# Without test batching the test set is evaluated as one batch, so cap it at
# the configured test batch size
test_loading = ml_settings["loading"]["test"]
if test_loading["do_batching"] is False and test_loading["batch_size"] < len(test_idxs):
    test_idxs = test_idxs[: test_loading["batch_size"]]

print("num train structures:", len(train_idxs))
print("num test structures:", len(test_idxs))

In [None]:
# Build density dataset
# Only the train and test indices are passed in; the validation indices are not.
# The input / output / overlap directories and the invariant-standardization
# option come from `data_settings`; `train_idxs` is passed separately as well.
# NOTE(review): `keep_in_mem=True` presumably makes RhoData cache loaded
# structures in memory -- confirm in rholearn.data.
rho_data = data.RhoData(
    idxs=np.concatenate([train_idxs, test_idxs]),
    input_dir=data_settings["input_dir"],
    output_dir=data_settings["output_dir"],
    overlap_dir=data_settings["overlap_dir"],
    keep_in_mem=True,
    standardize_invariants=data_settings["standardize_invariants"],
    train_idxs=train_idxs,
    **torch_settings,
)

In [None]:
# Create fresh training objects, or restore them from a saved checkpoint
restart_epoch = ml_settings["training"]["restart_epoch"]

# Output invariant means are only passed on when the outputs are standardized
if "output" in data_settings["standardize_invariants"]:
    out_inv_means = rho_data.out_invariant_means
else:
    out_inv_means = None

# Elements 1 and 2 of the first dataset item serve as the example input and
# output handed to the training-object builders
first_idx = rho_data._idxs[0]
example_tensors = dict(
    input=rho_data[first_idx][1],
    output=rho_data[first_idx][2],
)

if restart_epoch == 0:
    # Fresh start
    objects = training.init_training_objects(
        ml_settings,
        out_invariant_means=out_inv_means,
        **example_tensors,
    )
else:
    # Resume from the checkpoint written at `restart_epoch`
    checkpoint_path = os.path.join(
        ml_settings["run_dir"], f"checkpoint_{restart_epoch}.pt"
    )
    objects = training.load_from_checkpoint(
        path=checkpoint_path,
        ml_settings=ml_settings,
        **example_tensors,
    )

# Unpack objects
model, optimizer, rho_loss_fn, coeff_loss_fn, scheduler = objects

## 5. Run training

In [None]:
def train():
    """
    Run the training loop using the notebook-level training objects.

    The function takes no arguments and returns nothing. It relies on these
    names having been defined by earlier cells: ``ml_settings``, ``rho_data``,
    ``train_idxs``, ``test_idxs``, ``model``, ``optimizer``, ``scheduler``,
    ``rho_loss_fn`` and ``coeff_loss_fn``, plus the ``io``, ``data`` and
    ``training`` modules imported from ``rholearn``.

    For every epoch from ``restart_epoch + 1`` to ``n_epochs`` (inclusive) it:

    1. steps the optimizer once per training batch,
    2. evaluates the test loss with ``rho_loss_fn`` under ``torch.no_grad()``,
    3. saves a checkpoint whenever ``epoch`` is a multiple of ``save_interval``,
    4. appends one line to ``log.txt`` in the run directory, and
    5. steps the learning-rate scheduler.

    The training loss is ``coeff_loss_fn`` until the epoch given by
    ``learn_on_rho_at_epoch`` is reached. From that epoch on ``rho_loss_fn`` is
    used and the train loader is rebuilt so that it also yields overlaps. This
    also holds when training restarts from a later epoch. A ``None`` or
    non-positive ``learn_on_rho_at_epoch`` never triggers the switch.

    Argument checking (``check_args``) is enabled only on the first epoch of
    the run and on the epoch where the loss switches.

    Only the loss of the last batch of each epoch is written to the log.
    """
    training_cfg = ml_settings["training"]
    loading_cfg = ml_settings["loading"]
    run_dir = ml_settings["run_dir"]

    # Make a run dir for saving results (no error if it already exists)
    os.makedirs(run_dir, exist_ok=True)

    # Define a log file
    log_file = os.path.join(run_dir, "log.txt")
    io.log(log_file, "# Model training")
    io.log(log_file, "# epoch train_loss test_loss lr time learning_on_rho")

    # Initialize the train and test loaders. The train loader starts without
    # overlaps because the coefficient loss does not take them.
    train_loader = data.RhoLoader(
        rho_data,
        idxs=train_idxs,
        get_overlaps=False,
        batch_size=loading_cfg["train"]["batch_size"],
    )
    test_loader = data.RhoLoader(
        rho_data,
        idxs=test_idxs,
        get_overlaps=True,
        batch_size=loading_cfg["test"]["batch_size"],
    )

    # Pre-collate test data if not performing test batching
    if loading_cfg["test"]["do_batching"] is False:
        test_batch_idxs, x_test, c_test, s_test = next(iter(test_loader))

    first_epoch = training_cfg["restart_epoch"] + 1
    switch_epoch = training_cfg["learn_on_rho_at_epoch"]

    # Start training
    train_losses = []
    test_losses = []
    use_rho_loss = False
    for epoch in range(first_epoch, training_cfg["n_epochs"] + 1):
        # Start timer
        t0 = time.time()

        # Validate arguments only on the first epoch of this run and on the
        # epoch where the loss function changes.
        check_args = epoch == first_epoch or epoch == switch_epoch

        # Switch to learning on rho once the switch epoch has been reached.
        # Using `<=` rather than `==` means a run restarted after the switch
        # epoch also trains on rho. The loader is rebuilt to yield overlaps.
        if (
            not use_rho_loss
            and switch_epoch is not None
            and 0 < switch_epoch <= epoch
        ):
            use_rho_loss = True
            train_loader = data.RhoLoader(
                rho_data,
                idxs=train_idxs,
                get_overlaps=True,
                batch_size=loading_cfg["train"]["batch_size"],
            )

        # ===== Iterate over training batches
        for train_batch in train_loader:
            # Reset gradients
            optimizer.zero_grad()

            # Unpack train batch (overlaps are only present for the rho loss)
            if use_rho_loss:
                train_batch_idxs, x_train, c_train, s_train = train_batch
            else:
                train_batch_idxs, x_train, c_train = train_batch

            # Make a prediction
            c_train_pred = model(x_train, check_args=check_args)

            # Evaluate the loss with either CoeffLoss or RhoLoss
            if use_rho_loss:
                train_loss = rho_loss_fn(
                    c_train_pred, c_train, s_train, check_args=check_args
                )
            else:  # use CoeffLoss
                train_loss = coeff_loss_fn(c_train_pred, c_train, check_args=check_args)

            # Calculate gradient and update parameters.
            # NOTE(review): `retain_graph=True` is kept as in the original;
            # presumably part of the graph is shared across batches -- confirm
            # before removing, since it holds on to graph memory.
            train_loss.backward(retain_graph=True)
            optimizer.step()

            # Store training loss, divided by the number of structures in the batch
            train_losses.append(train_loss.detach().numpy() / len(train_batch_idxs))

        # ===== Evaluate test loss *on the density*
        with torch.no_grad():
            # Option 1) perform test batching
            if loading_cfg["test"]["do_batching"]:
                # Iterate over test batches: calculate the test loss
                for test_batch in test_loader:
                    # Unpack test batch
                    test_batch_idxs, x_test, c_test, s_test = test_batch
                    # Make a prediction
                    c_test_pred = model(x_test, check_args=check_args)
                    # Evaluate and store test loss per structure
                    test_loss = rho_loss_fn(
                        c_test_pred, c_test, s_test, check_args=check_args
                    )
                    test_losses.append(
                        test_loss.detach().numpy() / len(test_batch_idxs)
                    )

            # Option 2) use a single batch, pre-collated
            else:
                # Make a prediction
                c_test_pred = model(x_test, check_args=check_args)
                # Evaluate and store test loss per structure
                test_loss = rho_loss_fn(
                    c_test_pred, c_test, s_test, check_args=check_args
                )
                test_losses.append(test_loss.detach().numpy() / len(test_batch_idxs))

        # Save checkpoint
        if epoch % training_cfg["save_interval"] == 0:
            training.save_checkpoint(
                run_dir, epoch, model, optimizer, scheduler=scheduler
            )

        # Write log for the epoch (losses are those of the last batch)
        io.log(
            log_file,
            f"{epoch} "
            f"{np.round(train_losses[-1], 7)} "
            f"{np.round(test_losses[-1], 7)} "
            f"{np.round(scheduler.get_last_lr()[0], 7)} "
            f"{np.round(time.time() - t0, 7)} "
            f"{1 if use_rho_loss else 0} ",
        )
        scheduler.step()

In [None]:
# Run the training loop defined in the previous cell; it logs to the run
# directory given by ml_settings["run_dir"].
train()

## 6. Plot results

In [None]:
# # Load the results
# results = np.loadtxt(os.path.join(ml_settings["run_dir"], "log.txt"))

# # Plot train and test loss versus epoch
# fig, ax = plt.subplots()
# ax.loglog(results[:, 0], results[:, 1], label="train")
# ax.loglog(results[:, 0], results[:, 2], label="test")
# ax.legend()
# ax.set_xlabel("epoch")
# ax.set_ylabel("loss per batch")

## 7. Make a prediction on the validation structure

In [None]:
# # Load the input and output validation TensorMaps
# in_val = io.load_tensormap_to_torch(
#     os.path.join(data_settings["data_dir"], "in_val.npz"), **ml_settings["torch"]
# )
# out_val = metatensor.load(os.path.join(data_settings["data_dir"], "out_val.npz"))

# # Retrieve the unique structure
# val_idx = metatensor.unique_metadata(in_val, axis="samples", names="structure")[0][0]
# val_frame = ase.io.read(
#     os.path.join(data_settings["data_dir"], "water_monomers_1k.xyz"), index=val_idx
# )

# # Build a pyscf Molecule object
# val_mol = pyscf.gto.Mole().build(
#     atom=[
#         (i, j) for i, j in zip(val_frame.get_chemical_symbols(), val_frame.positions)
#     ],
#     basis="ccpvqz jkfit",
# )

# # Predict the density
# out_val_pred, coeffs = predictor.predict_density_from_mol(
#     in_val,
#     val_mol,
#     model_path=os.path.join(ml_settings["run_dir"], "epoch_10", "model.pt"),
#     inv_means_path=os.path.join(data_settings["data_dir"], "inv_means.npz"),
# )

### Parity plot: target vs predicted coefficients

In [None]:
# # Calculate the MSE Error
# with torch.no_grad():
#     val_loss = loss.MSELoss(reduction="sum")(
#         utils.tensor_to_torch(out_val, **ml_settings["torch"]), 
#         utils.tensor_to_torch(out_val_pred, **ml_settings["torch"])
#     ).detach().numpy()

# # Plot the target vs predicted coefficients, standardized
# fig, ax = plots.parity_plot(
#     target=out_val,
#     predicted=out_val_pred,
#     color_by="spherical_harmonics_l",
# )
# lim = [-0.05, 0.1]
# ax.set_xlim(lim)
# ax.set_ylim(lim)
# ax.set_aspect("equal")
# ax.set_xlabel("target density coefficient")
# ax.set_ylabel("predicted density coefficient")
# ax.set_title(f"Validation MSE Error: {round(val_loss * 1e6, 3)}"r" $\times 10^{-6}$")
# ax.legend()

## 8. Process densities with `Q-stack` and visualize

In [None]:
# # Build a delta density TensorMap
# out_val_delta = metatensor.abs(metatensor.subtract(out_val_pred, out_val))

# # Vectorize the coefficients from each of the TensorMaps
# new_key_names = ["spherical_harmonics_l", "element"]
# vect_coeffs_target = qstack.equio.tensormap_to_vector(
#     val_mol,
#     utils.rename_tensor(
#         utils.drop_metadata_name(out_val, "samples", "structure"),
#         keys_names=new_key_names,
#     ),
# )
# vect_coeffs_input = qstack.equio.tensormap_to_vector(
#     val_mol,
#     utils.rename_tensor(
#         utils.drop_metadata_name(out_val_pred, "samples", "structure"),
#         keys_names=new_key_names,
#     ),
# )
# vect_coeffs_delta = qstack.equio.tensormap_to_vector(
#     val_mol,
#     utils.rename_tensor(
#         utils.drop_metadata_name(out_val_delta, "samples", "structure"),
#         keys_names=new_key_names,
#     ),
# )

# # Convert the basis function coefficients to a cube file
# plot_dir = os.path.join(ml_settings["run_dir"], "plots")
# io.check_or_create_dir(plot_dir)
# n = 60  # grid points per dimension
# for (coeffs, filename) in [
#     (vect_coeffs_target, "out_val.cube"),
#     (vect_coeffs_input, "out_val_pred.cube"),
#     (vect_coeffs_delta, "out_val_delta.cube"),
# ]:
#     qstack.fields.density2file.coeffs_to_cube(
#         val_mol,
#         coeffs,
#         os.path.join(plot_dir, filename),
#         nx=n,
#         ny=n,
#         nz=n,
#         resolution=None,
#     )

### Predicted electron density

In [None]:
# # Visualize the predicted density
# v = py3Dmol.view()
# v.addModelsAsFrames(open(os.path.join(plot_dir, "out_val_pred.cube"), "r").read(), "cube")
# v.setStyle({"stick": {}})
# v.addVolumetricData(
#     open(os.path.join(plot_dir, "out_val_pred.cube"), "r").read(),
#     "cube",
#     {"isoval": 0.05, "color": "blue", "opacity": 0.8},
# )
# v.show()

### "Delta electron density" - i.e. the ML error (100x magnification)

In [None]:
# # Visualize the delta density
# v = py3Dmol.view()
# v.addModelsAsFrames(open(os.path.join(plot_dir, "out_val_delta.cube"), "r").read(), "cube")
# v.setStyle({"stick": {}})
# v.addVolumetricData(
#     open(os.path.join(plot_dir, "out_val_delta.cube"), "r").read(),
#     "cube",
#     {"isoval": 0.0005, "color": "blue", "opacity": 0.8},
# )
# v.show()