In [1]:
import sys
sys.path.insert(0, '../../')

import numpy as np
import matplotlib.pyplot as plt
import time
import os

import torch
import torch.nn as nn

import pandas as pd

import ase

import time
import argparse

from models.descriptors import DihedralDescriptors, DistanceDescriptors, CoordinateDescriptors

from models.basis_set_bias import BasisBias
from models.gaussian_models import GridGaussians

from simulator.simple_diff_md import simulateNVTSystem, simulateNVTSystem_adjoint, simulateNVTSystem_warmup
from simulator.descriptor_loss import DescriptorLoss
from simulator.trainer import train_epoch
from simulator.adjoint_provider import get_adjoints

from plotting.plot_intermediate import plot_intermediate

from scipy.stats import chi2

# Script-style configuration. The parser is evaluated on an empty argument list
# at the bottom of this block, so inside the notebook every option takes its
# default unless it is overridden explicitly afterwards (see args.save_figs).
parser = argparse.ArgumentParser()
parser.add_argument("--seed", default=42, type=int, help="Random seed.")
parser.add_argument('--save_figs', action='store_true')
parser.add_argument('--no_adjoint', action='store_true')
parser.add_argument("--plot_every", default=1, type=int, help="When to plot figures.")

# Main simulation parameters
# Descriptor loss
parser.add_argument("--iterations", default=6000, type=int, help="Number of iterations to generate training data.")
parser.add_argument("--warmup", default=1200, type=int, help="Warm-up period for the trajectory thermalization.")
parser.add_argument("--backsteps", default=190, type=int, help="Number of backsteps. Must be smaller than warmup.")
parser.add_argument("--plot_nth", default=10, type=int, help="Plotting every nth point.")


parser.add_argument("--batch_size", default=300, type=int, help="Batch size.")
parser.add_argument("--epochs", default=100, type=int, help="Number of training epochs.")
parser.add_argument("--save_steps", default=1, type=int, help="When to save tensor")
parser.add_argument("--dt", default=1.0, type=float, help="Timestep in fs")
parser.add_argument("--barrier", default=1.0, type=float, help="Barrier size.")
parser.add_argument("--bias", default="gauss", type=str, help="Bias potential type: 'gauss' or 'descriptor_basis'.")
parser.add_argument("--dimension", default=2, type=int, help="Dimensionality of the system.")
parser.add_argument("--neurons", default=150, type=int, help="Neurons in a net")
parser.add_argument("--loss", default="quad", type=str, help="Loss function type.")
parser.add_argument("--p_in_domain", default=0.1, type=float, help="Tolerance to hit the target.")

#Simulation controls
parser.add_argument("--temperature", default=10, type=float, help="System temperature.")
parser.add_argument("--gamma", default=0.1, type=float, help="Friction in langevin.")

#Training parameters
parser.add_argument("--learning_factor", default=2.0, type=float, help="Scaling factor applied to the learning rate.")
parser.add_argument("--mini_batch", default=120, type=int, help="Mini batch")
parser.add_argument('--use_non_batched', action='store_true')
parser.add_argument('--device', default="cuda:0", type=str, help="Which device to use. CPU or GPU?")
parser.add_argument('--save_every_model', default=1, type=int, help="When to save the bias potential")

# Parse an explicit empty argument list so the kernel's own sys.argv is never
# consulted; every option therefore starts from its default.
args = parser.parse_args([])

# Notebook override: always request figure saving.
args.save_figs = True

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'models'

In [None]:
# Define some runtime settings derived from the parsed arguments.
if args.save_steps > 1:
    # NotImplemented is a comparison sentinel, not an exception class; calling it
    # raises TypeError. NotImplementedError is the exception that is meant here.
    raise NotImplementedError("Feature not yet implemented and tested")

# Express warm-up and backstep counts in units of saved frames.
args.warmup = args.warmup//args.save_steps
args.backsteps = args.backsteps//args.save_steps

