In [None]:
import numpy as np

import matplotlib.pyplot as plt

from numba import njit, vectorize, float32

from monk import nb

from typing import Callable, Optional

import hoomd
import hoomd.forward_flux.forward_flux as ffs
import gsd.hoomd

from monk import prep
from monk import pair
import ex_render
import pandas as pd

from monk import nb

import pathlib

In [None]:
import importlib

# Re-execute the forward-flux module so source edits are picked up without a kernel
# restart; reload() hands back the same module object, so rebinding ffs is a no-op
# for identity but keeps the name explicit (and suppresses the cell's output).
ffs = importlib.reload(ffs)

In [None]:
data_dir = pathlib.Path("/home/ian/Documents/Data/monk/states")
data = {}
for file in data_dir.glob("*.gsd"):
    name = file.name
    temp = float(name.split("temp-")[-1].split("_seed")[0])
    if temp >= .65 or temp < 0.5:
        continue

    sub_data = {}
    sub_data["traj"] = file
    sub_data["soft"] = file.parents[0].glob(f"temp-{temp}*.parquet").__next__()
    data[temp] = sub_data

In [None]:
# Relative paths of the single-state trajectory and its softness table.
file, soft = "traj.gsd", "soft.parquet"

In [None]:
pd.read_parquet(path=soft)  # preview the softness table

In [None]:
from hoomd.custom import Action
import freud
import time

class TrackParticle(Action):
    """HOOMD custom action that follows one particle's position.

    Each call to ``act`` looks the particle up by tag in the CPU local snapshot and
    appends an independent copy of its position to ``self.data``.
    """

    def __init__(self, pid):
        # tag of the particle being followed
        self._pid = pid
        # one copied position per ``act`` call, in call order
        self.data = []

    def act(self, timestep):
        with self._state.cpu_local_snapshot as snapshot:
            particles = snapshot.particles
            local_index = particles.rtag[self._pid]
            position = np.array(particles.position[local_index], copy=True)
        self.data.append(position)

@njit
def _diff_with_rtag(ref_pos, pos, rtags):
    """Displacement of each local particle from its reference position.

    ``ref_pos`` is indexed by particle tag, ``pos`` by local index, and
    ``rtags[tag]`` gives the local index of ``tag``. The result has the shape of
    ``pos`` and is indexed like it; rows never reached through ``rtags`` stay zero.
    No periodic wrapping is applied here.
    """
    displacement = np.zeros_like(pos)
    for tag, local_idx in enumerate(rtags):
        displacement[local_idx] = pos[local_idx] - ref_pos[tag]
    return displacement

class ZeroDrift(Action):
    """HOOMD custom action that removes centre-of-mass drift.

    On each call the minimum-image displacement of every particle from
    ``reference_positions`` is computed, and all particles are shifted by minus
    the mean displacement. Positions are re-wrapped into the box and the image
    flags are updated for any particle the shift pushed across a box face, so
    unwrapped coordinates stay continuous.
    """

    def __init__(self, reference_positions, box):
        # assumes reference_positions is (N, 3) indexed by particle tag, as produced by
        # ``State.get_snapshot().particles.position`` in ``from_state``
        self._ref_pos = reference_positions
        self._box = freud.box.Box.from_box(box)

    @classmethod
    def from_state(cls, state: hoomd.State):
        """Create a ZeroDrift that uses ``state``'s current positions and box as reference."""
        return cls(state.get_snapshot().particles.position, state.box)

    def act(self, timestep):
        with self._state.cpu_local_snapshot as data:
            pos = data.particles.position._coerce_to_ndarray()
            rtags = data.particles.rtag._coerce_to_ndarray()
            # minimum-image displacement from the reference configuration
            diff = self._box.wrap(_diff_with_rtag(self._ref_pos, pos, rtags))
            dx = np.mean(diff, axis=0)
            shifted = pos - dx
            # Wrapping alone would silently move particles that crossed a face by a box
            # length in unwrapped coordinates; record those crossings in the images.
            crossed = self._box.get_images(shifted)
            data.particles.position = self._box.wrap(shifted)
            data.particles.image = data.particles.image + crossed


In [None]:
import contextlib

