In [1]:
import neutromeratio
from openmmtools.constants import kB
from simtk import unit
import numpy as np
import pickle
import mdtraj as md
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
import torch
from neutromeratio.parameter_gradients import FreeEnergyCalculator
from neutromeratio.constants import kT, device, exclude_set
from glob import glob

# job idx (1-based index into the sorted, filtered list of systems built below)
idx = int(1)
# where to write the results
base_path = str('/home/mwieder/')
env = str('droplet')
diameter_in_angstrom = int(18)
#######################

mode = 'forward'

# Read in exp results, smiles and names.
# NOTE: unpickling can execute arbitrary code — this assumes the pickle is a trusted local file.
with open('../data/exp_results.pickle', 'rb') as exp_results_file:
    exp_results = pickle.load(exp_results_file)

# Names of all systems (sorted), minus the excluded ones.
protocoll = [name for name in sorted(exp_results) if name not in exclude_set]

# idx is 1-based: reject 0/negative values instead of letting protocoll[idx-1]
# silently wrap around to the end of the list.
if not 1 <= idx <= len(protocoll):
    raise IndexError(f"idx={idx} is out of range; expected 1..{len(protocoll)}")
name = protocoll[idx-1]
print(name)

# don't change - direction is fixed for all runs
#################
t1_smiles = exp_results[name]['t1-smiles']
t2_smiles = exp_results[name]['t2-smiles']

# Display the reference 'energy' entry stored for this system.
exp_results[name]['energy']

_ColormakerRegistry()

[utils.py:141 - _init_num_threads()] NumExpr defaulting to 8 threads.


SAMPLmol2


-6.1

In [3]:
# Scratch check: convert a free-energy value expressed in units of kT to kcal/mol.
# NOTE(review): the origin of this hard-coded value is not recorded in the notebook —
# TODO document which calculation produced it (or compute it instead of pasting it).
(-12.497892722203273 * kT).in_units_of(unit.kilocalorie_per_mole)

Quantity(value=-7.450768522014049, unit=kilocalorie/mole)

In [None]:
import neutromeratio
from openmmtools.constants import kB
from simtk import unit
import numpy as np
import pickle
import mdtraj as md
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
import torch
from neutromeratio.parameter_gradients import FreeEnergyCalculator
from neutromeratio.constants import kT, device, exclude_set
from glob import glob

def parse_lambda_from_dcd_filename(dcd_filename, env):
    """Extract the lambda value encoded in a trajectory file name.

    The lambda value is the underscore-separated token immediately preceding
    the ``_in_<env>`` marker in the file name. For example,
    ``mol_lambda_0.3500_in_droplet_forward.dcd`` with ``env='droplet'``
    gives ``0.35``.

    Parameters
    ----------
    dcd_filename : str
        Path (or bare file name) of the trajectory.
    env : str
        Environment tag embedded in the file name (e.g. ``'droplet'``).

    Returns
    -------
    float
        The parsed lambda value.

    Raises
    ------
    ValueError
        If the file name does not contain ``_in_<env>`` or the token before
        it is not a number.
    """
    import os

    # Only inspect the file name so a directory containing the marker cannot match first.
    basename = os.path.basename(dcd_filename)
    marker = f"_in_{env}"
    marker_pos = basename.find(marker)
    # find() returns -1 when the marker is absent; slicing with -1 would drop the
    # last character and could silently parse a wrong number.
    if marker_pos == -1:
        raise ValueError(f"could not find '{marker}' in file name {dcd_filename!r}")
    return float(basename[:marker_pos].split('_')[-1])


# job idx (1-based index into the sorted, filtered list of systems built below)
idx = int(1)
# where to write the results
base_path = str('/home/mwieder/')
env = str('droplet')
diameter_in_angstrom = int(18)
#######################

mode = 'forward'

# Read in exp results, smiles and names.
# NOTE: unpickling can execute arbitrary code — this assumes the pickle is a trusted local file.
with open('../data/exp_results.pickle', 'rb') as exp_results_file:
    exp_results = pickle.load(exp_results_file)

# Names of all systems (sorted), minus the excluded ones.
protocoll = [name for name in sorted(exp_results) if name not in exclude_set]

# idx is 1-based: reject 0/negative values instead of letting protocoll[idx-1]
# silently wrap around to the end of the list.
if not 1 <= idx <= len(protocoll):
    raise IndexError(f"idx={idx} is out of range; expected 1..{len(protocoll)}")
