# ASMSA: Tune AAE model hyperparameters

**Previous step**
- [prepare.ipynb](prepare.ipynb): Download and sanity check input files

**Next steps**
- [train.ipynb](train.ipynb): Use results of previous tuning in more thorough training
- [md.ipynb](md.ipynb): Use a trained model in MD simulation with Gromacs

## Notebook setup

In [None]:
threads = 2
import os
os.environ['OMP_NUM_THREADS']=str(threads)
import tensorflow as tf

# PyTorch favours OMP_NUM_THREADS in environment
import torch

# TensorFlow needs explicit config calls
tf.config.threading.set_inter_op_parallelism_threads(threads)
tf.config.threading.set_intra_op_parallelism_threads(threads)

In [None]:
import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
import urllib.request
from tensorflow import keras
import keras_tuner
import asmsa
from datetime import datetime

## Input files

All input files are prepared (up- or downloaded) in [prepare.ipynb](prepare.ipynb). 

This is for demonstration purposes. In real use, the inputs should be placed here, and the _conf, traj, topol, index_ variables set to their filenames.

In [None]:
# Define input files

# input conformation
#conf = "alaninedipeptide_H.pdb"
conf = "trpcage_correct.pdb"

# input trajectory
# atom numbering must be consistent with {conf}

#traj = "alaninedipeptide_reduced.xtc"
traj = "trpcage_red.xtc"

# input topology
# expected to be produced with 
#    gmx pdb2gmx -f {conf} -p {topol} -n {index} -o {gro}

# Gromacs changes atom numbering, the index file must be generated and used as well
# gro file is used to generate inverse indexing for plumed.dat

#topol = "topol.top"
topol = "topol_correct.top"
index = 'index_correct.ndx'
gro = 'trpcage_correct.gro'

## Internal coordinates computation

In [None]:
# Load the trajectory; it should report the expected numbers of frames and atoms/residues

tr = md.load(traj,top=conf)
idx=tr[0].top.select("name CA")

# for trivial cases like Ala-Ala, where superposing on CAs fails
#idx=tr[0].top.select("element != H")

tr.superpose(tr[0],atom_indices=idx)

In [None]:
# reshuffle the geometry to get frame last so that we can use vectorized calculations

geom = np.moveaxis(tr.xyz ,0,-1)
geom.shape
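
The effect of `np.moveaxis` in the cell above can be checked on a small synthetic array. This is a minimal sketch, assuming the usual mdtraj layout of `xyz` as (frames, atoms, 3); the array sizes and the names `xyz_demo` and `geom_demo` are made up for illustration.

```python
import numpy as np

# Hypothetical sizes for illustration only: 5 frames, 4 atoms, 3 Cartesian coordinates
n_frames, n_atoms = 5, 4
xyz_demo = np.arange(n_frames * n_atoms * 3, dtype=float).reshape(n_frames, n_atoms, 3)

# Move the frame axis (0) to the last position, as done for `geom` above
geom_demo = np.moveaxis(xyz_demo, 0, -1)
print(geom_demo.shape)  # (4, 3, 5)

# The same coordinate is now addressed with the frame index last
frame, atom, dim = 2, 1, 0
same_value = bool(geom_demo[atom, dim, frame] == xyz_demo[frame, atom, dim])
```

With the frame axis last, an operation written for one atom or one pair of atoms applies to all frames at once through broadcasting.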

In [None]:
# Prepare internal coordinates computation. There are multiple options, see prepare.ipynb, and adjust here accordingly

density = 2 # integer in [1, n_atoms-1]

sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=density)
mol = asmsa.Molecule(pdb=conf,top=topol,ndx=index,fms=[sparse_dists])

In [None]:
# compute on the actual input 

X_train = mol.intcoord(geom).T
X_train.shape

In [None]:
train_ds = tf.data.Dataset.from_tensor_slices(X_train).shuffle(2048).batch(64,drop_remainder=True)
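
The batching arithmetic behind `batch(64, drop_remainder=True)` can be sketched without TensorFlow. The helper `batch_counts` and the frame count of 1000 are hypothetical, chosen only to illustrate how many samples are left out of an epoch.

```python
def batch_counts(n_samples, batch_size, drop_remainder=True):
    """Return (number of batches, number of samples left out of the epoch)."""
    full, rest = divmod(n_samples, batch_size)
    if drop_remainder:
        # The incomplete last batch is discarded
        return full, rest
    # Otherwise the last, smaller batch is kept
    return full + (1 if rest else 0), 0

# Hypothetical trajectory length, for illustration only
dropped_case = batch_counts(1000, 64, drop_remainder=True)
kept_case = batch_counts(1000, 64, drop_remainder=False)
print(dropped_case, kept_case)  # (15, 40) (16, 0)
```

Because the dataset is shuffled before batching, the frames that end up in the dropped remainder generally differ from epoch to epoch.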

## Sequential hyperparameter tuning

This approach is robust: it does not require a Kubernetes environment for submitting additional jobs. On the other hand, it is correspondingly slow.

**Skip to the next section if you run the notebook in our recommended setup in Kubernetes.**

In [None]:
medium_hp = {
    'activation' : ['relu','gelu'],
    'ae_neuron_number_seed' : [32,96,128],
    'disc_neuron_number_seed' : [32,96],
    'ae_number_of_layers' : [2,2],
    'disc_number_of_layers' : [3,3],
    'batch_size' : [64,128,256],
    'optimizer' : ['Adam'],
    'learning_rate' : 0.0002,
    'ae_loss_fn' : ['MeanSquaredError'],
    'disc_loss_fn' : ['BinaryCrossentropy']
}
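
To get a feeling for how large this search space is compared to the number of trials, one can count the combinations. The sketch below is only an estimate under a loudly stated assumption: every list is treated as a set of discrete choices (duplicates collapsed) and every scalar as a fixed value. How `asmsa.AAEHyperModel` actually interprets the entries is not shown in this notebook, so the helper `n_choices` and the copy `example_hp` are illustrative only.

```python
from math import prod

# Copy of the dictionary above, kept separate so the sketch is self-contained
example_hp = {
    'activation': ['relu', 'gelu'],
    'ae_neuron_number_seed': [32, 96, 128],
    'disc_neuron_number_seed': [32, 96],
    'ae_number_of_layers': [2, 2],
    'disc_number_of_layers': [3, 3],
    'batch_size': [64, 128, 256],
    'optimizer': ['Adam'],
    'learning_rate': 0.0002,
    'ae_loss_fn': ['MeanSquaredError'],
    'disc_loss_fn': ['BinaryCrossentropy'],
}

def n_choices(value):
    """Assumption: a list means discrete choices, a scalar means a fixed value."""
    if isinstance(value, (list, tuple)):
        return len(set(value))
    return 1

sizes = {name: n_choices(value) for name, value in example_hp.items()}
total = prod(sizes.values())
print(total)  # 36 under the assumption above
```

Under this reading, a random search with a handful of trials samples only a small fraction of the combinations.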

In [None]:
# Just testing numbers of epochs and hyperparameter setting trials
# Don't expect anything meaningful

trials=3
epochs=15
results_dir=datetime.today().strftime("%m%d%Y-%H%M%S")

os.environ['START_TIME']=results_dir
tuner = keras_tuner.RandomSearch(
    max_trials=trials,
    hypermodel=asmsa.AAEHyperModel((X_train.shape[1],),hpfunc=medium_hp),
    objective=keras_tuner.Objective("score", direction="min"),
    directory="./results",
    project_name="Random",
    overwrite=True
)
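
The `START_TIME` value is a timestamp built with the format string `"%m%d%Y-%H%M%S"`. The sketch below uses a fixed, made-up time to show what such a string looks like and that it can be parsed back; the names `STAMP_FORMAT`, `example`, `stamp` and `parsed` are illustrative.

```python
from datetime import datetime

STAMP_FORMAT = "%m%d%Y-%H%M%S"  # same format string as in the cell above

# Fixed example time (9 May 2023, 13:52:49), chosen for illustration
example = datetime(2023, 5, 9, 13, 52, 49)
stamp = example.strftime(STAMP_FORMAT)
print(stamp)  # 05092023-135249

# The string can be converted back to a datetime with the same format
parsed = datetime.strptime(stamp, STAMP_FORMAT)
```

Note that the month comes first, so an alphabetical listing of such names is not chronological across years.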

In [None]:
tuner.search(X_train,epochs=epochs)

In [None]:
from asmsa.tuning_analyzer import TuningAnalyzer

# Create an analyzer object that analyses the results of tuning
# By default it uses the latest tuning, but this can be configured with the tuning flag,
#  which is set to the directory of the tuning, e.g. tuning='ASMSA_visualization/05092023-135249'
analyzer = TuningAnalyzer()

In [None]:
# Get sorted hyperparameters by score, by default 10 best HP
analyzer.get_best_hp()

In [None]:
# Matplotlib visualization: not the recommended way, it does not look that good and does not scale
#  that well, but at least the colors are consistent across measures. It could look better after more work
# The number of plots can be reduced with num_trials, down to a single one
# By default the latest tuning is chosen (can be changed in the definition of TuningAnalyzer)
analyzer.visualize_tuning()

In [None]:
# Recommended option: TensorBoard. This function populates TB events,
#  which can be viewed natively in TensorBoard. By default it chooses
#  the latest tuning and populates its _TB directory, e.g.:
#  ASMSA_visualization/05092023-135249/_TB
analyzer.populate_TB()

In [None]:
%load_ext tensorboard
%tensorboard --logdir ASMSA_visualization/

## Parallel hyperparameter tuning

In [None]:
# Finally, this is the real stuff
# medium settings known to be working for trpcage

epochs=15
trials=3
hpfunc=medium_hp

# testing only
#epochs=8
#trials=6
#hpfunc=tiny_hp

In [None]:
# number of parallel workers, each runs a single trial at a time
# balance between resource availability and size of the problem
# currently each slave runs on 4 cores and 4 GB RAM (hardcoded in src/asmsa/tunewrapper.py)

slaves=3

In [None]:
# XXX: Kubernetes magic: find out names of container image and volume
# check the result, it can go wrong

with open('IMAGE') as img:
    image=img.read().rstrip()

import re
mnt=os.popen('mount | grep /home/jovyan').read()
pvcid=re.search('pvc-[0-9a-z-]+',mnt).group(0)
pvc=os.popen(f'kubectl get pvc | grep {pvcid} | cut -f1 -d" "').read().rstrip()

print(f"""\
image: {image}
volume: {pvc}
""")

In [None]:
# Python wrapper around scripts that prepare and execute parallel Keras Tuner in Kubernetes
from asmsa.tunewrapper import TuneWrapper

wrapper = TuneWrapper(hpfunc=hpfunc,output='best.txt',epochs=epochs,trials=trials,pdb=conf,top=topol,xtc=traj,ndx=index, pvc=pvc)

In [None]:
# Necessary but destructive cleanup before hyperparameter tuning

# DON'T RUN THIS CELL BLINDLY
# it kills any running processes including the workers, and it purges previous results

!kubectl delete job/tuner
!kill $(ps ax | grep '[t]uning.py' | awk '{print $1}')
!rm -rf results

In [None]:
# backup previous results; do so only if you want to

!mv best.txt best.txtO

In [None]:
# start the master (chief) of tuners in background
# the computation takes rather long, so this is a more robust approach than keeping it in the notebook

wrapper.master_start()

In [None]:
# therefore one should check the status occasionally; it should show a tuning.py process running
print(wrapper.master_status())

In [None]:
# spawn the requested number of workers as a separate Kubernetes job with several pods
# they receive work from the master (chief) started above

wrapper.workers_start(num=slaves)

In [None]:
# This status should show {slaves} pods; all of them start in the Pending state, go through ContainerCreating
# to Running, and finally reach Completed

# This takes time, minutes to hours depending on size of the model, number of trials, and number of slaves
# Run this cell repeatedly, waiting until all the pods are completed

wrapper.workers_status()
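
Instead of re-running the status cell by hand, the waiting can be automated with a small polling loop. The sketch below is generic and makes no assumption about what `wrapper.workers_status()` returns: the completion test is a user-supplied callable, and both `wait_until` and `fake_check` are hypothetical names introduced for illustration.

```python
import time

def wait_until(check, interval=60.0, timeout=3600.0):
    """Call check() every `interval` seconds until it returns True or `timeout` elapses."""
    deadline = time.monotonic() + timeout
    while True:
        if check():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)

# Stand-in for a real completion test; it reports success on the third call
calls = {"n": 0}

def fake_check():
    calls["n"] += 1
    return calls["n"] >= 3

finished = wait_until(fake_check, interval=0.0, timeout=5.0)
print(finished, calls["n"])  # True 3
```

In real use, `check` would have to inspect the pod states in whatever form the status call reports them, which is left open here.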

In [None]:
# Same steps for analysis as with serial tuning
analyzer = TuningAnalyzer()
analyzer.get_best_hp()

In [None]:
# We can choose output dir for TB event this time
analyzer.populate_TB(out_dir='test2')

In [None]:
# Might need to kill previous tensorboard instance to change logdir
!pkill -f 'tensorboard'

%load_ext tensorboard
%tensorboard --logdir test2