In [None]:
import tf2onnx
import onnx2torch
import tempfile

import torch
import numpy as np
import tensorflow as tf

## Input files

All input files are prepared (up- or downloaded) in [prepare.ipynb](prepare.ipynb). 


In [None]:
# Execute the shared settings script in this notebook's namespace.
# NOTE(review): it presumably defines names used below but not defined here
# (e.g. conf, topol, index, nb_density, and possibly asmsa/md imports) — confirm.
# `with` closes the file (the bare open() leaked the handle), and compile() with
# the real file name makes tracebacks point into inputs.py instead of <string>.
with open('inputs.py') as inputs_file:
    exec(compile(inputs_file.read(), 'inputs.py', 'exec'))

## Save the model for Gromacs

*Another wave of magics ...*

Atoms can be numbered differently in PDB, GRO, and other structure files.

So far we have worked with atoms numbered as in the `conf` PDB file, assuming the `traj` XTC file is consistent with it.
If a topology was used, its numbering may have differed, as Gromacs tends to renumber atoms.

In the subsequent simulations, we assume the usual protocol, starting with `pdb2gmx` to generate the topology,
so Gromacs atom numbering is used from then on.
Therefore `plumed.dat` must pick the atoms in the PDB file order, and skip the hydrogens added by Gromacs.

Many things can go wrong, so we strongly encourage checking the results manually. For example, the first residue (ASP) of the tryptophan cage may look like this in the PDB file:

    ATOM      1  N   ASP     1      28.538  39.747  31.722  1.00  1.00           N
    ATOM      2  CA  ASP     1      28.463  39.427  33.168  1.00  1.00           C
    ATOM      3  C   ASP     1      29.059  37.987  33.422  1.00  1.00           C
    ATOM      4  O   ASP     1      30.226  37.748  33.735  1.00  1.00           O
    ATOM      5  CB  ASP     1      26.995  39.482  33.630  1.00  1.00           C
    ATOM      6  CG  ASP     1      26.889  39.307  35.101  1.00  1.00           C
    ATOM      7  OD1 ASP     1      27.749  39.962  35.773  1.00  1.00           O
    ATOM      8  OD2 ASP     1      26.012  38.510  35.611  1.00  1.00           O
    
In the Gromacs topology, it becomes:

     1         N3      1    ASP      N      1     0.0782      14.01   ; qtot 0.0782
     2          H      1    ASP     H1      2       0.22      1.008   ; qtot 0.2982
     3          H      1    ASP     H2      3       0.22      1.008   ; qtot 0.5182
     4          H      1    ASP     H3      4       0.22      1.008   ; qtot 0.7382
     5         CT      1    ASP     CA      5     0.0292      12.01   ; qtot 0.7674
     6         HP      1    ASP     HA      6     0.1141      1.008   ; qtot 0.8815
     7         CT      1    ASP     CB      7    -0.0235      12.01   ; qtot 0.858
     8         HC      1    ASP    HB1      8    -0.0169      1.008   ; qtot 0.8411
     9         HC      1    ASP    HB2      9    -0.0169      1.008   ; qtot 0.8242
    10          C      1    ASP     CG     10     0.8194      12.01   ; qtot 1.644
    11         O2      1    ASP    OD1     11    -0.8084         16   ; qtot 0.8352
    12         O2      1    ASP    OD2     12    -0.8084         16   ; qtot 0.0268
    13          C      1    ASP      C     13     0.5621      12.01   ; qtot 0.5889
    14          O      1    ASP      O     14    -0.5889         16   ; qtot 0
    
Besides adding hydrogens, the backbone carbonyl group (C and O, atoms 3 and 4 in the PDB file) is moved to the end of the residue (becoming atoms 13 and 14 in the topology).

Consequently, the ATOMS setting in the generated `plumed.dat` must be:

    model: PYTORCH_MODEL_CV FILE=model.pt ATOMS=1,5,13,14,7,10,11,12, ...
    
i.e., the atoms are listed *in the order* of the PDB file but *referring to the numbers* in the topology file.

If there is any mismatch, the MD simulations are likely to fail, or at least to produce meaningless results.

It's also **critical** that `{conf}`, `{top}`, and `{gro}` correspond to one another, and that `{gro}` **includes hydrogens**.


In [None]:
import tf2onnx
import onnx2torch
import tempfile

def _convert_to_onnx(model, destination_path):
    """Export a Keras model to an ONNX file at ``destination_path``.

    The model is wrapped in a ``tf.function`` with a fixed input signature taken
    from its first layer's input tensor. That function is then converted with
    ``tf2onnx.convert.from_function``. The single output is returned in a dict
    keyed by the name of the model's last layer.

    Args:
        model: Keras model to export. NOTE(review): it must expose the private
            ``model.layers[0]._input_tensor`` attribute (Keras internals; may break
            across versions). ``model.inputs[0]`` is the public alternative, but
            confirm it is equivalent before switching.
        destination_path: file path the ONNX model is written to.

    Returns:
        None; the result is written to ``destination_path``.
    """
    source_tensor = model.layers[0]._input_tensor
    spec = tf.TensorSpec(
        shape=source_tensor.shape,
        dtype=source_tensor.dtype,
        name=source_tensor.name,
    )
    output_key = model.layers[-1].name

    def _wrapped_model(input_data):
        # Dict output, presumably so the ONNX output is named after the last layer.
        return {output_key: model(input_data)}

    traced_fn = tf.function(_wrapped_model, input_signature=[spec])
    tf2onnx.convert.from_function(
        traced_fn, input_signature=[spec], output_path=destination_path
    )

