In [1]:
## Test of the training task
## Train a model with enough runs
## Compare the results with the default MD / DFT



In [2]:
## import modules
from pathlib import Path
import torch 
import logging
import shutil
from collections import deque
from dataclasses import dataclass
from random import shuffle, sample
from typing import Dict, Any, Optional
import json
from functools import partial, update_wrapper
import numpy as np

import ase
from ase.db import connect
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from fff.learning.gc.ase import SchnetCalculator
from fff.learning.gc.functions import GCSchNetForcefield
from fff.learning.gc.models import SchNet, load_pretrained_model
from fff.learning.util.messages import TorchMessage
from fff.sampling.md import MolecularDynamics
from fff.simulation import run_calculator
from fff.simulation.utils import read_from_string, write_to_string

In [12]:
## Paths and variables
# NOTE(review): absolute, machine-specific path; consider making it configurable
multisite_path = "/home/yxx/work/project/colmena/multisite_"
training_set = f"{multisite_path}/data/forcefields/starting-model/initial-database.db"
model_path = f"{multisite_path}/data/forcefields/starting-model/starting-model"
search_path = training_set  # The search space is read from the same database as the training set
out_dir = Path(multisite_path, "my_test/temp")
out_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): torch.load unpickles the file; only load model files from a trusted source
starting_model = torch.load(model_path, map_location='cpu')

# Training settings
num_epochs = 10  # Passed as `num_epochs` to the training function
huber_deltas = (1, 10)  # Passed as `huber_deltas` to the training function
n_models = 1  # Number of train/validation splits that are prepared

# Sampling settings
sampler_kwargs = {'device': "cuda", 'timestep': 0.1, 'log_interval': 10}
sampler = MolecularDynamics()

# Run-length control
n_qc_workers = 8  # More audit results than this are needed before the run length is estimated
min_run_length = 1  # Lower bound on the sampling run length
max_run_length = 100  # Upper bound on the sampling run length
energy_tolerance = 0.1  # The target error for a sampling run is twice this value

In [11]:
## Preparation for model training

# Apply wrappers to functions that will be used to fix certain requirements
def _wrap(func, **kwargs):
    out = partial(func, **kwargs)
    update_wrapper(out, func)
    return out

## MD objectives
@dataclass
class Trajectory:
    """Tracks the state of searching along individual trajectories

    We mark the starting point, the last point produced from sampling,
    and the last point we produced that has been validated
    """
    id: int  # ID number of the trajectory
    starting: ase.Atoms  # Starting point of the trajectory
    last_validated: Optional[ase.Atoms] = None  # Last validated point on the trajectory
    current: Optional[ase.Atoms] = None  # Last point produced along the trajectory
    last_run_length: int = 0  # How long between current and last_validated
    name: Optional[str] = None  # Name of the trajectory
    # How many timesteps have been used so far.
    # Annotated so that it is a real dataclass field (part of __init__, repr and asdict)
    # instead of a bare class attribute; declared last so the positional order of the
    # pre-existing fields is unchanged.
    current_timestep: int = 0

    def __post_init__(self):
        # Fall back to the starting structure only for the points that were not supplied,
        # so explicitly provided values are no longer silently discarded
        if self.last_validated is None:
            self.last_validated = self.starting
        if self.current is None:
            self.current = self.last_validated

    def update_current_structure(self, strc: ase.Atoms, run_length: int):
        """Update the structure that has yet to be validated

        Args:
            strc: Structure produced by sampling
            run_length: How many timesteps were performed in sampling run
        """
        self.current = strc.copy()  # Store a copy so later changes to `strc` do not leak in
        self.last_run_length = run_length

    def set_validation(self, success: bool):
        """Set whether the trajectory was successfully validated

        On failure nothing changes: `last_validated` and `current_timestep` keep their values.

        Args:
            success: Whether the validation was successful
        """
        if success:
            self.last_validated = self.current  # Move the last validated forward
            self.current_timestep += self.last_run_length


@dataclass
class SimulationTask:
    """Bundles a structure to be run with the machine-learning prediction made for it

    Attributes:
        atoms: Structure to be run
        traj_id: Which trajectory this came from
        ml_eng: Energy predicted from machine learning model
        ml_std: Uncertainty of the model (``None`` when not provided)
    """
    atoms: ase.Atoms
    traj_id: int
    ml_eng: float
    ml_std: Optional[float] = None

## Force-field interface built around the starting model
schnet = GCSchNetForcefield(starting_model)

## Copy the training database into the output directory
train_path = out_dir / "train.db"
shutil.copyfile(training_set, train_path)

### Task functions with their fixed options pinned
## Training
my_train_schnet = _wrap(
    schnet.train,
    num_epochs=num_epochs,
    device='cuda',
    patience=8,
    reset_weights=False,
    huber_deltas=huber_deltas,
)

