The repository is built on the PyTorch Lightning (PL) framework, which takes care of the boilerplate code.

A quick start is available in the [documentation](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html), with more details in the 'Core API' section.

PL has three main ingredients:
- LightningModule
    Defines the model and all the steps of the training pipeline: the training step (whose returned loss is backpropagated), the validation step, the test step, routines that run before training, etc.
- LightningDataModule
    Defines how the data is downloaded or loaded and how it is prepared (normalization etc.), and provides the train_dataloader, val_dataloader and test_dataloader functions, which return the corresponding data loaders.
- Trainer
    Combines the LightningModule with the LightningDataModule and runs the whole training loop while taking care of GPUs, logging and callbacks. Its fit function trains the LightningModule on the LightningDataModule, and its test function evaluates it with the LightningModule automatically switched to inference mode.
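To make this division of labour concrete, here is a deliberately framework-free toy sketch. It does not import PyTorch Lightning and is not how PL is implemented: ToyModule, ToyDataModule and ToyTrainer are hypothetical classes whose method names merely mirror the PL hooks training_step, validation_step, setup, train_dataloader and val_dataloader. The "model" is a single weight w fitted to y = 3x, with the gradient written out by hand where PL would rely on autograd.

```python
from typing import Iterator, List, Tuple

Batch = Tuple[List[float], List[float]]


class ToyModule:
    """Stands in for a LightningModule: owns the model and the per-step logic."""

    def __init__(self, lr: float = 0.01) -> None:
        self.w = 0.0  # the whole "model" is y_hat = w * x
        self.lr = lr

    def training_step(self, batch: Batch, batch_idx: int) -> Tuple[float, float]:
        xs, ys = batch
        n = len(xs)
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n  # MSE
        # In PL, autograd derives the gradient from the returned loss;
        # here d(MSE)/dw is written out by hand.
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        return loss, grad

    def validation_step(self, batch: Batch, batch_idx: int) -> float:
        xs, ys = batch
        return sum(abs(self.w * x - y) for x, y in zip(xs, ys)) / len(xs)  # MAE


class ToyDataModule:
    """Stands in for a LightningDataModule: prepares data and hands out loaders."""

    def setup(self) -> None:
        self.pairs = [(float(x), 3.0 * x) for x in range(10)]  # targets follow y = 3x

    @staticmethod
    def _batches(pairs, batch_size: int) -> Iterator[Batch]:
        for i in range(0, len(pairs), batch_size):
            chunk = pairs[i : i + batch_size]
            yield [x for x, _ in chunk], [y for _, y in chunk]

    def train_dataloader(self) -> Iterator[Batch]:
        return self._batches(self.pairs[:8], batch_size=4)

    def val_dataloader(self) -> Iterator[Batch]:
        return self._batches(self.pairs[8:], batch_size=2)


class ToyTrainer:
    """Stands in for the Trainer: runs the loop and applies the optimizer step."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs

    def fit(self, model: ToyModule, datamodule: ToyDataModule) -> List[float]:
        datamodule.setup()
        val_history = []
        for _ in range(self.max_epochs):
            for batch_idx, batch in enumerate(datamodule.train_dataloader()):
                _, grad = model.training_step(batch, batch_idx)
                model.w -= model.lr * grad  # plain SGD step, done by the "Trainer"
            val = [model.validation_step(b, i) for i, b in enumerate(datamodule.val_dataloader())]
            val_history.append(sum(val) / len(val))  # mean validation MAE per epoch
        return val_history


toy_model = ToyModule()
toy_history = ToyTrainer(max_epochs=5).fit(toy_model, ToyDataModule())
print(f"w = {toy_model.w:.3f}, validation MAE per epoch: {[round(v, 3) for v in toy_history]}")
```

The point of the split is that the module never loops over data and the data module never touches the model; only the trainer knows about both, which is also why the real Trainer can add devices, logging and callbacks without changes to the other two.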

In [1]:
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping

# Project-local modules, all imported from the same src package
from src.MD_DataUtils import load_dm_data
from src.MD_HyperparameterParser import Interpolation_HParamParser
from src.MD_PLModules import Interpolator

First we set up the hyperparameter parser, which creates an ArgumentParser (this could be improved by using [module-specific hyperparameters](https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html)).

In [2]:
# Datasets that can be selected below; datasets[0] picks malonaldehyde_dft.npz
datasets = ['malonaldehyde_dft.npz', 'benzene_dft.npz', 'ethanol_dft.npz', 'toluene_dft.npz',
            'naphthalene_dft.npz', 'salicylic_dft.npz', 'paracetamol_dft.npz', 'aspirin_dft.npz',
            'keto_100K_0.2fs.npz', 'keto_300K_0.2fs.npz', 'keto_500K_0.2fs.npz']

# The 0/1 switches (logger ... fast_dev_run) are all off for this quick run;
# limit_train_batches/limit_val_batches cap the number of batches per epoch.
hparams = Interpolation_HParamParser(logger=0,
                                     plot=0,
                                     show=0,
                                     load_weights=0,
                                     save_weights=0,
                                     fast_dev_run=0,
                                     project='vibrationalspectra',
                                     model='bi_lstm',
                                     num_layers=5,
                                     num_hidden_multiplier=10,
                                     criterion='MAE',
                                     interpolation=True,
                                     interpolation_mode='adiabatic',
                                     integration_mode='diffeq',
                                     diffeq_output_scaling=1,
                                     dataset=datasets[0],
                                     input_length=1,
                                     output_length=20,
                                     batch_size=49,
                                     auto_scale_batch_size=False,
                                     optim='adam',
                                     lr=1e-3,
                                     train_traj_repetition=1,
                                     max_epochs=2,
                                     limit_train_batches=25,
                                     limit_val_batches=25)
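The internals of Interpolation_HParamParser are not shown in this notebook. Assuming it follows a common pattern (build an argparse.ArgumentParser, fill in defaults, let keyword arguments override them), a minimal sketch looks like the hypothetical make_hparams below, which covers only a handful of the arguments. Returning an argparse.Namespace is what makes a call like Interpolator(**vars(hparams)) possible, and it is the kind of object PL 1.x's Trainer.from_argparse_args reads Trainer settings such as max_epochs or limit_train_batches from.

```python
import argparse


def make_hparams(**overrides) -> argparse.Namespace:
    """Hypothetical stand-in for Interpolation_HParamParser (a few arguments only)."""
    parser = argparse.ArgumentParser(description="Interpolation hyperparameters (sketch)")
    parser.add_argument("--model", type=str, default="bi_lstm")
    parser.add_argument("--num_layers", type=int, default=5)
    parser.add_argument("--criterion", type=str, default="MAE")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=49)
    parser.add_argument("--max_epochs", type=int, default=2)

    # Parse an empty argv so the notebook kernel's own command line is ignored.
    namespace = parser.parse_args([])

    # Keyword arguments override the defaults; unknown names are rejected.
    for name, value in overrides.items():
        if not hasattr(namespace, name):
            raise ValueError(f"unknown hyperparameter: {name!r}")
        setattr(namespace, name, value)
    return namespace


demo_hparams = make_hparams(num_layers=3, lr=5e-4)
print(vars(demo_hparams))  # a plain dict, ready to be unpacked with **
```

Rejecting unknown keywords is a design choice of this sketch: it catches typos such as num_layer=3 that would otherwise be silently ignored.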


Once the hyperparameters are parsed, we are ready to load the data.
Each dataset is wrapped in a LightningDataModule, while the data itself is held in BiDirectional_DataSets, which take care of fetching the correct initial and final conditions.

In [3]:
dm = load_dm_data(hparams)

malonaldehyde_dft.npz: [Num Trajectory, Time Steps, Features ]=torch.Size([1, 993236, 54]) features
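How BiDirectional_DataSets indexes the trajectory is not shown here. As a hedged illustration of what "fetching the correct initial and final conditions" can mean, the hypothetical bidirectional_windows function below cuts one trajectory into samples made of input_length frames as the initial condition, the next output_length frames as the target, and the following input_length frames as the final condition. This layout is an assumption, not the repository's actual indexing.

```python
from typing import List, Tuple

Frame = List[float]
Sample = Tuple[List[Frame], List[Frame], List[Frame]]


def bidirectional_windows(traj: List[Frame], input_length: int, output_length: int) -> List[Sample]:
    """Hypothetical windowing into (initial condition, target frames, final condition)."""
    window = 2 * input_length + output_length  # frames consumed by one sample
    samples = []
    for start in range(len(traj) - window + 1):
        mid = start + input_length   # index of the first target frame
        end = mid + output_length    # index of the first final-condition frame
        initial = traj[start:mid]
        target = traj[mid:end]
        final = traj[end:end + input_length]
        samples.append((initial, target, final))
    return samples


# Toy trajectory: 10 time steps with a single feature equal to the time index.
toy_traj = [[float(t)] for t in range(10)]
toy_samples = bidirectional_windows(toy_traj, input_length=1, output_length=3)
print(len(toy_samples))  # 10 - (2 * 1 + 3) + 1 windows
print(toy_samples[0])
```

Under this assumption, the notebook's settings (input_length=1, output_length=20) would make each sample span 22 consecutive frames.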


Next, we initialize the Interpolator module from the hyperparameters and pass it the output scaling statistics, i.e. the mean (dm.dy_mu) and standard deviation (dm.dy_std) provided by the data module.

In [4]:
model = Interpolator(**vars(hparams))
model.model.set_diffeq_output_scaling_statistics(dm.dy_mu, dm.dy_std)
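The notebook does not show how dy_mu and dy_std are computed or used. Assuming, as the names and integration_mode='diffeq' suggest but the source does not confirm, that they are the per-feature mean and standard deviation of the frame-to-frame increments dy = y[t+1] - y[t], the sketch below computes such statistics and maps a network output z given in standardized units back to a physical increment (dy = z * std + mu) before adding it to the current state. Both function names are hypothetical.

```python
import statistics
from typing import List, Tuple

Frame = List[float]


def increment_stats(traj: List[Frame], eps: float = 1e-8) -> Tuple[Frame, Frame]:
    """Per-feature mean/std of dy = y[t+1] - y[t] (hypothetical helper)."""
    dys = [[b - a for a, b in zip(y0, y1)] for y0, y1 in zip(traj, traj[1:])]
    columns = list(zip(*dys))  # one tuple of increments per feature
    mu = [statistics.fmean(c) for c in columns]
    std = [max(statistics.pstdev(c), eps) for c in columns]  # eps guards constant features
    return mu, std


def apply_increment(y: Frame, z: Frame, mu: Frame, std: Frame) -> Frame:
    """Undo the standardization of a predicted increment z and add it to state y."""
    return [yi + zi * si + mi for yi, zi, mi, si in zip(y, z, mu, std)]


traj_2d = [[0.0, 1.0], [1.0, 1.5], [3.0, 2.0], [6.0, 2.5]]
dy_mu, dy_std = increment_stats(traj_2d)
print(dy_mu, dy_std)
# A standardized output of zero corresponds to the average increment.
print(apply_increment(traj_2d[-1], [0.0, 0.0], dy_mu, dy_std))
```

Standardizing the targets this way keeps the network's outputs of order one regardless of the physical units of each feature.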

Here we define a callback that stops training early once the monitored validation metric stops improving:

In [5]:
early_stop_callback = EarlyStopping(monitor='Val/Epoch' + hparams.criterion, mode='min', patience=3, min_delta=0.0005, verbose=True)
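The stopping rule behind these arguments can be mimicked in a few lines. The hypothetical early_stop_epoch function below is a simplified re-implementation, not PL's code: with mode='min', a new value counts as an improvement only if it is lower than the best value so far by more than min_delta, and training stops once patience consecutive validation checks bring no such improvement. Note also that the monitor key, here 'Val/Epoch' + 'MAE' = 'Val/EpochMAE', must match a metric the LightningModule logs with self.log; with the default strict=True, EarlyStopping raises an error if it cannot find it.

```python
import math
from typing import Optional, Sequence


def early_stop_epoch(val_losses: Sequence[float], patience: int = 3,
                     min_delta: float = 0.0005) -> Optional[int]:
    """Index of the validation check at which training would stop (mode='min'), or None."""
    best = math.inf
    wait = 0  # consecutive checks without a sufficient improvement
    for i, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


# The last three values improve by less than min_delta, so training stops at index 4.
print(early_stop_epoch([1.0, 0.9, 0.8998, 0.8997, 0.8996]))
# Steady improvement: the rule never fires.
print(early_stop_epoch([1.0, 0.9, 0.8, 0.7]))
```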

Finally, we initialize the Trainer:

In [6]:
trainer = Trainer.from_argparse_args(hparams,
                                     # min_steps=1000,
                                     # max_steps=50,
                                     callbacks=[early_stop_callback],  # early stopping is the only callback
                                     val_check_interval=1.0,  # validate once at the end of every training epoch
                                     gpus=1 if torch.cuda.is_available() else None)  # one CUDA GPU if available

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
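val_check_interval=1.0 means validation runs once per training epoch. As a rough, simplified model (PL's own scheduling handles more cases), a float interval f means "validate every int(f × number of training batches) batches", an integer N means "validate every N training batches", and with limit_train_batches=25 an epoch has at most 25 training batches. The hypothetical validation_points helper below lists the batch numbers after which validation would run under that model.

```python
from typing import List, Union


def validation_points(num_train_batches: int, val_check_interval: Union[int, float]) -> List[int]:
    """Simplified: 1-based batch numbers within one epoch after which validation runs."""
    if isinstance(val_check_interval, float):
        # Fraction of the epoch, e.g. 0.5 -> roughly every half epoch.
        every = max(1, int(num_train_batches * val_check_interval))
    else:
        # Integer: check every N training batches.
        every = val_check_interval
    return [b for b in range(1, num_train_batches + 1) if b % every == 0]


limit_train_batches, batch_size = 25, 49
print(validation_points(limit_train_batches, 1.0))  # once, after the last batch
print(validation_points(limit_train_batches, 0.5))
print(validation_points(limit_train_batches, 10))
# Upper bound on training samples seen per epoch with the notebook's settings.
print(limit_train_batches * batch_size)
```

With batch_size=49, the 25-batch limit caps each epoch at 25 × 49 = 1225 training samples, a small fraction of the 993,236 time steps in malonaldehyde_dft.npz, which keeps this demonstration run short.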


In [7]:
trainer.fit(model=model, datamodule=dm)


  | Name  | Type                  | Params
------------------------------------------------
0 | model | MD_BiDirectional_LSTM | 8.3 M 
------------------------------------------------
8.3 M     Trainable params
108       Non-trainable params
8.3 M     Total params
33.312    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
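The last line of the model summary follows from the parameter count: PL's summary estimates the size as total parameters × bytes per parameter, i.e. 4 bytes at the default 32-bit precision, with 1 MB = 10^6 bytes. The small helpers below (hypothetical names) redo that arithmetic and invert it, showing that 33.312 MB corresponds to about 8.33 million parameters, consistent with the rounded "8.3 M" in the table.

```python
def estimated_size_mb(num_params: int, bits_per_param: int = 32) -> float:
    """Parameter memory in megabytes (1 MB = 1e6 bytes)."""
    return num_params * (bits_per_param / 8) / 1e6


def params_from_size_mb(size_mb: float, bits_per_param: int = 32) -> int:
    """Inverse: approximate parameter count behind a reported size."""
    return round(size_mb * 1e6 / (bits_per_param / 8))


n_params = params_from_size_mb(33.312)
print(n_params)                     # parameters implied by the summary line
print(estimated_size_mb(n_params))  # back to megabytes
```

In PyTorch, the count itself is usually obtained as sum(p.numel() for p in model.parameters()), optionally filtered on p.requires_grad to separate trainable from non-trainable parameters.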
