# ASMSA: Train AAE model with the tuned hyperparameters

**Previous steps**
- [prepare.ipynb](prepare.ipynb): Download and sanity check input files
- [tune.ipynb](tune.ipynb): Perform initial hyperparameter tuning for this molecule

**Next step**
- [md.ipynb](md.ipynb): Use a trained model in MD simulation with Gromacs

## Notebook setup

In [None]:
# Number of threads used for all thread-count settings in this cell
threads = 2
import os
# Exported before TensorFlow and PyTorch are imported below
os.environ['OMP_NUM_THREADS']=str(threads)
import tensorflow as tf

# PyTorch favours OMP_NUM_THREADS in environment
import torch

# TensorFlow needs explicit config calls
tf.config.threading.set_inter_op_parallelism_threads(threads)
tf.config.threading.set_intra_op_parallelism_threads(threads)

In [None]:
from asmsa.tuning_analyzer import TuningAnalyzer
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
import urllib.request
from tensorflow import keras
import keras_tuner
import asmsa.visualizer as visualizer
import asmsa

## Input files

All input files are prepared (up- or downloaded) in [prepare.ipynb](prepare.ipynb). 


In [None]:
# Execute inputs.py in the notebook namespace.
# Presumably it defines the input names used by later cells (conf, topol,
# index, gro, nb_density) -- verify against prepare.ipynb.
# NOTE: exec runs arbitrary code; only use an inputs.py you trust.
with open('inputs.py') as inputs_file:  # closes the handle deterministically
    exec(inputs_file.read())

## Apply the tuning results

In [None]:
# Seeds for the encoder and the discriminator; pick from plots in tune.ipynb
best_enc_seed, best_disc_seed = 96, 32

In [None]:
# Get best HP from latest tuning
analyzer = TuningAnalyzer()
# Ask for the 3 best trials; as the last expression of the cell its return
# value (if any) is shown as the cell output.
# presumably this also fills analyzer.sorted_trials used in the next cell -- TODO confirm
analyzer.get_best_hp(num_trials=3)

In [None]:
# Select HP to use by specifying trial_id
#  e.g: trial_id = '483883b929b3445bff6dee9759c4d50ee3a4ba7f0db22e665c49b5f942d9693b'
# ... or don't specify, by default use the trial with the lowest score
trial_id = ''

# Hyperparameters of the first trial whose ID matches; None when there is no match
hps = next(
    (trial['hp'] for trial in analyzer.sorted_trials if trial['trial_id'] == trial_id),
    None,
)

if not hps:
    # Fall back to the first entry of sorted_trials (the lowest score)
    best_trial = analyzer.sorted_trials[0]
    if trial_id:
        print(f'Could not find trial with specified ID, using one with the lowest score - {best_trial["trial_id"]}')
    else:
        # No ID was requested, so "could not find" would be misleading here
        print(f'No trial ID specified, using the trial with the lowest score - {best_trial["trial_id"]}')
    hps = best_trial['hp']

print(hps)

## Load datasets
Load filtered trajectory datasets that were processed in **prepare.ipynb**. Trajectories are in internal coordinates format.

In [None]:
# load train dataset
X_train = tf.data.Dataset.load('datasets/intcoords/train')

# get batched version of dataset to feed to AAE model for training
# (drop_remainder=True discards a final batch smaller than hps['batch_size'])
X_train_batched = X_train.batch(hps['batch_size'],drop_remainder=True)

# get numpy version for visualization purposes
# (materializes every element of the dataset in memory)
X_train_np = np.stack(list(X_train))
# shown as the cell output
X_train_np.shape

In [None]:
# load test dataset
X_test = tf.data.Dataset.load('datasets/intcoords/test')

# get batched version of dataset to feed to AAE model for prediction
# (drop_remainder=True discards a final batch smaller than hps['batch_size'],
#  so this can hold fewer samples than X_test_np)
X_test_batched = X_test.batch(hps['batch_size'],drop_remainder=True)

