# Exploration of the hippynn graph system

## Let's revisit the simple training script "barebones.py"

In [None]:
'''
To obtain the data files needed for this example, use the script process_QM7_data.py, 
also located in this folder. The script contains further instructions for use.
'''

import torch

# Setup pytorch things
torch.set_default_dtype(torch.float32)

import hippynn
hippynn.settings.WARN_LOW_DISTANCES=False

# Hyperparameters for the network
# These are set deliberately small so that you can easily run the example on a laptop or similar.
network_params = {
    "possible_species": [0, 1, 6, 7, 8, 16],  # Z values of the elements in QM7
    "n_features": 20,  # Number of neurons at each layer
    "n_sensitivities": 20,  # Number of sensitivity functions in an interaction layer
    "dist_soft_min": 1.6,  # qm7 is in Bohr!
    "dist_soft_max": 10.0,
    "dist_hard_max": 12.5,
    "n_interaction_layers": 2,  # Number of interaction blocks
    "n_atom_layers": 3,  # Number of atom layers in an interaction block
}

# Define a model
from hippynn.graphs import inputs, networks, targets, physics

species = inputs.SpeciesNode(db_name="Z")
positions = inputs.PositionsNode(db_name="R")

network = networks.Hipnn("hipnn_model", (species, positions), module_kwargs=network_params)
henergy = targets.HEnergyNode("HEnergy", network, db_name="T")
# hierarchicality = henergy.hierarchicality

# define loss quantities
from hippynn.graphs import loss

mse_energy = loss.MSELoss.of_node(henergy)
mae_energy = loss.MAELoss.of_node(henergy)
rmse_energy = mse_energy ** (1 / 2)

# Validation losses are what we check on the data between epochs -- we can only train to
# a single loss, but we can check other metrics too to better understand how the model is training.
# There will also be plots of these things over time when training completes.
validation_losses = {
    "RMSE": rmse_energy,
    "MAE": mae_energy,
    "MSE": mse_energy,
}

# This piece of code glues the stuff together as a pytorch model,
# dropping things that are irrelevant for the losses defined.
training_modules, db_info = hippynn.experiment.assemble_for_training(mse_energy, validation_losses)

# Load the preprocessed QM7 arrays from disk and split them into train/valid/test sets.
database = hippynn.databases.DirectoryDatabase(
    name="data-qm7",  # Prefix for arrays in the directory
    directory="../../datasets/qm7_processed",
    test_size=0.1,  # Fraction or number of samples to test on
    valid_size=0.1,  # Fraction or number of samples to validate on
    seed=2001,  # Random seed for splitting data
    **db_info,  # Adds the inputs and targets db_names from the model as things to load
)

# Now that we have a database and a model, we can fit the
# non-interacting energies by examining the database.
# This tends to stabilize training a lot.
from hippynn.pretraining import hierarchical_energy_initialization

hierarchical_energy_initialization(henergy, database, trainable_after=False)

# Parameters describing the training procedure.
from hippynn.experiment import setup_and_train

experiment_params = hippynn.experiment.SetupParams(
    stopping_key="MSE",  # The name in the validation_losses dictionary.
    batch_size=12,
    optimizer=torch.optim.Adam,
    max_epochs=1,
    learning_rate=0.001,
)
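netname = "TEST_BAREBONES_SCRIPT"
# Go to a directory for the model.
# hippynn will save training files in the current working directory.
with hippynn.tools.active_directory(netname):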
    setup_and_train(
        training_modules=training_modules,
        database=database,
        setup_params=experiment_params,
    )


## Assembling a graph for training

Perhaps one of the more mysterious lines is:

`training_modules, db_info = hippynn.experiment.assemble_for_training(mse_energy, validation_losses)`

In [None]:
db_info

In [None]:
type(training_modules)

`training_modules` contains three objects: a `model`, a (training) `loss`, and an `evaluator` (which computes the validation losses).

In [None]:
for x in [training_modules.model,training_modules.loss,training_modules.evaluator]:
    print(type(x))

With the Python graphviz interface installed, it is easy to visualize what a `GraphModule` does:

In [None]:
from hippynn.graphs.viz import visualize_connected_nodes, visualize_graph_module, visualize_node_set

In [None]:
visualize_graph_module(training_modules.model)

The bundled arrows hide child nodes. Each node with multiple arrows is a MultiNode, which outputs multiple tensors.

In [None]:
visualize_graph_module(training_modules.model,compactify=False)

Let's take a look at just the one-hot encoder. The `node_from_name` method will make it easy to get a reference to a particular node from the printed or visualized information.

In [None]:
onehot = training_modules.model.node_from_name("OneHot")

In [None]:
visualize_node_set([onehot,*onehot.children],compactify=False)

A Predictor interface can make it simpler to compute the value of some nodes over a database.

In [None]:
from hippynn.graphs import Predictor
onehot_predictor = Predictor([species],[*onehot.children,network.input_features])

In [None]:
outputs = onehot_predictor.apply_to_database(database,batch_size=512)
outputs.keys()

In [None]:
train_outs = outputs['train']

The outputs can be indexed by the node name to get the output value:

In [None]:
train_outs["OneHot.encoding"]

You can also get the value for a node using the node directly:

In [None]:
onehot_train = train_outs[onehot.encoding]

In [None]:
input_features = train_outs["PaddingIndexer.indexed_features"]

In this context, the input features appear to be the same as the one-hot encoding.

In [None]:
print(input_features.shape,input_features.dtype)
print(onehot_train.shape,onehot_train.dtype)
print(torch.equal(onehot_train,input_features))

But actually, the predictor is hiding some complexity. Let's take a look at a more rudimentary GraphModule constructed directly; we will manually specify the set of inputs and outputs.

In [None]:
from hippynn.graphs import GraphModule
onehot_graphmodule = GraphModule([species],[onehot.encoding,onehot.nonblank,network.input_features])
visualize_graph_module(onehot_graphmodule)

We will also manually grab the input array:

In [None]:
arrays = database.splits['train']

outputs_graph = onehot_graphmodule(arrays['Z'])

In [None]:
type(outputs_graph)

In [None]:
len(outputs_graph)

Each one corresponds to one of the outputs directly: `[onehot.encoding,onehot.nonblank,network.input_features]`
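
To make the idea of a module defined by its inputs and outputs concrete, here is a minimal sketch of a graph evaluator in plain Python. It is not hippynn's implementation: `ToyNode` and `ToyGraph` are hypothetical names introduced only for illustration, and the sketch assumes nothing beyond what was shown above, namely that calling the module yields one value per requested output, in the order requested.

```python
class ToyNode:
    """A named computation with parent nodes (hypothetical; not a hippynn class)."""

    def __init__(self, name, parents=(), func=None):
        self.name = name
        self.parents = tuple(parents)
        self.func = func  # None marks an input node


class ToyGraph:
    """Evaluate the requested output nodes given values for the input nodes."""

    def __init__(self, inputs, outputs):
        self.inputs = list(inputs)
        self.outputs = list(outputs)

    def __call__(self, *values):
        # Seed the cache with the supplied input values.
        cache = dict(zip(self.inputs, values))

        def evaluate(node):
            # Each node is computed at most once, then reused from the cache.
            if node not in cache:
                if node.func is None:
                    raise ValueError(f"No value supplied for input {node.name!r}")
                args = [evaluate(parent) for parent in node.parents]
                cache[node] = node.func(*args)
            return cache[node]

        # One result per requested output, in the order requested.
        return tuple(evaluate(out) for out in self.outputs)


x = ToyNode("x")
doubled = ToyNode("doubled", [x], lambda v: 2 * v)
shifted = ToyNode("shifted", [doubled], lambda v: v + 1)

toy_graph = ToyGraph([x], [doubled, shifted])
toy_results = toy_graph(3)
print(toy_results)
```

Because intermediate values are cached, `doubled` is computed once even though it is both an output and a parent of `shifted`.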

In [None]:
features_graph = outputs_graph[-1]

What shape are the features now?

In [None]:
features_graph.shape

Hmm, that's not familiar.

Let's compare this to the output of the Predictor interface:

In [None]:
input_features.shape,features_graph.shape

What's going on? It has to do with the fact that we have batches of systems, but each system has a different number of atoms: 

In [None]:
database.splits['train']['Z']
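
The effect of padding on array shapes can be sketched in plain Python, without hippynn. The example below assumes that species `0` marks a padding slot and that the blank species is dropped from the encoding; both are assumptions of this sketch, and the variable names are hypothetical. It shows how the same data has one shape in a padded (molecule, atom) layout and another once only the real atoms are kept.

```python
# Two toy molecules padded to the same length; 0 marks a blank (padding) slot.
species_batch = [
    [8, 1, 1, 0],  # water: O, H, H, plus one padding slot
    [6, 1, 1, 1],  # methyl: C, H, H, H, no padding needed
]
possible_species = [0, 1, 6, 8]  # 0 is the blank species (assumption of this sketch)


def one_hot(z, species):
    """One-hot encode z over the non-blank species; blank atoms give all zeros."""
    return [int(z == s) for s in species if s != 0]


# Padded layout: shape (n_molecules, max_atoms, n_real_species)
padded = [[one_hot(z, possible_species) for z in mol] for mol in species_batch]
nonblank = [[z != 0 for z in mol] for mol in species_batch]

# Flattened layout: keep only real atoms, shape (n_real_atoms, n_real_species)
flat = [
    row
    for mol, mask in zip(padded, nonblank)
    for row, keep in zip(mol, mask)
    if keep
]

padded_shape = (len(padded), len(padded[0]), len(padded[0][0]))
flat_shape = (len(flat), len(flat[0]))
print(padded_shape, flat_shape)
```

The padded layout has 2 x 4 = 8 atom slots, but only 7 of them hold real atoms, so the flattened array has 7 rows.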

The predictor actually uses the graph system too:

In [None]:
visualize_graph_module(onehot_predictor.graph,compactify=True)

In [None]:
from hippynn.graphs import IdxType
type(IdxType)

`IdxType` is an enumeration:

In [None]:
dir(IdxType)

## IdxType tags the "batch information" for the tensor

* On each tensor, the batch axis might refer to a different quantity.
* We can have a batch of atoms, a batch of molecules, or a batch of MolAtom, meaning molecules on the first batch axis followed by atoms on the second batch axis.
* Index types like MolAtom and MolAtomAtom can be conveniently batched over.
* Index types like Atoms and Pair are sparse, and so make for more efficient computation.
* To track the relationship between the different batch types, we need _indexing_ information.
* `hippynn` looks at the index types associated with inputs and outputs and can automatically construct conversions between the types whenever the answer is unambiguous.
* In cases where the automatic construction fails, an advanced user can directly specify the intended result.
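
The role of the indexing information can be illustrated with a small sketch in plain Python. The function names and toy values below are hypothetical and this is not how hippynn implements its conversions; the sketch only shows that, once each real atom remembers which (molecule, atom) slot it came from, a padded MolAtom-style layout can be converted to a sparse Atoms-style layout, reduced to per-molecule values, and scattered back.

```python
def mol_atom_to_atoms(padded, mask):
    """Flatten a padded (molecule, atom) array, recording where each atom came from."""
    values, index = [], []
    for i_mol, (mol, mol_mask) in enumerate(zip(padded, mask)):
        for i_atom, (value, keep) in enumerate(zip(mol, mol_mask)):
            if keep:
                values.append(value)
                index.append((i_mol, i_atom))
    return values, index


def atoms_to_mol_atom(values, index, n_mol, max_atoms, fill=0):
    """Scatter per-atom values back into the padded layout using the saved index."""
    padded = [[fill] * max_atoms for _ in range(n_mol)]
    for value, (i_mol, i_atom) in zip(values, index):
        padded[i_mol][i_atom] = value
    return padded


def atoms_to_molecules(values, index, n_mol):
    """Sum per-atom values into one total per molecule."""
    totals = [0] * n_mol
    for value, (i_mol, _) in zip(values, index):
        totals[i_mol] += value
    return totals


# Toy per-atom values in padded form; padding slots hold the fill value 0.
atom_values_padded = [[-2, 1, 1, 0], [3, -1, 0, 0]]
real_atom_mask = [[True, True, True, False], [True, True, False, False]]

atom_values, atom_index = mol_atom_to_atoms(atom_values_padded, real_atom_mask)
round_trip = atoms_to_mol_atom(atom_values, atom_index, n_mol=2, max_atoms=4)
molecule_totals = atoms_to_molecules(atom_values, atom_index, n_mol=2)
print(len(atom_values), round_trip == atom_values_padded, molecule_totals)
```

The sparse form stores only the five real atoms, and the saved index is what makes both the round trip and the per-molecule sum possible.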

Under the hood, the loss and evaluator also use graphs!

- This is what allows us to use Python syntax to build a loss function from algebraic operations.

In [None]:
visualize_graph_module(training_modules.loss)
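
As a rough illustration of how algebraic Python syntax can build a graph instead of computing a number immediately, the sketch below uses operator overloading on a small expression class. The `Expr` class and `mse` helper are hypothetical and are not hippynn's loss nodes; they only mirror the pattern of writing something like `mse_energy ** (1 / 2)` and evaluating it later on data.

```python
class Expr:
    """A lazily evaluated expression node (hypothetical; not hippynn's loss node)."""

    def __init__(self, func, parents=(), label="expr"):
        self.func = func
        self.parents = tuple(parents)
        self.label = label

    def evaluate(self, data):
        # Evaluate the parents first, then combine their values.
        args = [parent.evaluate(data) for parent in self.parents]
        return self.func(data, *args)

    @staticmethod
    def _wrap(value):
        # Promote plain numbers to constant expression nodes.
        if isinstance(value, Expr):
            return value
        return Expr(lambda data: value, label=repr(value))

    def __add__(self, other):
        other = Expr._wrap(other)
        return Expr(lambda data, a, b: a + b, (self, other),
                    f"({self.label} + {other.label})")

    def __mul__(self, other):
        other = Expr._wrap(other)
        return Expr(lambda data, a, b: a * b, (self, other),
                    f"({self.label} * {other.label})")

    # Addition and multiplication commute, so the reflected versions can be reused.
    __radd__ = __add__
    __rmul__ = __mul__

    def __pow__(self, power):
        return Expr(lambda data, a: a ** power, (self,),
                    f"({self.label} ** {power})")


def mse(pred_key, true_key):
    """Leaf node: mean squared error between two entries of the data dict."""
    def compute(data):
        pred, true = data[pred_key], data[true_key]
        return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred)
    return Expr(compute, label=f"MSE({pred_key}, {true_key})")


toy_mse = mse("pred", "true")
toy_rmse = toy_mse ** (1 / 2)          # builds a new node; nothing is computed yet
toy_total = toy_rmse + 0.5 * toy_mse   # still just a graph

batch = {"pred": [3.0, 1.0, 4.0, 0.0], "true": [1.0, 3.0, 2.0, 2.0]}
print(toy_total.label)
print(toy_total.evaluate(batch))
```

Nothing is evaluated until `evaluate` is called with data, which is what lets the same expression serve as both a description of the loss and a way to compute it.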

Every model quantity with a db_name can be an input into the loss graph, either in 'true' (database) form, or 'predicted' (model) form:

In [None]:
henergy.mol_energy.true

In [None]:
henergy.mol_energy.pred

In [None]:
visualize_graph_module(training_modules.evaluator.loss)

# Graph transformations

## ASE Interface

In [None]:
# To run this, train a model using ani_aluminum_example.py! 
with hippynn.tools.active_directory('./TEST_ALUMINUM_MODEL/'):
    model=hippynn.experiment.serialization.load_model_from_cwd()

In [None]:
type(model)

In [None]:
visualize_graph_module(model)

Notice the graph structure is somewhat different here because, for example, we have per-atom energies to train to, and periodic boundary conditions.

Let's send this to the Atomic Simulation Environment (ASE), a Python package for setting up and running atomistic simulations such as molecular dynamics.

In [None]:
from hippynn.interfaces.ase_interface import calculator_from_model

calc = calculator_from_model(model)

In [None]:
visualize_graph_module(calc.module)

Very similarly, we can send the model to an MLIAPInterface for the LAMMPS molecular dynamics code, which is very useful for highly parallel simulations.

In [None]:
from hippynn.interfaces.lammps_interface import MLIAPInterface

In [None]:
lammps_interface = MLIAPInterface(model.node_from_name("HEnergy"),element_types=['Al'])

In [None]:
visualize_graph_module(lammps_interface.graph)

# Ensembling

Often it is useful to ensemble multiple models in machine learning. `hippynn` has some tools to automatically ensemble nodes and graphs.

In [None]:
n_ensemble=5
useful_nodes = []

for i in range(n_ensemble):
    this_species = inputs.SpeciesNode(db_name="Z")
    this_positions = inputs.PositionsNode(db_name="R")
    this_network = networks.Hipnn("hipnn_model", (this_species, this_positions), module_kwargs=network_params)
    this_henergy = targets.HEnergyNode("HEnergy", this_network, db_name="T")
    this_force = physics.GradientNode("Force",(this_henergy,this_positions),sign=-1,db_name="F")
    
    useful_nodes.append(this_henergy)
    useful_nodes.append(this_force)

In [None]:
visualize_connected_nodes(useful_nodes)

Note that because multiple nodes share the same name in this visualization, each one is tagged with its id.

In [None]:
from hippynn.graphs import make_ensemble

ensemble,ensemble_info = make_ensemble(useful_nodes)

In [None]:
visualize_graph_module(ensemble)

The graph interface allows us to easily glue these models together and share intermediate computations where possible.

Now, the models are merged as far as possible, sharing inputs and early calculations.
 
At the same time, the ensemble quantities for energy ("T") and force ("F") have been constructed as nodes.

In [None]:
ensemble_T = ensemble.node_from_name("ensemble_T")
ensemble_F = ensemble.node_from_name("ensemble_F")

In [None]:
ensemble_predictor = Predictor.from_graph(ensemble)

In [None]:
outputs = ensemble_predictor.apply_to_database(database,batch_size=128)

In [None]:
ensemble_T.mean

In [None]:
outputs['test'][ensemble_T.mean].shape

In [None]:
outputs['test'][ensemble_T.std].shape

The "all" node outputs each individual prediction, stacked:

In [None]:
outputs['test'][ensemble_T.all].shape

In [None]:
outputs['test'][ensemble_F.all].shape
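
The relationship between the `mean`, `std`, and `all` outputs can be sketched with the standard library. The numbers below are made up, and the sketch assumes a population standard deviation and a particular stacking order; which conventions hippynn uses is not stated in this notebook, so treat both as assumptions.

```python
from statistics import fmean, pstdev

# Hypothetical energy predictions: one row per ensemble member, one column per molecule.
member_predictions = [
    [1.0, 10.0],
    [2.0, 10.0],
    [3.0, 10.0],
]

# "all": regroup so that each molecule carries every member's prediction.
per_molecule = [list(column) for column in zip(*member_predictions)]

# "mean" and "std": reduce over the ensemble members for each molecule.
ensemble_mean = [fmean(values) for values in per_molecule]
ensemble_std = [pstdev(values) for values in per_molecule]  # population std (assumption)

print(per_molecule)
print(ensemble_mean)
print(ensemble_std)
```

A molecule on which all members agree has zero spread, while disagreement between members shows up as a nonzero standard deviation, which is why the spread is often used as a rough uncertainty signal.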

The features above can be combined, for example by building an ASE calculator from the ensemble module.

In [None]:
from hippynn.interfaces.ase_interface import HippynnCalculator

In [None]:
ensemble_calculator = HippynnCalculator(ensemble_T.mean)

In [None]:
visualize_graph_module(ensemble_calculator.module)