In [None]:
# !pip install seaborn --user
# !pip install bayesian-optimization --user
# !pip install fast-ml --user

In [None]:
# Import packages
from math import floor

import matplotlib.pyplot as plt
import mdtraj as md
import nglview as nv
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from bayes_opt import BayesianOptimization
from fast_ml.model_development import train_valid_test_split
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import Dense, BatchNormalization, Dropout
from keras.layers import Flatten, Input, Reshape
from keras.layers import LeakyReLU
from keras.losses import BinaryCrossentropy, MeanSquaredError
from keras.models import Sequential
from keras.models import Sequential, Model
from keras.wrappers.scikit_learn import KerasClassifier
from scipy.stats import gaussian_kde
from sklearn.metrics import make_scorer, accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam, SGD, RMSprop, Adadelta, Adagrad, Adamax, Nadam, Ftrl

from src import asmsa
from src.gan import GAN
from src.visualizer import GAN_visualizer

import warnings

# NOTE(review): this rebinds the imported layer *class* name to a LeakyReLU
# layer *instance* (alpha=0.1) so it can be offered as an activation choice.
# From here on, `LeakyReLU(alpha=...)` calls the instance instead of
# constructing a new layer.
LeakyReLU = LeakyReLU(alpha=0.1)

# Silence every warning and let pandas display all DataFrame columns.
warnings.filterwarnings(action='ignore')
pd.set_option("display.max_columns", None)

In [None]:
# Accuracy wrapped as a scikit-learn scorer object.
score_acc = make_scorer(score_func=accuracy_score)

In [None]:
# Define input files
%cd ~

# input conformation
#conf = "alaninedipeptide_H.pdb"
conf = "trpcage_correct.pdb"

# input trajectory
# atom numbering must be consistent with {conf}

#traj = "alaninedipeptide_reduced.xtc"
traj = "trpcage_red.xtc"

# input topology
# expected to be produced with 
#    gmx pdb2gmx -f {conf} -p {topol} -n {index} 

# Gromacs changes atom numbering, the index file must be generated and used as well

#topol = "topol.top"
topol = "topol_correct.top"
index = 'index_correct.ndx'

In [None]:
# Load the trajectory and align every frame onto frame 0 using the C-alpha atoms.
tr = md.load(traj, top=conf)
idx = tr.topology.select("name CA")
# heavy-atom alternative: tr.topology.select("element != H")
tr.superpose(tr[0], atom_indices=idx)
# mdtraj's (frames, atoms, 3) -> (atoms, 3, frames), the layout used by the asmsa cells below
geom = tr.xyz.transpose(1, 2, 0)

In [None]:
# Interactive 3D view of the aligned trajectory, drawn only in licorice (stick) style.
v = nv.show_mdtraj(tr)
v.clear()
v.add_representation(repr_type="licorice")
v

In [None]:
np.shape(geom)  # (n_atoms, 3, n_frames) after the transpose above

In [None]:
# Sparse and dense feature extensions of the internal coordinates (IC).
# `density` must be an integer in [1, n_atoms-1]; geom.shape[0] is the atom count.
density = 2
sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=density)
dense_dists = asmsa.NBDistancesDense(geom.shape[0])

# Molecule from the PDB, topology and index, extended with the sparse distances
# only (dense_dists is not passed here).
# Simpler variants: asmsa.Molecule(conf, topol) or
#                   asmsa.Molecule(conf, topol, fms=[sparse_dists])
mol = asmsa.Molecule(
    pdb=conf,
    top=topol,
    ndx=index,
    fms=[sparse_dists],
)

In [None]:
# Internal coordinates of all frames, transposed from asmsa's output
# (presumably to one row per frame - TODO confirm). Despite the name this is
# the full data set; it is split into train/validation/test below.
X_train = mol.intcoord(geom).T
X_train.shape

In [None]:
_train, test = train_test_split(X_train, test_size=0.2, shuffle=True)  # hold out 20 % for testing

In [None]:
# Split the remaining 80 % (`_train`) into training and validation sets.
# Splitting the full `X_train` here would let held-out `test` frames leak
# into training/validation.
train, validation = train_test_split(_train, test_size=0.2)
assert len(train) + len(validation) + len(test) == len(X_train)

In [None]:
# Latent-space dimensionality of the AAE.
lowdim = 2
# Prior the latent space is matched to: 'normal' or 'uniform'.
prior = 'normal'
# Shape of one input sample: the number of internal-coordinate features.
molecule_shape = (X_train.shape[1],)

