In [None]:
import os
import ase.io
import numpy as np
import torch

import chemiscope  # interactive molecular visualization
import equistore   # storage format for atomistic ML
import qstack      # quantum chemistry toolkit
import rascaline   # generating structural representations
import rholearn    # torch-based density leaning

# Root of the rholearn repository checkout. Set the RHOLEARN_DIR environment
# variable to point at your own checkout; the fallback below is only the
# original author's example path and will not exist on other machines.
RHOLEARN_DIR = os.environ.get(
    "RHOLEARN_DIR", "/Users/joe.abbott/Documents/phd/code/rho/rho_learn/"
)
# Directory holding the water example's structures and electron-density data
data_dir = os.path.join(RHOLEARN_DIR, "docs/example/water/data")

In [None]:
# Load the first `n_structures` water monomers from the example dataset
n_structures = 1000
xyz_path = os.path.join(data_dir, "water_monomers_1k.xyz")
frames = ase.io.read(xyz_path, index=f":{n_structures}")

# Treat every structure as an isolated molecule (no periodic boundaries)
for frame in frames:
    frame.set_pbc(False)

# Per-structure geometric properties to browse alongside the 3D viewer.
# Atom order is assumed to be O, H, H (indices 0, 1, 2) -- TODO confirm
# against the xyz file.
mean_oh_lengths = [
    np.mean([frame.get_distance(0, 1), frame.get_distance(0, 2)])
    for frame in frames
]
hoh_angles = [frame.get_angle(1, 0, 2) for frame in frames]

# Interactive chemiscope widget showing the structures and their properties
cs = chemiscope.show(
    frames,
    mode="default",
    properties={
        "Mean O-H bond length, Angstrom": mean_oh_lengths,
        "H-O-H angle, degrees": hoh_angles,
    },
)
display(cs)

In [None]:
import equistore.io
from equistore import Labels

from rholearn import features, utils

# rascaline hyperparameters for the SphericalExpansion underlying lambda-SOAP
rascal_hypers = dict(
    cutoff=3.0,  # Angstrom
    max_radial=6,  # Exclusive
    max_angular=5,  # Inclusive
    atomic_gaussian_width=0.2,
    radial_basis={"Gto": {}},
    cutoff_function={"ShiftedCosine": {"width": 0.5}},
    center_atom_weight=1.0,
)

# lambda-SOAP descriptor restricted to even-parity blocks (runtime ~25 s)
input = features.lambda_soap_vector(frames, rascal_hypers, even_parity_only=True)

# The target electron density has no (l=5, Hydrogen) block, so the matching
# input block is dropped. Assumes the key names are ordered as
# (angular channel, center species) with species 1 = H -- TODO confirm
# against `input.keys.names`.
unmatched_keys = Labels(input.keys.names, np.array([[5, 1]]))
input = utils.drop_blocks(input, keys=unmatched_keys)

# Target electron densities; input and output must agree on the samples
# and components metadata before they can be used for training.
output = equistore.io.load(os.path.join(data_dir, "e_densities.npz"))
assert utils.equal_metadata(input, output, check=["samples", "components"])

# Write the descriptor to disk so it can be reloaded without recomputing it
equistore.io.save(os.path.join(data_dir, "lambda_soap.npz"), input)

In [None]:
import equistore.io

# Reload the lambda-SOAP descriptor previously written to disk
lambda_soap_path = os.path.join(data_dir, "lambda_soap.npz")
input = equistore.io.load(lambda_soap_path)

In [None]:
from rholearn import pretraining

# Settings controlling how the data is split into groups and partitioned
# into exercises and training subsets
settings = dict(
    io=dict(
        input=os.path.join(data_dir, "lambda_soap.npz"),
        output=os.path.join(data_dir, "e_densities.npz"),
        data_dir=os.path.join(data_dir, "partitions"),  # where partitions are written
    ),
    numpy=dict(random_seed=10),  # seed for a reproducible split
    train_test_split=dict(
        axis="samples",
        names=["structure"],  # split by structure index
        n_groups=3,
        group_sizes=[0.7, 0.2, 0.1],  # fractions per group; presumably train/test/val -- confirm
    ),
    data_partitions=dict(n_exercises=3, n_subsets=4),
)