name = protocoll[idx-1]
print(name)

# don't change - direction is fixed for all runs
#################
t1_smiles = exp_results[name]['t1-smiles']
t2_smiles = exp_results[name]['t2-smiles']


# Build RDKit molecules for both tautomers and wrap them in a Tautomer object.
initial_mol = neutromeratio.generate_rdkit_mol(t1_smiles)
final_mol = neutromeratio.generate_rdkit_mol(t2_smiles)
tautomer = neutromeratio.Tautomer(
    name=name,
    initial_state_mol=initial_mol,
    final_state_mol=final_mol,
    nr_of_conformations=20,
)

# Apply the tautomer transformation in the requested direction.
if mode not in ('forward', 'reverse'):
    raise RuntimeError('No tautomer reaction direction was specified.')
transform = (tautomer.perform_tautomer_transformation_forward if mode == 'forward'
             else tautomer.perform_tautomer_transformation_reverse)
transform()

# Pick the atoms/topology to simulate: the solvated hybrid ligand for the droplet,
# otherwise the bare hybrid ligand.
if env != 'droplet':
    atoms = tautomer.hybrid_atoms
    top = tautomer.hybrid_topology
else:
    tautomer.add_droplet(
        tautomer.hybrid_topology,
        tautomer.hybrid_coords,
        diameter=diameter_in_angstrom * unit.angstrom,
        restrain_hydrogens=True,
        file=f"{base_path}/{name}/{name}_in_droplet_{mode}.pdb",
    )
    atoms = tautomer.ligand_in_water_atoms
    top = tautomer.ligand_in_water_topology
    print('Nr of atoms: {}'.format(len(atoms)))

# Alchemical atoms: the hybrid hydrogen indices, lambda=1 entry first, then lambda=0.
alchemical_atoms = [
    tautomer.hybrid_hydrogen_idx_at_lambda_1,
    tautomer.hybrid_hydrogen_idx_at_lambda_0,
]

# Linear alchemical dual-topology ANI potential, moved to the configured device.
model = neutromeratio.ani.LinearAlchemicalDualTopologyANI(
    alchemical_atoms=alchemical_atoms, adventure_mode=True
).to(device)
# Limit PyTorch's intra-op CPU parallelism to two threads.
torch.set_num_threads(2)

# Energy/force wrapper around the model, used for all energy evaluations below.
energy_function = neutromeratio.ANI1_force_and_energy(model=model, atoms=atoms, mol=None)

# Ligand restraints, then hybrid-ligand restraints.
for restraint_list in (tautomer.ligand_restraints, tautomer.hybrid_ligand_restraints):
    for r in restraint_list:
        energy_function.add_restraint(r)

if env == 'droplet':
    # Centre-of-mass restraint point at (d/2, d/2, d/2) Angstrom for droplet diameter d
    # (presumably the droplet centre).
    half_diameter = diameter_in_angstrom / 2
    tautomer.add_COM_for_hybrid_ligand(np.array([half_diameter] * 3) * unit.angstrom)

    # Solvent restraints, then COM restraints; both are read only after the COM has been added.
    for restraint_list in (tautomer.solvent_restraints, tautomer.com_restraints):
        for r in restraint_list:
            energy_function.add_restraint(r)


# Load every trajectory for this system/environment together with the lambda value
# parsed from its file name and the per-frame energies stored in the matching CSV.
# glob() returns files in arbitrary, filesystem-dependent order, so sort for a
# deterministic processing order. Only files carrying the `_in_{env}` marker are
# matched, because parse_lambda_from_dcd_filename needs that marker.
dcds = sorted(glob(f"{base_path}/{name}/*_in_{env}*.dcd"))

lambdas = []
ani_trajs = []
energies = []

for dcd_filename in dcds:
    lam = parse_lambda_from_dcd_filename(dcd_filename, env)
    lambdas.append(lam)
    traj = md.load_dcd(dcd_filename, top=top)
    ani_trajs.append(traj)
    energy_filename = f"{base_path}/{name}/{name}_lambda_{lam:0.4f}_energy_in_{env}_{mode}.csv"
    # One float per line. `with` closes the file even if a line fails to parse;
    # blank lines are skipped instead of being passed to float().
    with open(energy_filename, 'r') as energy_file:
        tmp_e = [float(e) for e in energy_file if e.strip()]
    energies.append(np.array(tmp_e))