In [None]:
def get_basins(snap, ids, seed, temp, steps=10_000, skip=10) -> list:
    """Sample the in-basin order parameter for each particle in ``ids``.

    Builds one CPU ``ffs.ForwardFluxSimulation`` from ``snap`` with ``pair.KA_LJ``
    forces on a ``Tree(0.3)`` neighbour list, a Langevin method at ``temp`` with
    dt=0.0025, and a ``ZeroDrift`` updater triggered every step. For each id the
    tracked particle (``sim.pid``) is switched, ``sim.sample_basin(steps, skip)`` is
    recorded, and ``sim.reset_state()`` is called.

    Args:
        snap: configuration passed to ``sim.create_state_from_snapshot``.
        ids: sequence of particle ids; ``ids[0]`` is also given to the constructor.
        seed: seed for ``ForwardFluxSimulation``.
        temp: Langevin temperature.
        steps: forwarded to ``sim.sample_basin``.
        skip: forwarded to ``sim.sample_basin``.

    Returns:
        List with one ``sample_basin`` result per id, in the order of ``ids``;
        empty when ``ids`` is empty.
    """
    if len(ids) == 0:
        return []

    print("temp:", temp)

    cpu = hoomd.device.CPU()
    sim = ffs.ForwardFluxSimulation(cpu, ids[0], seed=seed)
    sim.create_state_from_snapshot(snap)

    integrator = hoomd.md.Integrator(dt=0.0025)
    tree = hoomd.md.nlist.Tree(0.3)
    lj = pair.KA_LJ(tree)
    lang = hoomd.md.methods.Langevin(hoomd.filter.All(), temp)
    integrator.forces = [lj]
    integrator.methods = [lang]
    sim.operations.integrator = integrator

    # Remove centre-of-mass drift every step, measured against the initial configuration.
    remove_drift = hoomd.update.CustomUpdater(
        hoomd.trigger.Periodic(1), ZeroDrift.from_state(sim.state)
    )
    sim.operations.updaters.clear()
    sim.operations.updaters.append(remove_drift)

    basin_output = []
    for pid in ids:
        print("    id:", pid)
        sim.pid = int(pid)
        basin_output.append(sim.sample_basin(steps, skip))
        # presumably restores the initial configuration before the next particle — confirm in ffs
        sim.reset_state()

    return basin_output

In [None]:
def run_ff_job(snap, ids, seed, temp, override_basin=None) -> tuple:
    """Run forward-flux sampling for each particle in ``ids``.

    Builds one CPU ``ffs.ForwardFluxSimulation`` from ``snap`` with ``pair.KA_LJ``
    forces on a ``Tree(0.3)`` neighbour list, a Langevin method at ``temp`` with
    dt=0.0025, and a ``ZeroDrift`` updater triggered every step. For each id:
    unless ``override_basin`` is given, the basin barrier is set to the 0.99
    quantile of ``sim.sample_basin(1_000, 1)``; then ``sim.run_ff`` is run with
    ``op_thresh = sim.basin_barrier + 0.1``.

    Args:
        snap: configuration passed to ``sim.create_state_from_snapshot``.
        ids: sequence of particle ids; ``ids[0]`` is also given to the constructor.
        seed: seed for ``ForwardFluxSimulation``.
        temp: Langevin temperature.
        override_basin: if not None, used as the basin barrier for every id instead
            of a per-particle estimate.

    Returns:
        ``(log_rates, arr_barriers, arr_rates)``: an array of ``np.log`` of each rate
        returned by ``run_ff``, plus lists of the per-id ``barriers`` and ``rates``
        outputs, all in the order of ``ids``. Empty when ``ids`` is empty.
    """
    if len(ids) == 0:
        return np.array([]), [], []

    print("temp:", temp)

    cpu = hoomd.device.CPU()
    sim = ffs.ForwardFluxSimulation(cpu, ids[0], seed=seed)
    sim.create_state_from_snapshot(snap)

    integrator = hoomd.md.Integrator(dt=0.0025)
    tree = hoomd.md.nlist.Tree(0.3)
    lj = pair.KA_LJ(tree)
    lang = hoomd.md.methods.Langevin(hoomd.filter.All(), temp)
    integrator.forces = [lj]
    integrator.methods = [lang]
    sim.operations.integrator = integrator

    # Remove centre-of-mass drift every step, measured against the initial configuration.
    remove_drift = hoomd.update.CustomUpdater(
        hoomd.trigger.Periodic(1), ZeroDrift.from_state(sim.state)
    )
    sim.operations.updaters.clear()
    sim.operations.updaters.append(remove_drift)

    if override_basin is not None:
        # NOTE(review): the zero-step sample_basin call presumably initialises state that
        # run_ff relies on — confirm in ffs.
        sim.sample_basin(0, 1)
        sim.basin_barrier = override_basin
        print("Universal barrier:", sim.basin_barrier)

    log_rates = []
    arr_barriers = []
    arr_rates = []

    for pid in ids:
        print("    id:", pid)
        sim.pid = int(pid)
        if override_basin is None:
            # Per-particle barrier: 99th percentile of a short in-basin sample.
            basin_op = sim.sample_basin(1_000, 1)
            barrier = np.quantile(basin_op, .99)
            print("    barrier:", barrier)
            sim.basin_barrier = barrier
        sim.reset_state()
        rate, barriers, rates = sim.run_ff(
            10_000,
            collect=100,
            trials=100,
            barrier_step=0.01,
            floor=1e-5,
            op_thresh=sim.basin_barrier + 0.1,
            thresh=0.95,
            verbose=True,
        )
        log_rates.append(np.log(rate))
        arr_barriers.append(barriers)
        arr_rates.append(rates)
        sim.reset_state()

    return np.array(log_rates), arr_barriers, arr_rates