# get numpy version for testing purposes
# (materializes every element of the dataset in memory)
X_test_np = np.stack(list(X_test))
# shown as the cell output
X_test_np.shape

In [None]:
# Merge (zip) the trajectory density of the training set points
# those will be aligned with the probability density of the prior distribution 

# one density value per row of datasets/train_density.txt
# (assumes the rows are in the same order as X_train -- TODO confirm)
dens = tf.data.Dataset.from_tensor_slices(np.loadtxt('datasets/train_density.txt'))
# each element is a (coordinates, density) pair; batched like X_train_batched
X_train_dens = tf.data.Dataset.zip((X_train,dens)).batch(hps['batch_size'],drop_remainder=True)

In [None]:
# Take the first (coordinates, density) batch to inspect its shapes.
# next() raises StopIteration on an empty dataset instead of silently leaving
# `e` unset (or stale from an earlier run of this cell).
e = next(iter(X_train_dens.as_numpy_iterator()))
e[0].shape,e[1].shape

## Train

### Distribution prior
Train with common prior distributions. See https://www.tensorflow.org/probability/api_docs/python/tfp/distributions for all available distributions. It is ideal to use tuned Hyperparameters for training.

In [None]:
# set used prior

# this one is (more or less) required to work with the density alignment
# (loc has two components, matching the two latent coordinates plotted below)
prior = tfp.distributions.MultivariateNormalDiag(loc=[0.,0.])

# alternative priors, currently disabled
#prior = tfp.distributions.Normal(loc=0, scale=1)
# prior = tfp.distributions.Uniform()
# prior = tfp.distributions.Weibull(1,0.5)
# prior = tfp.distributions.Cauchy(loc=0, scale=1)

In [None]:
# prepare model using the best hyperparameters
# - the input shape is the number of columns of the training array
# - seeds come from the "Apply the tuning results" cell
# - with_density=True: the model is trained on X_train_dens below
#   (presumably enables the density alignment -- TODO confirm in asmsa)
testm = asmsa.AAEModel((X_train_np.shape[1],),
                       prior=prior,
                       hp=hps,
                       enc_seed=best_enc_seed,
                       disc_seed=best_disc_seed,
                       with_density=True
                      )
testm.compile()

In [None]:
# specify earlystopping callback to avoid overfitting
# name of the logged metric to watch; also used for the history plot below
monitored_metric = "AE loss min"

# stop once the metric has not improved by at least min_delta for `patience`
# epochs, and restore the weights of the best epoch
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor=monitored_metric,
    min_delta=0.0001,
    patience=15,
    verbose=1,
    mode="min",
    restore_best_weights=True,
)

In [None]:
# train it (can be repeated several times to add more epochs)

# trains on the (coordinates, density) batches; the plain X_train_batched
# alternative is kept as a comment
testm.fit(X_train_dens, # X_train_batched,
          epochs=600,
          verbose=2, # this flag is essential due to connection with EarlyStopping callback (epoch vs batch)
          callbacks=[
              early_stop_cb,
              # visualization callback fed with a fixed slice of the training data
              visualizer.VisualizeCallback(testm,freq=25,inputs=X_train_np[15000:25000],figsize=(12,3))
          ])

In [None]:
# - plot AE loss min during training
# - specify "since_epoch" for better plot scaling (ignore outliers)
# - note that numbering of epochs starts at 1, 0th epoch does not exist
since_epoch = 1

assert since_epoch > 0
# per-epoch values of the monitored metric from the last fit() call
history = np.array(testm.history.history[monitored_metric])
y = history[since_epoch-1:]
x = list(range(since_epoch, len(y)+since_epoch))
# epochs at which the metric hits its minimum; +1 converts index to epoch
result = np.where(history == history.min())[0] + 1

# mark every best epoch with a dotted vertical line
for _x in result:
    plt.axvline(_x, linewidth=0.5, color='r', ls=':')
