# Import Libraries

In [1]:
'''
Main file to run Experiment 2:
compare MRR and HIT@k (HIT@1, HIT@10)
'''

#Import Libraries
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNN, Embedding
from tensorflow.keras.models import Sequential
from tensorflow.keras.initializers import Constant
from gensim.models import Word2Vec
import functools
import numpy as np
import sys
import os
import pprint
from keras.preprocessing.text import Tokenizer
pp = pprint.PrettyPrinter(indent=4)
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import StratifiedKFold, KFold

# Import modules
import utils
import evaluation
# Import Baselines
from SimpleRNN import Simple_RNN_baseline
from Average_baseline import AVG_baseline

Using TensorFlow backend.


# Constant declaration

In [2]:
# ---- File locations -------------------------------------------------------
# Word2Vec flavour and the file name of the pretrained vectors
type_of_Word2Vec_model = 'CBOW'
vector_file_name = 'wiki-db_more50_200'
vector_file_name_path = './../model/{}/{}'.format(type_of_Word2Vec_model, vector_file_name)

# Training pairs: directory and file name (joined later with os.path.join)
train_file_path = './../dataset/train_data/'
train_file_name = 'uni_pair_combine_less100'

# Cached feature / label matrices live under <save_model_path>Evaluation/
save_model_path = './../model/'
x_file = '{}Evaluation/{}_X_feature.npy'.format(save_model_path, type_of_Word2Vec_model)
y_file = '{}Evaluation/{}_Y_label.npy'.format(save_model_path, type_of_Word2Vec_model)

# ---- Training constants ---------------------------------------------------
MAX_SEQUENCE_LENGTH = 21
num_of_epochs = 1000
batch_size = 32 * 1024
validation_split = 0.01

# ---- Model hyperparameters ------------------------------------------------
embedding_dim = 200
num_hidden = 128

# Hyper parameters Setup

# Function Implementations

In [3]:
def train_evaluate(wordvec, main_baseline, x_train_cv, y_train_cv , x_test_cv, y_label_cv):
    '''
    Fit one baseline on a cross-validation fold and score its predictions (Experiment 2).

    Input:
            wordvec: word-vector model; forwarded untouched to evaluation.calculateMRR_HIT
            main_baseline: model object exposing train(...) and predict(...)
            x_train_cv: training features (original notes: word indices per example,
                        MAX_SEQUENCE_LENGTH wide - TODO confirm)
            y_train_cv: training targets (original notes: word vector of the compound
                        word, embedding_dim wide - TODO confirm)
            x_test_cv: held-out features, given to main_baseline.predict
            y_label_cv: reference labels for x_test_cv, given to the evaluation routine

    Output:
            MRR: Mean reciprocal rank of the main_baseline
            HIT_1: HIT@1 of the main_baseline
            HIT_10: HIT@10 of the main_baseline

    Note: the training call reads the module-level num_of_epochs, batch_size and
    validation_split settings.
    '''
    # 1) Fit the baseline on the training fold
    main_baseline.train(x_train_cv, y_train_cv, num_of_epochs, batch_size, validation_split)

    # 2) Run inference on the held-out fold
    predicted_vectors = main_baseline.predict(x_test_cv)

    # 3) Score the predictions against the reference labels
    scores = evaluation.calculateMRR_HIT(wordvec, y_label_cv, predicted_vectors)
    mrr, hit_at_1, hit_at_10 = scores

    return mrr, hit_at_1, hit_at_10

# Main

In [4]:
# Load the pretrained Word2Vec model that was saved with gensim
wordvec = Word2Vec.load(vector_file_name_path) # path is built from type_of_Word2Vec_model and vector_file_name
# NOTE(review): init_sims(replace=True) presumably overwrites the stored vectors with their
# L2-normalised versions (gensim < 4 API) - confirm against the installed gensim version.
wordvec.wv.init_sims(replace=True)
print('Loaded Word2Vec model')

Loaded Word2Vec model


In [5]:
# Get Vocabulary Size: number of entries in the loaded model's vocabulary (wv.vocab)
vocab_size = len(wordvec.wv.vocab)
print('Vocab size: ', vocab_size)

Vocab size:  968009


