In [None]:
import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow.keras.backend as K
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from sklearn.metrics import *

In [None]:
# Bug-pattern name -> integer class id. Ids are assigned 1..23 following the
# order of this list; id 0 is not one of the named patterns.
label_index = {
    pattern: class_id
    for class_id, pattern in enumerate(
        [
            'same_function_more_args',
            'wrong_method/function_name',
            'change_identifier_used',
            'change_numeric_literal',
            'change_operand',
            'same_function_less_args',
            'more_specific_if',
            'change_unary_operator',
            'change_boolean_literal',
            'same_function_wrong_caller',
            'change_binary_operator',
            'less_specific_if',
            'same_function_swap_args',
            'change_modifier',
            'add_throws_exception',
            'delete_throws_exception',
            'change_attribute_used',
            'change_keyword_argument_used',
            'add_method_call',
            'add_attribute_access',
            'add_elements_to_iterable',
            'add_function_around_expression',
            'change_constant_type',
        ],
        start=1,
    )
}

In [None]:
# Load the preprocessed SStuBs arrays. For an .npz file np.load returns an
# NpzFile that holds the file open until closed; the context manager closes it
# once both arrays have been read into memory.
# Presumably X holds token ids (n_samples, seq_len) and Y2 integer class ids -- TODO confirm.
with np.load('./Data/java_processed_sstubs.npz') as sstubs_npz:
    X = sstubs_npz['X']
    Y2 = sstubs_npz['Y2']

# Stratified 80/20 split: class proportions are preserved in both partitions.
x_train, x_test, y_train, y_test = train_test_split(X, Y2, stratify=Y2, test_size=0.2, random_state=42)
x_train.shape, x_test.shape, y_train.shape, y_test.shape

In [None]:
def _position_ids(tokens, pad_id=1):
    """Return position ids for a 2-D batch of token ids.

    Every non-padding cell gets its 1-based column index (1..seq_len); cells
    whose token equals ``pad_id`` get 0, which selects the all-zero row of the
    positional-encoding table.

    Args:
        tokens: 2-D integer array (n_samples, seq_len) of token ids.
        pad_id: token id treated as padding (1 here -- presumably the
            RoBERTa/CodeBERT ``<pad>`` id; TODO confirm).

    Returns:
        Integer array with the same shape as ``tokens``.
    """
    seq_len = tokens.shape[1]
    # broadcast_to gives a read-only view; copy so the padding mask can be written.
    pids = np.broadcast_to(np.arange(1, seq_len + 1), tokens.shape).copy()
    pids[tokens == pad_id] = 0
    return pids


def _one_hot(labels, n_classes):
    """One-hot encode integer class ids into float rows of width ``n_classes``.

    Input that is already one-hot (2-D with ``n_classes`` columns) is returned
    unchanged, so re-running this cell does not corrupt ``y_train``/``y_test``.
    """
    labels = np.asarray(labels)
    if labels.ndim == 2 and labels.shape[1] == n_classes:
        return labels
    return np.eye(n_classes)[labels.reshape(-1)]


# label_index numbers the named patterns from 1, so column 0 covers label id 0.
n_classes = len(label_index) + 1
x_train_pid = _position_ids(x_train)
x_test_pid = _position_ids(x_test)
y_train = _one_hot(y_train, n_classes)
y_test = _one_hot(y_test, n_classes)
x_train.shape, x_test.shape, x_train_pid.shape, x_test_pid.shape, y_train.shape, y_test.shape

In [None]:
# Use a TPU when one can be resolved; otherwise fall back to the default
# distribution strategy for the current device.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # A ValueError from the resolver is treated as "no TPU available".
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # Attach to the TPU cluster and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
def _getPosEncodingMat(length, dim):
    posEnc = np.array([[pos/np.power(10000, 2*(j//2)/dim) for j in range(dim)]
                        if pos!=0 else np.zeros(dim) for pos in range(length)], dtype=np.float32)
    posEnc[1:, 0::2] = np.sin(posEnc[1:, 0::2])
    posEnc[1:, 1::2] = np.cos(posEnc[1:, 1::2])
    return posEnc

In [None]:
max_len = 100       # tokens per sample; must match the width of x_train / x_train_pid
embed_dim = 512     # width shared by token and positional embeddings (they are summed)
vocab_size = 50265  # token-id vocabulary size; presumably the RoBERTa/CodeBERT tokenizer -- TODO confirm

# Fix the global TF seed so weight initialisation is repeatable between runs.
tf.random.set_seed(42)

with strategy.scope():
    code_input = tfk.layers.Input(shape=(max_len, ), name='code_input')
    pid_input = tfk.layers.Input(shape=(max_len, ), name='pid_input')

    token_emb = tfk.layers.Embedding(input_dim=vocab_size, output_dim=embed_dim, name='embds')(code_input)
    # Position ids run 1..max_len, with 0 marking padding (row 0 of the encoding
    # table is all zeros). That is max_len + 1 distinct ids, so the table needs
    # max_len + 1 rows. With only max_len rows the last position id is out of
    # range: tf.gather raises on CPU and returns zeros on GPU.
    pos_emb = tfk.layers.Embedding(input_dim=max_len + 1, output_dim=embed_dim, trainable=False,
                                   weights=[_getPosEncodingMat(max_len + 1, embed_dim)],
                                   name='pid_embds')(pid_input)
    emb = tfk.layers.Add(name='seq_embdAdd')([token_emb, pos_emb])

    # Two stacked bidirectional GRUs over the summed embeddings.
    hidden = tfk.layers.Bidirectional(tfk.layers.GRU(units=1024, return_sequences=True), name='bigru1')(emb)
    hidden = tfk.layers.Bidirectional(tfk.layers.GRU(units=1024, return_sequences=True), name='bigru2')(hidden)

    # Skip connection: concatenate the GRU features with the raw embeddings.
    hidden = tfk.layers.concatenate([hidden, emb], axis=-1, name='concat')

    # NOTE(review): 2056 may be a typo for 2048 -- kept as-is.
    hidden = tfk.layers.TimeDistributed(tfk.layers.Dense(2056, activation='relu'), name='td')(hidden)
    hidden = tfk.layers.Flatten(name='flatten')(hidden)

    # One output per label id: the named patterns in label_index plus id 0.
    output = tfk.layers.Dense(len(label_index) + 1, activation='softmax', name='output')(hidden)

    model = tfk.models.Model([code_input, pid_input], output)
    model.compile(loss='categorical_crossentropy', metrics=['acc'], optimizer='adam')

    model.summary()

In [None]:
# Train on (token ids, position ids) -> one-hot labels.
model.fit(
    x=[x_train, x_train_pid],
    y=y_train,
    batch_size=256,
    epochs=10,
    verbose=1,
)

In [None]:
# Persist the trained weights so they can be restored later with model.load_weights('m1.h5').
model.save_weights(filepath='m1.h5')

In [None]:
# Set to True to score the weights saved to 'm1.h5' instead of the in-memory model.
RELOAD_SAVED_WEIGHTS = False
if RELOAD_SAVED_WEIGHTS:
    model.load_weights('m1.h5')

# Class probabilities for each test sample.
preds = model.predict([x_test, x_test_pid], batch_size=32, verbose=1)


In [None]:
# Collapse predicted probabilities and one-hot targets back to integer class ids.
y_p = preds.argmax(axis=-1).ravel()
y_t = y_test.argmax(axis=-1).ravel()

print(classification_report(y_t, y_p))
accuracy_score(y_t, y_p)