plt.plot(x, y)
plt.title(f'Best weights for metric [{monitored_metric}] at epoch/s {result}')
plt.show()

In [None]:
# whatever test
# The code below is disabled: it sits inside a string literal, so it is not
# executed (the notebook only echoes the string as the cell output).
''' 
batch_size = 256

val_result = testm.predict(X_test_batched)
mse = keras.losses.MeanSquaredError()
dataset_size = X_test_np.shape[0]
print(dataset_size)
mse_result=[]
for i in range(0, dataset_size, batch_size):
    if i+batch_size > dataset_size:
        batch_size = batch_size-(i+batch_size-dataset_size)
    batch_mse = mse(X_test_np[i:i+batch_size],val_result[i:i+batch_size]).numpy()
    mse_result.append(batch_mse)

mse_result'''

In [None]:
# final visualization, pick a slice of the input data for demo purposes
# (the training-data variant is kept disabled)
#visualizer.Visualizer(figsize=(12,3)).make_visualization(testm.call_enc(X_train_np[15000:20000]).numpy())

# on test data
# encode the whole test set and hand the resulting array to the visualizer
visualizer.Visualizer(figsize=(12,3)).make_visualization(testm.call_enc(X_test_np).numpy())

In [None]:
# load testing trajectory for further visualizations and computations
# (`conf` is not defined in this notebook's cells; presumably it comes from inputs.py)
tr = md.load('x_test.xtc',top=conf)
# indices of atoms named CA, used as the reference for the superposition
idx=tr[0].top.select("name CA")

# for trivial cases like Ala-Ala, where superposing on CAs fails
#idx=tr[0].top.select("element != H") 

# align all frames onto the first frame using the selected atoms
tr.superpose(tr[0],atom_indices=idx)

# reshuffle the geometry to get frame last so that we can use vectorized calculations
# (axis 0 of tr.xyz is moved to the last position)
geom = np.moveaxis(tr.xyz ,0,-1)
# shown as the cell output
geom.shape

In [None]:
# Rgyr and rmsd color coded in low dim (rough view)

# latent coordinates of the test set plus the two per-frame quantities
lows = testm.call_enc(X_test_np).numpy()
rg = md.compute_rg(tr)
base = md.load(conf)
rmsd = md.rmsd(tr,base[0])
cmap = plt.get_cmap('rainbow')

# one scatter panel per quantity, side by side
plt.figure(figsize=(12,4))
for _pos, _colors, _title in ((121, rg, "Rg"), (122, rmsd, "RMSD")):
    plt.subplot(_pos)
    plt.scatter(lows[:,0],lows[:,1],marker='.',c=_colors,cmap=cmap)
    plt.colorbar(cmap=cmap)
    plt.title(_title)
plt.show()

In [None]:
# not used
# The saves below are disabled: they sit inside a string literal and are not executed.
# NOTE(review): the third call writes testm.disc to 'dec.keras', the same file
# as the decoder -- presumably 'disc.keras' was intended; fix before enabling.
'''testm.enc.save('enc.keras')
testm.dec.save('dec.keras')
testm.disc.save('dec.keras')'''

### Image prior

**Almost surely broken now with the density alignment**

Use Image as a prior distribution. Again use tuned Hyperparameters for better training performance.

In [None]:
# download the image used as the prior and store it as mushroom_bw.png
# in the current directory (requires network access)
urllib.request.urlretrieve("https://drive.google.com/uc?export=download&id=1I2WP92MMWS5s5vin_4cvmruuV-1W77Hl", "mushroom_bw.png")

In [None]:
# second model with the same hyperparameters and seeds, but with the
# downloaded image file as the prior; with_density is not passed here
mmush = asmsa.AAEModel((X_train_np.shape[1],),
                       hp=hps,
                       enc_seed=best_enc_seed,
                       disc_seed=best_disc_seed,
                       prior='mushroom_bw.png'
                      )
mmush.compile()