# Partition the data according to the settings above
pretraining.partition_data(settings)

In [None]:
from rholearn import io, pretraining

# Parent directory for all training runs of the water example
run_dir = os.path.join(RHOLEARN_DIR, "docs/example/water/runs")
io.check_or_create_dir(run_dir)

# Settings for this run: a linear model trained on standardized invariant
# features, written to the "02_linear_std" run directory.
settings = dict(
    io=dict(
        data_dir=os.path.join(data_dir, "partitions"),
        run_dir=os.path.join(run_dir, "02_linear_std"),
    ),
    # presumably must agree with the values used when partitioning -- confirm
    data_partitions=dict(n_exercises=3, n_subsets=4),
    torch=dict(
        requires_grad=True,  # needed to track gradients
        dtype=torch.float64,  # recommended
        device=torch.device("cpu"),  # device that tensors are loaded onto
    ),
    model=dict(
        type="linear",  # "linear" or "nonlinear"
        # A "nonlinear" model would take e.g.
        # hidden_layer_widths=[16, 16, 16], activation_fn="SiLU"
        args={},
    ),
    optimizer=dict(
        algorithm=torch.optim.LBFGS,
        args=dict(lr=0.25),
    ),
    loss=dict(
        fn="MSELoss",  # "CoulombLoss" or "MSELoss"
        args=dict(reduction="sum"),  # reduction applies to MSELoss
    ),
    training=dict(
        n_epochs=100,  # total number of epochs to run
        save_interval=10,  # checkpoint model/optimizer every x (presumably epochs)
        restart_epoch=None,  # None, or the checkpoint epoch to restart from
        standardize_invariant_features=True,
    ),
)

# IMPORTANT! - the torch default dtype must match the configured dtype
torch.set_default_dtype(settings["torch"]["dtype"])

# Build the torch objects (model, loss function, ...) ahead of training
pretraining.construct_torch_objects(settings)

In [None]:
# Which exercises and training subsets to train on in the loop below
exercises = [0]
subsets = list(range(4))  # subsets 0, 1, 2 and 3

In [None]:
import time
from rholearn import training, utils

# Train every requested (exercise, subset) pair. Each run is timed, and a
# timing summary is printed and appended to `log.txt` in its training directory.
for exercise in exercises:
    for subset in subsets:

        # Start timer
        t0 = time.time()

        # None, or the checkpoint epoch that training restarts from
        restart_epoch = settings["training"]["restart_epoch"]

        # Training outputs for this (exercise, subset) pair are saved here
        train_dir = os.path.join(
            settings["io"]["run_dir"], f"exercise_{exercise}", f"subset_{subset}"
        )

        # Load training data and torch objects
        data, model, loss_fn, optimizer = pretraining.load_training_objects(
            settings, exercise, subset, restart_epoch
        )

        # Unpack the data
        in_train, in_test, out_train, out_test = data

        # Execute model training
        print(f"\nTraining in subdirectory {train_dir}")
        training.train(
            in_train=in_train,
            out_train=out_train,
            in_test=in_test,
            out_test=out_test,
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            n_epochs=settings["training"]["n_epochs"],
            save_interval=settings["training"]["save_interval"],
            save_dir=train_dir,
            restart=restart_epoch,
        )

        # Report on timings
        dt = time.time() - t0
        num_epochs_run = (
            settings["training"]["n_epochs"]
            if restart_epoch is None
            else settings["training"]["n_epochs"] - restart_epoch
        )
        # Guard the per-epoch figure: restarting at (or beyond) n_epochs gives
        # num_epochs_run <= 0, which previously raised ZeroDivisionError.
        time_per_epoch = dt / num_epochs_run if num_epochs_run > 0 else float("nan")
        msg = (
            f"\nTraining finished in {np.round(dt, 2)} s = {np.round(time_per_epoch, 2)} s per epoch"
            + f"\n(Timed over {num_epochs_run} epochs, perhaps since restart)"
        )
        print(msg)
        # Append so that summaries from earlier runs / restarts are kept
        with open(os.path.join(train_dir, "log.txt"), "a") as log:
            log.write(msg)