In [6]:
# Prepare the training data
fname = os.path.join(train_file_path,train_file_name)
# One label per example, later indexed per fold as label[j]
# NOTE(review): the label format is defined in utils and not visible here - confirm.
label = utils.load_label_data_from_text_file(fname,wordvec,MAX_SEQUENCE_LENGTH) # Labels read from the same text file as X/Y
# X: model inputs, Y: regression targets
# NOTE(review): presumably X holds padded word indices and Y word vectors - confirm in utils.
X,Y = utils.load_data_from_text_file(fname,wordvec,MAX_SEQUENCE_LENGTH)
# X, Y = utils.load_data_from_numpy(x_file, y_file)            # Load input data from numpy file

In [7]:
# Convert the gensim Word2Vec model to an embedding matrix to feed into the RNN
embedding_matrix = utils.Word2VecTOEmbeddingMatrix(wordvec,embedding_dim)  # NOTE(review): presumably one row per vocabulary entry, embedding_dim columns - confirm in utils

In [8]:
# Sanity check: number of examples in X (should match len(Y) and len(label))
len(X)

201970

In [9]:
# Sanity check: number of examples in Y (should match len(X) and len(label))
len(Y)

201970

In [10]:
# Sanity check: number of labels (should match len(X) and len(Y))
len(label)

201970

In [28]:
# Inspect the first target entry of Y
print(Y[0])

