In [None]:
# Work from the trpcage directory: the relative paths used below
# (inputs.py, datasets/..., *.pt, x_test.xtc) are resolved against it.
%cd trpcage

In [None]:
import mdtraj as md
import torch
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

## Input files

All input files are prepared (up- or downloaded) in [prepare.ipynb](prepare.ipynb). 


In [None]:
# Load the input-file settings prepared in prepare.ipynb. Names used later in
# this notebook (e.g. conf, gro, index) are presumably defined by inputs.py.
# The context manager closes the file handle, which the bare open() leaked.
# inputs.py is a trusted local file; never exec content from untrusted sources.
with open('inputs.py') as inputs_file:
    exec(inputs_file.read())

## Save the model for Gromacs

*Another wave of magics ...*

Atoms may be numbered differently in PDB, GRO, and other structure files. 

So far we have worked with atoms numbered as in the `conf` PDB file, assuming the `traj` XTC file is consistent with it.
Had the topology been used instead, its numbering could differ, since GROMACS orders atoms its own way. 

In the subsequent simulations, we assume the usual protocol starting with `pdb2gmx` to generate the topology,
hence GROMACS atom numbering is followed from then on.
Therefore `plumed.dat` must pick the atoms in the PDB file order, refer to them by their GROMACS numbers, and skip the hydrogens added by GROMACS. 

Many things can go wrong, so we strongly encourage checking the results manually. For example, the first residue (ASP) of the tryptophan cage may look like this in the PDB file:

    ATOM      1  N   ASP     1      28.538  39.747  31.722  1.00  1.00           N
    ATOM      2  CA  ASP     1      28.463  39.427  33.168  1.00  1.00           C
    ATOM      3  C   ASP     1      29.059  37.987  33.422  1.00  1.00           C
    ATOM      4  O   ASP     1      30.226  37.748  33.735  1.00  1.00           O
    ATOM      5  CB  ASP     1      26.995  39.482  33.630  1.00  1.00           C
    ATOM      6  CG  ASP     1      26.889  39.307  35.101  1.00  1.00           C
    ATOM      7  OD1 ASP     1      27.749  39.962  35.773  1.00  1.00           O
    ATOM      8  OD2 ASP     1      26.012  38.510  35.611  1.00  1.00           O
    
which becomes the following in the GROMACS topology: 

     1         N3      1    ASP      N      1     0.0782      14.01   ; qtot 0.0782
     2          H      1    ASP     H1      2       0.22      1.008   ; qtot 0.2982
     3          H      1    ASP     H2      3       0.22      1.008   ; qtot 0.5182
     4          H      1    ASP     H3      4       0.22      1.008   ; qtot 0.7382
     5         CT      1    ASP     CA      5     0.0292      12.01   ; qtot 0.7674
     6         HP      1    ASP     HA      6     0.1141      1.008   ; qtot 0.8815
     7         CT      1    ASP     CB      7    -0.0235      12.01   ; qtot 0.858
     8         HC      1    ASP    HB1      8    -0.0169      1.008   ; qtot 0.8411
     9         HC      1    ASP    HB2      9    -0.0169      1.008   ; qtot 0.8242
    10          C      1    ASP     CG     10     0.8194      12.01   ; qtot 1.644
    11         O2      1    ASP    OD1     11    -0.8084         16   ; qtot 0.8352
    12         O2      1    ASP    OD2     12    -0.8084         16   ; qtot 0.0268
    13          C      1    ASP      C     13     0.5621      12.01   ; qtot 0.5889
    14          O      1    ASP      O     14    -0.5889         16   ; qtot 0
    
Besides the added hydrogens, the backbone carbonyl group (C and O, atoms 3 and 4 in the PDB file) is moved to the end of the residue (becoming atoms 13 and 14 in the topology).

Consequently, the ATOMS setting in the generated `plumed.dat` must be:

    model: PYTORCH_MODEL_CV FILE=model.pt ATOMS=1,5,13,14,7,10,11,12, ...
    
i.e., the atoms are enumerated *in the order* of PDB file but *referring to numbers* of topology file. 

If there is any mismatch, the MD simulations are likely to fail, or at least to produce meaningless results.

It's also **critical** that `{conf}`, `{top}`, and `{gro}` correspond to one another, and that `{gro}` **includes hydrogens**.