In [None]:
import matplotlib.pyplot as plt
from rholearn import analysis, plots

# Both runs being compared live under the same "runs" folder
runs_root = os.path.join(RHOLEARN_DIR, "docs/example/water/runs")

# NOTE(review): "01_linear" is not trained in this notebook (only
# "02_linear_std" is), so it must already exist on disk for the analysis below.
run_dir_1 = os.path.join(runs_root, "01_linear")
run_dir_2 = os.path.join(runs_root, "02_linear_std")

# Per-run directories for saved figures
plot_dir_1, plot_dir_2 = (os.path.join(rd, "plots") for rd in (run_dir_1, run_dir_2))

In [None]:
# Collect the train/test loss curves of each run, then average them
# (presumably across exercises -- confirm in rholearn.analysis).

# Run 1: "01_linear"
train_1, test_1 = analysis.compile_loss_data(run_dir_1, exercises, subsets)
mean_train_1, mean_test_1 = map(analysis.average_losses, (train_1, test_1))

# Run 2: "02_linear_std" (linear model, standardized invariant features)
train_2, test_2 = analysis.compile_loss_data(run_dir_2, exercises, subsets)
mean_train_2, mean_test_2 = map(analysis.average_losses, (train_2, test_2))

In [None]:
# Log-log plot of loss vs epoch, comparing the train/test curves of runs 1
# and 2 at subset key 3.
# NOTE(review): the keyword below is spelled `mutliple_traces` exactly as in
# the original call -- confirm it matches rholearn.plots.loss_vs_epoch.
fig, ax = plots.loss_vs_epoch(
    [mean_train_1[3], mean_test_1[3], mean_train_2[3], mean_test_2[3]],
    sharey=True,
    mutliple_traces=False,
)

# Format: size and label the figure first, then tighten the layout so the
# spacing is computed for the final size and labels.
fig.set_figheight(7)
fig.set_figwidth(12)
# NOTE(review): these legend labels enumerate subsets, but the traces passed
# above are train/test losses of runs 1 and 2 at a single subset -- confirm.
ax[0].legend(labels=[f"subset {s}" for s in np.sort(list(mean_test_1.keys()))])
ax[0].set_ylabel(r"train loss (MSE)")
ax[1].set_ylabel(r"test loss (MSE)")
fig.tight_layout()

# Save
# NOTE(review): `plot_dir` is undefined in this notebook; use plot_dir_1 or plot_dir_2
# plots.save_fig_mpltex(fig, os.path.join(plot_dir, "loss_vs_epoch"))

In [None]:
# Log-log learning curve plot of loss vs training set size for both runs
point = "final"  # take the "final" or "best" epoch loss
fig, ax = plots.learning_curve(
    [mean_train_1, mean_test_1, mean_train_2, mean_test_2],
    np.load(os.path.join(settings["io"]["data_dir"], "subset_sizes_train.npy")),
    point=point,
)

# Format: label first, then tighten the layout so the labels are accounted for
ax.set_ylabel(point + r" loss")
# Legend entries follow the order of the curves passed above
# (previously the second entry was mislabelled "test_2")
ax.legend(labels=["train_1", "test_1", "train_2", "test_2"])
fig.tight_layout()

# Save
# NOTE(review): `plot_dir` is undefined in this notebook; use plot_dir_1 or plot_dir_2
# plots.save_fig_mpltex(fig, os.path.join(plot_dir, "learning_curve"))