## Pseudocode

We want to build something like this
```python
# Presumably collects audited frames destined for the training set.
new_training_examples = list()
# Starting-point files, one per trajectory to explore.
initial_starting_points = list()

# MD steps to run between audits of each trajectory.
explore_timesteps = 50
# How many of the highest-uncertainty frames to audit per trajectory.
n_audit_frames = 10

class Trajectory:
    """A trajectory being explored, identified by the file it lives in.

    NOTE(review): the driver loop below also calls `traj.done()`, which is
    not defined yet.
    """

    # index of the last trusted frame; the driver restarts dynamics from here
    last_trusted_ix = 0

    def __init__(self, file):
        # store on the instance: the driver loop reads `traj.file`
        self.file = file
        self.last_trusted_ix = 0

def audit(trajectory, reference_calc):
    """Audit the most uncertain untrusted frames of `trajectory`.

    Pseudocode: `uq`, `compute_forces`, and `compute_error` are placeholders.

    Args:
        trajectory: object with `.file` and `.last_trusted_ix`.
        reference_calc: calculator used to label the audited frames.

    Returns:
        (error, ref_energies, ref_forces) for the audited frames.
    """
    # only frames after the last trusted one still need auditing
    frames = read(trajectory.file, index=slice(trajectory.last_trusted_ix, None))
    uqs = [uq(frame) for frame in frames]
    # np.argsort sorts ascending, so the top-k highest-UQ frames are the LAST k
    # indices (`[::-k]` would be a strided reverse slice, not the top k)
    audit_indices = np.argsort(uqs)[-n_audit_frames:]
    # assumes compute_forces returns an (energy, forces) pair -- TODO confirm
    ref_results = [compute_forces(frames[i], reference_calc) for i in audit_indices]
    ref_energies, ref_forces = zip(*ref_results)
    # presumably the model forces stored on these frames during exploration
    forces = [frames[i].get_forces() for i in audit_indices]
    error = compute_error(forces, ref_forces)
    return error, ref_energies, ref_forces
    
trajectories = [Trajectory(file) for file in initial_starting_points]

done = False
while not done:  # loop breaks when all trajectories are done
    if training_set.updated():
        model.retrain()
    # the reference calculator is the same for every trajectory
    calc = reference_calc
    for traj in trajectories:
        # restart exploration from the last trusted frame
        init = read(traj.file, index=traj.last_trusted_ix)
        dyn.run(init, explore_timesteps, traj)
        error, ref_energies, ref_forces = audit(traj, calc)
        if error > tol:
            # TODO: the audited frames themselves are also needed as training inputs
            training_set.append((ref_energies, ref_forces))

    done = all(traj.done() for traj in trajectories)
```

## prototype

In [1]:
from glob import glob
from pathlib import Path
from ase.io import read
from ase.io.trajectory import Trajectory
from ase import units
from ase.md.npt import NPT

import numpy as np
from cascade.calculator import make_calculator

from mace.calculators import MACECalculator

### Read in initial conditions and get trajectory names

In [2]:
# glob returns files in arbitrary (filesystem) order; sort so that "the last
# trajectory" is reproducible between runs. Without recursive=True, `**`
# matches a single directory level, like `*`.
initial_trajectories = sorted(glob('../0_setup/md/**/*md.traj'))[-1:]
initial_trajectories

['../0_setup/md/packmol-CH4-in-H2O=32-seed=3-blyp-npt=298/md.traj']

In [3]:
# name each trajectory after the directory that contains its md.traj;
# pathlib handles platform path separators, unlike splitting on '/'
traj_names = [Path(traj).parent.name for traj in initial_trajectories]
traj_names

['packmol-CH4-in-H2O=32-seed=3-blyp-npt=298']

In [4]:
# Map each trajectory name to the first frame (index '0') of its trajectory file.
initial_conditions = dict(
    zip(traj_names, (read(path, index='0') for path in initial_trajectories))
)

In [5]:
# display the starting Atoms for each trajectory
initial_conditions

{'packmol-CH4-in-H2O=32-seed=3-blyp-npt=298': Atoms(symbols='CH68O32', pbc=True, cell=[10.288817093428836, 10.288817093428836, 10.288817093428836], momenta=..., calculator=SinglePointCalculator(...))}

## Read in initial models

In [6]:
# committee of MACE models that drives the dynamics; sorted so the ensemble
# member order does not depend on filesystem ordering
initial_model_files = sorted(glob('../1_mace/ensemble/*.pt'))
if not initial_model_files:
    raise FileNotFoundError('no MACE model files (*.pt) found in ../1_mace/ensemble/')
calc_ml = MACECalculator(initial_model_files, device='cuda:0')
# reference calculator at the BLYP level of theory
calc_ref = make_calculator('blyp')

Running committee mace with 4 models
No dtype selected, switching to float32 to match model dtype.


## TODO:
Handle the trajectory splicing/overwriting. It probably makes sense to see how Logan did this.

