#Define callback for recording training times

In [None]:
from timeit import default_timer as timer
import keras
class TimingCallback(keras.callbacks.Callback):
    """Keras callback that records the wall-clock duration of every epoch.

    After (or during) training, ``self.logs`` holds one elapsed time in
    seconds per completed epoch, in the order the epochs finished.
    """

    def __init__(self, logs=None):
        # ``logs`` is accepted only for backward compatibility and is ignored;
        # it defaults to None instead of a shared mutable ``{}``.
        # Initialise the base Callback, as Keras expects from subclasses.
        super().__init__()
        self.logs = []
        self.starttime = None

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the time at which the epoch started."""
        self.starttime = timer()

    def on_epoch_end(self, epoch, logs=None):
        """Append the elapsed time of the epoch that just finished."""
        self.logs.append(timer() - self.starttime)


tcb = TimingCallback()

## sBERT Model Definition

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from util import load_WOS,  load_glove_embeddings,  create_embeddings_matrix
#Implement a Transformer block as a layer
class TransformerBlock(layers.Layer):
    """Transformer encoder block: self-attention followed by a feed-forward net.

    Each of the two sub-layers is followed by dropout, a residual connection
    and layer normalisation.

    Args:
        embed_dim: Size of the last axis of the input; also used as the
            per-head ``key_dim`` of the attention layer and as the output
            size of the feed-forward network.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden layer of the feed-forward network.
        rate: Dropout rate applied after attention and after the
            feed-forward network.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``,
            ``trainable``, ...). Accepting them is what allows
            ``from_config(get_config())`` to rebuild the layer, because
            ``get_config`` includes those base-layer keys.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = dict(super().get_config())
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        })
        return config

    def call(self, inputs, training=None):
        """Apply the block to ``inputs``.

        Args:
            inputs: Tensor whose last axis has size ``embed_dim`` (required by
                the residual additions below).
            training: Forwarded to the dropout layers. Defaults to None so the
                layer can also be called without an explicit flag.

        Returns:
            Tensor produced by the second layer normalisation.
        """
        # Self-attention: the same tensor is passed as query and value.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
