## Load and process the dataset

In [10]:
import gzip, json
import numpy as np
import utils 
from sklearn.model_selection import train_test_split

# 20 standard amino acids
aa2idx = {'A':0, 'R':1, 'N':2, 'D':3, 'C':4, 'Q':5, 'E':6, 'G':7, 'H':8, 'I':9,
          'L':10, 'K':11, 'M':12, 'F':13, 'P':14, 'S':15, 'T':16, 'W':17, 'Y':18, 'V':19}

# load
dataset = utils.load_phipsi()

# 90% train, 10% test
train,test = train_test_split(dataset, test_size=0.1, random_state=42)

## Clustering

In [11]:
from sklearn.cluster import KMeans
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import log_loss


In [12]:
WINDOW = 15

X_train = utils.getX(train, WINDOW)
X_test = utils.getX(test, WINDOW)

phi_ref = utils.getPHI(test, WINDOW)
psi_ref = utils.getPSI(test, WINDOW)


In [58]:
NCLUST = 20

KM = KMeans(n_clusters=NCLUST, max_iter=5, random_state=42)
KM.fit(np.vstack([item['avec'] for item in train]))


KMeans(algorithm='auto', copy_x=True, init='k-means++', max_iter=5,
    n_clusters=20, n_init=10, n_jobs=None, precompute_distances='auto',
    random_state=42, tol=0.0001, verbose=0)

## Test tensorflow on GPU

In [59]:
import tensorflow as tf

print("Built with GPU:", tf.test.is_built_with_cuda())
print("GPU available:", tf.test.is_gpu_available())
print("GPU device:", tf.test.gpu_device_name())


Built with GPU: True
GPU available: True
GPU device: /device:GPU:0


In [60]:
# convert sequences & dihedral clusters
# to one-hot representation
for item in train:
    item['X'] = np.eye(20)[item['sequence']]
    item['Y'] = np.eye(NCLUST)[np.array(KM.predict(item['avec']), dtype=np.int8)]
    item['X'] = item['X'][np.newaxis]
    item['Y'] = item['Y'][np.newaxis]

for item in test:
    item['X'] = np.eye(20)[item['sequence']]
    item['Y'] = np.eye(NCLUST)[np.array(KM.predict(item['avec']), dtype=np.int8)]
    item['X'] = item['X'][np.newaxis]
    item['Y'] = item['Y'][np.newaxis]


In [61]:
phi_ref = np.hstack([item['phi'] for item in test])
psi_ref = np.hstack([item['psi'] for item in test])


In [64]:
import utils
from random import shuffle

lr           = 0.001   # learning rate
l2_coef      = 0.001  # L2 penalty weight
nb_epochs    = 1000