## Evaluation
my_eval_schnet = _wrap(schnet.evaluate, device='cuda')

## Sampling with the model
my_run_dynamics = _wrap(sampler.run_sampling, **sampler_kwargs)

### Prepare the inputs
# Search space: one Trajectory per database row, queued in random order
with connect(search_path) as db:
    search_space = [
        Trajectory(idx, row.toatoms(), name=row.get('filename', f'traj-{idx}'))
        for idx, row in enumerate(db.select(''))
    ]
    shuffle(search_space)
    search_space = deque(search_space)

# Training dataset: every structure stored in the copied database
with connect(train_path) as db:
    # NOTE(review): np.array may descend into sequence-like objects; confirm the result
    # is a 1-D array of structures when all structures have the same number of atoms
    all_examples = np.array([row.toatoms() for row in db.select("")], dtype=object)

    # Filtering of unrealistic structures is currently disabled:
    # if self.max_force is not None:
    #     all_examples = [a for a in all_examples if np.abs(a.get_forces()).max() < max_force]

## Bookkeeping for the audits
to_audit: dict[int, Trajectory] = {}  # Trajectories that need to be audited
audit_results: deque[float] = deque(maxlen=50)  # Results of the last 50 audits

# Prepare the initial model
StartModelMessage = TorchMessage(starting_model)
ActiveModelMessage = SchnetCalculator(starting_model)

# Prepare the datasets: one random 90/10 train/validation split per model
train_sets = []
valid_sets = []
n_train = int(len(all_examples) * 0.9)
for _ in range(n_models):
    shuffle(all_examples)  # In-place shuffle of the shared array of examples
    # `all_examples` is a NumPy array, so plain slices are views that share its memory.
    # Copy them: otherwise the next in-place shuffle rewrites every split stored so far
    # and all models end up with the same train/validation data.
    train_sets.append(all_examples[:n_train].copy())
    valid_sets.append(all_examples[n_train:].copy())

In [18]:
## train model
for round_idx in range(0, 1):  # Single training round for now
    model_msgs = []
    train_logs = []
    # Pair every training split with its own validation split.
    # (Previously the loop iterated over `valid_sets`, so the model was trained and
    # validated on the same held-out 10% and never saw the training split.)
    for train_set, valid_set in zip(train_sets, valid_sets):
        model_msg, train_log = my_train_schnet(model_msg=StartModelMessage,
                                               train_data=train_set,
                                               valid_data=valid_set)
        model_msgs.append(model_msg)
        train_logs.append(train_log)

    ## store model
    # now we just test one model
    model_save_path = out_dir / "model.pth"
    with open(model_save_path, 'wb') as fp:
        torch.save(model_msgs[0].get_model(), fp)
    # Save the training history: one JSON line is appended per training round
    with open(out_dir / 'training-history.json', 'a') as fp:
        print(json.dumps(train_logs[0].to_dict(orient='list')), file=fp)

    # The newly trained model becomes both the sampling calculator and
    # the starting point of the next training round
    active_model_proxy = SchnetCalculator(model_msgs[0].get_model())
    StartModelMessage = TorchMessage(model_msgs[0].get_model())
    model_msgs = []
    train_logs = []

In [None]:
## use model for sampling
# Pick the next eligible trajectory and start from the last validated structure
trajectory = search_space.popleft()
# Same object as `trajectory.starting` until a validation succeeds; using `last_validated`
# keeps the code in line with the stated intent for trajectories that already advanced
starting_point = trajectory.last_validated

# Initialize the structure if need be
if trajectory.current_timestep == 0:
    MaxwellBoltzmannDistribution(starting_point, temperature_K=100)
    print('Initialized temperature to 100K')
# Add the structure to a list of those being validated
to_audit[trajectory.id] = trajectory

# Determine the run length based on observations of errors
run_length = min_run_length
if len(audit_results) > n_qc_workers:
    # Predict run length given audit error
    error_per_step = np.median(audit_results)
    target_error = energy_tolerance * 2
    if error_per_step > 0:
        estimated_run_length = int(target_error / error_per_step)
    else:
        # No positive median error to divide by (int(x / 0) would fail): use the longest allowed run
        estimated_run_length = max_run_length
    print(f'Estimated run length of {estimated_run_length} steps to have an error of {target_error:.3f} eV/atom')
    run_length = max(min_run_length, min(max_run_length, estimated_run_length))  # Keep to within the user-defined bounds

# Run the sampling from the chosen structure with the latest trained model.
# The original call passed no arguments at all.
# NOTE(review): assumes run_sampling takes (atoms, steps, calc) positionally — confirm against fff
sampling_result = my_run_dynamics(starting_point, run_length, active_model_proxy)