# Distance d - arg of normal distribution follows chi2. We therefore consult the scipy package and its chi2
# implementation to get the actual d**2 tolerance so that P(x in X) >= p where p is defined in args.
# NOTE(review): chi2.isf(q, df) returns the threshold whose upper-tail probability is q,
# i.e. P(d**2 > tol) = p_in_domain -- confirm this is the intended direction.
args.domain_tol = chi2.isf(args.p_in_domain, args.dimension)

# The numerical stability factor. Adjoints tend to be too small and numerical accuracy is an issue.
adjoint_multiplier = 1e4


assert args.iterations > 3* args.warmup
assert args.backsteps < args.warmup
torch.manual_seed(args.seed)

# Select the forward integrator: the plain one, or the one used with adjoints.
if args.no_adjoint:
    forward_simulate = simulateNVTSystem
else:
    forward_simulate = simulateNVTSystem_adjoint


first_checkpoint = False
first_thr = 0.51
second_checkpoint = False
second_thr = 0.71
#third_checkpoint = True
#fourth_checkpoint = True


# Output folder encodes the main run parameters.
# NOTE(review): the "I:" fragment puts a colon into the directory name, which is
# not a valid path character on Windows. Left unchanged so existing result
# folders keep resolving.
args.folder = "resultsMB/"+args.loss+args.bias+ "D" + str(args.dimension)+"I:"+str(args.iterations) \
    + "N"+ str(args.neurons) + "dt" + str(args.dt) + "T" + str(args.temperature) + "g" \
    + str(args.gamma) +"/"
isExist = os.path.exists(args.folder)

if not isExist and (args.save_figs or args.save_every_model > 0):
    # exist_ok guards against the folder appearing between the check and the call.
    os.makedirs(args.folder, exist_ok=True)
    print("making folder")

In [None]:
#Cell to design descriptors.
# project_to holds the two descriptor indices that the plotting cell uses to
# project trajectories onto a 2D plane.

if args.dimension == 2:
    project_to = torch.tensor([0,1])
elif args.dimension == 5:
    project_to = torch.tensor([0,2])
else:
    # Without this branch project_to stays undefined and the notebook only
    # fails much later with an unrelated-looking NameError.
    raise ValueError(
        f"No 2D projection defined for dimension={args.dimension}; supported values are 2 and 5."
    )

In [None]:
from simulator.system import ToySystem
from models.simpleMB import SimpleMB
from functorch import vmap, jvp, vjp, grad


# Potential energy model defined over args.dimension input coordinates.
potential = SimpleMB(args, n_in=args.dimension)

# Mass tensor: a single row holding 0.1 for every coordinate.
M = torch.ones(1, args.dimension) * 0.1

# The system carries twice the batch size in replicas.
system = ToySystem(
    args.dimension,
    nreplicas=2 * args.batch_size,
    masses=M,
    device=args.device,
)

# Place the replicas at the starting points supplied by the potential.
initial_positions = potential.get_init_point(args.batch_size)
system.set_positions(initial_positions)


In [None]:
#Define initialization and setup the run
from simulator.utils import maxwell_boltzmann, kinetic_energy, kinetic_to_temp, temp_to_kin

# Target kinetic energy used to rescale the freshly drawn velocities.
# NOTE(review): the temperature 300 is hard-coded here while the velocities are
# drawn at args.temperature -- confirm the mismatch is intentional.
desired_kin = temp_to_kin(300, args.dimension)


def get_init_point(steps = 300):
    """Reset the global ``system`` and run a short unbiased warm-up on it.

    Positions are reset from ``potential.get_init_point`` and velocities are
    drawn with ``maxwell_boltzmann`` at ``args.temperature``; the velocities
    are then scaled by ``sqrt(desired_kin / kin)`` per replica before
    ``simulateNVTSystem_warmup`` is run with the unbiased force.

    Parameters
    ----------
    steps : int
        Number of warm-up steps passed to ``simulateNVTSystem_warmup``.

    Returns
    -------
    tuple
        ``(system, cat(start, end), cat(end, start))`` where ``start`` and
        ``end`` are the two halves (along dim 0) of the detached positions
        after the warm-up.
    """
    system.set_positions(potential.get_init_point(args.batch_size))
    system.set_velocities(maxwell_boltzmann(system.M, args.temperature, args.dimension, replicas=2*args.batch_size))

    # One scaling factor per replica, broadcast over all coordinates.
    kin = kinetic_energy(system.M, system.vel)
    factor = torch.sqrt(desired_kin/kin).reshape(-1,1).repeat(1,args.dimension)
    system.vel = system.vel*factor

    system.to_(args.device)
    #Pass a vmap on the non-biased forces
    simulateNVTSystem_warmup(system, vmap(potential.force_func), args, steps=steps)

    # Split the replicas into two halves and return them in both orders.
    start, end = system.pos.detach().chunk(2,dim=0)
    return system, torch.cat((start, end)), torch.cat((end, start))