One idea is to track the trusted segments, with something like:

```python
@dataclass
class Trajectory:
    """Sketch of a trajectory that tracks which of its chunks are trusted."""

    id: str
    starting: 'ase.Atoms'  # string annotation: no import of ase needed here
    # annotated so these become dataclass fields, not shared class attributes
    last_trusted_timestep: int = 0

    # which chunks to splice together into a trusted trajectory
    trusted_chunks: tuple = ()
```

In [8]:
from typing import Tuple, List, Dict
from dataclasses import dataclass, field
from collections import defaultdict

What may make more sense is something like:

```python

@dataclass
class Trajectory:
    """Sketch of a trajectory split into fixed-size chunks of MD timesteps."""

    chunk_size: int = 50  # number of timesteps in a chunk (for now is constant)
    last_trusted_timestep: int = 0
    # dict mapping chunk -> list of filenames, or just int -> int; the factory
    # gives each instance its own dict instead of one shared class-level dict
    chunks: dict = field(default_factory=dict)
```

In [9]:
@dataclass
class CascadeTrajectory:
    """A class to encapsulate a cascade trajectory, which has trusted and untrusted chunks.

    Attributes:
        dir: directory for this trajectory (presumably where its chunk/pass
            outputs live -- TODO confirm).
        last_trusted_timestep: last MD timestep considered trusted.
        trusted_chunks: indices of the chunks accepted as trusted.
    """

    dir: str
    last_trusted_timestep: int = 0
    # default_factory gives each instance its own list
    trusted_chunks: List[int] = field(default_factory=list)

    def read(self):
        """Read the trusted trajectory.

        Raises:
            NotImplementedError: always; this is still a stub.
        """
        # fail loudly rather than silently returning None to callers
        raise NotImplementedError('CascadeTrajectory.read is not implemented yet')

In [15]:
# Prototype driver: run chunked, ML-driven NPT dynamics for a single trajectory.
name = traj_names[0]
chunk_size = 50  # MD timesteps per chunk
total_steps = 512  # total MD timesteps to simulate
n_chunks = int(np.ceil(total_steps / chunk_size))  # the last chunk may be shorter
done = False
retrain_ix = 1  # start counting at 1 (scary!)
chunk_ix = 1
chunk_passes = defaultdict(lambda: 0)  # chunk index -> how many passes were made at it

# every chunk/pass directory for this trajectory lives under run_dir
run_dir = Path('cascade-md') / name

while not done:
    # NOTE(review): pass_ix is reset on every iteration, so repeated passes at
    # the same chunk would overwrite pass=1; derive it from chunk_passes once
    # the loop does more than one pass per chunk
    pass_ix = 1

    # set up the directory to hold the trajectory for this pass
    pass_dir = run_dir / f'chunk={chunk_ix}-pass={pass_ix}'
    pass_dir.mkdir(exist_ok=True, parents=True)

    # pull in the initial conditions, or the last frame of the most recent trusted chunk
    if chunk_ix == 1:
        # copy so the dynamics below do not mutate the stored initial conditions
        # (re-running this cell would otherwise start from the previous end state)
        atoms = initial_conditions[name].copy()
    else:
        last_pass = chunk_passes[chunk_ix - 1]
        # run_dir already includes `name`; pass directories are named
        # `chunk=<i>-pass=<j>` and hold `md.traj` (see pass_dir / trajfile)
        last_trajfile = run_dir / f'chunk={chunk_ix - 1}-pass={last_pass}' / 'md.traj'
        atoms = read(str(last_trajfile), index='-1')

    # we save the trajectory in chunks, including every pass at simulating that chunk
    logfile = str(pass_dir / 'md.log')
    trajfile = str(pass_dir / 'md.traj')

    # set up the ML-driven dynamics
    atoms.calc = calc_ml
    dyn = NPT(
        atoms,
        timestep=0.5 * units.fs,
        temperature_K=298,
        ttime=100 * units.fs,
        pfactor=0.01,
        externalstress=0,
        logfile=logfile,
        trajectory=trajfile,
        append_trajectory=False,
    )

    # a full chunk, or only the remaining steps for the final (short) chunk
    chunk_steps = min(chunk_size, total_steps - (chunk_ix - 1) * chunk_size)

    # run the dynamics for this chunk
    dyn.run(chunk_steps)

    # read in every frame of the recent chunk (ase.io.read returns only the
    # last frame unless an index is given)
    chunk = read(trajfile, index=':')

    # prototype: stop after a single pass for now
    break

## todo

- [ ] make the UQ eval plots
- [ ] get ASE/cp2k working locally again
- [ ] get the above working code working with a single trajectory in a token manner
- [ ] get it working with multiple trajectories in a less token manner

In [18]:
# list the directories and files written under cascade-md by the run above
!tree cascade-md

cascade-md
└── packmol-CH4-in-H2O=32-seed=3-blyp-npt=298
    └── chunk=1-pass=1
        ├── md.log
        └── md.traj

2 directories, 2 files