In [2]:
# Fixed subsample of each lambda window used by the stddev check:
# every SNAPSHOT_STRIDE-th frame among the first SNAPSHOT_MAX_FRAMES frames.
# NOTE: equilibration detection / adaptive thinning (detectEquilibration) is not used;
# a fixed frame window is taken instead.
SNAPSHOT_MAX_FRAMES = 200
SNAPSHOT_STRIDE = 10

snapshots = {}
K = len(ani_trajs)  # number of lambda windows (one trajectory per window)
for lam, traj in zip(lambdas, ani_trajs):
    new_snapshots = list(traj[0:SNAPSHOT_MAX_FRAMES:SNAPSHOT_STRIDE].xyz * unit.nanometer)
    snapshots[lam] = new_snapshots

In [3]:
def calculate_stddev(snapshots):
    """Per-snapshot energy stddev at both alchemical end states, in units of kT.

    Every snapshot is evaluated with the notebook-global ``energy_function``,
    first at lambda=0 for all snapshots and then at lambda=1. Element 2 of each
    returned result is the stddev; it is divided by ``kT``.

    Parameters
    ----------
    snapshots : sequence
        Coordinates accepted by ``energy_function.calculate_energy``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(stddev_at_lambda_0, stddev_at_lambda_1)``.
    """
    def _stddev_in_kT(lambda_value):
        results = [energy_function.calculate_energy(x, lambda_value=lambda_value)
                   for x in tqdm(snapshots)]
        return np.array([result[2] / kT for result in results])

    return _stddev_in_kT(0.0), _stddev_in_kT(1.0)

def compute_linear_penalty(current_stddev, n_atoms, per_atom_thresh=None):
    """Hinge penalty on an energy stddev above a threshold that scales with atom count.

    The threshold is ``per_atom_thresh * n_atoms`` divided by ``kT``. The
    penalty is ``max(0, current_stddev - threshold)``, taken elementwise.

    Parameters
    ----------
    current_stddev : numpy.ndarray or float
        Energy standard deviation(s), already in units of kT.
    n_atoms : int
        Number of atoms the per-atom threshold is multiplied by.
    per_atom_thresh : simtk.unit.Quantity, optional
        Allowed stddev per atom, in energy units. Defaults to 0.5 kJ/mol.

    Returns
    -------
    numpy.ndarray or numpy.floating
        Non-negative penalty; zero wherever the stddev is within the threshold.
    """
    if per_atom_thresh is None:
        per_atom_thresh = 0.5 * unit.kilojoule_per_mole
    total_thresh = per_atom_thresh * n_atoms
    # Dividing by kT puts the threshold in the same units as current_stddev.
    return np.maximum(0, current_stddev - (total_thresh / kT))

def compute_last_valid_ind(linear_penalty):
    """Count the leading snapshots whose accumulated penalty is still zero.

    The result is the index of the first snapshot at which the cumulative
    penalty becomes positive, or ``len(linear_penalty)`` if it never does.
    Snapshots before the returned index are usable.

    Parameters
    ----------
    linear_penalty : array_like
        Per-snapshot penalties (non-negative when produced by
        ``compute_linear_penalty``).

    Returns
    -------
    int
        Number of usable leading snapshots, in ``[0, len(linear_penalty)]``.
    """
    exceeded = np.cumsum(linear_penalty) > 0
    # np.argmax returns 0 both when nothing is exceeded and when the very first
    # snapshot is; only the former means "all snapshots are valid".
    if not np.any(exceeded):
        return len(linear_penalty)
    return int(np.argmax(exceeded))

In [4]:
# For each lambda window, mix the end-state stddevs linearly in lambda and record
# how many leading snapshots stay below the stddev threshold.
n_atoms = len(atoms)
last_valid_inds = {}
for lam, lam_snapshots in snapshots.items():
    lambda0_stddev, lambda1_stddev = calculate_stddev(lam_snapshots)
    current_stddev = lam * lambda1_stddev + (1 - lam) * lambda0_stddev
    print(current_stddev)

    linear_penalty = compute_linear_penalty(current_stddev, n_atoms)
    last_valid_inds[lam] = last_valid_ind = compute_last_valid_ind(linear_penalty)
    print(last_valid_ind)