In [None]:
# train the image-prior model on the plain batched training set (no density);
# the EarlyStopping callback object from the distribution-prior section is reused
mmush.fit(X_train_batched,
          epochs=500,
          verbose=2,
          callbacks=[
              early_stop_cb,
              # visualization callback fed with a fixed slice of the training data
              visualizer.VisualizeCallback(mmush,freq=25,inputs=X_train_np[15000:25000],figsize=(12,3))
          ])

In [None]:
# - plot AE loss min during training
# - specify "since_epoch" for better plot scaling (ignore outliers)
# - note that numbering of epochs starts at 1, 0th epoch does not exist
since_epoch = 1
monitored_metric = 'AE loss min'

assert since_epoch > 0
# per-epoch values of the monitored metric from the last fit() call
history = np.array(mmush.history.history[monitored_metric])
y = history[since_epoch-1:]
x = list(range(since_epoch, len(y)+since_epoch))
# epochs at which the metric hits its minimum; +1 converts index to epoch
result = np.where(history == history.min())[0] + 1

# mark every best epoch with a dotted vertical line
for _x in result:
    plt.axvline(_x, linewidth=0.5, color='r', ls=':')
plt.plot(x, y)
plt.title(f'Best weights for metric [{monitored_metric}] at epoch/s {result}')
plt.show()

In [None]:
# Mean squared error between the test inputs and the model predictions,
# computed in chunks of `batch_size` rows.
batch_size = 256

# NOTE(review): this section trains `mmush`, yet the prediction uses `testm`
# -- confirm which model is intended here.
val_result = testm.predict(X_test_batched)
mse = keras.losses.MeanSquaredError()
dataset_size = X_test_np.shape[0]
print(dataset_size)

# X_test_batched was built with drop_remainder=True, so there can be fewer
# predictions than test samples; compare only rows present in both arrays.
n_eval = min(dataset_size, len(val_result))

mse_result=[]
for i in range(0, n_eval, batch_size):
    # clamp the last chunk instead of shrinking batch_size itself
    j = min(i+batch_size, n_eval)
    batch_mse = mse(X_test_np[i:j],val_result[i:j]).numpy()
    mse_result.append(batch_mse)

mse_result

In [None]:
# Rg and RMSD colour-coded in the latent space of the image-prior model,
# using every `step`-th test frame
step=4
tr2 = tr[::step]
lows = mmush.call_enc(X_test_np[::step]).numpy()
rg = md.compute_rg(tr2)
base = md.load(conf)
rmsd = md.rmsd(tr2,base[0])
cmap = plt.get_cmap('rainbow')

# one scatter panel per quantity, side by side
plt.figure(figsize=(12,4))
for _pos, _colors, _title in ((121, rg, "Rg"), (122, rmsd, "RMSD")):
    plt.subplot(_pos)
    plt.scatter(lows[:,0],lows[:,1],marker='.',c=_colors,cmap=cmap)
    plt.colorbar(cmap=cmap)
    plt.title(_title)
plt.show()

## Save the model for Gromacs

*Another wave of magics ...*

There are multiple ways how atoms are numbered in PDB, GRO, etc. files. 

So far we worked with atoms numbered as in the `conf` PDB file, assuming `traj` XTC file was consistent with those.
If the topology was used, it might have had different numbering, as Gromacs likes. 

In the subsequent simulations, we assume the usual protocol starting with `pdb2gmx` to generate topology,
hence Gromacsish atom numbering will be followed afterwards.
Therefore we need `plumed.dat` to pick the atoms according to the PDB file order, and skip hydrogens added by Gromacs. 