In [None]:
with tf.Graph().as_default():
    with tf.name_scope('input'):
        features = tf.placeholder(dtype=tf.int8, shape=(1, None, 20))
        labels = tf.placeholder(dtype=tf.int8, shape=(1, None, NCLUST))

    l0 = tf.nn.relu(tf.layers.conv1d(tf.to_float(features), 60, 3, padding='SAME'))
    
    l10 = tf.nn.relu(tf.layers.conv1d(l0, 60, 5, padding='SAME'))
    l11 = tf.nn.relu(tf.layers.conv1d(l10, 60, 5, padding='SAME')+l0)

    l20 = tf.nn.relu(tf.layers.conv1d(l11, 60, 5, padding='SAME'))
    l21 = tf.nn.relu(tf.layers.conv1d(l20, 60, 5, padding='SAME')+l11)

    l30 = tf.nn.relu(tf.layers.conv1d(l21, 60, 5, padding='SAME'))
    l31 = tf.nn.relu(tf.layers.conv1d(l30, 60, 5, padding='SAME')+l21)

    l40 = tf.nn.relu(tf.layers.conv1d(l31, 60, 5, padding='SAME'))
    l41 = tf.nn.relu(tf.layers.conv1d(l40, 60, 5, padding='SAME')+l31)

    l50 = tf.nn.relu(tf.layers.conv1d(l41, 60, 5, padding='SAME'))
    l51 = tf.nn.relu(tf.layers.conv1d(l50, 60, 5, padding='SAME')+l41)

    l60 = tf.nn.relu(tf.layers.conv1d(l51, 60, 5, padding='SAME'))
    l61 = tf.nn.relu(tf.layers.conv1d(l60, 60, 5, padding='SAME')+l51)

    l70 = tf.nn.relu(tf.layers.conv1d(l61, 60, 5, padding='SAME'))
    l71 = tf.nn.relu(tf.layers.conv1d(l70, 60, 5, padding='SAME')+l61)

    l80 = tf.nn.relu(tf.layers.conv1d(l71, 60, 5, padding='SAME'))
    l81 = tf.nn.relu(tf.layers.conv1d(l80, 60, 5, padding='SAME')+l71)

    l90 = tf.nn.relu(tf.layers.conv1d(l81, 60, 5, padding='SAME'))
    l91 = tf.nn.relu(tf.layers.conv1d(l90, 60, 5, padding='SAME')+l81)
    
    l10 = tf.nn.relu(tf.layers.conv1d(l51, NCLUST, 5, padding='SAME'))
    
    out = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.to_float(labels), logits=l10)
    prob = tf.nn.softmax(l10)
    
    loss = tf.reduce_mean(out)
    
    vars = tf.trainable_variables()
    lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if v.name not
                       in ['bias', 'gamma', 'b', 'g', 'beta']]) * l2_coef
    # optimizer
    opt = tf.train.AdamOptimizer(learning_rate=lr)

    # training op
    train_op = opt.minimize(loss+lossL2)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    total_parameters=np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
    print("tot. params: " + str(total_parameters))
    
    with tf.Session() as sess:
        sess.run(init_op)
        
        for epoch in range(nb_epochs):
            
            train_loss = 0
            step = 0
            rmse_phi = 0
            rmse_psi = 0
            shuffle(train)
            for item in train:
                _, loss_value = sess.run([train_op, loss],
                                         feed_dict={
                                             features: item['X'],
                                             labels: item['Y'] })
                step += 1
                train_loss += loss_value
            train_loss /= step

            val_loss = 0
            step = 0
            for item in test:
                loss_value,pred  = sess.run([loss, prob],
                                      feed_dict={
                                          features: item['X'],
                                          labels: item['Y'] })
                step += 1
                val_loss += loss_value

                # RMSE
                avec = np.matmul(pred.reshape((len(item['sequence']),NCLUST)), KM.cluster_centers_)
                norm_phi = np.sqrt(np.square(avec[:,0])+np.square(avec[:,1]))
                norm_psi = np.sqrt(np.square(avec[:,2])+np.square(avec[:,3]))
                phi_pred = np.arctan2(avec[:,0] / norm_phi, avec[:,1] / norm_phi)
                psi_pred = np.arctan2(avec[:,2] / norm_psi, avec[:,3] / norm_psi)
                
                rmse_phi += utils.ang_rmse(item['phi'], phi_pred)
                rmse_psi += utils.ang_rmse(item['psi'], psi_pred)

            val_loss /= step
            rmse_phi /= step
            rmse_psi /= step
        
            print("epoch {:5d} | train_loss {:8.5f} | val_loss {:8.5f} | rmse(phi) {:9.5f} | rmse(psi) {:9.5f}".
                  format(epoch, train_loss, val_loss, rmse_phi*180/np.pi, rmse_psi*180/np.pi))


tot. params: 334760
epoch     0 | train_loss  2.15998 | val_loss  2.06199 | rmse(phi)  40.21774 | rmse(psi)  78.85559
epoch     1 | train_loss  2.05160 | val_loss  2.02722 | rmse(phi)  41.75149 | rmse(psi)  77.65162
epoch     2 | train_loss  2.03211 | val_loss  2.01913 | rmse(phi)  40.98060 | rmse(psi)  78.44041
epoch     3 | train_loss  2.02714 | val_loss  2.06549 | rmse(phi)  40.82686 | rmse(psi)  79.94071
epoch     4 | train_loss  2.02437 | val_loss  2.05392 | rmse(phi)  41.81914 | rmse(psi)  77.19992
epoch     5 | train_loss  2.02204 | val_loss  2.02693 | rmse(phi)  41.77697 | rmse(psi)  77.38639
epoch     6 | train_loss  2.02090 | val_loss  2.01174 | rmse(phi)  41.64509 | rmse(psi)  76.37003
epoch     7 | train_loss  2.01927 | val_loss  2.01975 | rmse(phi)  41.37024 | rmse(psi)  76.98155
epoch     8 | train_loss  2.01729 | val_loss  2.01467 | rmse(phi)  42.19787 | rmse(psi)  78.32435
epoch     9 | train_loss  2.01766 | val_loss  2.00726 | rmse(phi)  40.64189 | rmse(psi)  77.45339


