## Pseudocode

We want to build something like this:
```python
# Containers for the sketch below.
# NOTE(review): new_training_examples is not referenced again in this sketch
# (the loop appends to `training_set` instead) — confirm which one is intended.
new_training_examples: list = []
# one Trajectory is built per entry of this list
initial_starting_points: list = []

# number of steps handed to dyn.run for each exploration pass
explore_timesteps: int = 50
# used by audit() when picking which frames to check
n_audit_frames: int = 10

class Trajectory:
    """Bookkeeping for one trajectory in the audit loop.

    Attributes:
        file: the trajectory file passed to the constructor; the driver
            loop reads frames from ``traj.file``.
        last_trusted_ix: index of the last trusted frame; starts at 0.
    """

    # class-level default shared by every new instance until it is reassigned
    last_trusted_ix = 0

    def __init__(self, file):
        # Store on the instance; a bare `file = file` only rebinds the local
        # name and leaves the object without a `.file` attribute.
        self.file = file

def audit(trajectory, reference_calc):
    """Audit the highest-UQ frames recorded since the last trusted index.

    Args:
        trajectory: trajectory to audit; only frames from
            ``trajectory.last_trusted_ix`` onward are considered.
        reference_calc: reference calculator.
            NOTE(review): not used yet in this sketch.

    Returns:
        ``(error, ref_energies, ref_forces)`` for the audited frames.
    """
    # candidate frames: everything after the last trusted index
    frames = trajectory[trajectory.last_trusted_ix:]
    uqs = [uq(frame) for frame in frames]
    # audit the topk worst UQ frames (assuming a larger uq value is worse).
    # argsort is ascending, so the last n_audit_frames entries are the top-k;
    # `[::-n_audit_frames]` would instead take every n-th entry in reverse.
    audit_indices = np.argsort(uqs)[-n_audit_frames:]
    # audit_indices are positions within `frames`, not the full trajectory.
    # assumes compute_forces returns an (energy, forces) pair per frame — TODO confirm
    ref_energies, ref_forces = zip(
        *(compute_forces(frames[i]) for i in audit_indices)
    )
    # NOTE(review): `forces` (presumably the model's forces for the audited
    # frames) is not defined anywhere in this sketch yet.
    error = compute_error(forces, ref_forces)
    return error, ref_energies, ref_forces
    
# one Trajectory per starting file (the constructor requires the file)
trajectories = [Trajectory(file) for file in initial_starting_points]

done = False
while not done:  # loop breaks when all trajectories are done
    if training_set.updated():
        model.retrain()
    for traj in trajectories:
        calc = reference_calc
        # restart from the last trusted frame of this trajectory
        init = read(traj.file, index=traj.last_trusted_ix)
        dyn.run(init, explore_timesteps, traj)
        # audit the trajectory that was just advanced; audit() returns
        # (error, ref_energies, ref_forces), so everything after the error
        # is collected as the reference data for the audited frames
        error, *audit_frames = audit(traj, calc)
        if error > tol:
            training_set.append(audit_frames)

    # NOTE(review): Trajectory does not define done() in this sketch yet
    done = all(traj.done() for traj in trajectories)
```

## prototype

In [59]:
from glob import glob
from pathlib import Path
from ase.io import read
from ase.io.trajectory import Trajectory
from ase import units
from ase.md.npt import NPT

import numpy as np
from cascade.calculator import make_calculator

from mace.calculators import MACECalculator

### Read in initial conditions and get trajectory names

In [37]:
# glob returns matches in arbitrary order, so sort before slicing to make the
# choice of "last" trajectory reproducible. `[-1:]` keeps a one-element list
# (or an empty list when nothing matches).
# Note: without recursive=True, '**' matches a single directory level, like '*'.
initial_trajectories = sorted(glob('../0_setup/md/**/*md.traj'))[-1:]
initial_trajectories

['../0_setup/md/packmol-CH4-in-H2O=32-seed=3-blyp-npt=298/md.traj']

In [38]:
# Name each run after the directory that contains its md.traj file.
# Path handles the platform's separator instead of splitting on a literal '/'.
traj_names = [Path(traj).parent.name for traj in initial_trajectories]
traj_names