# Short re-initialisation of the global system (10 warm-up steps).
get_init_point(10)

# Long unbiased run (vmapped non-biased force); the returned list of position
# snapshots is stacked along axis 1.
pos_list = torch.stack(
    simulateNVTSystem_warmup(system, vmap(potential.force_func), args, steps=4200),
    axis=1,
)

# First batch_size replicas are the reactant trajectories, the rest the product ones.
react_pos, prod_pos = pos_list[:args.batch_size], pos_list[args.batch_size:]

coordinate_descriptors = CoordinateDescriptors()

# Plotting window (lower/upper bounds per axis) and grid resolution.
domain = {
    "Lx": 10,
    "Hx": 50,
    "Ly": 0,
    "Hy": 40,
    "res": 100,
}

In [None]:
# Build the trainable bias potential and pick a matching learning rate.
# The branches are mutually exclusive; an unknown --bias value is rejected
# explicitly instead of leaving sample_bias_force / lr undefined.
if args.bias == "gauss":
    height = 0.01
    resolution = 50
    var = 3**2
    sample_bias_force = GridGaussians(height, args.dimension, -10, 55, resolution, var=var, descriptors=coordinate_descriptors, device=args.device)
    if args.use_non_batched or args.no_adjoint:
        lr = 0.001/adjoint_multiplier
    else:
        lr = 50*args.learning_factor/adjoint_multiplier/args.batch_size/10

elif args.bias == "descriptor_basis":

    height = 0.01
    resolution = 50
    var = 25/10
    sample_bias_force = BasisBias(height,args.dimension, -5, 55,
                                       resolution=resolution, var=var, neurons=args.neurons,
                                       descriptors=coordinate_descriptors, device=args.device)
    if args.use_non_batched or args.no_adjoint:
        lr = 0.1/adjoint_multiplier
    else:
        lr = args.learning_factor/adjoint_multiplier/args.batch_size

else:
    raise ValueError(
        f"Unknown bias type {args.bias!r}; expected 'gauss' or 'descriptor_basis'."
    )


sample_bias_force.to(args.device)

#Adam optimizer seems to work the best
optimizer = torch.optim.Adam(sample_bias_force.parameters(), lr=lr)

# Exponential decay: every scheduler.step() multiplies the learning rate by 0.3.
# NOTE(review): the original comment here was truncated ("Since we update
# scheduler only") -- confirm how often the scheduler is meant to be stepped.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.3)

#Descriptors
#descriptor_loss = DescriptorLoss(coordinate_descriptors, react_var_inv, prod_var_inv, args)

def sample_total_force(R):
    """Return the sum of the potential's force and the bias force at ``R``.

    ``potential.force_func`` returns a pair whose first element is discarded
    here; only its second element (the force) is combined with the output of
    ``sample_bias_force.force_func``.
    """
    _, physical_force = potential.force_func(R)
    bias_force = sample_bias_force.force_func(R)
    return physical_force + bias_force

def get_forces_vjp(R, vec):
    """Vector-Jacobian product of ``sample_total_force`` at ``R`` with ``vec``.

    Only the first entry of the tuple returned by the vjp closure is handed
    back to the caller.
    """
    _, pullback = vjp(sample_total_force, R)
    return pullback(vec)[0]