In [None]:
def build_encoder(params):
    """Build the encoder: one molecule's features -> a ``lowdim`` latent vector.

    Args:
        params: dict with 'neurons' (width of every hidden Dense layer) and
            'activation' (activation of those layers).

    Returns:
        A functional keras Model named "Encoder" with input shape
        ``molecule_shape`` and a linear ``lowdim``-sized output.
    """
    model = Sequential()
    # input layer
    model.add(Dense(params['neurons'], input_dim=np.prod(molecule_shape), activation=params['activation']))
    model.add(BatchNormalization(momentum=0.8))
    # hidden layers
    for _ in range(2):
        model.add(Dense(params['neurons'], activation=params['activation']))
        model.add(BatchNormalization(momentum=0.8))
    # output layer: unbounded latent coordinates
    model.add(Dense(lowdim, activation='linear'))

    # Distinct local names: assigning to `lowdim` anywhere in this function
    # would make it local to the whole body, and Dense(lowdim) above would
    # raise UnboundLocalError.
    mol_in = Input(shape=molecule_shape)
    latent = model(mol_in)
    return Model(mol_in, latent, name="Encoder")


In [None]:
# decoder tuning

def build_decoder(params):
    """Build the decoder: a ``lowdim`` latent vector -> reconstructed molecule features.

    Args:
        params: dict with 'neurons' (width of every hidden Dense layer) and
            'activation' (activation of the hidden and output layers).

    Returns:
        A functional keras Model named "Decoder" whose output has shape
        ``molecule_shape``.
    """
    model = Sequential()
    model._name = "Decoder"
    # input layer
    model.add(Dense(params['neurons'], input_dim=lowdim, activation='linear'))
    model.add(BatchNormalization(momentum=0.8))
    # hidden layers
    for _ in range(2):
        model.add(Dense(params['neurons'], activation=params['activation']))
        model.add(BatchNormalization(momentum=0.8))
    # output layer
    # NOTE(review): with a bounded/non-negative activation (sigmoid, relu, ...)
    # the decoder can only reproduce features inside that range - confirm the
    # internal coordinates are scaled accordingly.
    model.add(Dense(np.prod(molecule_shape), activation=params['activation']))
    model.add(Reshape(molecule_shape))

    # Distinct local names: `lowdim = Input(shape=(lowdim,))` made `lowdim`
    # local to the whole function, so `input_dim=lowdim` above raised
    # UnboundLocalError.
    latent_in = Input(shape=(lowdim,))
    reconstruction = model(latent_in)
    return Model(latent_in, reconstruction, name="Decoder")

In [None]:
# discriminator tuning

def build_discriminator(params):
    """Build the discriminator: a ``lowdim`` latent vector -> one logit.

    AAEModel.train_step compares the output with (batch, 1) labels using
    BinaryCrossentropy(from_logits=True), so the last layer is a single
    linear unit.

    Args:
        params: dict with 'neurons' (width of every hidden Dense layer).

    Returns:
        A functional keras Model named "Discriminator".
    """
    # The notebook rebinds the global `LeakyReLU` to a layer *instance*, so
    # `LeakyReLU(alpha=0.2)` would call that instance instead of building a
    # layer. Use the layer class under a private name.
    from keras.layers import LeakyReLU as _LeakyReLULayer

    model = Sequential()
    model._name = "Discriminator"
    model.add(Flatten(input_shape=(lowdim,)))
    for _ in range(3):
        model.add(Dense(params['neurons']))
        model.add(_LeakyReLULayer(alpha=0.2))
    # One logit per sample (previously `neurons` outputs, which do not match
    # the (batch, 1) labels).
    model.add(Dense(1))

    latent_in = Input(shape=(lowdim,))
    validity = model(latent_in)
    return Model(latent_in, validity, name="Discriminator")