Many things can go wrong, therefore we strongly encourage you to check the results manually. For example, the first residue (ASP) of tryptophan cage may look like the following in the PDB file:

    ATOM      1  N   ASP     1      28.538  39.747  31.722  1.00  1.00           N
    ATOM      2  CA  ASP     1      28.463  39.427  33.168  1.00  1.00           C
    ATOM      3  C   ASP     1      29.059  37.987  33.422  1.00  1.00           C
    ATOM      4  O   ASP     1      30.226  37.748  33.735  1.00  1.00           O
    ATOM      5  CB  ASP     1      26.995  39.482  33.630  1.00  1.00           C
    ATOM      6  CG  ASP     1      26.889  39.307  35.101  1.00  1.00           C
    ATOM      7  OD1 ASP     1      27.749  39.962  35.773  1.00  1.00           O
    ATOM      8  OD2 ASP     1      26.012  38.510  35.611  1.00  1.00           O
    
This becomes the following in the Gromacs topology:

     1         N3      1    ASP      N      1     0.0782      14.01   ; qtot 0.0782
     2          H      1    ASP     H1      2       0.22      1.008   ; qtot 0.2982
     3          H      1    ASP     H2      3       0.22      1.008   ; qtot 0.5182
     4          H      1    ASP     H3      4       0.22      1.008   ; qtot 0.7382
     5         CT      1    ASP     CA      5     0.0292      12.01   ; qtot 0.7674
     6         HP      1    ASP     HA      6     0.1141      1.008   ; qtot 0.8815
     7         CT      1    ASP     CB      7    -0.0235      12.01   ; qtot 0.858
     8         HC      1    ASP    HB1      8    -0.0169      1.008   ; qtot 0.8411
     9         HC      1    ASP    HB2      9    -0.0169      1.008   ; qtot 0.8242
    10          C      1    ASP     CG     10     0.8194      12.01   ; qtot 1.644
    11         O2      1    ASP    OD1     11    -0.8084         16   ; qtot 0.8352
    12         O2      1    ASP    OD2     12    -0.8084         16   ; qtot 0.0268
    13          C      1    ASP      C     13     0.5621      12.01   ; qtot 0.5889
    14          O      1    ASP      O     14    -0.5889         16   ; qtot 0
    
Besides adding hydrogens, the carboxyl group of the protein backbone (atoms 3,4 in PDB) is pushed down (to become 13,14 in the topology).

Consequently, the ATOMS setting in the generated `plumed.dat` must be:

    model: PYTORCH_MODEL_CV FILE=model.pt ATOMS=1,5,13,14,7,10,11,12, ...
    
i.e., the atoms are enumerated *in the order* of PDB file but *referring to numbers* of topology file. 

If there is any mismatch, the MD simulations are likely to fail, or at least to produce meaningless results.

It's also **critical** that `{conf}`, `{top}`, and `{gro}` correspond to one another, and that `{gro}` **includes hydrogens**.


In [None]:
import tf2onnx
import onnx2torch
import tempfile

def _convert_to_onnx(model, destination_path):
    """Write `model` to `destination_path` as ONNX using tf2onnx.

    model: Keras model whose first layer exposes `_input_tensor`.
    destination_path: path of the ONNX file to create.
    """
    # The input spec is copied from the first layer's input tensor
    # (alternative kept disabled in the original: model.inputs[0]).
    first_input = model.layers[0]._input_tensor
    signature = [
        tf.TensorSpec(
            shape=first_input.shape, dtype=first_input.dtype, name=first_input.name
        )
    ]
    # The single output is keyed by the name of the model's last layer.
    output_name = model.layers[-1].name

    @tf.function(input_signature=signature)
    def _wrapped_model(input_data):
        return {output_name: model(input_data)}

    tf2onnx.convert.from_function(
        _wrapped_model, input_signature=signature, output_path=destination_path
    )

In [None]:
# model whose encoder is exported (the distribution-prior model trained above)
model = testm

# convert the Keras encoder to ONNX in a temporary file, then to a PyTorch
# module; the temporary file is removed when the with-block exits
with tempfile.NamedTemporaryFile() as onnx:
#    tf2onnx.convert.from_keras(model.enc,output_path=onnx.name)
    _convert_to_onnx(model.enc,onnx.name)
    torch_encoder = onnx2torch.convert(onnx.name)

