In [None]:
import glob
import os
import sys
from typing import Callable, Optional

import gsd.hoomd
import hoomd
import hoomd.forward_flux.forward_flux as ffs
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm, colors
from numba import float32, njit, vectorize

In [None]:
# Make the local `monk` package importable. The location can be overridden with the
# MONK_SRC environment variable instead of editing this hard-coded default path.
# Any existing entry is removed first so re-running this cell leaves exactly one
# copy of the path, at the front of sys.path.
MONK_SRC = os.environ.get("MONK_SRC", "/home/ian/Projects/work/monk/src")
while MONK_SRC in sys.path:
    sys.path.remove(MONK_SRC)
sys.path.insert(0, MONK_SRC)
sys.path

In [None]:
from monk import nb
from monk import prep
from monk import pair
import freud
import ex_render

In [None]:
import importlib

# Re-execute the forward-flux module so edits to its source are picked up without
# restarting the kernel. reload returns the same module object, so rebinding is a no-op
# for every other reference to it.
ffs = importlib.reload(ffs)

In [None]:
def gen_highT_state(
    seed=1000,
    n_particles=64,
    phi=1.2,
    kT=1.0,
    tau=1.0,
    n_blocks=10,
    steps_per_block=40_001,
    filename="init-state-2d.gsd",
):
    """Build a 2D binary ``pair.KA_LJ`` system, run it under NVT, and save the final state.

    The system is a two-component mixture (number ratio 60:40, diameters 1.0 and 0.88)
    created by ``prep.approx_euclidean_snapshot`` and integrated with ``dt = 0.0025``.
    The pressure is printed once before any dynamics and again after each run block,
    so the approach to a steady state can be checked by eye.

    Every argument defaults to the value that used to be hard-coded here, so calling
    ``gen_highT_state()`` with no arguments behaves exactly as before.

    Parameters
    ----------
    seed : int
        Simulation seed; ``seed + 1`` seeds the RNG used to place the particles.
    n_particles : int
        Number of particles.
    phi : float
        Value handed to ``prep.len_from_phi`` to size the (2D) box.
    kT : float
        NVT thermostat temperature.
    tau : float
        NVT thermostat coupling time constant.
    n_blocks : int
        Number of run blocks; pressure is printed after each.
    steps_per_block : int
        Number of time steps per block.
    filename : str
        Path of the GSD file the final state is written to via ``hoomd.write.GSD.write``.

    Returns
    -------
    None
    """
    # auto_select may pick a GPU or a CPU, so the name is device-neutral.
    device = hoomd.device.auto_select()
    print(device)
    sim = hoomd.Simulation(device, seed=seed)

    rng = prep.init_rng(seed + 1)
    box_length = prep.len_from_phi(n_particles, phi, dim=2)
    snap = prep.approx_euclidean_snapshot(
        n_particles, box_length, rng, dim=2, ratios=[60, 40], diams=[1.0, 0.88]
    )
    sim.create_state_from_snapshot(snap)

    integrator = hoomd.md.Integrator(dt=0.0025)
    tree = hoomd.md.nlist.Tree(0.3)
    lj = pair.KA_LJ(tree)
    nvt = hoomd.md.methods.NVT(hoomd.filter.All(), kT, tau)
    integrator.forces = [lj]
    integrator.methods = [nvt]

    # Request pressure computation so thermo.pressure can be read between runs.
    sim.always_compute_pressure = True
    thermo = hoomd.md.compute.ThermodynamicQuantities(filter=hoomd.filter.All())
    sim.operations.computes.append(thermo)

    sim.operations.integrator = integrator

    # Zero-step run attaches the operations so the initial pressure is available.
    sim.run(0)
    print(thermo.pressure)

    for block in range(n_blocks):
        sim.run(steps_per_block, write_at_start=True)
        print(block, thermo.pressure)

    hoomd.write.GSD.write(sim.state, filename)


gen_highT_state()

In [None]:
# Forward-flux simulation on the CPU, started from the state saved by gen_highT_state.
# NOTE(review): the positional 20 is forwarded to ForwardFluxSimulation unchanged —
# its meaning is defined in that class; TODO document it here.
seed = 3412
cpu = hoomd.device.CPU()
init_gsd = "init-state-2d.gsd"
sim = ffs.ForwardFluxSimulation(cpu, 20, seed=seed)
sim.create_state_from_gsd(init_gsd)

In [None]:
from hoomd.custom import Action
import freud
import time

class TrackParticle(Action):
    """Custom action that records the position of a single particle each time it runs.

    Parameters
    ----------
    pid : int
        Tag of the particle to follow; it is translated to a local index through
        ``rtag`` on every call.

    Attributes
    ----------
    data : list of numpy.ndarray
        One copied position vector appended per call to ``act``.
    """

    def __init__(self, pid):
        self.data = []
        self._pid = pid

    def act(self, timestep):
        """Append a copy of the tracked particle's current position to ``self.data``."""
        with self._state.cpu_local_snapshot as snapshot:
            particles = snapshot.particles
            local_index = particles.rtag[self._pid]
            # Copy: the local snapshot buffer is only valid inside this context.
            self.data.append(np.array(particles.position[local_index], copy=True))

@njit
def _diff_with_rtag(ref_pos, pos, rtags):
    """Displacement of every particle from its reference, laid out in local order.

    ``ref_pos`` is indexed by particle tag and ``pos`` by local index; ``rtags[tag]``
    gives the local index of ``tag``. Row ``rtags[tag]`` of the result holds
    ``pos[rtags[tag]] - ref_pos[tag]``; rows no tag maps to stay zero.
    """
    displacement = np.zeros_like(pos)
    for tag, local_idx in enumerate(rtags):
        displacement[local_idx] = pos[local_idx] - ref_pos[tag]
    return displacement

