In [None]:
import tensorflow as tf
from tensorflow import keras as k
import tensorflow_probability as tfp
import numpy as np
import math

import matplotlib.pyplot as plt

import asmsa

In [None]:
_default_hp = {
    'activation' : 'gelu',
    'ae_loss_fn': 'MeanSquaredError',
    'optimizer': 'Adam',
    'learning_rate' : 0.0002,
    'kde_sigma': 0.01,
    'kl_weight': 1.0,
}

@tf.function
def _KDEProb(ref,qry,sigma=1.):
    """Gaussian kernel density estimate built on ``ref`` and evaluated at ``qry``.

    Every row of ``ref`` carries an isotropic Gaussian kernel with standard
    deviation ``sigma``. The value returned for a query point is the average of
    those kernels, normalised as a density in ``d = ref.shape[1]`` dimensions::

        p(q) = (2*pi*sigma**2) ** (-d/2) * mean_r exp(-|q - r|**2 / (2*sigma**2))

    Args:
        ref: rank-2 tensor ``[n_ref, d]`` of kernel centres.
        qry: rank-2 tensor ``[n_qry, d]`` of evaluation points (same ``d``).
        sigma: kernel bandwidth (standard deviation of each kernel).

    Returns:
        Tensor ``[n_qry]`` with the estimated density at every query point.

    Notes:
        The normalisation constant is float32, so inputs are expected to be float32.
        All pairwise differences are materialised as an ``[n_qry, n_ref, d]``
        tensor, so memory grows as ``n_qry * n_ref * d``.
    """
    rsigma2 = -1./(2.*sigma*sigma)
    d = tf.cast(tf.shape(ref)[1], tf.float32)

    # [n_qry, 1, d] - [1, n_ref, d] -> [n_qry, n_ref, d] by broadcasting
    diff = tf.expand_dims(qry, 1) - tf.expand_dims(ref, 0)
    dist2 = tf.math.reduce_sum(tf.math.square(diff), axis=2)   # [n_qry, n_ref]
    kdes = tf.exp(dist2 * rsigma2)

    # A d-dimensional isotropic Gaussian is normalised by (2*pi*sigma^2)^(-d/2),
    # i.e. the sigma factor enters as sigma**d (a single 1/sigma is only right for d == 1).
    norm = tf.math.pow(tf.cast(2.*math.pi*sigma*sigma, tf.float32), -.5 * d)
    return tf.reduce_mean(kdes, axis=1) * norm

