In [None]:
!python -m pip install keras_tuner

In [None]:
import keras_tuner
import tensorflow as tf
from tensorflow import keras
import numpy as np

from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
from src import asmsa
from src.gan import GAN
from src.visualizer import GAN_visualizer
import mdtraj as md
import nglview as nv

from tensorflow.keras.optimizers import Adam
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.callbacks import CSVLogger, EarlyStopping
from keras.models import Sequential, Model
from keras.losses import BinaryCrossentropy, MeanSquaredError
from keras import backend as kb

In [None]:
# Create an explicit TF1-compat session to work around "insufficient VRAM" errors:
# at most 80 % of the GPU memory may be claimed, and it is allocated on demand.
# (A `device_count = {'GPU': 1}` entry could additionally be passed to ConfigProto.)
config = tf.compat.v1.ConfigProto(
    gpu_options=tf.compat.v1.GPUOptions(
        per_process_gpu_memory_fraction=0.8,
        allow_growth=True,
    )
)
session = tf.compat.v1.Session(config=config)

# Register the session with the Keras backend.
tf.compat.v1.keras.backend.set_session(session)

In [None]:
# Define input files
# NOTE: the working directory is switched to the user's home directory, so the
# relative file names below are resolved from there.
%cd ~

# input conformation (PDB file; also used as topology when loading the trajectory)
#conf = "alaninedipeptide_H.pdb"
conf = "trpcage_correct.pdb"

# input trajectory
# atom numbering must be consistent with {conf}

#traj = "alaninedipeptide_reduced.xtc"
traj = "trpcage_red.xtc"

# input topology
# expected to be produced with
#    gmx pdb2gmx -f {conf} -p {topol} -n {index}

# Gromacs changes atom numbering, so the index file must be generated and used as well

#topol = "topol.top"
topol = "topol_correct.top"
index = 'index_correct.ndx'

In [None]:
# Read the trajectory; the conformation file supplies the topology.
tr = md.load(traj, top=conf)

# Atoms used for the alignment: alpha carbons only
# (alternative selection of all heavy atoms: "element != H").
idx = tr.topology.select("name CA")

# Align every frame onto the first frame (modifies `tr` in place).
tr.superpose(tr[0], atom_indices=idx)

# Move the leading (frame) axis of the coordinate array to the last position.
geom = np.moveaxis(tr.xyz, source=0, destination=-1)

In [None]:
# Define sparse and dense feature extensions of the IC
# (IC presumably stands for internal coordinates, cf. mol.intcoord — TODO confirm)
density = 2 # integer in [1, n_atoms-1]
# Both feature maps are sized by the first axis of `geom`.
sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=density)
# NOTE(review): dense_dists is created but not passed to Molecule below;
# only sparse_dists is used in the visible code.
dense_dists = asmsa.NBDistancesDense(geom.shape[0])

# Alternative ways of constructing the molecule (without index file / feature maps):
# mol = asmsa.Molecule(conf,topol)
# mol = asmsa.Molecule(conf,topol,fms=[sparse_dists])
mol = asmsa.Molecule(pdb=conf,top=topol,ndx=index,fms=[sparse_dists])

In [None]:
# Compute the coordinates returned by mol.intcoord for the aligned geometry and
# transpose them. Axis 0 of X_train is what the training dataset later slices into
# samples; axis 1 is used as the network input width.
# (presumably rows are trajectory frames — TODO confirm against asmsa.Molecule.intcoord)
X_train = mol.intcoord(geom).T
X_train.shape

In [None]:
# Shape of one training sample (number of features per row of X_train)
molecule_shape = (X_train.shape[1],)
# Dimensionality of the latent (low-dimensional) space
latent_dim = 2
# Prior distribution sampled for the discriminator: 'normal' or 'uniform'
prior = 'normal'

In [None]:
# encoder tuning