In [None]:
frame = 80
# Read one frame and close the trajectory file straight away.
with gsd.hoomd.open(file) as traj:
    snap = traj[frame]
t = snap.log["NVT/kT"][0]
print(t)
df = pd.read_parquet(soft)
n_basin_ids = 10
ids = np.unique(df["ids"])[:n_basin_ids]
basin_dist = get_basins(snap, ids, 1234, t, steps=50_000, skip=1)

In [None]:
# Basin order-parameter histograms using the first 5000 samples per particle.
hist_bins = np.linspace(0.0, 0.5, 25)
fig, ax = plt.subplots(dpi=150)
for samples in basin_dist:
    ax.hist(samples[:5000], bins=hist_bins, histtype="step", density=True)

ax.set_ylabel(r"$P(\Delta x)$")
ax.set_xlabel(r"$\Delta x$")

In [None]:
# Basin order-parameter histograms using every sample per particle.
hist_bins = np.linspace(0.0, 0.7, 25)
fig, ax = plt.subplots(dpi=150)
for samples in basin_dist:
    ax.hist(samples, bins=hist_bins, histtype="step", density=True)

ax.set_ylabel(r"$P(\Delta x)$")
ax.set_xlabel(r"$\Delta x$")

In [None]:
frame = 80
Np = 100
with gsd.hoomd.open(file) as traj:
    snap = traj[frame]
t = snap.log["NVT/kT"][0]
print(t)
df = pd.read_parquet(soft)
ids = np.unique(df["ids"])[:Np]
dF, arr_barriers, arr_rates = run_ff_job(snap, ids, 1238, t)
# Look softness up by id so frame_soft[k] belongs to ids[k] (and hence dF[k]); the
# frame's row order in df is not guaranteed to follow np.unique ordering. Ids with
# no row in this frame get NaN.
frame_soft = (
    df[df.frames == frame]
    .set_index("ids")["softness"]
    .reindex(ids)
    .to_numpy()
)

In [None]:
plt.figure(dpi=150)
# dF holds log(rate), so exp(dF) is the rate itself; with this notebook's convention
# ΔF = -dF that is e^{-ΔF}, not ΔF.
plt.scatter(frame_soft, np.exp(dF))
plt.xlabel(r"$S$")
plt.ylabel(r"$e^{-\Delta F}$")

In [None]:
first_particle_rates = arr_rates[0]
first_particle_rates[0]

In [None]:
from matplotlib import cm, colors

In [None]:
plt.figure(dpi=150)