# load test geometry dataset
# (rebinds `geom`, which an earlier cell set from the test trajectory)
geom = np.stack(list(tf.data.Dataset.load('datasets/geoms/test')))

# XXX: we rely on determinism of the model creation, it must be the same as in prepare.ipynb
# better to store it there in onnx, and reload here

# `nb_density`, `conf`, `topol` and `index` are not defined in this notebook's
# cells; presumably they come from inputs.py
sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=nb_density)
mol = asmsa.Molecule(pdb=conf,top=topol,ndx=index,fms=[sparse_dists])

mol_model = mol.get_model()

# per-coordinate mean and scale applied to the internal coordinates in complete_model below
train_mean = np.loadtxt('datasets/intcoords/mean.txt',dtype=np.float32)
train_scale = np.loadtxt('datasets/intcoords/scale.txt',dtype=np.float32)

def complete_model(x):
    """Map a geometry tensor to the encoder output.

    Applies mol_model to `x`, standardizes the result with train_mean and
    train_scale (as column vectors), flattens it, and feeds it to torch_encoder.
    """
    internal = mol_model(x)
    mean_col = torch.from_numpy(np.reshape(train_mean,(-1,1)))
    scale_col = torch.from_numpy(np.reshape(train_scale,(-1,1)))
    standardized = (internal - mean_col) / scale_col
    return torch_encoder(standardized.reshape(-1))

# Save Torch model using TorchScript trace
# random example input with the leading two dimensions of `geom` and a trailing axis of size 1
example_input = torch.randn([geom.shape[0], geom.shape[1], 1])
# tracing records the operations executed for this example input
traced_script_module = torch.jit.trace(complete_model, example_input)

# file name referenced later when plumed.dat is written
model_file_name = "model.pt"
traced_script_module.save(model_file_name)

In [None]:
# validate

# random geometry with the same leading dimensions as the test geometries
example_geom = np.random.rand(geom.shape[0], geom.shape[1], 1)
#X = mol.intcoord(example_geom).T
# reference path: internal coordinates standardized with numpy, then the Keras encoder
X = ((mol.intcoord(example_geom) - np.reshape(train_mean,(-1,1))) / np.reshape(train_scale,(-1,1))).T
tf_low = np.array(model.enc(X))

# exported path: the traced Torch model applied to the flattened geometry
torch_geom = torch.tensor(example_geom.reshape(-1), dtype=torch.float32, requires_grad=True)
torch_low = traced_script_module(torch_geom)

# check that a gradient w.r.t. the input geometry can be computed for every
# output component (the gradient values themselves are not inspected)
for out in torch_low:
    grad = torch.autograd.grad(out, torch_geom, retain_graph=True)


# should be very small, eg. less than 1e-5
# (largest absolute difference between the Keras and the Torch outputs)
np.max(np.abs(tf_low - torch_low.detach().numpy()))

In [None]:
# Atom numbering magic with Gromacs, see above

# 0-based indices of all non-hydrogen atoms in the gro file
grotr = md.load(gro)
nhs = grotr.topology.select('element != H')

# read the index file: skip its first line, parse the remaining
# whitespace-separated integers and subtract 1 from each
with open(index) as f:
    f.readline()
    ndx = np.fromstring(" ".join(f),dtype=np.int32,sep=' ')-1

# reorder the heavy-atom indices by argsort(ndx); +1 turns them into 1-based numbers
pdb2gmx = nhs[np.argsort(ndx)]+1

# maybe double check manually wrt. the files
# pdb2gmx

In [None]:
# display the parsed index array for manual inspection
ndx

#### Determine range of CVs for simulation

Plumed maintains a grid to approximate the accumulated bias potential, and the size of this grid must be known in advance.

Making it wider is safe (the simulation is less likely to escape and crash), but there is a performance penalty.

Calculate the CVs on the testset, determine their range, and add some margins


In [None]:
grid_margin = 1.  # that many times the actual computed size added on both sides