['packmol-CH4-in-H2O=32-seed=3-blyp-npt=298']

In [40]:
# Map each run name to the first frame (index '0') of its trajectory file.
initial_conditions = dict(
    zip(traj_names,
        (read(path, index='0') for path in initial_trajectories))
)

In [43]:
# show the loaded starting structures (cell output follows)
initial_conditions

{'packmol-CH4-in-H2O=32-seed=3-blyp-npt=298': Atoms(symbols='CH68O32', pbc=True, cell=[10.288817093428836, 10.288817093428836, 10.288817093428836], momenta=..., calculator=SinglePointCalculator(...))}

## Read in initial models

In [44]:
# All model checkpoints in the ensemble directory are handed to one
# MACECalculator (the cell output reports a 4-model committee).
initial_model_files = glob('../1_mace/ensemble/*.pt')
# ML calculator attached to the atoms in the MD loop; runs on the first CUDA device
calc_ml = MACECalculator(initial_model_files, device='cuda:0')
# Reference calculator built by the project helper for the 'blyp' key.
# NOTE(review): calc_ref is not used by the cells shown here yet.
calc_ref = make_calculator('blyp')

Running committee mace with 4 models
No dtype selected, switching to float32 to match model dtype.


In [55]:
# Scratch check of the final-chunk length: once 500 steps are done, fewer
# than chunk_size steps remain, so the remainder wins (see the output below).
min(chunk_size, total_time - 500)

12

## TODO:
Handle the trajectory splicing/overwriting. I think it makes sense to see how Logan did this.

One idea is to track the trusted segments,
something like:

```python
from dataclasses import dataclass


@dataclass
class Trajectory:
    """State tracked for one trajectory while splicing trusted segments."""

    id: str
    # quoted so the annotation is not evaluated when the class is created
    # (only submodules of ase are imported in this notebook)
    starting: 'ase.Atoms'
    # annotated so it is a real dataclass field (settable via the constructor,
    # included in repr/eq) rather than a plain class attribute
    last_trusted_timestep: int = 0

    # which chunks to splice together into a trusted trajectory
    # NOTE(review): the element type of a "chunk" is not decided yet
    trusted_chunks: tuple = ()
```

In [61]:
# Prototype of the chunked ML-driven MD loop (work in progress).
# NOTE(review): nothing below updates last_trusted_ix, retrain_ix or done yet,
# and the `break` at the end of the while body stops after a single pass.
last_trusted_ix = 0
chunk_size = 50
total_time = 512
retrain_ix = 0
done = False

while not done:
    for name, atoms in initial_conditions.items():
        # set up the run dir: cascade-md/<name>/md.traj
        run_dir = Path('cascade-md') / name
        run_dir.mkdir(exist_ok=True, parents=True)
        traj_file = run_dir / 'md.traj'

        # after the first chunk, restart from the last trusted frame that was
        # written to this run's own trajectory file instead of the initial frame
        if last_trusted_ix > 0: 
            atoms = read(traj_file, 
                         index=last_trusted_ix)

        # setup the ml-driven dynamics
        atoms.calc = calc_ml
        # NOTE(review): append_trajectory=False presumably rewrites md.traj each
        # time the dynamics are built, which would conflict with reading frame
        # `last_trusted_ix` / `start:stop` from the same file on later chunks —
        # confirm (this is the splicing/overwriting TODO).
        dyn = NPT(atoms,
          timestep=0.5 * units.fs,
          temperature_K=298,
          ttime=100 * units.fs,
          pfactor=0.01,
          externalstress=0,
          logfile=str(run_dir / 'md.log'),
          trajectory=str(traj_file),
          append_trajectory=False)

        # timestep indexing: the last chunk is shortened so the run never
        # goes past total_time
        start = last_trusted_ix
        step = min(chunk_size, total_time - last_trusted_ix)
        stop = start + step

        # run the dynamics for this chunk
        dyn.run(step)

        # read in the recent chunk (frames start..stop-1 of the trajectory file)
        chunk = read(traj_file, index=f'{start}:{stop}')

        # placeholder for the audit step: a bare name does nothing when it is
        # defined and raises NameError when it is not
        audit
    # unconditional exit from the while loop after one pass over all runs
    break

