# Simple Experiments: A Tutorial

This notebook is an introductory tour of this code base. We will run a simple experiment on simplified data to
introduce the relevant functionalities and explain what each part does.

Real "production" grade experiments would not be conducted through a notebook like this one. For production, use
the main entry points and configuration files.

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# The dataset

For the purpose of this tutorial, we generate a simplified dataset on the fly. We draw samples from a simple isotropic Gaussian
distribution centered on the equilibrium relative coordinates of crystalline silicon. Each sample is composed of the relative coordinates of 8 atoms in 3D.
This is not representative of a real dataset: "real" silicon atoms would be displaced according to thermalized phonons, not a simple isotropic normal distribution.
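A minimal NumPy sketch of how such a dataset could be drawn. It assumes the standard 8-atom diamond-cubic basis for the equilibrium positions (the ordering returned by `get_silicon_supercell` may differ) and assumes that samples are wrapped back into the unit cell with a modulo. `sample_gaussian_dataset` is a hypothetical helper for illustration, not the implementation of `GaussianDataModule`.

```python
import numpy as np

# Standard diamond-cubic basis in relative coordinates: 8 Si atoms.
# Assumption: the code base may list these atoms in a different order.
DIAMOND_BASIS = np.array([
    [0.00, 0.00, 0.00], [0.00, 0.50, 0.50], [0.50, 0.00, 0.50], [0.50, 0.50, 0.00],
    [0.25, 0.25, 0.25], [0.25, 0.75, 0.75], [0.75, 0.25, 0.75], [0.75, 0.75, 0.25],
])


def sample_gaussian_dataset(equilibrium: np.ndarray, sigma_d: float, size: int, seed: int = 42) -> np.ndarray:
    """Hypothetical helper: draw `size` configurations from an isotropic Gaussian around `equilibrium`.

    Returns an array of shape [size, number_of_atoms, spatial_dimension] with entries in [0, 1).
    """
    rng = np.random.default_rng(seed)
    noise = sigma_d * rng.standard_normal((size, *equilibrium.shape))
    # Relative coordinates are periodic, so wrap every component back into the unit cell.
    return np.mod(equilibrium + noise, 1.0)


samples = sample_gaussian_dataset(DIAMOND_BASIS, sigma_d=0.05, size=2_048)
print(samples.shape)  # (2048, 8, 3)
```

Wrapping matters only for atoms sitting on a cell face (coordinate 0.0 here): a small negative displacement lands near 1.0 rather than slightly below 0.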


In [None]:
# Basic imports and defining global variables
import numpy as np
from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import MSELossParameters
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import NoiseParameters
from diffusion_for_multi_scale_molecular_dynamics.utils.reference_configurations import get_silicon_supercell

silicon_equilibrium_relative_coordinates = get_silicon_supercell(supercell_factor=1).astype(np.float32)
equilibrium_relative_coordinates = list(list(frac) for frac in silicon_equilibrium_relative_coordinates)
number_of_atoms = 8
spatial_dimension = 3

# The dataset will be a sample from an isotropic Gaussian centered on the equilibrium relative coordinates.
sigma_d = 0.05

# Let's choose convenient dataset sizes so that the code executes quickly.
train_dataset_size = 2_048
valid_dataset_size = 512

# Diffusion models rely on introducing "noise" which scrambles the original data. The goal of the diffusion model
# is to start from a completely scrambled sample and to "denoise" it back to something similar to the data it
# was trained on.

# We must define the parameters of this noising process.
noise_parameters = NoiseParameters(total_time_steps=200,
                                   schedule_type="exponential",
                                   sigma_min=0.005,
                                   sigma_max=0.5)
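If "exponential" denotes the geometric interpolation σ(t) = σ_min^(1−t) σ_max^t that is common in variance-exploding diffusion (an assumption: check the code base's noise scheduler for the exact convention, including where the discrete times are placed), the 200 noise levels defined above could be tabulated as follows.

```python
import numpy as np


def exponential_sigma_schedule(total_time_steps: int, sigma_min: float, sigma_max: float) -> np.ndarray:
    """Geometric noise levels sigma(t) = sigma_min**(1 - t) * sigma_max**t on a uniform grid of t in [0, 1].

    Assumption: the grid includes both endpoints; the code base may place its times differently.
    """
    times = np.linspace(0.0, 1.0, total_time_steps)
    return sigma_min ** (1.0 - times) * sigma_max ** times


sigmas = exponential_sigma_schedule(total_time_steps=200, sigma_min=0.005, sigma_max=0.5)
print(sigmas[0], sigmas[-1])  # the endpoints are sigma_min and sigma_max
```

Consecutive levels differ by a constant factor (100^(1/199), about 2.3% here), so every decade of noise gets roughly the same number of steps.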


In [None]:
from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.gaussian_data_module import \
    GaussianDataModuleParameters, GaussianDataModule


gaussian_datamodule_parameters = GaussianDataModuleParameters(noise_parameters=noise_parameters,
                                                              elements=["Si"],
                                                              use_optimal_transport=False,
                                                              random_seed=42,
                                                              number_of_atoms=number_of_atoms,
                                                              sigma_d=sigma_d,
                                                              equilibrium_relative_coordinates=equilibrium_relative_coordinates,
                                                              train_dataset_size=train_dataset_size,
                                                              valid_dataset_size=valid_dataset_size,
                                                              batch_size=512,
                                                              num_workers=8,
                                                              max_atom=number_of_atoms,
                                                              spatial_dimension=spatial_dimension)

data_module = GaussianDataModule(gaussian_datamodule_parameters)

data_module.setup()
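During training, each clean configuration is corrupted at a randomly chosen noise level before it is shown to the score network. The sketch below illustrates that noising step for relative coordinates, assuming a variance-exploding perturbation x_t = (x_0 + σ(t) z) mod 1 with the geometric form σ(t) = σ_min^(1−t) σ_max^t for the "exponential" schedule. `noise_batch` is a hypothetical helper; the code base's own noising code may differ in detail.

```python
import numpy as np


def noise_batch(x0: np.ndarray, sigma_min: float, sigma_max: float, rng: np.random.Generator):
    """Hypothetical helper: perturb a batch of relative coordinates of shape [batch, atoms, dim].

    Returns the noisy coordinates, the sampled times and the matching noise levels.
    """
    times = rng.uniform(0.0, 1.0, size=x0.shape[0])
    # Assumed geometric ("exponential") schedule evaluated at the sampled times.
    sigmas = sigma_min ** (1.0 - times) * sigma_max ** times
    z = rng.standard_normal(x0.shape)
    # One sigma per configuration, broadcast over atoms and spatial dimensions,
    # then wrapped back into the periodic unit cell.
    xt = np.mod(x0 + sigmas[:, None, None] * z, 1.0)
    return xt, times, sigmas


rng = np.random.default_rng(0)
x0 = rng.uniform(size=(512, 8, 3))  # stand-in for one batch of clean configurations
xt, times, sigmas = noise_batch(x0, sigma_min=0.005, sigma_max=0.5, rng=rng)
```

The score network is then trained to predict the score of this perturbation kernel, whose targets the notebook controls through `kmax_target_score`.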

# The AXL Diffusion Lightning Model

The main "model" in charge of predicting how to denoise a sample is an "AXL network". The PyTorch Lightning-derived AXLDiffusionLightningModel is the class in charge of managing all the inputs and outputs needed to train an AXL network.

Creating this object requires a fair bit of configuration. We'll build these configurations now.
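The AXL container groups one entry per diffusable component: atom types (A), relative coordinates (X) and lattice (L). The sketch below mimics that idea with a hypothetical `AXLSketch` named tuple and shows how per-component lambda weights could combine into a single training loss. It is an illustration of the assumed weighted-sum form, not the code base's `AXL` class or its loss implementation.

```python
from typing import NamedTuple


class AXLSketch(NamedTuple):
    """Hypothetical stand-in for an AXL container: one entry per component."""
    A: float  # atom types
    X: float  # relative coordinates
    L: float  # lattice parameters


def weighted_total_loss(component_losses: AXLSketch, lambda_weights: AXLSketch) -> float:
    """Combine per-component losses as a lambda-weighted sum (assumed form)."""
    return sum(weight * loss for weight, loss in zip(lambda_weights, component_losses))


# With the weights used in this tutorial, only the coordinate loss contributes.
weights = AXLSketch(A=0.0, X=1.0, L=0.0)
losses = AXLSketch(A=2.3, X=0.7, L=5.1)
print(weighted_total_loss(losses, weights))  # 0.7
```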

In [None]:
from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL
from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import AtomTypeLossParameters
from diffusion_for_multi_scale_molecular_dynamics.models.scheduler import ReduceLROnPlateauSchedulerParameters
from diffusion_for_multi_scale_molecular_dynamics.models.optimizer import OptimizerParameters
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mlp_score_network import \
    MLPScoreNetworkParameters
from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import AXLDiffusionParameters, \
    AXLDiffusionLightningModel


# Score Network parameters
# -------------------------
# This is the machine learning model that will be taught how to denoise samples. It is the main "engine" on which
# everything else rests. There are many possible choices available: here we pick a simple MLP for convenience. Note
# that the MLP is not equivariant; this score network is really just for sanity checking and simplified experiments.
score_network_parameters = MLPScoreNetworkParameters(number_of_atoms=number_of_atoms,
                                                     spatial_dimension=spatial_dimension,
                                                     num_atom_types=1,
                                                     n_hidden_dimensions=3,
                                                     hidden_dimensions_size=128,
                                                     noise_embedding_dimensions_size=8,
                                                     relative_coordinates_embedding_dimensions_size=64,
                                                     time_embedding_dimensions_size=8,
                                                     atom_type_embedding_dimensions_size=8,
                                                     lattice_parameters_embedding_dimensions_size=8)


# Loss parameters
#-----------------
# We must define the various lambda weights for the loss. Since we'll only consider diffusion over relative coordinates
# we can simply turn off the loss for atom types and lattice.
loss_parameters = AXL(A=AtomTypeLossParameters(lambda_weight=0.0),
                      X=MSELossParameters(lambda_weight=1.0),
                      L=MSELossParameters(lambda_weight=0.0))


# Optimizer and Scheduler parameters
#-----------------------------------
# We must define some standard algorithmic hyperparameters that control the training process.
optimizer_parameters = OptimizerParameters(name="adamw",
                                           learning_rate=0.001,
                                           weight_decay=1.0e-8)

scheduler_parameters = ReduceLROnPlateauSchedulerParameters(factor=0.5, patience=10)


# kmax_target_score is a convergence parameter for the Ewald-like sum of the perturbation kernel for coordinates. This
# is used to compute the targets for the score network operating on relative coordinates.
kmax_target_score = 4

# AXL Diffusion Parameters:  we combine the various configurations above into a single input for the Model.
diffusion_parameters = AXLDiffusionParameters(score_network_parameters=score_network_parameters,
                                              loss_parameters=loss_parameters,
                                              optimizer_parameters=optimizer_parameters,
                                              scheduler_parameters=scheduler_parameters,
                                              kmax_target_score=kmax_target_score)

# Create the Pytorch-Lightning "model", the class that a Trainer object operates on.
model = AXLDiffusionLightningModel(diffusion_parameters)
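On the periodic unit interval, a coordinate perturbed by Gaussian noise follows a wrapped Gaussian kernel K(u) ∝ Σ_k exp(−(u + k)² / (2σ²)), a sum over periodic images k. The sketch below shows how a cutoff such as `kmax_target_score` could truncate that sum to |k| ≤ kmax when computing the target score ∂/∂u log K(u). It uses only the real-space form of the sum; the code base's "Ewald-like" treatment may also use a reciprocal-space representation. `wrapped_gaussian_score` is a hypothetical name.

```python
import numpy as np


def wrapped_gaussian_score(u: np.ndarray, sigma: float, kmax: int) -> np.ndarray:
    """Score d/du log K(u) of a wrapped Gaussian, truncating the image sum to |k| <= kmax.

    u: displacements (x_t - x_0) in relative coordinates, as a NumPy array of any shape.
    """
    k = np.arange(-kmax, kmax + 1)
    shifted = u[..., None] + k  # displacement to each periodic image, shape [..., 2 * kmax + 1]
    exponents = -shifted**2 / (2.0 * sigma**2)
    # Normalized image weights (a softmax), stable even when sigma is small.
    weights = np.exp(exponents - exponents.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # Weighted average of the single-Gaussian scores -(u + k) / sigma**2.
    return np.sum(weights * (-shifted / sigma**2), axis=-1)


u = np.linspace(-0.5, 0.5, 5)
print(wrapped_gaussian_score(u, sigma=0.05, kmax=4))
```

For small σ only the k = 0 image matters and the score reduces to −u/σ²; at larger σ (up to σ_max = 0.5 in this tutorial) neighboring images contribute, which is what the cutoff has to capture.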

# Training

We can finally train the model. We leverage PyTorch Lightning for training: this gives us easy access to convenient metrology tools to keep
track of training as it progresses.
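How long training runs is governed by two plateau rules: the `ReduceLROnPlateauSchedulerParameters(factor=0.5, patience=10)` scheduler configured above (assuming it monitors the validation loss) and the early-stopping callback with `patience=20` on `validation_epoch_loss` configured in the next cell. The pure-Python sketch below replays simplified versions of both rules. Its counters follow the conventions of PyTorch's `ReduceLROnPlateau` (reduce once the count of bad epochs exceeds `patience`) and Lightning's `EarlyStopping` (stop once it reaches `patience`), assuming the code base wraps those classes; it ignores their improvement thresholds and cooldown. `simulate_plateau_rules` is hypothetical.

```python
def simulate_plateau_rules(validation_losses, learning_rate=1e-3, lr_factor=0.5, lr_patience=10,
                           stopping_patience=20):
    """Replay simplified plateau rules over a sequence of validation losses.

    Returns (learning rate after the last epoch run, number of epochs run).
    """
    best_loss = float("inf")
    lr_bad_epochs = 0  # scheduler counter, reset on improvement and after each reduction
    stopping_wait = 0  # early-stopping counter, reset only on improvement
    for epoch, loss in enumerate(validation_losses, start=1):
        if loss < best_loss:
            best_loss = loss
            lr_bad_epochs = stopping_wait = 0
        else:
            lr_bad_epochs += 1
            stopping_wait += 1
            # Scheduler: reduce once the count of bad epochs exceeds the patience.
            if lr_bad_epochs > lr_patience:
                learning_rate *= lr_factor
                lr_bad_epochs = 0
            # Early stopping: stop once the wait count reaches the patience.
            if stopping_wait >= stopping_patience:
                return learning_rate, epoch
    return learning_rate, len(validation_losses)


# Validation loss improves for 5 epochs, then plateaus.
plateau_losses = [1.0 / (1 + e) for e in range(5)] + [0.2] * 100
print(simulate_plateau_rules(plateau_losses))  # (0.0005, 25)
```

With these settings, the learning rate is halved once (after 11 flat epochs) before early stopping ends training after 20 flat epochs, well before the `max_epochs=300` limit.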

In [None]:
from diffusion_for_multi_scale_molecular_dynamics import TOP_DIR
from pytorch_lightning.loggers import TensorBoardLogger
from diffusion_for_multi_scale_molecular_dynamics.callbacks.callback_loader import create_all_callbacks
from pytorch_lightning import Trainer


# Logger
#-------
# A "logger" keeps track of various artifacts during training (the loss, for instance) and lets us visualize what is going on.
# Some loggers, such as Wandb or Comet, connect directly to the cloud so that progress can be monitored from a different machine
# than the one running the training: this is useful when a job runs on a cluster. Here, we'll use TensorBoard, a completely local
# solution. Using notebook magics, we will embed the TensorBoard UI within this notebook.

# We must define a name for the experiment and a local folder where artifacts will be written.

experiment_name = "tutorial_mlp"
output_path = TOP_DIR / "experiments" / "tutorials" / "output" / experiment_name
output_path.mkdir(parents=True, exist_ok=True)

output_directory = str(output_path)

tensorboard_logger = TensorBoardLogger(save_dir=output_directory,
                                       default_hp_metric=False,
                                       name=experiment_name,
                                       version=0)

# Callbacks
# ---------
# A "callback" is an object with internal methods that we can pass to Pytorch-Lightning (PL); PL will "call back"
# these various methods at specified points during execution, allowing us to do some metrology on the ongoing training.

# Loss monitoring will plot the loss at every epoch, in bins over the noising time axis.
loss_monitoring_parameters = dict(number_of_bins=50,
                                  sample_every_n_epochs=1,
                                  spatial_dimension=spatial_dimension)

early_stopping_parameters = dict(metric="validation_epoch_loss",
                                 mode="min",
                                 patience=20)

callback_parameters = dict(loss_monitoring=loss_monitoring_parameters, early_stopping=early_stopping_parameters)

callbacks_dict = create_all_callbacks(callback_parameters,
                                      output_directory=output_directory,
                                      verbose=True)

callbacks = list(callbacks_dict.values())

trainer = Trainer(callbacks=callbacks,
                  max_epochs=300,
                  log_every_n_steps=1,
                  fast_dev_run=False,
                  logger=tensorboard_logger,
                  enable_progress_bar=True)

# We will monitor progress using TensorBoard, which will be embedded in this notebook.
%tensorboard --logdir output --samples_per_plugin images=99999

# Training can take a few minutes...
trainer.fit(model, datamodule=data_module)
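The loss-monitoring callback configured above (`number_of_bins=50`) plots the loss in bins over the noising time axis. A NumPy sketch of that kind of binning, assuming per-sample losses and their diffusion times t ∈ [0, 1] are available; `bin_losses_by_time` is a hypothetical helper, not the callback's code, and the losses in the usage lines are made up for illustration.

```python
import numpy as np


def bin_losses_by_time(times: np.ndarray, losses: np.ndarray, number_of_bins: int = 50):
    """Average per-sample losses within equal-width bins of diffusion time on [0, 1].

    Returns the bin centers and the mean loss in each bin (NaN for empty bins).
    """
    edges = np.linspace(0.0, 1.0, number_of_bins + 1)
    # np.digitize returns 1-based bin indices; clip so that t == 1.0 falls in the last bin.
    indices = np.clip(np.digitize(times, edges) - 1, 0, number_of_bins - 1)
    sums = np.bincount(indices, weights=losses, minlength=number_of_bins)
    counts = np.bincount(indices, minlength=number_of_bins)
    with np.errstate(invalid="ignore", divide="ignore"):
        means = sums / counts
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, means


rng = np.random.default_rng(0)
toy_times = rng.uniform(size=10_000)
toy_losses = toy_times**2 + 0.01 * rng.standard_normal(10_000)  # made-up losses for illustration
centers, mean_losses = bin_losses_by_time(toy_times, toy_losses, number_of_bins=50)
```

Binning by time makes it visible whether the network struggles more at low or at high noise, which a single epoch-averaged loss hides.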


# Sampling
Now that the model is trained, we can draw new samples with it. In order to do so, we create a "generator" which is responsible for
creating new samples by using the axl_network to denoise random starting points.
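To illustrate what a predictor-corrector Langevin sampler does, the sketch below runs one on a toy problem whose score is known analytically. If the data are Gaussian with standard deviation σ_d around fixed positions μ and periodicity is ignored, the noised distribution at level σ is Gaussian with variance σ_d² + σ², so its score is −(x − μ)/(σ_d² + σ²). The analytic `toy_score` stands in for the trained `axl_network`; the geometric schedule, the reverse-diffusion predictor and the step-size rule are assumptions in the style of variance-exploding samplers, not `LangevinGenerator`'s exact update.

```python
import numpy as np


def toy_score(x: np.ndarray, sigma: float, mu: np.ndarray, sigma_d: float) -> np.ndarray:
    """Exact score of N(mu, sigma_d**2 + sigma**2); a stand-in for a trained score network."""
    return -(x - mu) / (sigma_d**2 + sigma**2)


def predictor_corrector_sample(mu: np.ndarray, sigma_d: float, n_samples: int,
                               sigma_min: float = 0.005, sigma_max: float = 0.5,
                               total_time_steps: int = 200, number_of_corrector_steps: int = 1,
                               step_scale: float = 0.1, seed: int = 0) -> np.ndarray:
    """Annealed predictor-corrector sampling with the analytic toy score (periodicity ignored)."""
    rng = np.random.default_rng(seed)
    # Assumed geometric ("exponential") noise levels, increasing from sigma_min to sigma_max.
    sigmas = sigma_min * (sigma_max / sigma_min) ** np.linspace(0.0, 1.0, total_time_steps)
    # Start from a broad Gaussian at the largest noise level.
    x = mu + sigma_max * rng.standard_normal((n_samples, *mu.shape))
    for i in range(total_time_steps - 1, 0, -1):
        # Predictor: reverse-diffusion step from sigmas[i] down to sigmas[i - 1].
        variance_drop = sigmas[i] ** 2 - sigmas[i - 1] ** 2
        x = (x + variance_drop * toy_score(x, sigmas[i], mu, sigma_d)
             + np.sqrt(variance_drop) * rng.standard_normal(x.shape))
        # Corrector: Langevin steps at the new noise level, with a step size proportional to sigma**2.
        step = step_scale * sigmas[i - 1] ** 2
        for _ in range(number_of_corrector_steps):
            x = (x + step * toy_score(x, sigmas[i - 1], mu, sigma_d)
                 + np.sqrt(2.0 * step) * rng.standard_normal(x.shape))
    return x


mu = np.full((8, 3), 0.25)  # toy "equilibrium" positions
toy_samples = predictor_corrector_sample(mu, sigma_d=0.05, n_samples=2_000)
print((toy_samples - mu).std())  # should be close to sigma_d = 0.05
```

With the default of one corrector step per noise level, this mirrors the `number_of_corrector_steps=1` setting used in the next cell; the trained network replaces `toy_score` when the data distribution is not known in closed form.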


In [None]:
from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL_COMPOSITION
from diffusion_for_multi_scale_molecular_dynamics.sampling.diffusion_sampling import create_batch_of_samples
import torch
from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \
    PredictorCorrectorSamplingParameters
from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import LangevinGenerator

# Sampling parameters
cell_dimensions = [5.43, 5.43, 5.43]  # Side lengths of the cubic Si cell (Si lattice constant, about 5.43 Å).

sampling_parameters = PredictorCorrectorSamplingParameters(number_of_samples=1_024,
                                                           spatial_dimension=spatial_dimension,
                                                           number_of_corrector_steps=1,
                                                           num_atom_types=1,
                                                           number_of_atoms=number_of_atoms,
                                                           use_fixed_lattice_parameters=True,
                                                           cell_dimensions=cell_dimensions)


generator = LangevinGenerator(noise_parameters=noise_parameters,
                              sampling_parameters=sampling_parameters,
                              axl_network=model.axl_network)


with torch.no_grad():
    samples_batch = create_batch_of_samples(generator=generator,
                                            sampling_parameters=sampling_parameters,
                                            device=model.device)

# We can extract the sampled relative coordinates
sampled_relative_coordinates = samples_batch[AXL_COMPOSITION].X


# Quantifying Results

We can gauge the quality of the samples by computing their total displacements from the equilibrium relative coordinates and
seeing how the distribution of these displacements compares with that of the validation dataset. We'll define a simple function to extract
the total distances, and then we'll plot the results.
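Written out, the total distance computed by `compute_total_distance` for a configuration x and a reference x⁰ of N atoms in d dimensions takes, for each component, the minimum over the three nearest periodic images:

```latex
D(x, x^{0}) = \sqrt{\sum_{i=1}^{N} \sum_{\alpha=1}^{d} \min_{k \in \{-1, 0, 1\}} \left( x_{i\alpha} - x^{0}_{i\alpha} + k \right)^{2}}
```

Because both configurations lie in [0, 1), every raw displacement lies in (−1, 1), so these three images always include the nearest one.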

In [None]:
import einops


def compute_total_distance(relative_coordinates: torch.Tensor, reference_relative_coordinates: torch.Tensor) -> float:
    """ Compute total distance.

    This function computes the "total distance" between two configurations, accounting for periodicity,
    by comparing atoms in order (atom i of one configuration with atom i of the other).

    Args:
        relative_coordinates: the relative coordinates of a configuration. Dimension [number_of_atoms, spatial_dimension]
        reference_relative_coordinates: the reference relative coordinates of a configuration.
            Dimension [number_of_atoms, spatial_dimension]

    Returns:
        Total distance: the total distance between the relative coordinates and the reference, in reduced units.
    """
    raw_displacements = relative_coordinates - reference_relative_coordinates
    augmented_displacements = [raw_displacements - 1.0, raw_displacements, raw_displacements + 1.0]

    squared_displacements = einops.rearrange(augmented_displacements, "c n d -> (n d) c")**2

    total_displacement = torch.sqrt(squared_displacements.min(dim=1).values.sum())
    return total_displacement.item()
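Looping over samples one at a time is fine at this scale. A vectorized NumPy variant of the same minimum-image computation can process a whole batch at once. `batched_total_distances` is a hypothetical helper that shifts each displacement to its nearest periodic image with `np.round`; for displacements in (−1, 1), this gives the same squared distance as the explicit minimum over {−1, 0, +1} (ties at exactly ±0.5 give equal distances either way).

```python
import numpy as np


def batched_total_distances(relative_coordinates: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Minimum-image total distances for a batch [batch, atoms, dim] against a reference [atoms, dim]."""
    displacements = relative_coordinates - reference
    # Shift every component to its nearest periodic image, i.e. into [-0.5, 0.5].
    displacements -= np.round(displacements)
    return np.sqrt(np.sum(displacements**2, axis=(1, 2)))


reference = np.zeros((8, 3))
batch = np.full((2, 8, 3), 0.01)
batch[1, 0, 0] = 0.99  # wraps around to a displacement of -0.01
print(batched_total_distances(batch, reference))  # both equal sqrt(24) * 0.01
```

Assuming `sampled_relative_coordinates` has shape [number_of_samples, number_of_atoms, spatial_dimension], it could be passed as `sampled_relative_coordinates.cpu().numpy()` together with `silicon_equilibrium_relative_coordinates`.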


In [None]:
# Compute the distances for samples and validation datasets

reference_relative_coordinates = torch.from_numpy(silicon_equilibrium_relative_coordinates)

sampled_distances = [compute_total_distance(relative_coordinates, reference_relative_coordinates)
                     for relative_coordinates in sampled_relative_coordinates]


validation_distances = []
for row in data_module.valid_dataset:
    relative_coordinates = row['relative_coordinates']
    distance = compute_total_distance(relative_coordinates, reference_relative_coordinates)
    validation_distances.append(distance)

In [None]:
from diffusion_for_multi_scale_molecular_dynamics.analysis import PLEASANT_FIG_SIZE
from matplotlib import pyplot as plt

fig = plt.figure(figsize=PLEASANT_FIG_SIZE)
fig.suptitle("Sampling Displacement Distributions")
common_params = dict(density=True, bins=100, histtype="stepfilled", alpha=0.25)

ax1 = fig.add_subplot(111)

ax1.hist(sampled_distances, **common_params, label="Samples", color="red")
ax1.hist(validation_distances, **common_params, label="Validation Data", color="green")

ax1.set_xlabel("Total Displacement (Unitless)")
ax1.set_ylabel("Density")
ax1.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=12)
fig.tight_layout()
plt.show()
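As a sanity check on this histogram: when σ_d is small compared with the cell (0.05 versus a half-width of 0.5), periodic wrapping is negligible, and the total distance of the validation data is σ_d times a chi-distributed variable with N·d = 24 degrees of freedom. Its mean, √2 Γ((k + 1)/2) / Γ(k/2) for k degrees of freedom, can be computed exactly with `math.lgamma` and compared with a Monte Carlo estimate. If the model has learned the data distribution, the samples histogram should peak near the same value.

```python
import math

import numpy as np


def chi_mean(degrees_of_freedom: int) -> float:
    """Mean of a chi distribution: sqrt(2) * Gamma((k + 1) / 2) / Gamma(k / 2)."""
    k = degrees_of_freedom
    return math.sqrt(2.0) * math.exp(math.lgamma((k + 1) / 2) - math.lgamma(k / 2))


sigma_data = 0.05  # same value as sigma_d in this tutorial
degrees_of_freedom = 8 * 3  # number_of_atoms * spatial_dimension
expected_mean_distance = sigma_data * chi_mean(degrees_of_freedom)

# Monte Carlo check: norms of isotropic Gaussian displacements (periodicity ignored).
rng = np.random.default_rng(0)
displacements = sigma_data * rng.standard_normal((100_000, degrees_of_freedom))
monte_carlo_mean = np.linalg.norm(displacements, axis=1).mean()
print(expected_mean_distance, monte_carlo_mean)  # both about 0.24
```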