In [None]:
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow.keras.backend as K
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from sklearn.metrics import *

In [None]:
# Load the pre-processed arrays. np.load on an .npz returns a lazy NpzFile that
# keeps the archive open, so read both arrays inside a `with` block to make sure
# the file handle is released once they are in memory.
with np.load('./Data/python_processed_sstubs.npz') as archive:
    X = archive['X']
    Y = archive['Y']

# Fixed random_state keeps the 80/20 train/test split reproducible across runs.
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
x_train.shape,x_test.shape,y_train.shape,y_test.shape

In [None]:
def _position_ids(tokens, pad_id=1):
    """Build a position-id matrix matching a 2-D token-id matrix.

    Every row receives the ids 1..seq_len (seq_len is taken from
    ``tokens.shape[1]``); entries whose token equals ``pad_id`` are set to 0.

    Args:
        tokens: 2-D array of token ids, one sequence per row.
        pad_id: token id whose positions are mapped to 0. Defaults to 1, the
            value this notebook masks on (presumably the tokenizer's padding
            id -- confirm against the preprocessing step).

    Returns:
        A new integer array with the same shape as ``tokens``.
    """
    seq_len = tokens.shape[1]
    positions = np.arange(1, seq_len + 1)
    # Broadcasting the 1-D position row against the 2-D mask yields one row per sequence.
    return np.where(tokens == pad_id, 0, positions)


x_train_pid = _position_ids(x_train)
x_test_pid = _position_ids(x_test)
x_train.shape,x_test.shape,x_train_pid.shape,x_test_pid.shape,y_train.shape,y_test.shape