In [None]:
class AAEModel(Model):
    """Adversarial autoencoder (AAE) trained with a custom ``train_step``.

    Every training step performs three updates with the same optimizer:
      1. encoder + decoder minimise the reconstruction loss ``ae_loss_fn``;
      2. the discriminator learns to label prior samples 1 and encoder outputs 0;
      3. the encoder is updated so the discriminator labels its outputs 1,
         pushing the latent distribution towards the prior.

    Args:
        enc: encoder model (input -> ``lowdim`` latent vector).
        dec: decoder model (latent vector -> reconstruction of the input).
        disc: discriminator model. Its output is compared with (batch, 1)
            labels by ``disc_loss_fn``, so it should produce one logit per sample.
        lowdim: latent dimensionality, used when sampling the prior.
        prior: 'normal' or 'uniform', the distribution the latent space is matched to.
    """

    def __init__(self, enc, dec, disc, lowdim, prior):
        super().__init__()
        self.enc = enc
        self.dec = dec
        self.disc = disc
        self.lowdim = lowdim
        self.prior = prior

    def call(self, inputs, training=None):
        """Reconstruct ``inputs`` (encoder followed by decoder)."""
        return self.dec(self.enc(inputs, training=training), training=training)

    def compile(self, opt=None, ae_loss_fn=None, disc_loss_fn=None):
        """Configure the optimizer and losses.

        Args:
            opt: optimizer shared by all three updates; default ``Adam(0.0002, 0.5)``.
            ae_loss_fn: reconstruction loss; default ``MeanSquaredError()``.
            disc_loss_fn: discriminator loss on logits; default
                ``BinaryCrossentropy(from_logits=True)``.

        Defaults are created on each call. As default-argument values they were
        built once at class definition and shared, optimizer state included, by
        every model compiled without explicit arguments.
        """
        super().compile()
        self.opt = opt if opt is not None else Adam(0.0002, 0.5)  # FIXME: justify
        self.ae_loss_fn = ae_loss_fn if ae_loss_fn is not None else MeanSquaredError()
        # XXX: logits as in https://keras.io/guides/customizing_what_happens_in_fit/,
        # hope it works as the discriminator output is never used directly
        self.disc_loss_fn = (disc_loss_fn if disc_loss_fn is not None
                             else BinaryCrossentropy(from_logits=True))

    def train_step(self, batch):
        """Run the three AAE updates on one batch.

        Called by ``fit``. Returns the reconstruction and discriminator losses,
        which appear in the fit logs as 'ae_loss' and 'd_loss'.

        Raises:
            ValueError: if ``self.prior`` is neither 'normal' nor 'uniform'.
        """
        def _get_prior(name, shape):
            if name == "normal":
                return tf.random.normal(shape=shape)
            if name == "uniform":
                return tf.random.uniform(shape=shape)

            raise ValueError(f"Invalid prior type '{name}'. Choose from 'normal|uniform'")

        if isinstance(batch, tuple):
            batch = batch[0]

        batch_size = tf.shape(batch)[0]

        # All sub-model calls pass training=True. This is a custom train_step,
        # so without it BatchNormalization runs in inference mode: it uses its
        # (never updated) moving statistics instead of batch statistics.

        # 1. improve AE to reconstruct
        enc_weights = self.enc.trainable_weights
        dec_weights = self.dec.trainable_weights
        with tf.GradientTape() as ae_tape:
            reconstruct = self.dec(self.enc(batch, training=True), training=True)
            ae_loss = self.ae_loss_fn(batch, reconstruct)

        # A single non-persistent tape gives the gradients for both networks
        # (the former persistent tape was never released). There are still two
        # apply_gradients calls, so the shared optimizer's step counter
        # advances exactly as before.
        ae_grads = ae_tape.gradient(ae_loss, enc_weights + dec_weights)
        n_enc = len(enc_weights)
        self.opt.apply_gradients(zip(ae_grads[:n_enc], enc_weights))
        self.opt.apply_gradients(zip(ae_grads[n_enc:], dec_weights))

        # 2. improve discriminator: prior samples labelled 1, encoder outputs 0
        rand_low = _get_prior(self.prior, (batch_size, self.lowdim))
        better_low = self.enc(batch, training=True)
        low = tf.concat([rand_low, better_low], axis=0)

        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
        labels += 0.05 * tf.random.uniform(tf.shape(labels))  # label noise, as in the Keras GAN guide

        with tf.GradientTape() as disc_tape:
            pred = self.disc(low, training=True)
            disc_loss = self.disc_loss_fn(labels, pred)

        disc_grads = disc_tape.gradient(disc_loss, self.disc.trainable_weights)
        self.opt.apply_gradients(zip(disc_grads, self.disc.trainable_weights))

        # 3. teach encoder to cheat: make the discriminator call its codes "prior" (1)
        alltrue = tf.ones((batch_size, 1))

        with tf.GradientTape() as cheat_tape:
            cheat = self.disc(self.enc(batch, training=True), training=True)
            cheat_loss = self.disc_loss_fn(alltrue, cheat)

        cheat_grads = cheat_tape.gradient(cheat_loss, self.enc.trainable_weights)
        self.opt.apply_gradients(zip(cheat_grads, self.enc.trainable_weights))

        return {'ae_loss': ae_loss, 'd_loss': disc_loss}