In [None]:
# Test-set geometries stored as a tf.data dataset: collect every element,
# stack them into one array and bring axis 2 to the front.
test_geom = np.moveaxis(
    np.stack([*tf.data.Dataset.load('datasets/geoms/test')]),
    source=2,
    destination=0,
)
test_geom.shape

In [None]:
# TorchScript pieces of the pipeline: features.pt (molecular features) and encoder.pt.
mol_model, torch_encoder = map(torch.jit.load, ('features.pt', 'encoder.pt'))

In [None]:
# Per-feature normalisation statistics from training, read as float32.
train_mean, train_scale = (
    np.loadtxt(path, dtype=np.float32)
    for path in ('datasets/intcoords/mean.txt', 'datasets/intcoords/scale.txt')
)


In [None]:
class CompleteModel(torch.nn.Module):
    """Coordinates -> features -> standardisation -> encoder, as one module.

    Chains the feature model and the encoder so that the whole pipeline can be
    traced and saved as a single TorchScript file for PLUMED.

    Args:
        mol_model: callable/module applied to the coordinates reshaped to
            (-1, 3, 1); its output is indexed as [:, 0] and broadcast against
            column vectors of length len(train_mean).
        torch_encoder: callable/module applied to one standardised feature vector.
        train_mean, train_scale: per-feature statistics (array-likes) used to
            standardise the features; stored as (n_features, 1) buffers.
    """

    def __init__(self, mol_model, torch_encoder, train_mean, train_scale):
        super().__init__()
        self.mol_model = mol_model
        self.torch_encoder = torch_encoder
        # Buffers (instead of plain tensor attributes) follow the module through
        # .to()/.cuda()/.double() and are part of state_dict(). The column shape
        # broadcasts against the feature output (presumably (n_features, batch)).
        self.register_buffer('train_mean', torch.from_numpy(np.reshape(train_mean, (-1, 1))))
        self.register_buffer('train_scale', torch.from_numpy(np.reshape(train_scale, (-1, 1))))

    def forward(self, x):
        """Map flat atom coordinates to the encoder output (the CVs).

        The PLUMED module supplies the selected atoms' coordinates as a flat
        array of n_atoms*3 values, so they are reshaped to (n_atoms, 3, 1) for
        mol_model. Only column 0 of the standardised features (presumably the
        single frame) is passed to the encoder.
        """
        mol_output = self.mol_model(x.reshape((-1, 3, 1)))
        normalized = (mol_output - self.train_mean) / self.train_scale
        return self.torch_encoder(normalized[:, 0])

# Assemble the full pipeline. eval() switches every submodule to inference mode
# before tracing, so train-only behaviour (e.g. dropout, if the loaded modules
# contain any) is not baked into the saved trace.
complete_model = CompleteModel(mol_model, torch_encoder, train_mean, train_scale).eval()

# The PLUMED module hands the model a flat coordinate array, so trace with one
# random example of shape (1, n_atoms*3), where n_atoms = test_geom.shape[1].
example_input = torch.randn([1, test_geom.shape[1] * 3])

# Trace to TorchScript and save the single file referenced from plumed.dat.
traced_script_module = torch.jit.trace(complete_model, example_input)

model_file_name = "model.pt"
traced_script_module.save(model_file_name)

In [None]:
np.shape(test_geom)  # shape of the stacked test-set geometries

In [None]:
# Debug: run only the feature model on the tracing example, in the same
# (-1, 3, 1) layout that CompleteModel.forward uses.
f = mol_model(example_input.reshape(-1, 3, 1))
f.shape

In [None]:
# Debug: standardise the features with the training statistics as column vectors.
n = (f - train_mean.reshape(-1, 1)) / train_scale.reshape(-1, 1)
n.shape

In [None]:
# Debug: encode the standardised features. .T presumably turns (n_features, 1)
# into (1, n_features), i.e. one sample per row -- confirm against n.shape above.
l = torch_encoder(n.T)
l.shape

In [None]:
# Debug: the flat tracing input vs. the (-1, 3, 1) layout given to mol_model.
(example_input.shape, example_input.view(-1, 3, 1).shape)

In [None]:
# Reload the saved TorchScript file (the same one plumed.dat points to) and
# evaluate it on the test geometries. no_grad() keeps the output detached:
# .numpy() raises a RuntimeError on a tensor that requires grad, which is the
# case when the traced module's parameters require grad.
m = torch.jit.load(model_file_name)
with torch.no_grad():
    lows = m(torch.tensor(test_geom)).numpy()