100%|██████████| 20/20 [00:25<00:00,  1.28s/it]
100%|██████████| 20/20 [00:32<00:00,  1.63s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[35.54565312 33.71561249 29.49961212 35.12164857 28.42659966 28.84589842
 34.7389601  32.82288587 34.58596583 31.23591895 37.81755263 23.92589236
 30.32409965 34.87454346 31.83022945 29.52869658 32.58677437 27.17266795
 25.70561765 27.1862321 ]
20


100%|██████████| 20/20 [00:29<00:00,  1.48s/it]
100%|██████████| 20/20 [00:30<00:00,  1.53s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[41.72148378 38.73133952 31.88020362 34.36612729 29.13551649 34.84994173
 36.83820444 33.20671832 32.86367665 24.70991068 36.71820588 30.74188788
 31.87267347 33.8298714  31.75555156 33.35595974 32.21869838 31.34417146
 24.84004115 31.51509683]
20


100%|██████████| 20/20 [00:24<00:00,  1.20s/it]
100%|██████████| 20/20 [00:26<00:00,  1.32s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[40.44271442 24.94409478 26.43278515 31.85163946 30.86227419 37.39729724
 29.40021184 30.34762301 33.63964266 33.59299758 23.18325027 26.964528
 33.13475693 34.42152901 29.48434521 27.21187419 34.20667277 28.84620325
 31.61685931 31.71508578]
20


100%|██████████| 20/20 [00:24<00:00,  1.24s/it]
100%|██████████| 20/20 [00:22<00:00,  1.10s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[46.15737098 35.6052868  25.98284831 28.2429934  32.70070589 37.29959716
 30.33835587 31.31561385 27.91765055 23.28349003 24.24664472 28.62384596
 18.65662589 26.0779154  29.59454891 27.91009697 31.85803607 29.56836541
 27.45507979 14.54543321]
20


100%|██████████| 20/20 [00:26<00:00,  1.32s/it]
100%|██████████| 20/20 [00:27<00:00,  1.37s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[ 38.29787107  34.4244135   35.08730245  28.07936982  20.0920642
 208.05610888 183.16983726 185.46570531 213.07385304 229.08609806
 356.22998509 358.43830078 387.86661441 398.88506542 395.7280576
 410.04981822 401.74697774 363.12417588 364.22240275 348.51920725]
5


100%|██████████| 20/20 [00:23<00:00,  1.19s/it]
100%|██████████| 20/20 [00:23<00:00,  1.15s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[40.58331105 32.13255434 32.95327201 28.59065102 22.00198139 23.43457606
 27.5631559  32.80517574 34.05353087 33.77547482 28.3979139  27.29898529
 36.3803739  36.42622035 35.50603472 26.16799797 33.02075541 25.06107507
 24.73598227 28.06990605]
20


100%|██████████| 20/20 [00:18<00:00,  1.06it/s]
100%|██████████| 20/20 [00:22<00:00,  1.12s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[41.01952917 35.92791427 29.49766461 28.51206456 31.50395145 28.26746279
 29.13664009 35.54779406 33.19614841 30.77976507 29.60929667 31.08870692
 31.13978028 32.43491201 29.10685386 38.48922982 30.75754816 38.21523627
 36.75672472 28.59682505]
20


100%|██████████| 20/20 [00:20<00:00,  1.01s/it]
100%|██████████| 20/20 [00:17<00:00,  1.13it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

[39.88798101 35.97334771 26.79890919 33.07582514 31.54080455 28.86927093
 33.06580096 31.22336624 25.69023056 36.38982965 27.12418581 29.88396135
 35.69470965 32.90008598 33.0320839  30.20140639 26.56629873 28.44877166
 25.84409672 28.67566608]
20


100%|██████████| 20/20 [00:19<00:00,  1.02it/s]
100%|██████████| 20/20 [00:19<00:00,  1.03it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

[38.24179285 28.97493632 32.41918293 33.26598473 33.50118178 35.13125356
 29.58579549 34.57823759 41.09349918 41.1183326  30.34817239 29.80618374
 27.17354682 35.65684492 32.13536156 30.92076848 30.99388956 31.49435006
 25.94334256 21.02792615]
20


100%|██████████| 20/20 [00:19<00:00,  1.05it/s]
100%|██████████| 20/20 [00:18<00:00,  1.10it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

[46.1834761  32.66728587 31.51728425 40.45116409 34.64499419 37.91720263
 35.43416229 37.03957697 37.01616543 31.82742391 41.87122316 39.29944744
 36.82820255 30.09667135 27.80991085 34.42472212 37.65217663 31.54870036
 39.60482486 34.91360981]
20


100%|██████████| 20/20 [00:18<00:00,  1.09it/s]
100%|██████████| 20/20 [00:18<00:00,  1.08it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

[43.57513038 28.60957003 37.94625024 30.77069923 30.46899849 33.53683015
 34.97683134 25.40852792 38.77668925 30.30336011 31.41455242 28.94106247
 31.90569855 25.39290467 28.18572793 30.88366228 30.15266686 29.74880194
 27.5626743  32.76944882]
20


100%|██████████| 20/20 [00:17<00:00,  1.12it/s]
100%|██████████| 20/20 [00:19<00:00,  1.04it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

[47.26024081 35.09300904 21.77947278 25.11532637 35.21489917 36.41663044
 34.37220505 38.46133603 28.67246396 30.46436803 31.18974554 38.037456
 37.32668544 29.52730252 36.45051072 26.26305664 27.68806138 30.80578608
 38.67635058 38.12031427]
20


100%|██████████| 20/20 [00:27<00:00,  1.40s/it]
100%|██████████| 20/20 [00:26<00:00,  1.31s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[45.73095117 26.04454042 29.4960406  28.02954222 34.7709307  22.8370641
 29.58921962 25.73595664 27.94140644 26.35951182 23.63759827 26.03907095
 24.76474791 29.07038191 25.71786377 27.32031413 29.32155971 23.33855899
 30.6725355  26.80648966]
20


100%|██████████| 20/20 [00:22<00:00,  1.10s/it]
100%|██████████| 20/20 [00:22<00:00,  1.11s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[37.54358662 30.2900024  34.09328763 27.93152877 36.63011074 26.22492089
 34.32889694 26.67190048 40.3727124  25.12323281 31.24295905 37.33414244
 30.3095151  34.92665718 28.14334922 29.54809177 31.90415609 29.36927991
 31.73884801 33.91766562]
20


100%|██████████| 20/20 [00:22<00:00,  1.11s/it]
100%|██████████| 20/20 [00:21<00:00,  1.07s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[39.05055415 29.93860774 27.45058323 35.85458565 29.56272677 30.06576014
 25.82851427 32.87134535 41.03806674 34.41615841 35.26920515 31.11600636
 32.67023087 35.56415749 28.9855766  28.39164101 34.810115   40.87836754
 39.12007534 35.68087337]
20


100%|██████████| 20/20 [00:22<00:00,  1.12s/it]
100%|██████████| 20/20 [00:22<00:00,  1.10s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[52.07462592 39.76168002 39.31769193 43.26362067 32.68015797 34.16726375
 33.73548969 25.30203982 32.08159089 34.46871562 30.54997928 30.80987445
 34.66375027 35.11952797 28.04581816 31.88348562 24.8998702  23.64892323
 28.9290197  27.96427029]
20


100%|██████████| 20/20 [00:22<00:00,  1.11s/it]
100%|██████████| 20/20 [00:22<00:00,  1.11s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[ 62.12772794  29.54313021  31.14450569 164.64444124 172.75312275
 261.03048783 235.62341533 230.88679027 261.88108645 255.97159844
 246.37548343 217.45338443 263.86865841 240.83433059 234.11825385
 223.26204634 242.77766753 259.86446044 233.46756772 222.97282436]
3


100%|██████████| 20/20 [00:21<00:00,  1.09s/it]
100%|██████████| 20/20 [00:22<00:00,  1.11s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[61.65218855 27.48866783 26.58840295 27.46314014 22.04887438 27.75885158
 32.01814645 30.42583024 27.44994648 33.50360583 29.036792   26.2725234
 30.53829095 40.20398269 35.6060967  27.29138661 29.86178627 34.84154251
 32.42171178 36.59129979]
20


100%|██████████| 20/20 [00:23<00:00,  1.19s/it]
100%|██████████| 20/20 [00:25<00:00,  1.29s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[47.70792313 32.88582664 33.31627715 29.83241976 29.90668104 32.75142959
 31.6528804  24.8463055  35.83958659 29.91163244 29.59900115 32.31458767
 33.83971691 31.83825474 33.69748973 32.4244972  27.66529539 31.05198847
 32.29890764 31.7087311 ]
20


100%|██████████| 20/20 [00:24<00:00,  1.24s/it]
100%|██████████| 20/20 [00:25<00:00,  1.26s/it]
  0%|          | 0/20 [00:00<?, ?it/s]

[  53.65440113   41.87982487   34.65762627   29.47528006   48.07861583
   55.70499074  151.88327027  192.28808121  211.19113176  355.74530355
  533.62426845  622.72381216  757.59083771  881.66462689  975.34427541
 1132.20267114 1118.37011242 1266.47942677 1293.79382527 1387.36391851]
6


100%|██████████| 20/20 [00:24<00:00,  1.24s/it]
100%|██████████| 20/20 [00:24<00:00,  1.24s/it]

[41.6987579  32.42672286 32.26539073 35.0023438  34.23746625 32.39632299
 31.40239445 33.18019432 35.74752288 34.65752252 26.18046237 33.96806234
 26.49155486 37.52926646 24.57156465 30.76052227 27.32622002 32.80497807
 34.60298191 37.57629227]
20





In [5]:
# Each stddev-check snapshot was taken with a frame stride of 10 (traj[0:200:10]),
# so scale the count of valid leading snapshots back to trajectory frames.
for lam in sorted(last_valid_inds):
    n_usable = max(0, 10 * last_valid_inds[lam])
    print('lam={}: usable snapshots: {}'.format(lam, n_usable))

lam=0.0: usable snapshots: 200
lam=0.05: usable snapshots: 200
lam=0.1: usable snapshots: 200
lam=0.15: usable snapshots: 200
lam=0.2: usable snapshots: 200
lam=0.25: usable snapshots: 200
lam=0.3: usable snapshots: 50
lam=0.35: usable snapshots: 200
lam=0.4: usable snapshots: 200
lam=0.45: usable snapshots: 200
lam=0.5: usable snapshots: 200
lam=0.55: usable snapshots: 200
lam=0.6: usable snapshots: 200
lam=0.65: usable snapshots: 200
lam=0.7: usable snapshots: 200
lam=0.75: usable snapshots: 200
lam=0.8: usable snapshots: 200
lam=0.85: usable snapshots: 30
lam=0.9: usable snapshots: 200
lam=0.95: usable snapshots: 60
lam=1.0: usable snapshots: 200


In [None]:
# Pool snapshots from all lambda windows for the free-energy calculation. For each window,
# keep the raw frames after a short burn-in and before the first stddev violation, then thin
# them to at most `max_n_snapshots_per_state`.
max_n_snapshots_per_state = 10
n_burn_in_frames = 5
# last_valid_inds counts snapshots of the stride-10 subsample traj[0:200:10];
# multiply by the stride to get a raw-frame index.
stddev_check_stride = 10

# Defined here so the cell runs on a fresh kernel.
# NOTE(review): window selection (more valid frames than the burn-in) is an assumption — confirm.
trajs = dict(zip(lambdas, ani_trajs))
lambdas_with_usable_samples = [
    lam for lam in sorted(last_valid_inds)
    if stddev_check_stride * last_valid_inds[lam] > n_burn_in_frames
]

snapshots = []
N_k = []
for lam in lambdas_with_usable_samples:
    traj = trajs[lam][n_burn_in_frames:stddev_check_stride * last_valid_inds[lam]]
    # Ceiling division guarantees the cap is respected (int(len / max) under-thins:
    # 15 frames with a cap of 10 would keep all 15).
    further_thinning = max(1, -(-len(traj) // max_n_snapshots_per_state))
    new_snapshots = list(traj.xyz[::further_thinning] * unit.nanometer)
    snapshots.extend(new_snapshots)
    N_k.append(len(new_snapshots))

N = len(snapshots)
N_k, N

In [None]:
from collections import namedtuple

# Per-lambda container for the two end-state energy arrays.
RawEnergies = namedtuple('RawEnergies',
                         ['raw_energies_without_dummy_0',
                          'raw_energies_without_dummy_1'
                         ]
                        )
raw_energy_dict = {}
# Reuse the trajectories already loaded from `dcds` (with the env-appropriate `top`)
# instead of re-reading every DCD; lambdas/ani_trajs were built from dcds in the same order.
for dcd, lam, traj in zip(dcds, lambdas, ani_trajs):
    print(dcd)
    # NOTE(review): compute_endstate_atomic_energy_contributions is not defined or imported
    # in this notebook — TODO import it before running this cell.
    raw_energies_without_dummy_0, raw_energies_without_dummy_1 = compute_endstate_atomic_energy_contributions(traj[::10], model, energy_function)
    raw_energy_dict[lam] = RawEnergies(raw_energies_without_dummy_0, raw_energies_without_dummy_1)