# Analysis and Sampling of Molecular Simulations by adversarial Autoencoders
---
1. [Packages import](#1.-Packages-import)
2. [Internal coordinates computation](#2.-Internal-coordinates-computation)
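3. [Test various batch sizes](#3.-Test-various-batch-sizes)
4. [Test various types of layers](#4.-Test-various-types-of-layers)
5. [Specification of prior distribution](#5.-Specification-of-prior-distribution)
6. [Set early stop](#6.-Set-early-stop)
7. [Visualization of the final result](#7.-Visualization-of-the-final-result)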

## 1. Packages import

In [None]:
# Import packages

from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
from src import asmsa
from src.gan import GAN
import mdtraj as md
import numpy as np
import nglview as nv

In [None]:
# Define input files
%cd ~

# input conformation
#conf = "alaninedipeptide_H.pdb"
conf = "trpcage_correct.pdb"

# input trajectory
# atom numbering must be consistent with {conf}

#traj = "alaninedipeptide_reduced.xtc"
traj = "trpcage_red.xtc"

# input topology
# expected to be produced with 
#    gmx pdb2gmx -f {conf} -p {topol} -n {index} 

# Gromacs changes atom numbering, the index file must be generated and used as well

#topol = "topol.top"
topol = "topol_correct.top"
index = 'index_correct.ndx'

## 2. Internal coordinates computation

In [None]:
# Load the trajectory, taking the topology from the reference conformation
tr = md.load(traj, top=conf)

# Superpose every frame onto the first one, fitting on atoms named CA
# (alternative selection: "element != H")
idx = tr[0].top.select("name CA")
tr.superpose(tr[0], atom_indices=idx)

# Move the leading axis of tr.xyz to the last position
geom = np.transpose(tr.xyz, (1, 2, 0))

In [None]:
# Show the trajectory in an NGLView widget, drawn with the "licorice" representation
v = nv.show_mdtraj(tr)
v.clear()  # presumably removes the default representation(s) — confirm against nglview docs
v.add_representation("licorice")
v  # last expression of the cell, so the notebook displays the widget

In [None]:
# Display the layout of the coordinate array (the former leading axis of tr.xyz is now last)
geom.shape

In [None]:
# Define sparse and dense feature extensions of the internal coordinates (IC)
density = 2  # integer in [1, n_atoms-1]
n_atoms = geom.shape[0]
sparse_dists = asmsa.NBDistancesSparse(n_atoms, density=density)
dense_dists = asmsa.NBDistancesDense(n_atoms)

# Only the sparse feature map is attached to the molecule here.
# Other ways to construct it:
#   mol = asmsa.Molecule(conf,topol)
#   mol = asmsa.Molecule(conf,topol,fms=[sparse_dists])
mol = asmsa.Molecule(
    pdb=conf,
    top=topol,
    ndx=index,
    fms=[sparse_dists],
)

In [None]:
# Compute the internal coordinates of the trajectory and transpose the result
# NOTE(review): presumably the transpose puts one sample (frame) per row — confirm
X_train = mol.intcoord(geom).T
X_train.shape  # displayed as the cell output

## 3. Test various batch sizes

In [None]:
# Short first run; output_file is read back in the visualization section
output_file = 'lows.txt'

gan = GAN(X_train, out_file=output_file)
test = gan.train(
    epochs=3,
    batch_size=132,
    visualize_freq=None,
)

In [None]:
# Batch size 50: convergence seems to be similar
gan2 = GAN(X_train, out_file=output_file)
test2 = gan2.train(
    epochs=50,
    batch_size=50,
    visualize_freq=10,
)

In [None]:
# Batch size 500: too big, training seems not to converge anymore
gan3 = GAN(X_train, out_file=output_file)
test3 = gan3.train(
    epochs=50,
    batch_size=500,
    visualize_freq=10,
)

In [None]:
# Batch size 256: sort of optimal for Ala-Ala
gan4 = GAN(X_train, out_file=output_file)
test4 = gan4.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
)

## 4. Test various types of layers

In [None]:
# Baseline for the layer experiments: a GAN built with default settings
# (the same configuration was already run above)
gan = GAN(X_train, out_file=output_file)
test5 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
)

In [None]:
# Encoder made of relu layers; build_decoder=True also builds the inverse
# decoder (with the same activation functions)
relu_encoder = [
    ("relu", 32),
    ("relu", 16),
    ("relu", 8),
    ("linear", None),
]
gan.set_encoder(params=relu_encoder, build_decoder=True)

test6 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
)

In [None]:
# Encoder with relu layers, decoder with selu layers, each set separately
encoder_layers = [
    ("relu", 32),
    ("relu", 16),
    ("relu", 8),
    ("linear", None),
]
decoder_layers = [
    ("selu", 8),
    ("selu", 16),
    ("selu", 32),
    ("linear", None),
]
gan.set_encoder(params=encoder_layers)
gan.set_decoder(params=decoder_layers)

test7 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
)

In [None]:
# Discriminator with no activation functions: every layer is given None as
# its activation, with the second tuple element going 64, 16, 4, 1
discriminator_layers = [(None, width) for width in (64, 16, 4, 1)]
gan.set_discriminator(params=discriminator_layers)

test8 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
)

## 5. Specification of prior distribution

In [None]:
# No prior given, so the default one (normal) is used
gan = GAN(X_train, out_file=output_file)
test9 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
)

In [None]:
# Same run, but with the uniform prior distribution requested explicitly
gan = GAN(
    X_train,
    out_file=output_file,
    prior='uniform',
)
test9 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
)

## 6. Set early stop

In [None]:
# Autoencoder early stop (ae_estop) is not triggered by default;
# ae_estop=True runs it with its default parameters
test10 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
    ae_estop=True,
)

In [None]:
# Discriminator early stopping is switched on exactly the same way
test10 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
    d_estop=True,
)

# Both early stops can be used at the same time
test10 = gan.train(
    epochs=50,
    batch_size=256,
    visualize_freq=10,
    ae_estop=True,
    d_estop=True,
)

In [None]:
# user can customise own early stop callback
from keras.callbacks import EarlyStopping

# Name of the metric the callback watches. It was previously referenced
# without ever being defined, which raised NameError. "val_loss" is the
# Keras default for EarlyStopping.
# NOTE(review): set this to the metric name GAN.train actually reports for
# the autoencoder — confirm against src.gan
monitor = "val_loss"

ae_stop_callback = EarlyStopping(monitor=monitor,
                                 patience=4,
                                 verbose=1,
                                 mode='min')

# pass the callback object instead of True to use it for the autoencoder early stop
test10 = gan.train(epochs=50, batch_size=256, visualize_freq=10, ae_estop=ae_stop_callback)

## 7. Visualization of the final result

In [None]:
# Visualization of low dimensional space

# define input files
%cd ~
lows = np.loadtxt(output_file)

%cd visualization
rama_ala = np.loadtxt('rama_ala_reduced.txt', usecols=(0,1))
angever1 = np.loadtxt('angever1.txt')
angever2 = np.loadtxt('angever2.txt')
angever3 = np.loadtxt('angever3.txt')


cvs = (lows[:, 0], lows[:, 1])
analysis_files = {
    'rama0' : rama_ala[:, 0],
    'rama1' : rama_ala[:, 1],
    'ang1' : angever1[:, 1],
    'ang2' : angever2[:, 1],
    'ang3' : angever3[:, 1]
}

# set limits
xmin, xmax = min(cvs[0]), max(cvs[0])
ymin, ymax = min(cvs[1]), max(cvs[1])

# plot configuration
plt.suptitle('Low Dimentional Space - Analysis')
plt.style.use("seaborn-white")
fig = plt.figure(figsize=(18, 10))
fig.supxlabel('CV1', x=0.5, fontsize=16, fontweight='bold')
fig.supylabel('CV2', x=0.1, fontsize=16, fontweight='bold')

# plot first graph
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([X.ravel(), Y.ravel()])
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X; pos[:, :, 1] = Y
values = np.vstack([cvs[0], cvs[1]])
kernel = gaussian_kde(values)
dens = np.reshape(kernel(positions).T, X.shape)
ax1 = plt.subplot(2, 3, 1)
ax1.set_xticks([])
plt.imshow(np.rot90(dens), cmap="hsv", aspect="auto", extent=[xmin, xmax, ymin, ymax])


# plot every other graph
i = 2
for name, data in analysis_files.items():
    ax = plt.subplot(2, 3, i)
    ax.set_xlim([xmin, xmax])
    ax.set_ylim([ymin, ymax])
    ax.set_title(name)
    if i in [2,3,5,6]:
        ax.set_yticks([])
    if i in [2,3]:
        ax.set_xticks([])
    plt.scatter(cvs[0], cvs[1], s=1, c=data, cmap="hsv")
    i += 1

    
# you can view .png output in visualization folder
plt.savefig('analysis.png')

In [None]:
# Rgyr color coded in low dim (rough view)

lows = np.loadtxt(output_file)
rg = md.compute_rg(tr)
cmap = plt.get_cmap('rainbow')
plt.figure(figsize=(12,12))
plt.scatter(lows[:,0],lows[:,1],marker='.',c=rg,cmap=cmap)
plt.colorbar(cmap=cmap)
plt.show()