# Cascade: serial prototype implementation

Here we use some of the classes we've written to create a serial prototype run of Cascade.

This is a minimum viable run, intended to inform upcoming design decisions before we move to distributed runs.

No science is done here. 

In [1]:
from glob import glob
from pathlib import Path
from dataclasses import dataclass, field


import ase
from ase.io import read, write
from ase.io.trajectory import Trajectory, TrajectoryWriter
from ase import units
from ase.md import MDLogger, VelocityVerlet
import numpy as np
import torch
from mace.calculators import mace_mp

from cascade.trajectory import CascadeTrajectory
from cascade.utils import canonicalize, apply_calculator
from cascade.auditor import RandomAuditor
from cascade.learning.torchani import TorchANI
from cascade.learning.torchani.build import make_output_nets, make_aev_computer
from cascade.runner import SerialCascadeRunner

%load_ext autoreload
%autoreload 2

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))
  from torch.distributed.optim import ZeroRedundancyOptimizer


## Read in structure
We'll run these simulations on a 2x2x2 Si supercell with a vacancy.

In [2]:
# Starting geometry: 2x2x2 Si supercell containing a vacancy, stored as a VASP-format file.
geometry_path = '../0_setup/initial-geometries/si-vacancy-2x2x2.vasp'
atoms = read(geometry_path)

## Set up calculator

We'll use a small MACE model as our *target*.   
That is to say, MACE is our ground truth physics.   
(We want something fast for this prototype.)

In [3]:
# Target ("ground truth") calculator: the small MACE-MP model.
# The device is passed explicitly so this setting actually takes effect,
# and we fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
calc = mace_mp('small', device=device)

  torch.load(f=model_path, map_location=device)


Using Materials Project MACE for MACECalculator with /home/mike/.cache/mace/20231210mace128L0_energy_epoch249model
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.


## Set up learner

We'll fit two ANI models to MACE

In [4]:
# ANI learner with default settings; the cascade runner below uses it to train the surrogate model.
learner: TorchANI = TorchANI()

In [5]:
# Sort the element symbols. Iterating a set of strings has no guaranteed order
# (string hashing is randomized per interpreter), so `list(set(...))` could
# reorder species between kernel restarts. That would change how the AEV
# computer and the output networks index the elements.
species = sorted(set(atoms.symbols))
aev = make_aev_computer(species)

# Model tuple: (AEV computer, per-species output nets, per-species energy
# offsets initialised to 0). Presumably the third entry is the self/reference
# energy per element that TorchANI expects; TODO confirm against TorchANI.
model = aev, make_output_nets(species, aev), {s: 0. for s in species}

## Minimum viable cascade loop

In [6]:

# create two cascasde trajectories from the same starating point but with different seeds
# Two cascade trajectories that start from the same structure; each gets its own file.
# NOTE: the seeds currently only affect the file name. They have no effect on the
# dynamics, since our dynamics are NVE.
seeds = [0, 1]
trajectories = []
for seed in seeds:
    trajectories.append(
        CascadeTrajectory(path=f'si-diffusion-seed={seed}.traj', starting=atoms.copy())
    )

runner_options = dict(
    # dynamics: total steps per trajectory, advanced `increment_steps` at a time
    total_steps=128,
    increment_steps=64,
    # auditing: segments whose audit score exceeds this threshold trigger audit calculations
    uq_threshold=0.5,
    auditor=RandomAuditor(random_state=42),
    # surrogate model and the target calculator it is checked against
    learner=learner,
    model=model,
    calculator=calc,
    dyn_cls=VelocityVerlet,
    # training
    train_kws=dict(device='cpu', num_epochs=1),
    max_train=10,
    val_frac=0.1,
    training_file='train.traj',
)
cascade = SerialCascadeRunner(trajectories=trajectories, **runner_options)

cascade.run(max_iter=10)

**********
Starting pass 1/10 of cascade loop
Currently 0 of 2 complete
Examining trajectory 1 of 2
Trajectory is trusted, advancing
Running ML-driven dynamics


  self_energies = torch.tensor(self_energies, dtype=torch.double)


Examining trajectory 2 of 2
Trajectory is trusted, advancing
Running ML-driven dynamics
**********
Starting pass 2/10 of cascade loop
Currently 0 of 2 complete
Examining trajectory 1 of 2
Trajectory has untrusted segment, auditing
Auditing trajectory
score < threshold (0.3745401188473625 < 0.5, marking recent segment as trusted
Examining trajectory 2 of 2
Trajectory has untrusted segment, auditing
Auditing trajectory
score < threshold (0.034388521115218396 < 0.5, marking recent segment as trusted
**********
Starting pass 3/10 of cascade loop
Currently 0 of 2 complete
Examining trajectory 1 of 2
Trajectory is trusted, advancing
Running ML-driven dynamics
Examining trajectory 2 of 2
Trajectory is trusted, advancing
Running ML-driven dynamics
**********
Starting pass 4/10 of cascade loop
Currently 0 of 2 complete
Examining trajectory 1 of 2
Trajectory has untrusted segment, auditing
Auditing trajectory
score > threshold (0.6688412526636073 > 0.5), running audit calculations and dropping u

## Did those complete?

This is great. Next steps:
- [ ] diagram out current/trusted logic
- [x] add training
- [x] break into functions/classes WIP
- [x] add tests WIP
- [ ] add logging

In [7]:
# A finished trajectory should hold 129 frames: total_steps=128 plus, presumably,
# the starting frame. Assert this instead of eyeballing, then display the counts.
frame_counts = [len(traj.read()) for traj in trajectories]
assert all(count == 128 + 1 for count in frame_counts), f'incomplete trajectories: {frame_counts}'
frame_counts

[129, 129]

This seems done enough for now.