epoch    84 | train_loss  2.00921 | val_loss  2.01489 | rmse(phi)  41.09754 | rmse(psi)  78.01065
epoch    85 | train_loss  2.00914 | val_loss  2.00821 | rmse(phi)  40.60214 | rmse(psi)  77.37670
epoch    86 | train_loss  2.00957 | val_loss  2.01110 | rmse(phi)  40.31326 | rmse(psi)  77.59677
epoch    87 | train_loss  2.00988 | val_loss  2.02802 | rmse(phi)  40.74608 | rmse(psi)  78.12624
epoch    88 | train_loss  2.00963 | val_loss  2.00863 | rmse(phi)  40.19829 | rmse(psi)  77.97068
epoch    89 | train_loss  2.00941 | val_loss  2.00477 | rmse(phi)  40.76427 | rmse(psi)  77.39531
epoch    90 | train_loss  2.00952 | val_loss  2.02930 | rmse(phi)  41.66777 | rmse(psi)  80.13583
epoch    91 | train_loss  2.00861 | val_loss  2.02489 | rmse(phi)  42.63584 | rmse(psi)  81.02082
epoch    92 | train_loss  2.00984 | val_loss  2.02239 | rmse(phi)  40.26197 | rmse(psi)  76.55724
epoch    93 | train_loss  2.00933 | val_loss  2.01113 | rmse(phi)  41.14929 | rmse(psi)  78.47055
epoch    94 | train_

epoch   168 | train_loss  2.00834 | val_loss  2.00493 | rmse(phi)  40.83327 | rmse(psi)  77.06669
epoch   169 | train_loss  2.00827 | val_loss  2.00782 | rmse(phi)  40.92821 | rmse(psi)  76.79160
epoch   170 | train_loss  2.00843 | val_loss  2.02599 | rmse(phi)  41.95144 | rmse(psi)  79.67996
epoch   171 | train_loss  2.00852 | val_loss  2.00461 | rmse(phi)  40.39780 | rmse(psi)  76.57389
epoch   172 | train_loss  2.00745 | val_loss  2.02036 | rmse(phi)  39.98861 | rmse(psi)  77.20259
epoch   173 | train_loss  2.00826 | val_loss  2.00677 | rmse(phi)  40.56040 | rmse(psi)  77.24479
epoch   174 | train_loss  2.00803 | val_loss  2.00671 | rmse(phi)  41.51460 | rmse(psi)  78.43040
epoch   175 | train_loss  2.00843 | val_loss  2.01740 | rmse(phi)  40.91351 | rmse(psi)  78.59634
epoch   176 | train_loss  2.00827 | val_loss  2.00402 | rmse(phi)  40.00360 | rmse(psi)  76.79534
epoch   177 | train_loss  2.00835 | val_loss  2.00303 | rmse(phi)  40.21120 | rmse(psi)  77.12880
epoch   178 | train_

epoch   252 | train_loss  2.00811 | val_loss  2.00925 | rmse(phi)  40.75659 | rmse(psi)  78.75019
epoch   253 | train_loss  2.00793 | val_loss  2.04646 | rmse(phi)  40.14118 | rmse(psi)  78.15004
epoch   254 | train_loss  2.00786 | val_loss  2.00507 | rmse(phi)  40.65122 | rmse(psi)  77.96697
epoch   255 | train_loss  2.00795 | val_loss  2.00650 | rmse(phi)  40.09195 | rmse(psi)  76.68621
epoch   256 | train_loss  2.00786 | val_loss  2.00559 | rmse(phi)  40.32652 | rmse(psi)  76.85802
epoch   257 | train_loss  2.00810 | val_loss  2.00859 | rmse(phi)  41.00842 | rmse(psi)  77.14567
epoch   258 | train_loss  2.00790 | val_loss  2.01334 | rmse(phi)  41.17079 | rmse(psi)  77.55210
epoch   259 | train_loss  2.00787 | val_loss  2.01786 | rmse(phi)  40.05708 | rmse(psi)  77.15062
epoch   260 | train_loss  2.00765 | val_loss  2.03697 | rmse(phi)  40.80792 | rmse(psi)  77.83097
epoch   261 | train_loss  2.00816 | val_loss  2.00925 | rmse(phi)  40.75735 | rmse(psi)  76.29603
epoch   262 | train_

epoch   336 | train_loss  2.00727 | val_loss  2.01850 | rmse(phi)  40.14711 | rmse(psi)  77.28890
epoch   337 | train_loss  2.00795 | val_loss  2.03650 | rmse(phi)  42.29253 | rmse(psi)  81.42824
epoch   338 | train_loss  2.00786 | val_loss  2.01041 | rmse(phi)  41.61743 | rmse(psi)  77.93311
epoch   339 | train_loss  2.00868 | val_loss  2.00867 | rmse(phi)  40.74255 | rmse(psi)  77.98250
epoch   340 | train_loss  2.00816 | val_loss  2.01976 | rmse(phi)  40.22644 | rmse(psi)  77.52289
epoch   341 | train_loss  2.00809 | val_loss  2.01374 | rmse(phi)  40.03749 | rmse(psi)  76.75791
epoch   342 | train_loss  2.00783 | val_loss  2.00643 | rmse(phi)  40.68632 | rmse(psi)  78.09790
epoch   343 | train_loss  2.00891 | val_loss  2.01324 | rmse(phi)  40.28973 | rmse(psi)  76.84927
epoch   344 | train_loss  2.00834 | val_loss  2.00790 | rmse(phi)  40.42861 | rmse(psi)  77.34595
epoch   345 | train_loss  2.00893 | val_loss  2.00335 | rmse(phi)  40.52016 | rmse(psi)  76.72861
epoch   346 | train_