In [None]:
# Load the test geometry dataset and stack all of its elements into one numpy array.
geom = np.stack([element for element in tf.data.Dataset.load('datasets/geoms/test')])


#### Traditional internal coordinates (all bond distances, angles, and torsions)

Run either this section or the next one, consistently with what was used in [prepare.ipynb](prepare.ipynb).

With this approach, the model is applicable only to the molecule used for training.

In [None]:
# All internal coordinates (taken from the topology/index files) plus sparse
# non-bonded distances.
# `sparse_dists` must exist before it is passed to the molecule. Previously it
# was only created in a later cell, so this cell raised NameError on a fresh kernel.
# The construction is identical to the one used in the alternative section.
# NOTE(review): must reproduce the selection made in prepare.ipynb — confirm
# NBDistancesSparse is deterministic.
sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=nb_density)
mol = asmsa.Molecule(pdb=conf, top=topol, ndx=index, fms=[sparse_dists])

#### Alternative: only backbone + Cbeta angles and dihedrals

In [None]:
# Read the backbone atom numbers from the Gromacs index file and convert them
# from 1-based (file) to 0-based indices.
# NOTE(review): assumes the file holds a single group — only the first line
# (the group header) is skipped and every remaining number is used.
with open('backbone.ndx') as ndx_file:
    ndx_file.readline()
    bb = np.array([int(token) - 1 for token in ndx_file.read().split()])

In [None]:
# Angles and dihedrals from consecutive backbone atoms: sliding windows of 3 and 4 indices over `bb`.
# NOTE(review): range(len(bb) - 3) / range(len(bb) - 4) stop one window early, so
# the last triple (bb[-3:]) and the last quadruple (bb[-4:]) are never included.
# Kept as-is because the feature count must match the model trained in
# prepare.ipynb; if unintended, fix both notebooks together.
n_angle_starts = len(bb) - 3
n_dihed_starts = len(bb) - 4
angles = np.array([bb[start:start + 3] for start in range(n_angle_starts)])
diheds = np.array([bb[start:start + 4] for start in range(n_dihed_starts)])

In [None]:
# Select the alpha carbons of all non-glycine residues and all beta carbons
# (glycine has no CB) from the reference structure, as atom indices.
tr1 = md.load(conf)
cas = tr1.topology.select('name CA and not resname GLY')
cbs = tr1.topology.select('name CB')
# cas[i] and cbs[i] are paired by position when building the CB angles and
# dihedrals, so equal counts are not enough: each pair must share a residue.
assert len(cas) == len(cbs), f'{len(cas)} non-GLY CA atoms but {len(cbs)} CB atoms'
assert all(
    tr1.topology.atom(ca).residue.index == tr1.topology.atom(cb).residue.index
    for ca, cb in zip(cas, cbs)
), 'CA and CB selections are not paired residue by residue'

In [None]:
# Positions within the backbone index array `bb` of the non-GLY CA atoms:
# the column of every match of cas[i] in bb, in the order of cas.
cai = np.argwhere(bb.reshape(1, -1) == cas.reshape(-1, 1))[:, 1]
# Every CA must appear exactly once on the backbone; otherwise cai[i] silently
# stops corresponding to cas[i]/cbs[i] in the CB angle/dihedral construction.
assert np.array_equal(bb[cai], cas), 'CA atoms missing from (or duplicated in) backbone.ndx'
cai

In [None]:
# CB-CA-X angles, one per non-GLY residue, where X is a neighbouring backbone atom:
#  - first residue: X is the backbone atom *after* its CA,
#  - all others:    X is the backbone atom *before* their CA.
# For a manual check, `cbangles + 1` gives 1-based atom numbers.
first_cbangle = [cbs[0], cas[0], bb[cai[0] + 1]]
other_cbangles = [[cbs[k], bb[cai[k]], bb[cai[k] - 1]] for k in range(1, len(cbs))]
cbangles = np.array([first_cbangle] + other_cbangles)

In [None]:
# CB-CA-X-Y dihedrals, one per non-GLY residue, walking along the backbone from CA:
#  - first residue: forward, using the two backbone atoms after its CA,
#  - all others:    backward, using the two backbone atoms before their CA.
# For a manual check, `cbdiheds + 1` gives 1-based atom numbers.
first_cbdihed = [cbs[0], cas[0], bb[cai[0] + 1], bb[cai[0] + 2]]
other_cbdiheds = [
    [cbs[k], bb[cai[k]], bb[cai[k] - 1], bb[cai[k] - 2]] for k in range(1, len(cbs))
]
cbdiheds = np.array([first_cbdihed] + other_cbdiheds)