class VAEModel(k.models.Model):
    """Autoencoder whose latent distribution is pushed towards a prior.

    The encoder maps inputs directly to a latent vector (no sampling). Each
    training step (see ``train_step``) performs two optimizer updates:

    1. encoder + decoder minimise the reconstruction loss
       ``lossfn(batch, dec(enc(batch)))``;
    2. the encoder alone minimises
       ``kl_weight * KL**2 + |mean(latent) - mean(prior sample)|**2``, where ``KL``
       is a sample estimate based on a Gaussian KDE of the latent batch and ``prior``.

    Args:
        mol_shape: width of the input vectors (an int; the decoder output has the
            same width).
        latent_dim: size of the latent space.
        ae_layers: widths of the hidden encoder layers; the decoder mirrors them in
            reverse order.
        prior: object with ``sample(n)`` and ``prob(x)`` methods (e.g. a tfp
            distribution). ``None`` (the default) means a standard normal in
            ``latent_dim`` dimensions.
        hp: hyper-parameter dict; missing keys are filled in from ``_default_hp``.
    """
    def __init__(self,mol_shape,latent_dim=2,ae_layers=(64,32,8),
                 prior=None,hp=_default_hp):

        super().__init__()
        # Copy and fill in defaults, so the model neither aliases the caller's dict
        # nor requires it to contain every key.
        self.hp = {**_default_hp, **(hp or {})}
        self.latent_dim = latent_dim
        if prior is None:
            # The prior lives in the latent space, so size it from latent_dim
            # (for latent_dim=2 this equals the former hard-coded 2-D default).
            prior = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros([latent_dim]))
        self.prior = prior
        act = self.hp['activation']

        inp = k.Input(shape = mol_shape)
        out = inp

        for i,n in enumerate(ae_layers):
            out = k.layers.Dense(n,activation=act,name=f'enc_{i}')(out)
            out = k.layers.BatchNormalization(momentum=0.8,name=f'enc_bn_{i}')(out)

        out = k.layers.Dense(latent_dim,name='enc_out')(out)
        latent = out

        for i,n in enumerate(reversed(ae_layers)):
            out = k.layers.Dense(n,activation=act,name=f'dec_{i}')(out)
            out = k.layers.BatchNormalization(momentum=0.8,name=f'dec_bn_{i}')(out)

        out = k.layers.Dense(mol_shape,activation=act,name='dec_out')(out)

        # Encoder and decoder share the layers built above; the decoder starts at
        # the latent tensor.
        self.enc = k.Model(inputs=inp,outputs=latent)
        self.dec = k.Model(inputs=latent,outputs=out)

    def compile(self,optimizer=None,lossfn=None):
        """Configure the optimizer and the reconstruction loss.

        Args:
            optimizer: ``None`` (use ``hp['optimizer']``), the name of a class in
                ``k.optimizers.legacy``, a callable (e.g. an optimizer class) that is
                called as ``optimizer(learning_rate=hp['learning_rate'])``, or an
                already constructed optimizer instance.
            lossfn: ``None`` (use ``hp['ae_loss_fn']``), a Keras loss identifier
                string resolved with ``k.losses.get``, or a callable
                ``lossfn(y_true, y_pred)``.
        """
        if optimizer is None:
            optimizer = self.hp['optimizer']

        if isinstance(optimizer,str):
            optimizer = getattr(k.optimizers.legacy,optimizer)

        # Classes/factories get the configured learning rate; instances are used as-is.
        if callable(optimizer):
            optimizer = optimizer(learning_rate=self.hp['learning_rate'])

        super().compile(optimizer = optimizer)

        if lossfn is None:
            lossfn = self.hp['ae_loss_fn']
        if isinstance(lossfn,str):
            # A class name such as 'MeanSquaredError' resolves to an instance of that loss.
            lossfn = k.losses.get(lossfn)

        # Variables updated by the reconstruction step (encoder + decoder).
        self.ae_weights = self.enc.trainable_weights + self.dec.trainable_weights
        self.lossfn = lossfn

    @tf.function
    def train_step(self,batch):
        """One optimisation step: a reconstruction update, then a prior-matching update.

        Args:
            batch: input tensor, or a tuple whose first element is the input.

        Returns:
            dict with the reconstruction loss (``'AE loss'``), the KL estimate
            (``'KL'``) and the squared distance between the latent and prior-sample
            means (``'mean distance'``).
        """
        # Added to both densities so log() and the ratio stay finite where one is 0.
        prob_shift = 1e-15

        if isinstance(batch,tuple):
            batch = batch[0]

        # 1) encoder + decoder step on the reconstruction loss
        with tf.GradientTape() as aetape:
            out = self.dec(self.enc(batch))
            loss = self.lossfn(batch,out)

        ae_grad = aetape.gradient(loss,self.ae_weights)
        self.optimizer.apply_gradients(zip(ae_grad,self.ae_weights))

        # 2) encoder-only step pulling the latent distribution towards the prior
        prior_sample = self.prior.sample(tf.shape(batch)[0])
        prior_mean = tf.reduce_mean(prior_sample,axis=0)
        with tf.GradientTape() as kltape:
            latent = self.enc(batch)
            latent_mean = tf.reduce_mean(latent,axis=0)
            latent_prior = latent_mean - prior_mean
            mean_dist2 = tf.reduce_sum(latent_prior * latent_prior)

            # Both densities are evaluated on the pooled latent points and prior samples.
            both = tf.concat([latent,prior_sample],axis=0)
            latent_prob = _KDEProb(latent,both,self.hp['kde_sigma']) + prob_shift
            prior_prob = self.prior.prob(both) + prob_shift
            # NOTE: averaged over points pooled from both distributions, so this is
            # not a plain Monte-Carlo estimate of KL(latent || prior).
            kl = tf.reduce_mean(tf.math.log(latent_prob / prior_prob))
            # Squared, so the loss is minimised when the estimate is 0, whatever its sign.
            kl_loss = kl * kl * self.hp['kl_weight'] + mean_dist2

        kl_grad = kltape.gradient(kl_loss,self.enc.trainable_weights)
        self.optimizer.apply_gradients(zip(kl_grad,self.enc.trainable_weights))

        return { 'AE loss' : loss, 'KL' : kl, 'mean distance' : mean_dist2}