# vmap-wrapped versions of the total force and of its vector-Jacobian product.
vmaped_force, vmaped_vjp = vmap(sample_total_force), vmap(get_forces_vjp)

In [None]:
# First run: simulate before the stored bias parameters are loaded. The
# plotting cell labels this trajectory as the unbiased sampling.
trajectory, _, _ = forward_simulate(system, vmaped_force, args)


# Load the stored bias parameters. map_location remaps the checkpoint's tensors
# onto args.device, so a file saved on a GPU can still be read on a CPU-only
# machine (without it torch.load tries to restore the original device).
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
sample_bias_force.load_state_dict(torch.load("model65.pt", map_location=args.device))

# Second run: same system and force wrapper, now with the loaded bias.
trajectory_conv, _, _ = forward_simulate(system, vmaped_force, args)


In [None]:
from scipy.interpolate import griddata

# Two-panel figure: sampled density (hexbin) over the potential surface for the
# trajectory simulated before (left) and after (right) loading the trained bias.
tick_font_size = 14
label_font_size = 18
alpha = 0.5  # hexbin transparency, so the potential underneath stays visible

Lx,Ly = domain["Lx"], domain["Ly"]
Hx,Hy = domain["Hx"], domain["Hy"]
resolution = domain["res"]

# The grid and the potential surface are identical for both panels, so they
# are evaluated once instead of once per panel.
x, y = torch.meshgrid(torch.linspace(Lx, Hx, resolution), torch.linspace(Ly, Hy, resolution), indexing='ij')
x_dev, y_dev = x.to(args.device), y.to(args.device)
z = potential.U_split(x_dev, y_dev).reshape((resolution,resolution)).detach().cpu()


def _subsampled_cvs(traj):
    """Stack a trajectory, project it onto ``project_to`` and thin it out.

    Returns a NumPy array with two columns containing every
    ``args.plot_nth``-th frame of every replica. The result is detached and
    moved to the CPU so matplotlib can consume both columns regardless of the
    device the simulation ran on.
    """
    traj_tensor = torch.stack(traj, axis=1)
    known_cvs = coordinate_descriptors.get_descriptors(traj_tensor)[...,project_to]
    cv_sample = known_cvs[:,::args.plot_nth].reshape(-1,2)
    return cv_sample.detach().cpu().numpy()


def _draw_density_panel(ax, cv_sample, cbar_label, cbar_location):
    """Draw the potential plus a log-count hexbin of ``cv_sample`` on ``ax``.

    A colorbar for the hexbin is attached on the ``cbar_location`` side and
    labelled with ``cbar_label``; all ticks and tick labels of ``ax`` are
    hidden. Returns the created colorbar.
    """
    ax.pcolormesh(x, y, z, cmap='magma', vmin=z.min(), vmax=10, shading='auto')
    density = ax.hexbin(cv_sample[:,0], cv_sample[:,1], bins="log", alpha=alpha,
                        cmap="viridis", gridsize=40, mincnt=5)

    panel_cbar = ax.figure.colorbar(density, ax=ax, location=cbar_location)
    panel_cbar.ax.tick_params(labelsize=tick_font_size)
    panel_cbar.ax.set_ylabel(cbar_label, fontsize=label_font_size)

    ax.axis([Lx, Hx, Ly, Hy])
    ax.tick_params(
        axis='both',       # changes apply to both axes
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        left=False,
        right=False,
        labelbottom=False, # labels along the bottom edge are off
        labelleft=False)   # labels along the left edge are off
    return panel_cbar


fig, (ax_unbiased, ax_biased) = plt.subplots(1, 2, figsize=(12,6))

cbar = _draw_density_panel(ax_unbiased, _subsampled_cvs(trajectory),
                           'Unbiased sampling - log-density', "left")
# The right-hand colorbar was previously never created (its assignment was
# commented out) although it was styled right afterwards, causing a NameError.
cbar2 = _draw_density_panel(ax_biased, _subsampled_cvs(trajectory_conv),
                            'Biased sampling - log-density', "right")

fig.tight_layout()
fig.savefig("initial_density")
plt.show()
plt.close(fig)