In [None]:
# Objective function for the Bayesian optimisation
def nn_cl_bo(neurons, activation, optimizer, learning_rate, batch_size, epochs):
    """Train one AAE with the given hyper-parameters and return its score.

    All arguments arrive as floats from BayesianOptimization:
    ``neurons``, ``batch_size`` and ``epochs`` are rounded to ints, while
    ``activation`` and ``optimizer`` are rounded and used as indices into the
    lists below.

    The model is trained on the global ``train`` array and scored on the
    global ``validation`` array.

    Returns:
        The negative mean squared reconstruction error on ``validation``.
        BayesianOptimization maximises, so a smaller error scores higher.
    """
    optimizerL = [SGD, Adam, RMSprop, Adadelta, Adagrad, Adamax, Nadam, Ftrl, SGD]
    activationL = ['relu', 'sigmoid', 'softplus', 'softsign', 'tanh', 'selu',
                   'elu', 'exponential', LeakyReLU, 'relu']

    params = {
        'neurons': round(neurons),
        'activation': activationL[round(activation)],
        'batch_size': round(batch_size),
        'epochs': round(epochs),
    }
    # Honour the sampled optimizer (it was previously ignored in favour of
    # Adam). Use `learning_rate=`, since `lr=` is a deprecated alias. Only this
    # one optimizer is constructed.
    opt = optimizerL[round(optimizer)](learning_rate=learning_rate)

    enc = build_encoder(params)
    dec = build_decoder(params)
    disc = build_discriminator(params)
    aae = AAEModel(enc, dec, disc, lowdim, prior)
    aae.compile(opt=opt)  # ae_loss_fn / disc_loss_fn are also parametrizable

    # fit() needs integer epochs/batch_size, not the raw floats BO proposes.
    es = EarlyStopping(monitor='ae_loss', mode='min', verbose=0, patience=20)
    aae.fit(train, epochs=params['epochs'], batch_size=params['batch_size'],
            callbacks=[es], verbose=0)

    # Score on held-out data. The previous KerasClassifier + StratifiedKFold +
    # accuracy setup cannot apply here: the AAE is unsupervised, and
    # `validation` was being passed as labels of a different length.
    latent = enc.predict(validation, verbose=0)
    reconstruction = dec.predict(latent, verbose=0)
    val_mse = float(np.mean(np.square(np.asarray(validation) - reconstruction)))
    # NOTE(review): a diverged run (large learning_rate) gives nan, which the
    # optimiser's Gaussian process cannot fit; consider mapping non-finite
    # values to a poor finite score.
    return -val_mse

In [None]:
# Search space for the Bayesian optimisation. Every hyper-parameter is sampled
# as a float within these bounds; the objective rounds the integer/categorical ones.
params_nn = dict(
    neurons=(10, 100),
    activation=(0, 9),
    optimizer=(0, 7),
    learning_rate=(0.01, 1),
    batch_size=(64, 500),
    epochs=(20, 100),
)

# Run Bayesian Optimization: 25 random probes, then 4 guided iterations.
nn_bo = BayesianOptimization(f=nn_cl_bo, pbounds=params_nn, random_state=111)
nn_bo.maximize(init_points=25, n_iter=4)

In [None]:
# Best hyper-parameters found, decoded from the optimiser's float encoding.
# BayesianOptimization stores the best (maximal) result in `.max`; there is no `.min`.
params_nn_ = dict(nn_bo.max['params'])  # copy, so the optimiser's record stays untouched
activationL = ['relu', 'sigmoid', 'softplus', 'softsign', 'tanh', 'selu',
               'elu', 'exponential', LeakyReLU, 'relu']
optimizerL = ['SGD', 'Adam', 'RMSprop', 'Adadelta', 'Adagrad', 'Adamax', 'Nadam', 'Ftrl', 'SGD']
params_nn_['activation'] = activationL[round(params_nn_['activation'])]
params_nn_['optimizer'] = optimizerL[round(params_nn_['optimizer'])]
params_nn_.update({name: round(params_nn_[name]) for name in ('neurons', 'batch_size', 'epochs')})
params_nn_