In [None]:
# Input files for the alanine-dipeptide example.
traj = 'alaninedipeptide_reduced.xtc'   # trajectory, loaded with `conf` as topology
conf = 'alaninedipeptide_H.pdb'         # reference structure (PDB)
gro = 'aladip_H.gro'                    # GROMACS structure
topol = 'aladip_H.top'                  # GROMACS topology, passed to asmsa.Molecule
index = None                            # no index file

In [None]:
import mdtraj as md
import nglview as nv

In [None]:
tr = md.load(traj,top=conf)

# Atoms used to superpose every frame onto frame 0. Superposing on CA atoms
# ("name CA") fails for trivial peptides such as Ala-Ala, so all heavy atoms are used.
idx = tr[0].top.select("element != H")

tr.superpose(tr[0],atom_indices=idx)
nv.show_mdtraj(tr)

In [None]:
# GROMACS command that produces aladip_H.gro / aladip_H.top from the PDB (run it in a shell):
# gmx pdb2gmx -f alaninedipeptide_H.pdb -o aladip_H.gro -p aladip_H.top -n aladip_H.ndx -water tip3p -ff amber99 -ignh

In [None]:
# Shuffle the frames before the train/validation/test split.
# Index the whole Trajectory (not just `tr.xyz` in place) so per-frame data such as
# `time` and unit cells are permuted together with the coordinates. The fixed seed
# makes the split reproducible; re-running this cell permutes the already-shuffled `tr` again.
tr = tr[np.random.default_rng(42).permutation(len(tr))]

In [None]:
train = .7
validation = .15
test = .15

# The fractions must partition the trajectory. Compare with a tolerance, because
# .7 + .15 + .15 is not exactly 1.0 in floating point (the old `== ... or 1` always passed).
assert math.isclose(train + validation + test, 1.0), "split fractions must sum to 1"

tr_i = len(tr) * train
X_train = tr.slice(slice(0,int(tr_i)))

va_i = len(tr) * validation
X_validate = tr.slice(slice(int(tr_i),int(tr_i)+int(va_i)))

# the test split takes every remaining frame, so rounding never drops frames
te_i = len(tr) * test
X_test = tr.slice(slice(int(tr_i)+int(va_i),len(tr)))

X_train.xyz.shape, X_validate.xyz.shape, X_test.xyz.shape

In [None]:
# Write each split to its own XTC file.
for _split_traj, _xtc_path in ((X_train, 'train.xtc'),
                               (X_validate, 'validate.xtc'),
                               (X_test, 'test.xtc')):
    _split_traj.save_xtc(_xtc_path)

In [None]:
# mdtraj's xyz is (frames, atoms, 3); move the frame axis to the end, the layout
# that is passed to asmsa's intcoord() further down.
trajs = [X_train, X_validate, X_test]
geoms = [np.moveaxis(t.xyz, 0, -1) for t in trajs]
for geom in geoms:
    print(geom.shape)