def build_encoder(params=[("selu", 32),
                          ("selu", 16),
                          ("selu", 8),
                          ("linear", None)]):
    """Build the encoder network mapping a molecule to the latent space.

    Args:
        params: sequence of four ``(activation, units)`` pairs. The first three
            describe the Dense layers (each followed by batch normalisation);
            of the fourth only the activation is used, because the output width
            is fixed to the notebook-level ``latent_dim``.

    Returns:
        A Keras ``Model`` named "Encoder" with input shape ``molecule_shape``
        and output shape ``(latent_dim,)``.
    """
    first_spec = params[0]
    hidden_specs = (params[1], params[2])
    output_activation = params[3][0]

    stack = Sequential()

    # First layer also declares the (flattened) input width.
    stack.add(Dense(first_spec[1],
                    input_dim=np.prod(molecule_shape),
                    activation=first_spec[0]))
    stack.add(BatchNormalization(momentum=0.8))

    # Remaining hidden layers, each followed by batch normalisation.
    for spec in hidden_specs:
        stack.add(Dense(spec[1], activation=spec[0]))
        stack.add(BatchNormalization(momentum=0.8))

    # Projection onto the latent space.
    stack.add(Dense(latent_dim, activation=output_activation))

    # Wrap the layer stack in a functional model with an explicit input.
    inputs = Input(shape=molecule_shape)
    latent = stack(inputs)
    return Model(inputs, latent, name="Encoder")

In [None]:
# decoder tuning

def build_decoder(params=[("selu", 8),
                                 ("selu", 16),
                                 ("selu", 32),
                                 ("linear", None)]):
    """Build the decoder network mapping a latent point back to a molecule.

    Args:
        params: sequence of four ``(activation, units)`` pairs. The first three
            describe the Dense layers (each followed by batch normalisation);
            of the fourth only the activation is used, because the output width
            is fixed by the notebook-level ``molecule_shape``.

    Returns:
        A Keras ``Model`` named "Decoder" with input shape ``(latent_dim,)``
        and output shape ``molecule_shape``.
    """
    # Name the inner model through the public constructor argument rather than
    # by writing to the private `_name` attribute.
    model = Sequential(name="Decoder")
    # input layer
    model.add(Dense(params[0][1], input_dim=latent_dim, activation=params[0][0]))
    model.add(BatchNormalization(momentum=0.8))
    # hidden layers
    model.add(Dense(params[1][1], activation=params[1][0]))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(params[2][1], activation=params[2][0]))
    model.add(BatchNormalization(momentum=0.8))
    # output layer: one unit per molecule feature, reshaped to molecule_shape
    model.add(Dense(np.prod(molecule_shape), activation=params[3][0]))
    model.add(Reshape(molecule_shape))

    # Wrap the layer stack in a functional model with an explicit input.
    # (local names avoid shadowing the notebook-level `mol` Molecule object)
    latent_in = Input(shape=(latent_dim,))
    reconstructed = model(latent_in)
    return Model(latent_in, reconstructed, name="Decoder")

In [None]:
# discriminator tuning

def build_discriminator(params=[(None, 512),
                                       (None, 256),
                                       (None, 256),
                                       (None, 1)]):
    """Build the discriminator that scores points of the latent space.

    Args:
        params: sequence of four ``(activation, units)`` pairs. Only the unit
            counts are used: the three hidden Dense layers are always followed
            by ``LeakyReLU(alpha=0.2)`` and the last Dense layer has no
            activation, i.e. it emits raw logits.

    Returns:
        A Keras ``Model`` named "Discriminator" with input shape
        ``(latent_dim,)`` and ``params[3][1]`` outputs.
    """
    # Name the inner model through the public constructor argument rather than
    # by writing to the private `_name` attribute.
    model = Sequential(name="Discriminator")
    model.add(Flatten(input_shape=(latent_dim,)))
    model.add(Dense(params[0][1]))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(params[1][1]))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(params[2][1]))
    model.add(LeakyReLU(alpha=0.2))
    # No activation here: the output is a logit.
    model.add(Dense(params[3][1]))

    # Wrap the layer stack in a functional model with an explicit input.
    # (local names avoid shadowing the notebook-level `mol` Molecule object)
    latent_in = Input(shape=(latent_dim,))
    validity = model(latent_in)
    return Model(latent_in, validity, name="Discriminator")