class ZeroDrift(Action):
    """Custom action that removes net drift relative to a fixed reference configuration.

    On each call the minimum-image displacement of every particle from its reference
    position is computed and averaged, and that mean is subtracted from all positions.

    Parameters
    ----------
    reference_positions : array-like
        Reference positions indexed by particle tag (this is how ``_diff_with_rtag``
        reads them). A private copy is stored so later changes to the caller's array
        cannot move the reference.
    box : box-like
        Simulation box, converted once with ``freud.box.Box.from_box``. Assumes the box
        does not change during the run (it is not re-read in ``act``).
    """

    def __init__(self, reference_positions, box):
        self._ref_pos = np.array(reference_positions, copy=True)
        self._box = freud.box.Box.from_box(box)

    @classmethod
    def from_state(cls, state: hoomd.State):
        """Build from the state's current snapshot positions and box."""
        return cls(state.get_snapshot().particles.position, state.box)

    def act(self, timestep):
        """Shift all local particles by minus their mean displacement from the reference."""
        with self._state.cpu_local_snapshot as data:
            pos = data.particles.position._coerce_to_ndarray()
            rtags = data.particles.rtag._coerce_to_ndarray()
            diff = self._box.wrap(_diff_with_rtag(self._ref_pos, pos, rtags))
            dx = np.mean(diff, axis=0)
            shifted = pos - dx
            # Wrapping can move a particle across a periodic boundary; record that in
            # its image flags so unwrapped coordinates (e.g. in GSD output) stay continuous.
            data.particles.image = data.particles.image + self._box.get_images(shifted)
            data.particles.position = self._box.wrap(shifted)


In [None]:
# Langevin dynamics at kT = 0.4 with the pair.KA_LJ potential on a tree neighbour list.
tree = hoomd.md.nlist.Tree(0.3)
lj = pair.KA_LJ(tree)
lang = hoomd.md.methods.Langevin(filter=hoomd.filter.All(), kT=0.4)

integrator = hoomd.md.Integrator(dt=0.0025)
integrator.forces = [lj]
integrator.methods = [lang]
sim.operations.integrator = integrator

# Remove net drift every step, measured against the configuration at this point.
pos = sim.state.get_snapshot().particles.position
box = sim.state.box
trigger = hoomd.trigger.Periodic(1)
remove_drift = hoomd.update.CustomUpdater(trigger=trigger, action=ZeroDrift(pos, box))

# Clear first so re-running this cell does not stack duplicate updaters or writers.
sim.operations.updaters.clear()
sim.operations.updaters.append(remove_drift)
sim.operations.writers.clear()

In [None]:
# Show the `default` entry of the Langevin `gamma` type parameter.
gamma_default = lang.gamma.default
print(gamma_default)

In [None]:
# Run before the trajectory writer is attached; in a top-to-bottom run the writers
# were cleared during integrator setup, so nothing is written for these steps.
PRE_SAMPLING_STEPS = 40_000
sim.run(PRE_SAMPLING_STEPS)

In [None]:
# Trajectory output for the sampling run: every 10 steps, all particles, with
# positions ("property") and velocities/images ("momentum") stored per frame.
GSD_FILENAME = "sampling-algo-2.gsd"

# Drop any writer left over from a previous execution of this cell, so re-running
# does not attach two writers to the same file.
for old_writer in list(sim.operations.writers):
    if isinstance(old_writer, hoomd.write.GSD) and old_writer.filename == GSD_FILENAME:
        sim.operations.writers.remove(old_writer)

writer = hoomd.write.GSD(
    trigger=hoomd.trigger.Periodic(10),
    filename=GSD_FILENAME,
    mode="wb",
    filter=hoomd.filter.All(),
    dynamic=["property", "momentum"],
)
sim.operations.writers.append(writer)

In [None]:
# Sample sub-basins with ForwardFluxSimulation.sample_all_sub_basins.
# NOTE(review): the positional 40_000 and 1 are forwarded unchanged — their meaning is
# defined by that method; check_interval=100 presumably sets how often basin
# membership is checked (confirm).
basin_ops, valid_basins, conf_dists = sim.sample_all_sub_basins(
    40_000,
    1,
    check_interval=100,
    reset_if_basin_left=True,
)

In [None]:
# Distance of each sampled configuration from the reference configuration.
fig, ax = plt.subplots(dpi=200)
ax.plot(conf_dists)
# Reference level at 0.5 (presumably the basin-exit threshold — confirm).
# axhline always spans the full axes width; hlines(0.5, *xlim()) fixed its endpoints
# at the limits read before the line was added, and autoscaling then re-added margins.
ax.axhline(0.5, color="k", linestyle="--")
ax.set_ylabel(r"$|\Omega_i - \Omega_0|$")
ax.set_xlabel(r"$t$");

In [None]:

# Stack the per-basin results and keep only the entries selected by valid_basins.
basin = np.array(basin_ops)[valid_basins]
print(basin.shape)

# 99th percentile of each slice along axis 2, pooled over the remaining axes.
# Previously computed and then discarded; now shown so the result is visible.
data = [np.quantile(basin[:, :, i], 0.99) for i in range(basin.shape[2])]
print(data)

fig, ax = plt.subplots()
ax.hist(basin.flatten(), bins=20)
ax.set_xlabel("basin_ops value (valid basins)")
ax.set_ylabel("count");