In [None]:
# NOTE(review): each geoms[i] is (atoms, 3, frames) and from_tensor_slices splits
# along axis 0, i.e. one element per atom rather than per frame. Confirm this is
# what the consumer of datasets/geoms/* expects.
for i, _ds_path in enumerate(('datasets/geoms/train',
                              'datasets/geoms/validate',
                              'datasets/geoms/test')):
    tf.data.Dataset.from_tensor_slices(geoms[i]).save(_ds_path)

In [None]:
# One asmsa.Molecule per split; its feature map is built from the atom count
# (axis 0 of each geometry array).
mols = []
for geom in geoms:
    sparse_dists = asmsa.NBDistancesDense(geom.shape[0])
    mols.append(asmsa.Molecule(pdb=conf, top=topol, ndx=index, fms=[sparse_dists]))


In [None]:
# Internal coordinates of each split, transposed (presumably to frames x features).
intcoords = []
for i, mol in enumerate(mols):
    ic = mol.intcoord(geoms[i]).T
    intcoords.append(ic)
    print(ic.shape)

In [None]:
# Persist each split of internal coordinates as a tf.data dataset (one element per row).
for i, _ds_path in enumerate(('datasets/intcoords/train',
                              'datasets/intcoords/validate',
                              'datasets/intcoords/test')):
    tf.data.Dataset.from_tensor_slices(intcoords[i]).save(_ds_path)

In [None]:
# NOTE: rebinds `train` / `test`, which earlier held the split fractions.
train, validate, test = intcoords

In [None]:
# Fixed-size batches of 1024 rows; the incomplete last batch is dropped.
train_batch = (
    tf.data.Dataset.from_tensor_slices(train)
    .batch(1024, drop_remainder=True)
)

In [None]:
class _Uniform2d:
    """Uniform distribution on the square ``[low, high] x [low, high]``.

    Provides the ``sample(n)`` / ``prob(samples)`` pair that VAEModel calls on its prior.
    """
    def __init__(self,low=-1.,high=1):
        # float32 bounds, compared against the samples in prob()
        self.low, self.high = (tf.constant(v, tf.float32) for v in (low, high))
        # constant density of a uniform distribution over a square of side (high - low)
        self.density = 1./(self.high-self.low)**2
        # 1-D uniform used to draw both coordinates independently
        self.uniform = tfp.distributions.Uniform(low=low,high=high)

    def prob(self,samples):
        """Density at each row of ``samples``; only columns 0 and 1 are inspected."""
        xy = tf.stack([samples[:,0], samples[:,1]], axis=1)
        inside = tf.reduce_all((xy >= self.low) & (xy <= self.high), axis=1)
        return tf.cast(inside, tf.float32) * self.density

    def sample(self,n):
        """Draw ``n`` points as an ``[n, 2]`` tensor."""
        return self.uniform.sample([n, 2])
        
        

In [None]:
# Standard-normal 2-D prior; the commented line switches to the uniform-square prior.
prior = tfp.distributions.MultivariateNormalDiag(loc=[0., 0.])
# prior = _Uniform2d()
mod = VAEModel(
    train.shape[1],          # input width = number of columns of `train`
    ae_layers=[16, 8],
    prior=prior,
)
mod.compile()

In [None]:
mod.fit(x=train_batch, epochs=10)   # train_batch is already batched (1024 rows)

In [None]:
# Latent projection of the test set.
lows = mod.enc(test).numpy()

fig, ax = plt.subplots()
ax.scatter(lows[:, 0], lows[:, 1], marker='.')
plt.show()

In [None]:
rg = md.compute_rg(X_test)
base = md.load(conf)
rmsd = md.rmsd(X_test, base[0])
cmap = plt.get_cmap('rainbow')

