# Demonstrate On-the-Fly Learning with Controlled Error Using Proxima
The [Proxima paper from Zamora et al.](https://dl.acm.org/doi/10.1145/3447818.3460370) combines on-the-fly learning of a surrogate model with a control system that ensures the model is only used where appropriate.
We demonstrate how to use our implementation of Proxima within this notebook.

In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
from cascade.calculator import make_calculator
from cascade.proxima import SerialLearningCalculator
from cascade.learning.torchani.build import make_aev_computer, make_output_nets
from cascade.learning.torchani import TorchANI
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.nptberendsen import NPTBerendsen
from ase.io.trajectory import Trajectory
from ase.io import read, iread
from ase.db import connect
from ase import units
from pathlib import Path
from tqdm import tqdm
import pickle as pkl
import pandas as pd
import json



Configuration

In [2]:
initial_data: str = 'init-run/si-perfect-2x2x2.vasp/md.traj'
starting_strc: str = 'structures/si-vacancy-2x2x2.vasp'
model_ensemble_size: int = 2
target_fmax: float = 0.1

Determine a run directory

In [3]:
starting_strc = Path(starting_strc)
run_name = f'{Path(starting_strc).name[:-5]}_fmax={target_fmax}'
run_dir = Path('proxima-run') / run_name
run_dir.mkdir(exist_ok=True, parents=True)

In [4]:
db_path = run_dir / 'proxima.db'

## Load the Initial Structure and Calculator.
Use either the last structure from the previous run or the starting structure.

Then create a calculator set to run CP2K with the correct settings.

In [5]:
traj_path = run_dir / 'md.traj'

In [6]:
if traj_path.exists():
    atoms = read(traj_path, -1)
else:
    atoms = read(starting_strc)
    MaxwellBoltzmannDistribution(atoms, temperature_K=300)

Make a calculator using a utility built into Cascade

In [7]:
target_calc = make_calculator('blyp')

Computing the forces should take about a minute

In [8]:
%%time
target_calc.get_forces(atoms);

CPU times: user 5.77 ms, sys: 0 ns, total: 5.77 ms
Wall time: 24 s


## Select the Interatomic Potential
The interatomic potential is what we will be training as a fast surrogate for the DFT calculation.

Cascade splits the definition of the potential into two parts. The first is a class that defines how to use the model; here we use [TorchANI](https://github.com/aiqm/torchani).

In [9]:
learner = TorchANI()

The second part is a list of models to be trained. The TorchANI model learner expects the model to be defined as:

1. An atomic environment computer (converts an atomic structure into a representation useful to neural networks)
2. A network to compute energies, forces, and stresses using the computed environment
3. A set of reference energies for each element (TorchANI will solve for these, but we must seed them with an initial value)

In [10]:
species = list(set(atoms.symbols))  # The elements our model must model

In [11]:
aev_computer = make_aev_computer(species)

In [12]:
models = [
    (aev_computer, make_output_nets(species, aev_computer), dict((s, 0.) for s in species))
    for i in range(model_ensemble_size)
]

## Select the Training Set
We ran a short trajectory in the previous notebook and we'll use it as training data for this run.

Skip if the DB path is already there

In [13]:
if not db_path.exists():
    with connect(db_path) as db:
        for frame in iread(initial_data):
            db.write(frame)

## Create the Online Trainer
The `SerialLearningCalculator` manages training the surrogate model and figuring out how to best use it.
It acts as a normal ASE calculator, so we'll use it later in place of the CP2K calculator created above.

Make one by defining:

In [14]:
learning_calc = SerialLearningCalculator(
    target_calc=target_calc,  # The target level of accuracy, as a calculator to use for generating training data
    learner=learner,  # A tool defining how to train the models
    models=models,  # Starting points for the models being trained
    train_kwargs={'num_epochs': 128, 'batch_size': 64, 'reset_weights': True, 'device': 'cuda'},  # Configuration for the training routines
    train_freq=64,  # After how many new points to retrain the surrogate
    target_ferr=0.3,  # The target level of error, defined as the maximum force error between the surrogate and target in eV/Å
    history_length=32,  # How many past evaluations of the target calculator to use when determining whether the surrogate is accurate enough
    db_path=db_path,  # A database file in which to persist training data
)

If we're starting from an old run, we can load the state of the learner from disk (we'll describe what that is below)

In [15]:
state_path = run_dir / 'proxima-state.pkl'
if state_path.exists():
    with state_path.open('rb') as fp:
        learning_calc.set_state(pkl.load(fp))
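
The state file is what makes a restart possible, so it helps if an interruption during a write cannot leave it half-written. Below is a minimal sketch of one way to do that with only the standard library: write to a temporary file in the same directory, then rename it over the old file. The helper name `save_state_atomically` is our own and is not part of Cascade.

```python
import os
import pickle as pkl
import tempfile
from pathlib import Path


def save_state_atomically(state, path):
    """Pickle `state` to `path` so readers never see a partial file.

    Hypothetical helper for illustration; not part of Cascade.
    """
    path = Path(path)
    # Create the temporary file next to the target so the rename stays on one filesystem
    fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), prefix=path.name, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as fp:
            pkl.dump(state, fp)
        # os.replace overwrites the destination in a single step
        os.replace(tmp_name, path)
    except BaseException:
        # Clean up the temporary file if anything went wrong before the rename
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
```

With this helper, `save_state_atomically(learning_calc.get_state(), state_path)` could stand in for a plain `pkl.dump` wherever the state is saved.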

As above, you can use it just like any other calculator. The first call will take a few minutes because the calculator trains the surrogate models and runs the target calculator.

In [16]:
%%time
learning_calc.get_forces(atoms);

CPU times: user 11min 27s, sys: 1.27 s, total: 11min 29s
Wall time: 11min 52s


The `used_surrogate` attribute is a diagnostic that tells you whether the most recent result came from the surrogate or the target calculator.

In [17]:
learning_calc.used_surrogate

False

The learning calculator exposes a few other fields it uses while operating that are useful for understanding how it works.

For example, the most recent uncertainty metric (defined as the largest variance in force across all models) and the observed error of the model are stored in `error_history` as (uncertainty, error) pairs.
The first error is going to be rather large.

In [18]:
learning_calc.error_history

deque([(0.030899279, 0.505205261344483)], maxlen=32)
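
To make "largest variance in force across all models" concrete, here is a minimal pure-Python sketch of such a metric. It is an illustration under assumptions, not Cascade's code: the function name `max_force_variance` is ours, and we assume the population variance taken per atom and per Cartesian component.

```python
from statistics import pvariance


def max_force_variance(ensemble_forces):
    """Largest variance across models of any single force component.

    ensemble_forces: one entry per model, each a list of [fx, fy, fz] per atom.
    Assumption: population variance; the real implementation may differ.
    """
    n_atoms = len(ensemble_forces[0])
    largest = 0.0
    for atom in range(n_atoms):
        for comp in range(3):
            # Collect the same force component from every model in the ensemble
            values = [model[atom][comp] for model in ensemble_forces]
            largest = max(largest, pvariance(values))
    return largest


# Two models that agree on atom 0 and disagree on the x-force of atom 1
model_a = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
model_b = [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]]
disagreement = max_force_variance([model_a, model_b])
```

The metric is zero when all models agree and grows as they diverge, which is why it can act as a signal for when the surrogate should not be trusted.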

The threshold for the uncertainty metric is available as `threshold`, but it is not yet set because there are not enough examples of using the surrogate.

In [19]:
learning_calc.threshold
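
One plausible way to turn the (uncertainty, error) history into a threshold is sketched below. This is an assumption made for illustration and not a description of how Cascade computes `threshold`: we fit a single coefficient relating error to uncertainty, then pick the uncertainty at which the predicted error reaches the target. The function names are hypothetical.

```python
def fit_alpha(history):
    """Least-squares fit through the origin of error ~ alpha * uncertainty.

    history: iterable of (uncertainty, error) pairs.
    Returns None when there is no usable data.
    """
    pairs = list(history)
    denominator = sum(u * u for u, _ in pairs)
    if denominator == 0:
        return None
    return sum(u * e for u, e in pairs) / denominator


def pick_threshold(history, target_error):
    """Uncertainty at which the fitted error would equal the target error."""
    alpha = fit_alpha(history)
    if alpha is None or alpha <= 0:
        return None  # No meaningful threshold yet
    return target_error / alpha


def should_use_surrogate(uncertainty, threshold):
    """Trust the surrogate only when a threshold exists and is not exceeded."""
    return threshold is not None and uncertainty < threshold


# Toy history in which the error is always twice the uncertainty
toy_history = [(1.0, 2.0), (2.0, 4.0)]
toy_threshold = pick_threshold(toy_history, target_error=0.3)
```

In this sketch an unset threshold (`None`) always falls back to the target calculator, which matches the behavior seen above where the first call did not use the surrogate.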

There is also a set of training data stored on disk as `proxima.db`.

In [20]:
read(db_path, 0)

Atoms(symbols='Si64', pbc=True, cell=[10.86, 10.86, 10.86], momenta=..., calculator=SinglePointCalculator(...))

## Running Dynamics and Learning On-the-Fly
We can now use the learning calculator as if it were any other way of computing atomic energies.

As an example, let's run NPT dynamics at 300 K and zero pressure.

In [21]:
atoms.calc = learning_calc

In [22]:
npt = NPTBerendsen(atoms, 
                   timestep=0.3 * units.fs,
                   temperature_K=300,
                   taut=10. * units.fs,  # Small because the structure starts out very unrelaxed
                   logfile=run_dir / 'md.log',
                   pressure_au=0,
                   compressibility_au=4.57e-5 / units.bar)

Run enough timesteps that the model will start training and, if we are lucky, get used.

In [None]:
with Trajectory(traj_path, mode='a', atoms=atoms) as traj:
    npt.attach(traj.write, interval=1)
    for step in tqdm(range(1024)):
        # Step forward with the NPT
        npt.step()
    
        # Save a log of what the learner is doing
        with open(run_dir / 'proxima-log.json', 'a') as fp:
            last_uncer, last_error = learning_calc.error_history[-1]
            print(json.dumps({
                'step': step,
                'energy': float(atoms.get_potential_energy()),
                'temperature': atoms.get_temperature(),
                'volume': atoms.get_volume(),
                'used_surrogate': bool(learning_calc.used_surrogate),
                'proxima_alpha': learning_calc.alpha,
                'proxima_threshold': learning_calc.threshold,
                'last_uncer': float(last_uncer),
                'last_error': float(last_error),
            }), file=fp)
    
        # Save proxima's state
        with open(state_path, 'wb') as fp:
            pkl.dump(learning_calc.get_state(), fp)

 16%|██████████████████████████████████████▋                                                                                                                                                                                                      | 167/1024 [2:48:43<10:22:35, 43.59s/it]
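
The log written in the loop above has one JSON record per line, so it can be summarized without loading the trajectory. Here is a minimal sketch of such a summary using only the standard library; the function name `summarize_log` is ours. It counts records rather than relying on `step`, because the log is opened in append mode and `step` starts again from zero when the notebook is restarted.

```python
import json


def summarize_log(lines):
    """Count logged steps and how many of them used the surrogate.

    lines: any iterable of JSON-lines strings, such as an open file.
    """
    total = 0
    surrogate_steps = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue  # Skip blank lines
        record = json.loads(line)
        total += 1
        if record['used_surrogate']:
            surrogate_steps += 1
    fraction = surrogate_steps / total if total else 0.0
    return {
        'steps': total,
        'surrogate_steps': surrogate_steps,
        'surrogate_fraction': fraction,
    }


# Usage with the log from this notebook:
# with open(run_dir / 'proxima-log.json') as fp:
#     summary = summarize_log(fp)
```

The fraction of steps that used the surrogate gives a rough measure of how much target-calculator time the learning calculator saved.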