In [None]:
# adversarial autoencoder

class AAEModel(Model):
    """Container model bundling the three networks of an adversarial autoencoder.

    The class defines no forward pass of its own; it only stores the
    sub-models and settings that the custom training loop reads.
    """

    def __init__(self,enc,dec,disc,latent_dim,prior):
        """Store the sub-models and latent-space settings.

        Args:
            enc: encoder model (molecule -> latent space).
            dec: decoder model (latent space -> molecule).
            disc: discriminator model operating on latent points.
            latent_dim: dimensionality of the latent space; stored as ``lowdim``.
            prior: name of the prior distribution ('normal' or 'uniform').
        """
        super().__init__()
        self.enc = enc
        self.dec = dec
        self.disc = disc
        # Note the attribute name differs from the argument name.
        self.lowdim = latent_dim
        self.prior = prior

In [None]:
# define hypermodel

class MyHyperModel(keras_tuner.HyperModel):
    """Hypermodel that trains an adversarial autoencoder with a custom loop.

    Only the batch size is exposed as a hyperparameter; the encoder, decoder
    and discriminator are built with their default layer specifications.
    """

    def build(self, hp):
        """Assemble a fresh AAEModel for one trial.

        Args:
            hp: hyperparameter container; accepted for the HyperModel
                interface but not used here.

        Returns:
            An ``AAEModel`` wrapping newly built encoder, decoder and
            discriminator, configured with the notebook-level ``latent_dim``
            and ``prior``.
        """
        enc = build_encoder()
        dec = build_decoder()
        disc = build_discriminator()

        return AAEModel(enc,dec,disc,latent_dim,prior)

    def fit(self, hp, model, X_train, callbacks=None, epochs=2, **kwargs):
        """Train `model` with a custom adversarial-autoencoder loop.

        Args:
            hp: hyperparameter container; provides ``batch_size``
                (32..128 in steps of 32, default 64).
            model: the ``AAEModel`` returned by ``build``.
            X_train: training samples, sliced along axis 0.
            callbacks: optional iterable of Keras callbacks; each gets
                ``on_epoch_end`` with ``logs={"ae_loss": ...}``. ``None`` is
                treated as "no callbacks".
            epochs: number of passes over the data (default 2, the value that
                used to be hard-coded).
            **kwargs: accepted and ignored.

        Returns:
            The lowest per-epoch mean reconstruction loss observed
            (``float("inf")`` if ``epochs`` is 0).
        """
        # The signature defaults `callbacks` to None; normalise it so the loops
        # below do not fail with "NoneType is not iterable".
        callbacks = [] if callbacks is None else list(callbacks)

        batch_size = hp.Int("batch_size", 32, 128, step=32, default=64)

        # Convert the training data to a shuffled, batched tf.data.Dataset.
        dataset = tf.data.Dataset.from_tensor_slices(X_train)
        dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

        # Loss functions; the discriminator emits logits, hence from_logits=True.
        ae_loss_fn = MeanSquaredError()
        disc_loss_fn = BinaryCrossentropy(from_logits=True)
        # One optimizer shared by all update steps (learning rate 0.0002, beta_1 0.5).
        opt = Adam(0.0002,0.5)

        # Running mean of the reconstruction loss over the batches of an epoch.
        epoch_loss_metric = keras.metrics.Mean()

        # Function to run the train step.
        @tf.function
        def run_train_step(batch):
            def _get_prior(name, shape):
                """Draw a sample of the given shape from the named prior."""
                if name == "normal":
                    return tf.random.normal(shape=shape)
                if name == "uniform":
                    return tf.random.uniform(shape=shape)

                raise ValueError(f"Invalid prior type '{name}'. Choose from 'normal|uniform'")

            # Datasets yielding (x, ...) tuples: keep only the inputs.
            if isinstance(batch,tuple):
                batch = batch[0]

            # Actual size of this batch (the last one may be smaller).
            batch_size = tf.shape(batch)[0]

            # improve AE to reconstruct
            # The tape is persistent because gradients are taken twice from it.
            with tf.GradientTape(persistent=True) as ae_tape:
                reconstruct = model.dec(model.enc(batch))
                ae_loss = ae_loss_fn(batch,reconstruct)

            enc_grads = ae_tape.gradient(ae_loss, model.enc.trainable_weights)
            opt.apply_gradients(zip(enc_grads,model.enc.trainable_weights))

            dec_grads = ae_tape.gradient(ae_loss, model.dec.trainable_weights)
            opt.apply_gradients(zip(dec_grads,model.dec.trainable_weights))

            # improve discriminator
            # Samples from the prior are labelled 1, encoded samples 0.
            rand_low = _get_prior(model.prior, (batch_size, model.lowdim))
            better_low = model.enc(batch)
            low = tf.concat([rand_low,better_low],axis=0)

            labels = tf.concat([tf.ones((batch_size,1)), tf.zeros((batch_size,1))], axis=0)
            # Add a little uniform noise to the labels
            # (marked "guide" in the original; presumably taken from a Keras guide).
            labels += 0.05 * tf.random.uniform(tf.shape(labels))

            with tf.GradientTape() as disc_tape:
                pred = model.disc(low)
                disc_loss = disc_loss_fn(labels,pred)

            disc_grads = disc_tape.gradient(disc_loss,model.disc.trainable_weights)
            opt.apply_gradients(zip(disc_grads,model.disc.trainable_weights))

            # teach encoder to cheat
            # Encoded samples are presented with the "prior" label; only the
            # encoder weights are updated in this step.
            alltrue = tf.ones((batch_size,1))

            with tf.GradientTape() as cheat_tape:
                cheat = model.disc(model.enc(batch))
                cheat_loss = disc_loss_fn(alltrue,cheat)

            cheat_grads = cheat_tape.gradient(cheat_loss,model.enc.trainable_weights)
            opt.apply_gradients(zip(cheat_grads,model.enc.trainable_weights))

            epoch_loss_metric.update_state(ae_loss)

        # Assign the model to the callbacks.
        for callback in callbacks:
            callback.model = model

        # Record the best (lowest) epoch-mean reconstruction loss on the training data.
        best_epoch_loss = float("inf")

        # The custom training loop.
        for epoch in range(epochs):
            print(f"Epoch: {epoch}")

            # Iterate the training data to run the training step.
            for batch in dataset:
                run_train_step(batch)

            # Calling the callbacks after epoch.
            epoch_loss = float(epoch_loss_metric.result().numpy())
            for callback in callbacks:
                # "ae_loss" is the name under which the metric is reported;
                # it must match the objective the tuner was created with.
                callback.on_epoch_end(epoch, logs={"ae_loss": epoch_loss})
            epoch_loss_metric.reset_states()

            print(f"Epoch loss: {epoch_loss}")
            best_epoch_loss = min(best_epoch_loss, epoch_loss)

        # Return the evaluation metric value.
        return best_epoch_loss

In [None]:
# Random search over the hyperparameters declared by MyHyperModel.
tuner = keras_tuner.RandomSearch(
    # what to tune and what to optimise: minimise the reported "ae_loss"
    hypermodel=MyHyperModel(),
    objective=keras_tuner.Objective("ae_loss", "min"),
    max_trials=10,
    # where results are stored; overwrite=True discards results of earlier runs
    directory="results",
    project_name="custom_training",
    overwrite=True,
)

In [None]:
# Run the hyperparameter search; arguments given to search() are passed on to
# the hypermodel's fit(), which here expects the training data as X_train.
tuner.search(X_train)

In [None]:
# Print a summary of the completed trials.
tuner.results_summary()