# Latent projection coloured by radius of gyration and by RMSD to the `conf` structure.
plt.figure(figsize=(12, 4))
for _panel, (_colour, _title) in enumerate([(rg, "Rg"), (rmsd, "RMSD")], start=1):
    plt.subplot(1, 2, _panel)
    plt.scatter(lows[:, 0], lows[:, 1], marker='.', c=_colour, cmap=cmap)
    plt.colorbar(cmap=cmap)
    plt.title(_title)
plt.show()

In [None]:
# Latent projection coloured by the two dihedrals defined by the atom quadruplets
# below (presumably the backbone phi/psi of the dipeptide — TODO confirm indices).
plt.figure(figsize=(12, 4))
dih = md.compute_dihedrals(X_test, np.array([[4, 6, 8, 14], [6, 8, 14, 16]]))
for _col in (0, 1):
    plt.subplot(1, 2, _col + 1)
    plt.scatter(lows[:, 0], lows[:, 1], marker='.', c=dih[:, _col], cmap=cmap)
    plt.colorbar(cmap=cmap)
plt.show()

In [None]:
np.shape(lows)

In [None]:
# NOTE(review): _KDEProb materialises an [n, n, d] difference tensor, so 16384 rows
# of the full-width input can need a lot of memory — consider a smaller slice.
lowbatch = train[:16384]
p = _KDEProb(ref=lowbatch, qry=lowbatch, sigma=0.01)

In [None]:
# NOTE: rebinds the global `prior` to a normal centred at (2, 0); later cells that use
# `prior` see this one, while `mod.prior` keeps the prior the model was built with.
prior = tfp.distributions.MultivariateNormalDiag(loc=[2.0, 0.0])

In [None]:
# Average of log p_kde - log p_prior over `lowbatch`.
# NOTE(review): `prior` is 2-D, so this only works if `lowbatch` has 2 columns — TODO confirm.
tf.math.reduce_mean(tf.math.subtract(tf.math.log(p), prior.log_prob(lowbatch)))

In [None]:
mod.dec.summary(print_fn=None)   # print_fn=None -> default printing

In [None]:
norm = tfp.distributions.Normal(0.0, 1.0)   # standard 1-D normal (loc, scale)

In [None]:
n = 1_000
s = norm.sample(sample_shape=n)
# s += 3   # optional shift of the samples, left disabled

In [None]:
tf.math.reduce_mean(norm.log_prob(s))

In [None]:
ss = np.sort(s.numpy())
fig, ax = plt.subplots()
ax.plot(ss, norm.prob(ss))
plt.show()

In [None]:
tf.math.reduce_sum(norm.prob(ss))

In [None]:
tf.gather(tf.shape(s), 0)

In [None]:
tf.broadcast_to(s, [2] + list(tf.shape(s).numpy()))

In [None]:
[2] + list(tf.shape(s).numpy())

In [None]:
[dim for dim in tf.shape(s).numpy()]

In [None]:
tf.shape(s).numpy()[0]

In [None]:
_KDEProb(
    ref=tf.constant([[1, 1], [4, 4]], tf.float32),
    qry=tf.constant([[1.1, 1], [2.5, 2.5], [8., 10.]], tf.float32),
)

In [None]:
ref = tf.constant([1.0, 4.0], dtype=tf.float32)

In [None]:
tf.slice(ref, begin=[1], size=[1])

In [None]:
n = 100
# n x n grid over [-3, 3]^2 flattened to (n*n, 2); point i*n + j is (x_j, x_i).
grid_axis = tf.linspace(-3.0, 3.0, n)
g = tf.reshape(tf.stack(tf.meshgrid(grid_axis, grid_axis), axis=2), [n * n, 2])

In [None]:
ref = tf.random.uniform([5, 2], minval=-0.05, maxval=0.2)
ref

In [None]:
tf.gather(tf.shape(ref), 1)

In [None]:
fig, ax = plt.subplots()
ax.scatter(ref[:, 0], ref[:, 1])
plt.show()