In [None]:
# Molecule model with explicit angles (backbone + CB) and dihedrals (backbone + CB),
# plus sparse non-bonded distances. No topology/bond list is passed here
# (distances are deliberately left out for now).
sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=nb_density)
all_angles = np.concatenate((angles, cbangles))
all_diheds = np.concatenate((diheds, cbdiheds))
mol = asmsa.Molecule(
    pdb=conf,
    n_atoms=geom.shape[0],
    angles=all_angles,
    diheds=all_diheds,
    fms=[sparse_dists],
)

In [None]:
import os

# NOTE(review): `testm` is not defined anywhere in this notebook. It presumably
# comes from inputs.py or an earlier session (hidden kernel state) — confirm,
# otherwise this cell fails on a fresh kernel.
model = testm

# Convert the Keras encoder (model.enc) to ONNX, then to a torch module.
# A temporary *directory* is used instead of NamedTemporaryFile because the
# ONNX file must be reopened by name. That fails on Windows while a
# NamedTemporaryFile is still open. The ONNX file is fully loaded by
# onnx2torch before the directory is removed.
with tempfile.TemporaryDirectory() as onnx_dir:
    onnx_path = os.path.join(onnx_dir, 'encoder.onnx')
    _convert_to_onnx(model.enc, onnx_path)
    torch_encoder = onnx2torch.convert(onnx_path)

# XXX: we rely on determinism of the model creation, it must be the same as in prepare.ipynb
# better to store it there in onnx, and reload here
# NOTE(review): this fresh `sparse_dists` is not attached to `mol` (built
# earlier); it only rebinds the global name.
sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=nb_density)

mol_model = mol.get_model()
# NOTE: the cell loading features.pt / encoder.pt replaces both `mol_model`
# and `torch_encoder` built here.

In [None]:
# Reload the test geometries under the name used when tracing the exported model.
test_geom = np.stack([element for element in tf.data.Dataset.load('datasets/geoms/test')])
test_geom.shape

In [None]:
# Load the TorchScript feature extractor and encoder from disk, replacing any
# `mol_model` / `torch_encoder` built earlier in the notebook.
# NOTE(review): presumably these files were written by prepare.ipynb — confirm.
torch_encoder = torch.jit.load('encoder.pt')
mol_model = torch.jit.load('features.pt')

In [None]:
# Per-feature normalization statistics of the training set, loaded as float32.
train_mean, train_scale = (
    np.loadtxt(path, dtype=np.float32)
    for path in ('datasets/intcoords/mean.txt', 'datasets/intcoords/scale.txt')
)


In [None]:
class CompleteModel(torch.nn.Module):
    """Single exportable module: inputs -> features -> standardization -> encoder.

    Chains the feature extractor ``mol_model``, standardization with the
    training-set statistics, and ``torch_encoder``, so the whole pipeline can be
    traced and saved as one TorchScript model.

    Args:
        mol_model: module producing the features. NOTE(review): the statistics
            are broadcast as ``(n_features, 1)`` columns, so features are
            presumably along the second-to-last axis of its output — confirm.
        torch_encoder: module applied to the flattened, standardized features.
        train_mean: per-feature mean (array-like, one value per feature).
        train_scale: per-feature scale (array-like, one value per feature).
    """

    def __init__(self, mol_model, torch_encoder, train_mean, train_scale):
        super().__init__()
        self.mol_model = mol_model
        self.torch_encoder = torch_encoder
        # Keep the statistics as (n_features, 1) buffers rather than plain tensor
        # attributes, so they follow .to()/.double()/.cuda() with the module and
        # are part of its state. torch.from_numpy shares memory with the source
        # array; .clone() decouples the model from later edits to that array.
        self.register_buffer(
            'train_mean', torch.from_numpy(np.reshape(train_mean, (-1, 1))).clone()
        )
        self.register_buffer(
            'train_scale', torch.from_numpy(np.reshape(train_scale, (-1, 1))).clone()
        )

    def forward(self, x):
        """Return the encoder output for input ``x``."""
        features = self.mol_model(x)
        standardized = (features - self.train_mean) / self.train_scale
        # Flatten to 1-D before the encoder.
        # NOTE(review): this also merges any frame/batch axis, so presumably a
        # single frame per call is expected (as in the traced example input).
        return self.torch_encoder(standardized.reshape(-1))

# Assemble the full pipeline from the loaded components and statistics.
complete_model = CompleteModel(mol_model, torch_encoder, train_mean, train_scale)

# Export with TorchScript tracing. Tracing records the executed operations, so a
# random input of the right shape is enough as long as there is no data-dependent
# control flow. The shape takes the first two axes of `test_geom` and a trailing
# axis of 1. NOTE(review): confirm what those axes mean.
example_shape = [test_geom.shape[0], test_geom.shape[1], 1]
example_input = torch.randn(example_shape)
traced_script_module = torch.jit.trace(complete_model, example_input)

# File referenced by PYTORCH_MODEL_CV in plumed.dat.
model_file_name = "model.pt"
traced_script_module.save(model_file_name)