#Implement embedding layer
#Two separate embedding layers, one for tokens, one for token index (positions).
class TokenAndPositionEmbedding(layers.Layer):
    """Sum of a token embedding and a position embedding.

    The token embedding is initialised from the matrix returned by
    ``create_embeddings_matrix`` (built from the GloVe vectors loaded by
    ``load_glove_embeddings``); the position embedding is initialised with a
    sinusoidal encoding. Both embeddings are trainable.

    Args:
        maxlen: Number of positions in the position embedding.
        vocab_size: Stored and reported by ``get_config`` only; the token
            embedding size is taken from the embeddings matrix instead.
        embed_dim: Embedding dimension of both embeddings.
        word_index: Passed to ``create_embeddings_matrix``
            (presumably a word -> index mapping; verify against ``util``).
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).

    NOTE(review): ``word_index`` is not part of ``get_config``, so
    ``from_config`` cannot rebuild this layer on its own; restoring weights
    into a model built with ``create_model`` works.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, word_index, **kwargs):
        # Initialise the Keras layer before assigning any attributes.
        super().__init__(**kwargs)
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        # Token embedding loaded from GloVe.
        embeddings_index = load_glove_embeddings(embed_dim)
        embeddings_matrix, _found = create_embeddings_matrix(
            embeddings_index, word_index, embed_dim
        )

        self.token_emb = layers.Embedding(
            input_dim=embeddings_matrix.shape[0],
            output_dim=embed_dim,
            weights=[embeddings_matrix],
            trainable=True,
        )
        self.pos_emb = layers.Embedding(
            input_dim=maxlen,
            output_dim=embed_dim,
            weights=[self._position_encoding(maxlen, embed_dim)],
            trainable=True,
        )

    @staticmethod
    def _position_encoding(seq_len, d, n=10000):
        """Return the ``(seq_len, d)`` sinusoidal position-encoding matrix.

        Column ``2*i`` holds ``sin(k / n**(2*i/d))`` and column ``2*i + 1``
        holds ``cos(k / n**(2*i/d))`` for position ``k`` and
        ``i < d // 2``. When ``d`` is odd the last column stays zero.
        """
        # Local import: numpy is only imported in a later cell of this file.
        import numpy as np

        half = d // 2
        positions = np.arange(seq_len)[:, np.newaxis]
        denominators = np.power(n, 2 * np.arange(half) / d)
        angles = positions / denominators

        encoding = np.zeros((seq_len, d))
        encoding[:, 0:2 * half:2] = np.sin(angles)
        encoding[:, 1:2 * half:2] = np.cos(angles)
        return encoding

    def call(self, x):
        """Embed token ids ``x`` and add the embedding of each position."""
        # Positions are derived from the last axis of the input.
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        """Return the base layer config plus maxlen, vocab_size and embed_dim."""
        config = dict(super().get_config())
        config.update({
            'maxlen': self.maxlen,
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
        })
        return config


def create_model(emb_dim, n_blocks, n_heads, ff_dim, maxlen, vocab_size, nClasses, word_index):
    """Build and compile the Transformer-based classifier.

    The network is: token + position embedding, ``n_blocks`` Transformer
    blocks, global average pooling, dropout, a 256-unit ReLU layer, dropout
    and a softmax layer with ``nClasses`` outputs. It is compiled with the
    Adam optimizer, categorical cross-entropy loss and the accuracy metric.

    Args:
        emb_dim: Embedding dimension.
        n_blocks: Number of stacked Transformer blocks.
        n_heads: Number of attention heads per block.
        ff_dim: Hidden width of each block's feed-forward network.
        maxlen: Length of the input sequences.
        vocab_size: Forwarded to the embedding layer.
        nClasses: Number of output classes.
        word_index: Forwarded to the embedding layer.

    Returns:
        The compiled ``keras.Model``.
    """
    token_ids = layers.Input(shape=(maxlen,))
    hidden = TokenAndPositionEmbedding(maxlen, vocab_size, emb_dim, word_index)(token_ids)

    # Stack the encoder blocks; each iteration creates a fresh block.
    for _ in range(n_blocks):
        hidden = TransformerBlock(emb_dim, n_heads, ff_dim)(hidden)

    # Classification head, applied in order.
    head_layers = [
        layers.GlobalAveragePooling1D(),
        layers.Dropout(0.2),
        layers.Dense(256, activation="relu"),
        layers.Dropout(0.2),
        layers.Dense(nClasses, activation="softmax"),
    ]
    for head_layer in head_layers:
        hidden = head_layer(hidden)

    classifier = keras.Model(inputs=token_ids, outputs=hidden)
    classifier.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier

#Function for plotting training graphs

In [None]:
import matplotlib.pyplot as plt

def plot_history(history):
    """Plot training/validation loss and accuracy side by side.

    Args:
        history: Object whose ``history`` attribute is a mapping with the
            keys ``'loss'``, ``'val_loss'``, ``'accuracy'`` and
            ``'val_accuracy'``, each a per-epoch sequence.
    """
    # Pull every curve up front so a missing key fails before any drawing.
    recorded = history.history
    loss_curves = (recorded['loss'], recorded['val_loss'])
    accuracy_curves = (recorded['accuracy'], recorded['val_accuracy'])

    # Epochs are numbered from 1 on the x axis.
    epochs = range(1, len(loss_curves[0]) + 1)

    plt.figure(figsize=(12, 6))

    # One panel per metric: (subplot position, metric label, (train, val)).
    panels = (
        (1, 'Loss', loss_curves),
        (2, 'Accuracy', accuracy_curves),
    )
    for position, metric, (train_values, val_values) in panels:
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_values, 'bo-', label=f'Training {metric}')
        plt.plot(epochs, val_values, 'r*-', label=f'Validation {metric}')
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.title(f'Training and Validation {metric}')
        plt.legend()

    plt.tight_layout()
    plt.show()

## Experiments on WOS Datasets

In [None]:
from util import load_WOS, load_slc
from keras.models import load_model
import pandas as pd
import numpy as np
import os
import keras

import pickle

vocab_size = 20000  # Only consider the top 20k words
maxlen = 250  # Only consider the first 250 words of each abstract
emb_dim = 100
n_blocks = 1
n_heads = 12
ff_dim = 256
batchsize = 16
# NOTE(review): num_epochs is never used; model.fit below trains with a
# hard-coded epochs=10. Confirm which value is intended before wiring it in.
num_epochs = 100
save_dir = 'saved_models6/'
# When True and saved weights exist for a dataset, skip training and only evaluate.
test = False

model_name = f'sBERT_{n_blocks}_{emb_dim}_{n_heads}_{ff_dim}_{vocab_size}'

# ANSI escape codes used to print the dataset name in bold.
start = "\033[1m"
end = "\033[0;0m"


def get_model_name(model_name, dsname):
    """Return the weights file name for a model/dataset pair."""
    return model_name + '_' + dsname + '.h5'


# Make sure the output directory exists before anything is written to it.
os.makedirs(save_dir, exist_ok=True)

for dsname in ["WOS5736", "WOS11967", "WOS46985"]:
    print('\n\n\n' + start + dsname + end)
    if 'WOS' in dsname:
        # Fixed: an invisible U+3164 character was glued to the front of
        # ``load_WOS`` here, so the call did not resolve to the imported function.
        x_train, Y_train, x_test, Y_test, word_index, nClasses = load_WOS(dsname, vocab_size, maxlen)
    else:
        x_train, Y_train, x_test, Y_test, word_index, nClasses = load_slc(dsname, maxlen, vocab_size)

    model = create_model(emb_dim, n_blocks, n_heads, ff_dim, maxlen, vocab_size, nClasses, word_index)
    weights_path = save_dir + get_model_name(model_name, dsname)

    # Train unless we are in test mode and saved weights are available.
    if not (test and os.path.exists(weights_path)):
        # NOTE(review): the test split is also used as validation data, so
        # early stopping and checkpoint selection see the test set.
        es = keras.callbacks.EarlyStopping(monitor='val_accuracy', mode='max', min_delta=0.001, patience=8)
        checkpoint = keras.callbacks.ModelCheckpoint(weights_path, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
        tcb = TimingCallback()
        callbacks_list = [tcb, checkpoint, es]
        print("Training..")
        history = model.fit(x_train, Y_train,
                            epochs=10,
                            batch_size=batchsize,
                            callbacks=callbacks_list,
                            validation_data=(x_test, Y_test),
                            verbose=1)
        plot_history(history)
        # Store the per-epoch durations alongside the Keras metrics.
        history.history["time"] = tcb.logs
        with open(save_dir + model_name + '_' + dsname + '_trainHistoryDict', 'wb') as file_pi:
            pickle.dump(history.history, file_pi)

    # Evaluate the checkpointed weights rather than the last-epoch weights.
    model.load_weights(weights_path)
    results = model.evaluate(x=x_test, y=Y_test)
    results = dict(zip(model.metrics_names, results))
    print(results)