epoch   420 | train_loss  2.00707 | val_loss  2.01437 | rmse(phi)  40.66571 | rmse(psi)  77.93640
epoch   421 | train_loss  2.00790 | val_loss  2.04478 | rmse(phi)  42.24680 | rmse(psi)  81.23540
epoch   422 | train_loss  2.00745 | val_loss  2.00389 | rmse(phi)  40.58837 | rmse(psi)  77.94005
epoch   423 | train_loss  2.00808 | val_loss  2.01472 | rmse(phi)  40.07275 | rmse(psi)  77.14191
epoch   424 | train_loss  2.00802 | val_loss  2.01366 | rmse(phi)  40.25899 | rmse(psi)  76.17469
epoch   425 | train_loss  2.00828 | val_loss  2.03184 | rmse(phi)  41.72564 | rmse(psi)  79.61374
epoch   426 | train_loss  2.00753 | val_loss  2.00484 | rmse(phi)  40.59029 | rmse(psi)  76.78500
epoch   427 | train_loss  2.00807 | val_loss  2.01144 | rmse(phi)  42.30831 | rmse(psi)  77.48283
epoch   428 | train_loss  2.00796 | val_loss  2.01039 | rmse(phi)  40.46804 | rmse(psi)  77.37150
epoch   429 | train_loss  2.00729 | val_loss  2.01277 | rmse(phi)  42.17642 | rmse(psi)  77.28145
epoch   430 | train_

epoch   504 | train_loss  2.00797 | val_loss  2.00248 | rmse(phi)  41.77340 | rmse(psi)  77.45373
epoch   505 | train_loss  2.00800 | val_loss  2.00197 | rmse(phi)  40.96646 | rmse(psi)  76.94976
epoch   506 | train_loss  2.00697 | val_loss  2.03219 | rmse(phi)  42.36223 | rmse(psi)  78.36898
epoch   507 | train_loss  2.00873 | val_loss  2.00308 | rmse(phi)  40.52887 | rmse(psi)  75.69607
epoch   508 | train_loss  2.00863 | val_loss  2.07153 | rmse(phi)  40.11391 | rmse(psi)  80.26872
epoch   509 | train_loss  2.00855 | val_loss  2.04165 | rmse(phi)  40.44529 | rmse(psi)  78.66780
epoch   510 | train_loss  2.00868 | val_loss  2.03211 | rmse(phi)  40.48721 | rmse(psi)  78.89693
epoch   511 | train_loss  2.00827 | val_loss  2.01479 | rmse(phi)  40.61548 | rmse(psi)  79.60747
epoch   512 | train_loss  2.00803 | val_loss  2.01043 | rmse(phi)  40.01243 | rmse(psi)  77.47746
epoch   513 | train_loss  2.00819 | val_loss  2.01221 | rmse(phi)  40.71816 | rmse(psi)  78.74077
epoch   514 | train_

epoch   588 | train_loss  2.00826 | val_loss  2.00916 | rmse(phi)  40.59608 | rmse(psi)  78.82076
epoch   589 | train_loss  2.00852 | val_loss  2.02055 | rmse(phi)  41.27078 | rmse(psi)  79.93997
epoch   590 | train_loss  2.00761 | val_loss  2.00762 | rmse(phi)  40.65275 | rmse(psi)  77.51588
epoch   591 | train_loss  2.00837 | val_loss  2.03945 | rmse(phi)  41.81785 | rmse(psi)  80.64498
epoch   592 | train_loss  2.00766 | val_loss  2.00305 | rmse(phi)  40.99797 | rmse(psi)  77.73949
epoch   593 | train_loss  2.00767 | val_loss  2.01906 | rmse(phi)  41.42690 | rmse(psi)  76.71568
epoch   594 | train_loss  2.00730 | val_loss  2.03968 | rmse(phi)  42.69384 | rmse(psi)  82.11560
epoch   595 | train_loss  2.00803 | val_loss  2.00409 | rmse(phi)  41.39264 | rmse(psi)  78.08825
epoch   596 | train_loss  2.00759 | val_loss  2.02597 | rmse(phi)  40.11227 | rmse(psi)  78.09585
epoch   597 | train_loss  2.00827 | val_loss  2.00693 | rmse(phi)  41.40813 | rmse(psi)  77.10719
epoch   598 | train_