In [None]:
import pandas as pd                     
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import shap

from constants import *
from indexing import BasicInvertedIndex
from document_preprocessor import RegexTokenizer
from ranker import Ranker, BM25, CrossEncoderScorer
from l2r import L2RFeatureExtractor, L2RRanker
from relevance import save_query_result, run_relevance_tests

In [None]:
# Tokenizer handed to the rankers and the L2R feature extractor below:
# tokens are maximal runs of word characters.
document_preprocessor = RegexTokenizer(r'\w+')

# Stopwords: one entry per line of STOPWORD_PATH, surrounding whitespace removed.
with open(STOPWORD_PATH, "r") as stopword_file:
    stopwords = {line.strip() for line in stopword_file}

# Pre-built inverted index over paper titles.
title_index = BasicInvertedIndex()
title_index.load(PAPER_TITLE_INDEX)

# Pre-built inverted index over paper abstracts.
abstract_index = BasicInvertedIndex()
abstract_index.load(PAPER_ABSTRACT_INDEX)

In [None]:
print("Load docid list")
# Unpickles a locally produced artifact (pickle can execute code — only load trusted files).
# NOTE(review): docid_list is not read by any later cell of this notebook — confirm it is still needed.
with open(DOCID_LIST_PATH, 'rb') as docid_file:
    docid_list = pickle.load(docid_file)

In [None]:
print("Load categories")
# Per-document category information; passed to L2RFeatureExtractor below.
with open(DOC_CATEGORY_INFO_PATH, 'rb') as category_info_file:
    doc_category_info = pickle.load(category_info_file)
# Category collection also passed to L2RFeatureExtractor; presumably the source of
# the "<cat> flag" feature names used in the SHAP cell (TODO confirm).
with open(RECOG_CATEGORY_PATH, 'rb') as recognized_categories_file:
    recognized_categories = pickle.load(recognized_categories_file)

In [None]:
print("Load year release")
# Document-id -> release-year mapping (per the path name), fed to the feature extractor.
with open(DOCID_TO_YEAR_RELEASE_PATH, 'rb') as year_file:
    docid_to_yr = pickle.load(year_file)

In [None]:
print("Load citation")
# Document-id -> citation data (per the path name), fed to the feature extractor.
with open(DOCID_TO_CITATION_PATH, 'rb') as citation_file:
    docid_to_citation = pickle.load(citation_file)

In [None]:
print("Load network features")
# Document-id -> precomputed network features (per the path name), fed to the feature extractor.
with open(DOCID_TO_NETWORK_FEATURES_PATH, 'rb') as network_features_file:
    docid_to_network_features = pickle.load(network_features_file)

In [None]:
print("Load Cross Encoder")
# The cross-encoder scorer is built over the abstract index's raw document text.
abstract_raw_text = abstract_index.raw_text_dict
cescorer = CrossEncoderScorer(abstract_raw_text)

In [None]:
print("Initializing Feature Extractor")
# All arguments are positional: their order must match L2RFeatureExtractor's signature.
feature_extractor = L2RFeatureExtractor(
    abstract_index,
    title_index,
    doc_category_info,
    document_preprocessor,
    stopwords,
    recognized_categories,
    docid_to_network_features,
    docid_to_yr,
    docid_to_citation,
    cescorer,
)

In [None]:
print("Initializing Ranker")
# BM25 ranker over the abstract index.
BM25scorer = BM25(abstract_index)
BM25Ranker = Ranker(
    abstract_index,
    document_preprocessor,
    stopwords,
    BM25scorer,
    raw_text_dict=abstract_index.raw_text_dict,
)

# Learning-to-rank ranker; receives the BM25 ranker (presumably as its first-stage
# retriever — TODO confirm in l2r.py) and the feature extractor.
l2rRanker = L2RRanker(
    abstract_index,
    title_index,
    document_preprocessor,
    stopwords,
    BM25Ranker,
    feature_extractor,
)

# Optional cache of the untrained rankers. The reload cell further down reads
# BM25_RANKER_PATH, which is only written if this dump is enabled.
# with open(BM25_RANKER_PATH, 'wb') as f:
#     pickle.dump(BM25Ranker, f, protocol=pickle.HIGHEST_PROTOCOL)
# with open(L2R_RANKER_PATH, 'wb') as f:
#     pickle.dump(l2rRanker, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# BM25 baseline: write its results for the test queries to PAPER_BM25_RANK_RESULT_PATH.
save_query_result(TEST_PAPER_DATA_PATH, BM25Ranker, PAPER_BM25_RANK_RESULT_PATH)

# Fit the L2R ranker on the training data, then write its results for the same test queries.
l2rRanker.train(TRAIN_PAPER_DATA_PATH)
save_query_result(TEST_PAPER_DATA_PATH, l2rRanker, PAPER_L2R_RANK_RESULT_PATH)

# Persist the fitted ranker so a later session can reload it instead of retraining.
with open(L2R_RANKER_FITTED_PATH, 'wb') as fitted_ranker_file:
    pickle.dump(l2rRanker, fitted_ranker_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
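# Optional: reload previously pickled rankers instead of rebuilding/retraining them.
# Note: BM25_RANKER_PATH only exists if the commented-out dump in the
# "Initializing Ranker" cell has been enabled and run.
# with open(BM25_RANKER_PATH, 'rb') as f:
#     BM25Ranker = pickle.load(f)
# with open(L2R_RANKER_FITTED_PATH, 'rb') as f:
#     l2rRanker = pickle.load(f)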

In [None]:
# Score each ranker's saved test-set results against the relevance judgements.
bm25_eval = run_relevance_tests(TEST_PAPER_DATA_PATH, PAPER_BM25_RANK_RESULT_PATH, id_col='docid')
l2r_eval = run_relevance_tests(TEST_PAPER_DATA_PATH, PAPER_L2R_RANK_RESULT_PATH, id_col='docid')

# Both evaluations keyed by model, in a fixed order (BM25 first, then L2R).
eval_result = {
    'bm25_eval': bm25_eval,
    'l2r_eval': l2r_eval,
}
# Optional persistence of the evaluation results:
# with open(PAPER_EVAL_RESULT_PATH, 'wb') as f:
#     pickle.dump(eval_result, f, protocol=pickle.HIGHEST_PROTOCOL)
# with open(PAPER_EVAL_RESULT_PATH, 'rb') as f:
#     eval_result = pickle.load(f)


# Reshape the per-query scores into long format for seaborn: one row per
# (metric, model, per-query score). Building the rows directly from the eval
# dicts keeps every score paired with its own metric/model label, instead of
# relying on parallel lists that only line up when every model/metric returns
# exactly the same number of per-query scores.
metric_labels = {'map': 'MAP', 'ndcg': 'NDCG'}
model_labels = {'bm25_eval': 'BM25', 'l2r_eval': 'L2R'}

eval_df = pd.DataFrame(
    [
        {
            "methods": metric_label,
            "scores": score,
            # Models without a display label (new eval_result entries) fall back to their key.
            "model_flags": model_labels.get(eval_key, eval_key),
        }
        for metric_key, metric_label in metric_labels.items()
        for eval_key, model_eval in eval_result.items()
        for score in model_eval[metric_key]
    ],
    columns=["methods", "scores", "model_flags"],
)

fig, ax = plt.subplots()
sns.boxplot(x="methods", y="scores", hue="model_flags", data=eval_df, ax=ax)
ax.set(xlabel='Evaluation Method', ylabel='Score', title='Paper Model Evaluation')
# Save only after labels AND title are set, so the PNG on disk matches the displayed figure.
fig.savefig('paper_eval.png')

In [None]:
# Display names for the columns of l2rRanker.feature_vectors_collection.
# NOTE(review): the order is assumed to match the vector L2RFeatureExtractor builds
# (TODO confirm against l2r.py); pd.DataFrame raises a ValueError if the count differs.
ft_list = ['article length', 'title length', 'query length', 'TF (abstract)', 'TF-IDF (abstract)', 'TF (title)', 'TF-IDF (title)', 
           'BM25', 'Pv Norm', 'Pagerank', 'Hub score', 'Authority score', 'Paris hierarchy 1', 'Paris hierarchy 2', 'Paris hierarchy 3', 
           'Year release', 'Citation number', 'Cross encoder score'] + [f'{cat} flag' for cat in l2rRanker.feature_extractor.recognized_categories]
X_pred = pd.DataFrame(l2rRanker.feature_vectors_collection, columns=ft_list)

# Attribute the fitted ranking model's outputs to individual features.
explainer = shap.Explainer(l2rRanker.model.ranker)
shap_values = explainer.shap_values(X_pred)

shap.summary_plot(shap_values, X_pred, show=False)
# Save BEFORE plt.show(): show() releases the current figure, so a savefig issued
# afterwards writes a blank image. bbox_inches='tight' keeps the long feature-name
# labels on the left edge from being clipped in the saved file.
plt.savefig('paper_shap.png', bbox_inches='tight')
plt.show()