# per-component minimum and maximum of the encoded test set
lows = model.call_enc(X_test_np).numpy()
lmin = np.min(lows,axis=0)
lmax = np.max(lows,axis=0)
# extent of each component; the bounds are widened in place by grid_margin extents
llen = lmax-lmin
lmin -= llen * grid_margin
lmax += llen * grid_margin

In [None]:
# Write plumed.dat (overwriting any existing file) from the values computed above:
# - the WHOLEMOLECULES range ends at the number of atoms in the gro file
# - ATOMS is the comma-separated pdb2gmx list
# - the grid bounds come from lmin/lmax
# Only components 0 and 1 are written, so exactly two CVs are assumed.
with open("plumed.dat","w") as p:
    p.write(f"""\
RESTART
WHOLEMOLECULES ENTITY0=1-{grotr.xyz.shape[1]}
model: PYTORCH_MODEL_CV FILE={model_file_name} ATOMS={','.join(map(str,pdb2gmx))}
metad: METAD ARG=model.node-0,model.node-1 PACE=1000 HEIGHT=1 BIASFACTOR=15 SIGMA=0.1,0.1 GRID_MIN={lmin[0]},{lmin[1]} GRID_MAX={lmax[0]},{lmax[1]} FILE=HILLS
PRINT FILE=COLVAR ARG=model.node-0,model.node-1,metad.bias STRIDE=100
""")

In [None]:
# XXX: these were some tests for the density alignment, I'm not sure anymore; not important either

In [None]:
# load the validation trajectory (rebinds `tr` from the earlier test-trajectory cell)
traj = 'validate.xtc'
tr = md.load(traj,top=conf)
# mean and scale used below to standardize the internal coordinates
train_mean = np.loadtxt('datasets/intcoords/mean.txt',dtype=np.float32)
train_scale = np.loadtxt('datasets/intcoords/scale.txt',dtype=np.float32)

In [None]:
# move axis 0 of tr.xyz to the last position; the shape is shown as the cell output
geom = np.moveaxis(tr.xyz,0,-1)
geom.shape

In [None]:
# NOTE(review): the export cell builds NBDistancesSparse with density=nb_density,
# here a literal 2 is used -- confirm the two agree
density = 2
sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=density)
mol = asmsa.Molecule(pdb=conf,top=topol,ndx=index,fms=[sparse_dists])


In [None]:
# internal coordinates of the validation geometry, transposed, then
# standardized in place with the training mean and scale
intcoord = mol.intcoord(geom).T
intcoord -= train_mean
intcoord /= train_scale

In [None]:
import gromacs.fileformats as gf

In [None]:
# read rmsd.xpm; its numeric content is accessed as rms.array below
# (presumably a pairwise RMSD matrix of the validation trajectory -- TODO confirm)
rms = gf.XPM('rmsd.xpm')

In [None]:
# histogram of all matrix entries
plt.hist(rms.array.flatten(),bins=50)
plt.show()

In [None]:
# sort every row in ascending order (np.sort works along the last axis)
rms_sort = np.sort(rms.array.astype(np.float32))
# exp(-value) of the 50 smallest entries of each row
erms = np.exp(-rms_sort[:,:50])

In [None]:
# per-row average of the erms terms after removing one unit term
# (presumably the subtracted 1 is the self term exp(-0) -- TODO confirm)
# NOTE: this rebinds `dens`, which earlier held the training density dataset
dens = (np.sum(erms,axis=1)-1.) / (erms.shape[1] - 1)

In [None]:
# histogram of the per-row density estimate
plt.hist(dens,bins=20)
plt.show()

In [None]:
# encode the standardized validation coordinates with the encoder of testm
lows = testm.enc(intcoord)

In [None]:
# display both shapes; the scatter plot below needs one density value per encoded point
lows.shape, dens.shape

In [None]:
# first two latent components, coloured by the density estimate
plt.scatter(lows[:,0],lows[:,1],c=dens,marker='.',cmap='rainbow')
plt.colorbar()
plt.show()