cmap = cm.viridis
# Spread the colormap over every particle that was run. A fixed Normalize(0, 9)
# clips every particle after the tenth to the same colour (Np is 100).
norm = colors.Normalize(0, max(len(arr_rates) - 1, 1))

for i, particle_rates in enumerate(arr_rates):
    for curve in particle_rates:
        plt.plot(curve, color=cmap(norm(i)))

In [None]:
# Free-energy proxy ΔF = -dF (dF is log(rate)) against softness.
fig, ax = plt.subplots(dpi=150)
ax.scatter(frame_soft, -dF)
ax.set_xlabel(r"$S$")
ax.set_ylabel(r"$\Delta F$")

In [None]:
# exp(dF) is the rate, i.e. e^{-ΔF}, against softness.
fig, ax = plt.subplots(dpi=150)
ax.scatter(frame_soft, np.exp(dF))
ax.set_xlabel(r"$S$")
ax.set_ylabel(r"$e^{-\Delta F}$")

In [None]:
for i, (t, d) in enumerate(data.items()):
    df = pd.read_parquet(d["soft"])
    ids = np.unique(df["ids"])
    frames = np.unique(df["frames"])
    frame = int(frames[0])  # earliest frame present in the softness table
    with gsd.hoomd.open(str(d["traj"])) as traj:
        snap = traj[frame]
    # run_ff_job returns (log_rates, barriers, rates); unpack it so "dF" holds the
    # log-rate array rather than the whole tuple.
    dF, barriers, rates = run_ff_job(snap, ids, 1234 + i, t)
    # Look softness up by id so "S" lines up element-wise with "dF" (NaN for ids
    # with no row in this frame).
    soft_by_id = (
        df[df.frames == frame]
        .set_index("ids")["softness"]
        .reindex(ids)
        .to_numpy()
    )
    data[t]["ff"] = {
        frame: {"dF": dF, "S": soft_by_id, "barriers": barriers, "rates": rates}
    }
    # NOTE: only the first temperature is processed; remove this break to run them all.
    break

In [None]:
# "ff" is keyed by frame number (not necessarily 0), so take its first entry.
first_temp_ff = next(iter(data.values()))["ff"]
next(iter(first_temp_ff.values()))["dF"].shape

In [None]:
# NOTE(review): ``sim`` is only created inside get_basins/run_ff_job and is never bound
# at notebook top level, so this cell raises NameError on a fresh kernel.
BASIN_BARRIER = 0.2936840136269137  # hard-coded; provenance not recorded in this notebook
sim.basin_barrier = BASIN_BARRIER

In [None]:
# NOTE(review): ``sim`` is not bound at notebook top level (it is local to
# get_basins/run_ff_job), and run_ff is unpacked into three values in run_ff_job —
# confirm the return arity for these keyword arguments.
sim.reset_state()
# Bind the second value to ``ff_data`` rather than ``data`` so the per-temperature
# ``data`` dict built earlier is not overwritten.
rate, ff_data = sim.run_ff(10_000, collect=10, trials=500, barrier_step=0.01, flex_step=0.0002, floor=1e-20, op_thresh=sim.basin_barrier+0.1, thresh=0.9, thermalize=10_000)

In [None]:
# NOTE(review): get_basins/run_ff_job set ``sim.pid`` but this reads ``sim.PID``; one of
# the two spellings is probably not the attribute ForwardFluxSimulation uses — confirm
# against ffs. ``sim`` is also not bound at notebook top level.
sim.PID

In [None]:
# NOTE(review): ``sim`` is not bound at notebook top level (it is local to
# get_basins/run_ff_job), and run_ff is unpacked into three values in run_ff_job —
# confirm the return arity for these keyword arguments.
sim.reset_state()
# Bind the second value to ``ff_data`` rather than ``data`` so the per-temperature
# ``data`` dict built earlier is not overwritten.
rate, ff_data = sim.run_ff(100_000, collect=50, trials=100, barrier_step=0.01, flex_step=0.0002, floor=1e-20, op_thresh=sim.basin_barrier+0.1, thresh=0.9)

In [None]:
# rate from the most recent top-level ``sim.run_ff`` call
rate