In [None]:
# Alternative KDE inputs on the grid; each line overwrites `p`, so only the last survives.
p = _KDEProb(ref=tf.constant([[1, 1], [4, 4.5], [1.2, 3.3]], tf.float32), qry=g, sigma=1)
p = _KDEProb(ref=ref, qry=g, sigma=2.)
p = _KDEProb(ref=tf.constant([[0, 0], [0, 2]], tf.float32), qry=g, sigma=.8)


In [None]:
fig, ax = plt.subplots()
im = ax.imshow(tf.reshape(p, [n, n]))
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
mnorm = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros([2]))   # standard 2-D normal

In [None]:
fig, ax = plt.subplots()
sc = ax.scatter(g[:, 0], g[:, 1], c=mnorm.prob(g))
fig.colorbar(sc, ax=ax)
plt.show()

In [None]:
# NOTE(review): averaged over the uniform grid `g`, not over samples drawn from `p`,
# so this is not a Monte-Carlo KL estimate.
tf.math.reduce_mean(tf.math.subtract(tf.math.log(p), mnorm.log_prob(g)))

In [None]:
fig, ax = plt.subplots()
ax.plot(p)
ax.plot(mnorm.prob(g))
plt.show()

In [None]:
s = mnorm.sample(10_000)
p = mnorm.prob(s)
fig, ax = plt.subplots()
ax.scatter(s[:, 0], s[:, 1], c=p)
plt.show()

In [None]:
nn = 1_000
ref = tf.random.uniform([nn, 2], minval=-0.05, maxval=0.05)

In [None]:
mnorm = tfp.distributions.MultivariateNormalDiag(loc=[0.2, 0.0])   # unit normal centred at (0.2, 0)

In [None]:
p = _KDEProb(ref=ref, qry=ref, sigma=1)

In [None]:
fig, ax = plt.subplots()
sc = ax.scatter(ref[:, 0], ref[:, 1], c=p)
fig.colorbar(sc, ax=ax)
plt.show()

In [None]:
# Sample average of log p_kde - log p_mnorm over the KDE's own reference points.
tf.math.reduce_mean(tf.math.subtract(tf.math.log(p), mnorm.log_prob(ref)))

In [None]:
a = np.arange(1, 7).reshape(2, 3)   # [[1, 2, 3], [4, 5, 6]]

In [None]:
# Shuffles the rows of `a` in place (axis 0 only) with the unseeded global RNG.
np.random.shuffle(a)
a

In [None]:
np.shape(tr.xyz)

In [None]:
# KDE of the encoded batch, evaluated on the grid `g` and shown as an n x n image.
pb = _KDEProb(ref=mod.enc(lowbatch), qry=g, sigma=_default_hp['kde_sigma'])
fig, ax = plt.subplots()
im = ax.imshow(tf.reshape(pb, [n, n]))
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# NOTE(review): `prior` is whatever the global currently holds (an earlier scratch cell
# rebinds it to a normal centred at (2, 0)); use `mod.prior` for the training prior.
pp = prior.prob(g)
fig, ax = plt.subplots()
im = ax.imshow(tf.reshape(pp, [n, n]))
fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# NOTE(review): grid points far from every encoded sample may get pb == 0 (computed in
# float32 before the cast), making log(pb) -inf; this is also a grid average, not a KL.
log_pb = tf.math.log(tf.cast(pb, tf.float64))
log_pp = tf.math.log(tf.cast(pp, tf.float64))
tf.reduce_mean(log_pb - log_pp)

In [None]:
# Average of log p_kde(l) - log p_prior(l) over the encoded batch, with the KDE
# built on the same points it is evaluated at.
# NOTE(review): uses the global `prior`, which may not be `mod.prior`.
l = mod.enc(lowbatch)
pm = _KDEProb(ref=l, qry=l, sigma=_default_hp['kde_sigma'])
tf.math.reduce_mean(tf.math.subtract(tf.math.log(pm), prior.log_prob(l)))

In [None]:
prior.sample(sample_shape=10)

In [None]:
tf.fill([2, 2], 1.14)