lows.shape

In [None]:
# Test trajectory, read with `conf` (presumably defined by inputs.py) as its topology.
x = md.load('x_test.xtc', top=conf)

In [None]:
# Compare the trajectory coordinate array with the dataset geometries.
(np.shape(x.xyz), np.shape(test_geom))

In [None]:
# CVs for every frame of the test trajectory. no_grad() keeps the output
# detached so .numpy() does not fail on a tensor that requires grad.
with torch.no_grad():
    lows = m(torch.tensor(x.xyz)).numpy()
lows.shape

In [None]:
# Visual check, should be the same as in train.ipynb: the two CVs of the test
# trajectory, coloured by radius of gyration and by RMSD to the first frame of conf.
rg = md.compute_rg(x)
base = md.load(conf)
rmsd = md.rmsd(x, base[0])
cmap = plt.get_cmap('rainbow')

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, values, title in zip(axes, (rg, rmsd), ("Rg", "RMSD")):
    points = ax.scatter(lows[:, 0], lows[:, 1], marker='.', c=values, cmap=cmap)
    # The colorbar takes its colormap from the scatter mappable; passing cmap
    # to colorbar as well was redundant.
    fig.colorbar(points, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("CV 0 (model.node-0)")
    ax.set_ylabel("CV 1 (model.node-1)")
plt.show()

#### Determine range of CVs for simulation

PLUMED maintains a grid to approximate the accumulated bias potential, and the grid size must be known in advance.

Making it wider is safe (the simulation is less likely to escape the grid and crash), but there is a performance penalty.

We calculate the CVs on the test set, determine their range, and add some margins.


In [None]:
# Widen the observed CV range by grid_margin times its width on each side,
# giving the METAD grid bounds.
grid_margin = 3.  # that many times the actual computed size added on both sides

llen = np.ptp(lows, axis=0)
lmin = np.min(lows, axis=0) - grid_margin * llen
lmax = np.max(lows, axis=0) + grid_margin * llen

lmin, lmax

In [None]:
# Atom numbering magic with Gromacs, see above

# 0-based indices of the heavy (non-hydrogen) atoms in the Gromacs structure.
grotr = md.load(gro)
nhs = grotr.topology.select('element != H')

# Index file: skip the first line (presumably the "[ group ]" header), then
# parse the whitespace-separated 1-based atom numbers and make them 0-based.
# The handle is named ndx_file so it does not clobber the earlier `f` feature
# tensor. int() raises on any stray token (e.g. a second group header) instead
# of stopping early with only a warning, as np.fromstring did.
with open(index) as ndx_file:
    ndx_file.readline()
    ndx = np.array([int(token) for token in ndx_file.read().split()], dtype=np.int32) - 1

# Reorder the heavy atoms by argsort(ndx); +1 converts back to the 1-based
# numbers used in the ATOMS= list of plumed.dat.
pdb2gmx = nhs[np.argsort(ndx)] + 1

# maybe double check manually wrt. the files
pdb2gmx

In [None]:
np.asarray(ndx)  # parsed index-file entries (0-based)

In [None]:
# Metadynamics settings, kept as named values so they are easy to find and tune.
metad_pace = 1000          # METAD PACE
metad_height = 1           # METAD HEIGHT
metad_biasfactor = 15      # METAD BIASFACTOR
metad_sigma = (0.1, 0.1)   # METAD SIGMA, one value per CV
colvar_stride = 100        # PRINT STRIDE for the COLVAR file

atoms = ','.join(map(str, pdb2gmx))
sigma = ','.join(map(str, metad_sigma))

# NOTE(review): RESTART makes PLUMED append to existing output files
# (HILLS, COLVAR) -- confirm this is intended for a fresh simulation.
with open("plumed.dat", "w") as p:
    p.write(f"""\
RESTART
WHOLEMOLECULES ENTITY0=1-{grotr.xyz.shape[1]}
model: PYTORCH_MODEL_CV FILE={model_file_name} ATOMS={atoms}
metad: METAD ARG=model.node-0,model.node-1 PACE={metad_pace} HEIGHT={metad_height} BIASFACTOR={metad_biasfactor} SIGMA={sigma} GRID_MIN={lmin[0]},{lmin[1]} GRID_MAX={lmax[0]},{lmax[1]} FILE=HILLS
PRINT FILE=COLVAR ARG=model.node-0,model.node-1,metad.bias STRIDE={colvar_stride}
""")