[ 0.00837345  0.06357518 -0.0471146   0.06825517 -0.05427921  0.13691325
 -0.17753075  0.08130135 -0.00359477 -0.08016488  0.04961602 -0.02380891
 -0.06177854 -0.10081691  0.06255907  0.07954949 -0.10121015  0.0416428
 -0.17606759 -0.02897603 -0.04077215  0.03595511 -0.05163613 -0.09634292
  0.03125735 -0.04170639 -0.11263253 -0.05378058 -0.0231298   0.01396706
  0.02751955  0.10176833  0.02180496  0.09411184 -0.02724924 -0.00399662
 -0.02790575 -0.16606046 -0.00913196  0.04561265  0.11841544  0.00432115
 -0.06407496  0.03478611 -0.02942246  0.10347781  0.02792424  0.01999918
 -0.04653015 -0.07954869  0.02942627 -0.09890237 -0.03873263  0.04726112
  0.111301    0.0531924   0.14243779 -0.09308343 -0.01928188 -0.00358713
 -0.04230163 -0.09325445  0.01532099  0.05977419 -0.00139433 -0.10548727
 -0.01949493 -0.0683512   0.13112333 -0.04309424  0.164556   -0.0300553
 -0.00548885  0.0457101  -0.10913439 -0.09718797  0.1171404  -0.12039133
 -0.01881818  0.08144035  0.01074499  0.04531887  0.0

In [32]:
# Do Cross Validation
# Number of folds. It also sizes the per-fold metric arrays below, so the fold
# count and the array length can never drift apart.
N_SPLITS = 10
# Training every fold is expensive (num_of_epochs epochs per fold), so it is opt-in.
# False: only build the fold splits (what this cell did before).
# True:  train a fresh SimpleRNN baseline per fold and record MRR / HIT@1 / HIT@10.
RUN_TRAINING = False

kFold = KFold(n_splits = N_SPLITS)
# Per-fold scores: one slot per fold for each metric
accuracy = {}
accuracy['MRR'] = np.zeros(N_SPLITS)
accuracy['HIT_1'] = np.zeros(N_SPLITS)
accuracy['HIT_10'] = np.zeros(N_SPLITS)
idx = 0 # Index of the fold whose scores are written next
for train_idx, test_idx in kFold.split(X,Y):
    # Define train and test data
    x_train_cv = X[train_idx]
    x_test_cv  = X[test_idx]

    y_train_cv = Y[train_idx]
    y_test_cv  = Y[test_idx]
    # Labels of the training fold (kept for inspection / backward compatibility)
    y_label_train_cv = [label[j] for j in train_idx]
    # Labels of the held-out fold: predictions for x_test_cv must be scored
    # against these, not against the training labels.
    y_label_cv = [label[j] for j in test_idx]

    if not RUN_TRAINING:
        continue

    # A new model per fold so that no weights leak between folds
    main_baseline = Simple_RNN_baseline(type_of_Word2Vec_model,vocab_size,embedding_dim,embedding_matrix,MAX_SEQUENCE_LENGTH) # Init main baseline: SimpleRNN

    accuracy['MRR'][idx],accuracy['HIT_1'][idx],accuracy['HIT_10'][idx] = train_evaluate(wordvec, main_baseline, x_train_cv, y_train_cv , x_test_cv,y_label_cv)
    # Report with a 1-based fold number, reading the slot that was just written;
    # advance idx only afterwards so the last fold does not index out of range.
    print('========= Fold {} ============='.format(idx + 1))
    print('MRR: {}'.format(accuracy['MRR'][idx]))
    print('HIT@1: {}'.format(accuracy['HIT_1'][idx]))
    print('HIT@10: {}'.format(accuracy['HIT_10'][idx]))
    idx += 1

[     0      0      0      0      0      0      0      0      0      0
      0      0      0      0      0      0      0      0      0 585512
 242953]
[-2.96031144e-02 -3.11792106e-03 -4.27692980e-02 -7.75558874e-02
 -1.94320306e-02  1.18603073e-01 -7.95774758e-02 -9.96847302e-02
  2.19520200e-02  7.24950656e-02  7.42032379e-02 -4.38216925e-02
  5.24335029e-03 -7.37807453e-02 -3.29931304e-02 -1.46611854e-01
 -2.66297702e-02  3.75816748e-02 -2.63197422e-02  7.99580961e-02
 -6.20610593e-03 -3.35663855e-02 -3.92808281e-02  3.06023192e-02
  7.33362883e-02 -1.79714356e-02 -1.16241515e-01 -1.45088481e-02
  7.80041218e-02 -5.33989333e-02 -2.32661199e-02 -2.28107758e-02
  7.79021680e-02 -5.21597005e-02  7.78206140e-02  9.60430726e-02
  1.55481860e-01  9.14756861e-03  6.14261664e-02  2.67730057e-02
  2.02006139e-02  8.31534639e-02 -3.43203265e-03 -1.61544263e-01
  3.56948227e-02  2.07596451e-01 -3.74688618e-02  1.13737710e-01
 -5.19509651e-02  4.63786907e-02  1.13432966e-01 -1.58561990e-02
  5.

[     0      0      0      0      0      0      0      0      0      0
      0      0      0      0 442093 443905 494987 498297 498297  61380
  22192]
[ 0.00837345  0.06357518 -0.0471146   0.06825517 -0.05427921  0.13691325
 -0.17753075  0.08130135 -0.00359477 -0.08016488  0.04961602 -0.02380891
 -0.06177854 -0.10081691  0.06255907  0.07954949 -0.10121015  0.0416428
 -0.17606759 -0.02897603 -0.04077215  0.03595511 -0.05163613 -0.09634292
  0.03125735 -0.04170639 -0.11263253 -0.05378058 -0.0231298   0.01396706
  0.02751955  0.10176833  0.02180496  0.09411184 -0.02724924 -0.00399662
 -0.02790575 -0.16606046 -0.00913196  0.04561265  0.11841544  0.00432115
 -0.06407496  0.03478611 -0.02942246  0.10347781  0.02792424  0.01999918
 -0.04653015 -0.07954869  0.02942627 -0.09890237 -0.03873263  0.04726112
  0.111301    0.0531924   0.14243779 -0.09308343 -0.01928188 -0.00358713
 -0.04230163 -0.09325445  0.01532099  0.05977419 -0.00139433 -0.10548727
 -0.01949493 -0.0683512   0.13112333 -0.0430942

[     0      0      0      0      0      0      0      0      0      0
      0      0      0      0 442093 443905 494987 498297 498297  61380
  22192]
[ 0.00837345  0.06357518 -0.0471146   0.06825517 -0.05427921  0.13691325
 -0.17753075  0.08130135 -0.00359477 -0.08016488  0.04961602 -0.02380891
 -0.06177854 -0.10081691  0.06255907  0.07954949 -0.10121015  0.0416428
 -0.17606759 -0.02897603 -0.04077215  0.03595511 -0.05163613 -0.09634292
  0.03125735 -0.04170639 -0.11263253 -0.05378058 -0.0231298   0.01396706
  0.02751955  0.10176833  0.02180496  0.09411184 -0.02724924 -0.00399662
 -0.02790575 -0.16606046 -0.00913196  0.04561265  0.11841544  0.00432115
 -0.06407496  0.03478611 -0.02942246  0.10347781  0.02792424  0.01999918
 -0.04653015 -0.07954869  0.02942627 -0.09890237 -0.03873263  0.04726112
  0.111301    0.0531924   0.14243779 -0.09308343 -0.01928188 -0.00358713
 -0.04230163 -0.09325445  0.01532099  0.05977419 -0.00139433 -0.10548727
 -0.01949493 -0.0683512   0.13112333 -0.0430942