In [None]:
# Pick the distribution strategy: use a TPU when the cluster resolver can find
# one, otherwise fall back to TensorFlow's current default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # A ValueError from the resolver is treated as "no TPU available".
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # The TPU system must be connected and initialised before the strategy is built.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
def _getPosEncodingMat(length, dim):
    posEnc = np.array([[pos/np.power(10000, 2*(j//2)/dim) for j in range(dim)]
                        if pos!=0 else np.zeros(dim) for pos in range(length)], dtype=np.float32)
    posEnc[1:, 0::2] = np.sin(posEnc[1:, 0::2])
    posEnc[1:, 1::2] = np.cos(posEnc[1:, 1::2])
    return posEnc

def categorical_crossentropyx(trues, preds):
    """Cross-entropy between integer class-id targets and predicted distributions.

    Uses the sparse form of the loss so the targets are never expanded into a
    dense one-hot tensor (one float per class for every target position), and
    the number of classes is taken from the last axis of ``preds`` rather than
    being hard-coded.

    Args:
        trues: integer tensor of class ids. Ids must lie in
            ``[0, preds.shape[-1])``; unlike a one-hot expansion, out-of-range
            ids are not silently turned into a zero loss.
        preds: tensor of class probabilities with the classes on the last axis.

    Returns:
        Tensor of per-position losses with the shape of ``trues``.
    """
    return tf.keras.losses.sparse_categorical_crossentropy(trues, preds)


In [None]:
max_len = 100        # tokens per input sequence
vocab_size = 50265   # token-id vocabulary: rows of the token embedding and width of the softmax

with strategy.scope():
    # Two aligned integer inputs per example: token ids and their position ids.
    input1 = tfk.layers.Input(shape=(max_len, ), name='code_input')
    input2 = tfk.layers.Input(shape=(max_len, ), name='pid_input')

    emb = tfk.layers.Embedding(input_dim=vocab_size, output_dim=512, name='embds')(input1)

    # Position ids fed to this layer run from 0 (masked positions) up to max_len
    # inclusive, so the frozen sinusoidal table needs max_len + 1 rows. With only
    # max_len rows the id `max_len` is out of range for the lookup.
    # NOTE: this changes the shape of the 'pid_embds' weights, so weight files
    # saved with the smaller max_len-row table will not load into this model.
    pidsEmbd = tfk.layers.Embedding(input_dim=max_len + 1, output_dim=512, trainable=False,
                                    weights=[_getPosEncodingMat(max_len + 1, 512)], name='pid_embds')(input2)
    emb = tfk.layers.Add(name='seq_embdAdd')([emb, pidsEmbd])

    # Parallel convolutions over the same embeddings with different kernel widths.
    x1 = tfk.layers.Conv1D(256, 11, strides=1, padding='same', activation='relu', name='conv1')(emb)
    x2 = tfk.layers.Conv1D(256, 7, strides=1, padding='same', activation='relu', name='conv2')(emb)
    x3 = tfk.layers.Conv1D(256, 3, strides=1, padding='same', activation='relu', name='conv3')(emb)
    x4 = tfk.layers.Conv1D(256, 1, strides=1, padding='same', activation='relu', name='conv4')(emb)

    x = tfk.layers.concatenate([emb, x1, x2, x3, x4], axis=-1, name='concat')

    # Stack of three bidirectional GRUs, each returning the full sequence.
    hidden = tfk.layers.Bidirectional(tfk.layers.GRU(units=256, return_sequences=True), name='bigru1')(x)
    hidden = tfk.layers.Bidirectional(tfk.layers.GRU(units=256, return_sequences=True), name='bigru2')(hidden)
    hidden = tfk.layers.Bidirectional(tfk.layers.GRU(units=512, return_sequences=True), name='bigru3')(hidden)

    # Skip connection: re-attach the pre-recurrent features to the GRU output.
    hidden = tfk.layers.concatenate([hidden, x], axis=-1, name='concat_')

    # NOTE(review): 1028 units may be a typo for 1024 -- kept as in the original; confirm.
    hidden = tfk.layers.TimeDistributed(tfk.layers.Dense(1028, activation='relu'), name='td')(hidden)

    # Per-position distribution over the token vocabulary.
    output = tfk.layers.TimeDistributed(tfk.layers.Dense(vocab_size, activation='softmax'), name='output')(hidden)

    model = tfk.models.Model([input1, input2], output)
    model.compile(loss=categorical_crossentropyx, metrics=['acc'], optimizer='adam')

    model.summary()


In [None]:
# Train on the token ids together with their position ids; the input order
# follows the model's [code_input, pid_input] definition.
model.fit(
    x=[x_train, x_train_pid],
    y=y_train,
    batch_size=128,
    epochs=10,
    verbose=1,
)

In [None]:
# Persist the trained weights to the working directory so they can be reloaded later.
model.save_weights(filepath='m1.h5')

In [None]:
# Uncomment to evaluate previously saved weights instead of the in-memory model:
# model.load_weights('m1.h5')

# Predict in small slices and keep only the arg-max token id per position: the
# raw output holds one probability per vocabulary entry for every position, so
# it is reduced slice by slice instead of being held for the whole test set.
pred_batch_size = 16
pred_chunks = []
# Stepping by the slice size never issues an empty trailing batch (which the
# previous `len // 16 + 1` bound did whenever the test size was a multiple of 16).
for start in tqdm(range(0, len(x_test), pred_batch_size)):
    stop = start + pred_batch_size
    probs = model.predict([x_test[start:stop], x_test_pid[start:stop]])
    pred_chunks.append(np.argmax(probs, axis=-1))

# Join once at the end; appending to a growing array copies it on every step.
if pred_chunks:
    preds = np.concatenate(pred_chunks, axis=0)
else:
    # Empty test set: keep `preds` a well-formed (0, seq_len) integer array.
    preds = np.empty((0, x_test.shape[1]), dtype=np.int64)


In [None]:
# Sequence-level exact-match accuracy: a test example counts as correct only
# when every compared position of the prediction equals the target.
correct_pred = 0
for idx in tqdm(range(len(y_test))):
    target = y_test[idx]
    # Positions whose target id is 1 are left out of the comparison
    # (presumably padding -- confirm against the tokenizer).
    keep = target != 1
    if np.array_equal(preds[idx][keep], target[keep]):
        correct_pred += 1
correct_pred/len(y_test)