In [2]:
# %pip install tensorflow==2.4.1
# %pip install transformers
# %pip install pyarrow
# %pip install tensorflow-addons

In [3]:
%tensorflow_version 2.x
import tensorflow as tf
import pandas as pd
import pickle
import os
import tensorflow_addons as tfa
import math
from math import ceil
from tensorflow.keras.utils import plot_model
from transformers import RobertaTokenizer, RobertaTokenizerFast, TFRobertaModel, TFAlbertModel

AUTO = tf.data.experimental.AUTOTUNE

In [4]:
model_iteration = 'iteration_1'

In [5]:
test_data = pd.read_parquet(f"/content/drive/My Drive/Colab Notebooks/mag_model/test_data_{model_iteration}/data_with_predictions.parquet")
# Count the target tokens that are not -1 (presumably padding -- TODO confirm).
test_data['target_test'] = test_data['target_tok'].apply(lambda toks: sum(1 for tok in toks if tok != -1))
# Drop papers that have no real target token at all.
test_data = test_data.loc[test_data['target_test'] > 0].copy()

In [6]:
VOCAB_DIR = f"/content/drive/My Drive/Colab Notebooks/mag_model/vocab_{model_iteration}"


def _load_vocab(file_name):
    """Load a pickled ``{name: token_id}`` vocabulary and build its inverse.

    Args:
      file_name: pickle file inside ``VOCAB_DIR``.

    Returns:
      ``(vocab, vocab_inv)`` where ``vocab_inv`` maps token id -> name.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at files produced by this project's own pipeline.
    """
    with open(f"{VOCAB_DIR}/{file_name}", "rb") as f:
        vocab = pickle.load(f)
    return vocab, {token_id: name for name, token_id in vocab.items()}


target_vocab, target_vocab_inv = _load_vocab("topics_vocab.pkl")
doc_vocab, doc_vocab_inv = _load_vocab("doc_type_vocab.pkl")
journal_vocab, journal_vocab_inv = _load_vocab("journal_name_vocab.pkl")


In [7]:
encoding_layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
    max_tokens=len(target_vocab)+1, output_mode="binary", sparse=False)

In [8]:
mag_model = tf.keras.models.load_model(f'/content/drive/My Drive/Colab Notebooks/mag_model/model_{model_iteration}/iteration_first_try_epoch_6')

In [9]:
final_model = tf.keras.Model(inputs=mag_model.inputs, 
                             outputs=tf.math.top_k(mag_model.outputs, k=30))

In [10]:
final_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
paper_title_ids (InputLayer)    [(None, 512)]        0                                            
__________________________________________________________________________________________________
title_embedding (Embedding)     (None, 512, 512)     15360512    paper_title_ids[0][0]            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 512, 1024)    525312      title_embedding[0][0]            
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 512, 1024)    0           dense_1[0][0]                    
______________________________________________________________________________________________

In [11]:
def thresh_preds(preds, scores, threshold, min_preds=1):
  """Keep the predictions whose score is at least ``threshold``.

  Args:
    preds: predicted token ids (list or 1-D array), aligned with ``scores``;
      presumably ordered by descending score, since the fallback keeps the
      leading entries.
    scores: one confidence score per entry of ``preds``.
    threshold: minimum score (inclusive) for a prediction to be kept.
    min_preds: number of leading predictions to keep when nothing clears
      the threshold (default 1, the original behaviour).

  Returns:
    A plain Python list of token ids. The fallback branch used to return a
    raw slice of ``preds`` (e.g. a numpy array), while the main branch
    returned a list. Both branches now return a list.
  """
  kept = [pred for pred, score in zip(preds, scores) if score >= threshold]
  if not kept:
    kept = list(preds[:min_preds])
  return kept

In [12]:
# Every `predictions_{t}` column used later in the notebook (25, 35 and 50 are
# compared below) must be built here, otherwise Restart & Run All fails with a
# KeyError. `thresh` is the threshold used by the per-slice analyses.
THRESHOLDS = (25, 35, 50)
thresh = 35

for thr in sorted(set(THRESHOLDS) | {thresh}):
  test_data[f'predictions_{thr}'] = test_data.apply(
      lambda row, thr=thr: thresh_preds(row.predictions, row.scores, thr / 100), axis=1)

In [13]:
test_raw = pd.read_parquet(f"/content/drive/My Drive/Colab Notebooks/mag_model/test_data_{model_iteration}/test_extra_data.parquet")

In [13]:
test_raw.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150630 entries, 0 to 150629
Data columns (total 5 columns):
 #   Column       Non-Null Count   Dtype 
---  ------       --------------   ----- 
 0   paper_id     150630 non-null  int64 
 1   paper_title  150630 non-null  object
 2   year         150630 non-null  int32 
 3   month        150630 non-null  int32 
 4   topic_len    150630 non-null  int32 
dtypes: int32(3), int64(1), object(1)
memory usage: 4.0+ MB


In [14]:
test_raw[test_raw['paper_id']==9881621]

Unnamed: 0,paper_id,paper_title,year,month,topic_len
18361,9881621,intersecting influences in american haiku,2001,1,3


In [21]:
test_data[test_data['paper_id']==2966288531]

Unnamed: 0,paper_id,publication_date,doc_type_tok,journal_tok,target_tok,paper_title_tok,paper_title_mask,predictions,scores,target_test,predictions_25
6949,2966288531,2018-06-01,[2],[2],"[849, 2734, 1667]","[2, 534, 172, 14257, 147, 363, 14, 3815, 16, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[2734, 6, 849, 13144, 36369, 425, 1015, 3804, ...","[0.585706889629364, 0.5703823566436768, 0.5270...",3,"[2734, 6, 849]"


In [17]:
for_examples = test_data[['paper_id','target_tok','predictions',f'predictions_{thresh}']].merge(test_raw[['paper_id','paper_title']], 
                                                                                                how='left', on='paper_id').sample(25).reset_index(drop=True)

In [26]:
for_examples

Unnamed: 0,paper_id,target_tok,predictions,predictions_25,paper_title
0,2375751698,"[20, 2922, 2325, 2, 2468, 66289, 402]","[2, 20, 5463, 576, 10, 95, 3321, 4443, 12862, ...","[2, 20, 5463, 576, 10, 95, 3321, 4443]",diagnosis and treatment of 76 cases with pulmo...
1,3093817657,"[1171, 83, 150, 4, 1248, 2475, 1688, 9411, 494...","[4, 150, 1033, 63, 83, 1822, 190, 1108, 1477, ...","[4, 150, 1033, 63, 83, 1822, 190, 1108, 1477, ...",efficient activation of peroxymonosulfate on c...
2,1979042108,"[68, 183567, 7024, 1132, 154, 2505, 3555, 2582...","[13, 4840, 8955, 136, 2871, 864, 3, 254, 52, 3...","[13, 4840, 8955, 136, 2871, 864, 3]",a problem in the theory of numerical estimation
3,3145918133,[110139],"[110139, 97511, 23, 68588, 3, 462, 2, 1314, 23...","[110139, 97511, 23, 68588, 3, 462, 2, 1314, 2305]",automatic applanation tonometer
4,3112076344,"[73, 21, 84]","[52, 24, 330, 21, 39, 84, 32, 2270, 172, 336, ...","[52, 24, 330, 21, 39, 84]",the life and work of ernesto de martino italia...
5,3134760140,"[880, 84949, 66, 14, 1424, 6868, 3818, 23418, ...","[51, 14, 23418, 880, 1558, 20066, 8122, 22854,...","[51, 14, 23418, 880, 1558, 20066, 8122, 22854,...",obligations erga omnes and the question of sta...
6,2858971720,"[43, 22, 8805, 441, 81, 3, 175, 30267, 101]","[3, 22, 101, 38, 81, 177, 43, 75, 47, 26, 230,...","[3, 22, 101, 38, 81, 177, 43]",charging system with detection function
7,3151379152,"[946, 1056, 775, 1084, 1917, 22672, 1154, 14, ...","[22672, 535, 251, 14, 29, 733, 4685, 16, 1917,...","[22672, 535, 251, 14, 29, 733, 4685, 16, 1917,...",influence of policy discourse networks on loca...
8,2906970193,"[11, 1026, 276, 508, 7132, 41089, 2946, 10320,...","[603, 11, 2946, 1074, 431, 14, 51, 1895, 356, ...","[603, 11, 2946, 1074, 431, 14, 51]",methods of developing stress resistance of law...
9,2391137018,"[3, 1178, 2315, 542, 613, 28915, 17396, 221399...","[2315, 3, 28915, 19871, 613, 32288, 43033, 264...","[2315, 3, 28915, 19871, 613, 32288, 43033, 264...",the theory and implement method of java rmi


In [20]:
# Print each sampled paper with its true tags and thresholded predictions,
# decoded back to topic names. Columns are picked by name, not by position,
# and the loop covers every row of `for_examples` instead of a hard-coded 25.
example_pred_col = f'predictions_{thresh}'
for _, row in for_examples.iterrows():
  print(row['paper_id'])
  print(row['paper_title'])
  print(f"Tags: {[target_vocab_inv.get(x) for x in row['target_tok']]}")
  print(f"Predictions: {[target_vocab_inv.get(x) for x in row[example_pred_col]]}")
  print("----------------------------------------------------------------------")
  print("\n")

2375751698
diagnosis and treatment of 76 cases with pulmonary contusion
Tags: ['surgery', 'mechanical ventilation', 'therapeutic effect', 'medicine', 'airway', 'pulmonary contusion', 'retrospective cohort study']
Predictions: ['medicine', 'surgery', 'pulmonary tuberculosis', 'lung', 'internal medicine', 'radiology', 'pulmonary function testing', 'clinical pathology']
----------------------------------------------------------------------


3093817657
efficient activation of peroxymonosulfate on cobalt hydroxychloride nanoplates through hydrogen bond for degradation of tetrabromobisphenol a
Tags: ['cobalt', 'catalysis', 'nuclear chemistry', 'chemistry', 'reactivity', 'electron transfer', 'reaction rate constant', 'singlet oxygen', 'tetrabromobisphenol a', 'hydrogen bond']
Predictions: ['chemistry', 'nuclear chemistry', 'degradation', 'inorganic chemistry', 'catalysis', 'phenol', 'aqueous solution', 'decomposition', 'hydrogen peroxide', 'persulfate', 'chemical engineering', 'cobalt', 'kin

In [14]:
def get_metrics(data, target_col, predict_col, encoder=None, verbose=True):
    """Multi-label recall / precision / accuracy of predicted topic tokens.

    Each row's target and predicted token lists are passed through ``encoder``
    (by default the notebook-level ``encoding_layer``, a binary-mode
    CategoryEncoding). The results are then accumulated in Keras streaming
    metrics, so the values are aggregated over every row of ``data``.

    Args:
      data: DataFrame with one paper per row.
      target_col: column holding each paper's true token ids.
      predict_col: column holding each paper's predicted token ids.
      encoder: callable mapping ``[token_list]`` to a multi-hot tensor;
        ``None`` means ``encoding_layer``.
      verbose: print the percentages (the original behaviour).

    Returns:
      dict with float values for "recall", "precision" and "accuracy" in [0, 1].

    NOTE(review): CategoricalAccuracy compares argmax(y_true) with argmax(y_pred).
    On multi-hot vectors the argmax is just one of the set indices (ties are
    not guaranteed to resolve consistently). So "accuracy" here is not
    exact-match or subset accuracy; interpret it with care.
    """
    if encoder is None:
        encoder = encoding_layer

    metrics = {
        "recall": tf.keras.metrics.Recall(),
        "precision": tf.keras.metrics.Precision(),
        "accuracy": tf.keras.metrics.CategoricalAccuracy(),
    }

    for targets, predictions in zip(data[target_col], data[predict_col]):
        # Encode once per row and share the tensors across all three metrics
        # (the previous version ran the encoder six times per row).
        y_true = encoder([targets])
        y_pred = encoder([predictions])
        for metric in metrics.values():
            metric.update_state(y_true, y_pred)

    raw = {name: metric.result().numpy() for name, metric in metrics.items()}
    if verbose:
        print(f"Recall: {round(raw['recall']*100, 1)}%")
        print(f"Precision: {round(raw['precision']*100, 1)}%")
        print(f"Accuracy: {round(raw['accuracy']*100, 1)}%")
    return {name: float(value) for name, value in raw.items()}

In [51]:
get_metrics(test_data.sample(10000), 'target_tok', f'predictions_25')

Recall: 44.1%
Precision: 32.9%
Accuracy: 65.1%


In [52]:
get_metrics(test_data.sample(10000), 'target_tok', 'predictions_35')

Recall: 31.0%
Precision: 53.8%
Accuracy: 67.1%


In [53]:
get_metrics(test_data.sample(10000), 'target_tok', 'predictions_50')

Recall: 16.4%
Precision: 74.6%
Accuracy: 53.6%


In [28]:
get_metrics(test_data.sample(10000), 'target_tok', f'predictions_{thresh}')

Recall: 46.2%
Precision: 36.4%
Accuracy: 64.0%


### Levels

In [36]:
# Topic-level table. Topics with no recorded level go into a catch-all
# bucket, UNKNOWN_LEVEL (TODO confirm the intended meaning of 6).
# Only the `level` column is filled: the old `.fillna(6)` also wrote 6 into
# every other column (e.g. missing topic names).
UNKNOWN_LEVEL = 6
levels_df = (
    pd.read_parquet(f"/content/drive/My Drive/Colab Notebooks/mag_model/test_data_{model_iteration}/tag_levels.parquet")
    .fillna({'level': UNKNOWN_LEVEL})
)
levels_df['level'] = levels_df['level'].astype('int')
# Map topic names to model token ids; drop topics that are not in the vocab.
levels_df['topic_tok'] = levels_df['topic_name'].map(target_vocab.get)
levels_df = levels_df.dropna(subset=['topic_tok']).copy()
levels_df['topic_tok'] = levels_df['topic_tok'].astype('int')

In [37]:
def fill_blank(arr):
  """Return ``arr`` unchanged if it is exactly a ``list``; otherwise ``[0]``.

  After the outer merge in ``get_df_for_specific_level``, a paper with no
  tags (or no predictions) at a level holds NaN instead of a list. This
  swaps that NaN for a one-token placeholder so the row can still be passed
  to ``get_metrics``.
  NOTE(review): token id 0 is then encoded like a real class -- confirm it
  is an unused/padding id, otherwise these placeholder rows affect the metrics.
  """
  return arr if type(arr) is list else [0]

In [38]:
def get_df_for_specific_level(old_df, levels, level_to_get=1, pred_col="predictions"):
    """Restrict each paper's target and predicted topics to one topic level.

    Args:
      old_df: papers with ``paper_id``, ``target_tok`` and ``pred_col`` columns
        holding token-id lists. It is not modified.
      levels: topic-level table with ``level`` and ``topic_tok`` columns.
      level_to_get: hierarchy level to keep.
      pred_col: column holding the predicted token lists.

    Returns:
      DataFrame with columns ``paper_id``, ``pred_col`` and ``target_tok``.
      It has one row per paper with at least one tag or one prediction at
      this level. A side with nothing at the level is filled with ``[0]`` by
      ``fill_blank``. Papers with neither a tag nor a prediction are dropped.

    Also prints the number of topics at the level and the percentage of
    papers in ``old_df`` that have tags / predictions at it.
    """
    level_topics = levels[levels['level'] == level_to_get]
    print(f"Number of Tags: {len(level_topics)}")

    def _tokens_at_level(col):
        # One row per (paper, token); keep tokens of this level; regroup to lists per paper.
        matched = old_df.explode(col).merge(level_topics, how='inner',
                                            left_on=col, right_on='topic_tok')
        return matched[['paper_id', col]].groupby('paper_id')[col].apply(list).reset_index()

    new_df = _tokens_at_level(pred_col).merge(_tokens_at_level('target_tok'),
                                              how='outer', on='paper_id')

    # Shares are relative to all input papers; guard against an empty input.
    n_papers = old_df.shape[0]
    if n_papers:
        papers_with_tag = int(new_df['target_tok'].notna().sum()) / n_papers
        papers_with_pred = int(new_df[pred_col].notna().sum()) / n_papers
    else:
        papers_with_tag = papers_with_pred = 0.0

    new_df['target_tok'] = new_df['target_tok'].apply(fill_blank)
    new_df[pred_col] = new_df[pred_col].apply(fill_blank)

    print(f"Percentage of papers with Level {level_to_get} Tags: {round(papers_with_tag*100, 1)}")
    print(f"Percentage of papers with Level {level_to_get} Preds: {round(papers_with_pred*100, 1)}")

    return new_df

In [54]:
# Evaluate every topic level (0-5, plus 6 for topics with no recorded level)
# on the SAME seeded sample of papers, so per-level numbers are comparable and
# reproducible. Previously each level drew its own unseeded sample and
# overwrote the global `df`.
LEVEL_SAMPLE_SIZE = 10000
LEVEL_SAMPLE_SEED = 42
print(f"Threshold: {thresh}")
level_sample = test_data.sample(LEVEL_SAMPLE_SIZE, random_state=LEVEL_SAMPLE_SEED)
for level_to_check in range(0, 7):
  print(f"Topic Level: {level_to_check}")
  level_df = get_df_for_specific_level(level_sample, levels_df, level_to_check, f"predictions_{thresh}")
  get_metrics(level_df, 'target_tok', f'predictions_{thresh}')
  print("------------------------------------------------------------")
  print("\n")

Threshold: 35
Topic Level: 0
Number of Tags: 19
Percentage of papers with Level 0 Tags: 99.3
Percentage of papers with Level 0 Preds: 93.2
Recall: 73.3%
Precision: 65.2%
Accuracy: 68.0%
------------------------------------------------------------


Topic Level: 1
Number of Tags: 292
Percentage of papers with Level 1 Tags: 98.4
Percentage of papers with Level 1 Preds: 75.0
Recall: 46.2%
Precision: 42.7%
Accuracy: 39.4%
------------------------------------------------------------


Topic Level: 2
Number of Tags: 91991
Percentage of papers with Level 2 Tags: 93.1
Percentage of papers with Level 2 Preds: 77.9
Recall: 21.2%
Precision: 48.4%
Accuracy: 24.7%
------------------------------------------------------------


Topic Level: 3
Number of Tags: 141305
Percentage of papers with Level 3 Tags: 76.6
Percentage of papers with Level 3 Preds: 45.7
Recall: 18.4%
Precision: 29.1%
Accuracy: 17.0%
------------------------------------------------------------


Topic Level: 4
Number of Tags: 99839
P

### Journal and Doc Types

In [40]:
# Decode the first journal / doc-type token of each paper back to its name
# (the token lists appear to hold a single id, e.g. [2]).
test_data['journal'] = test_data['journal_tok'].apply(lambda toks: journal_vocab_inv.get(toks[0]))
test_data['doc_type'] = test_data['doc_type_tok'].apply(lambda toks: doc_vocab_inv.get(toks[0]))

In [55]:
# Metrics per document type, most frequent type first, on at most
# MAX_PER_DOC_TYPE papers of each type.
MAX_PER_DOC_TYPE = 5000
for doc_type in test_data['doc_type'].value_counts().index:
    print(doc_type)
    doc_type_rows = test_data[test_data['doc_type'] == doc_type]
    n_to_sample = min(doc_type_rows.shape[0], MAX_PER_DOC_TYPE)
    get_metrics(doc_type_rows.sample(n_to_sample), "target_tok", f"predictions_{thresh}")
    print("-----------------------------------------------------------------------")
    print("\n")

Journal
Recall: 32.0%
Precision: 53.2%
Accuracy: 72.9%
-----------------------------------------------------------------------


[NONE]
Recall: 31.1%
Precision: 53.5%
Accuracy: 54.8%
-----------------------------------------------------------------------


Patent
Recall: 27.7%
Precision: 58.9%
Accuracy: 61.8%
-----------------------------------------------------------------------


Conference
Recall: 29.4%
Precision: 53.6%
Accuracy: 72.8%
-----------------------------------------------------------------------


Repository
Recall: 27.8%
Precision: 56.0%
Accuracy: 74.1%
-----------------------------------------------------------------------


Book
Recall: 34.4%
Precision: 52.4%
Accuracy: 51.6%
-----------------------------------------------------------------------


Thesis
Recall: 30.4%
Precision: 52.4%
Accuracy: 55.9%
-----------------------------------------------------------------------


BookChapter
Recall: 39.8%
Precision: 54.9%
Accuracy: 54.4%
--------------------------------------

### Paper Title Length

In [None]:
# Title-length buckets (in tokens) used below: <10, 10-20, 20-40, 40+

In [16]:
test_data['paper_title_tok_len'] = test_data['paper_title_tok'].apply(len)

In [32]:
(test_data[test_data['paper_title_tok_len'] < 10].shape[0])/test_data.shape[0]

0.12967622436582596

In [33]:
(test_data[(test_data['paper_title_tok_len'] >= 10) & (test_data['paper_title_tok_len']<20)].shape[0])/test_data.shape[0]

0.48596219851423034

In [34]:
(test_data[(test_data['paper_title_tok_len'] >= 20) & (test_data['paper_title_tok_len']<40)].shape[0])/test_data.shape[0]

0.3599240518094125

In [35]:
(test_data[test_data['paper_title_tok_len'] >= 40].shape[0])/test_data.shape[0]

0.024437525310531172

In [21]:
get_metrics(test_data[test_data['paper_title_tok_len'] < 10].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 27.8%
Precision: 56.5%
Accuracy: 57.6%


In [22]:
get_metrics(test_data[(test_data['paper_title_tok_len'] >= 10) & (test_data['paper_title_tok_len']<20)].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 30.1%
Precision: 54.3%
Accuracy: 65.8%


In [23]:
get_metrics(test_data[(test_data['paper_title_tok_len'] >= 20) & (test_data['paper_title_tok_len']<40)].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 32.9%
Precision: 53.1%
Accuracy: 73.2%


In [24]:
get_metrics(test_data[test_data['paper_title_tok_len'] >= 40], "target_tok", f"predictions_{thresh}")

Recall: 32.7%
Precision: 52.3%
Accuracy: 68.4%


### Time

In [25]:
explore_time = test_data.merge(test_raw, how='left', on='paper_id')

In [26]:
explore_time['yearMonth'] = explore_time['year']*100+explore_time['month']

In [29]:
get_metrics(explore_time[explore_time['yearMonth'] >= 202106].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 30.7%
Precision: 51.9%
Accuracy: 72.9%


In [30]:
get_metrics(explore_time[explore_time['yearMonth'] < 202106].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 30.9%
Precision: 55.4%
Accuracy: 65.0%


In [51]:
get_metrics(explore_time[explore_time['year'] <= 2002].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 32.8%
Precision: 54.5%
Accuracy: 65.3%


In [52]:
get_metrics(explore_time[(explore_time['year'] >= 2003) & (explore_time['year'] <= 2008)].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 31.1%
Precision: 55.7%
Accuracy: 63.7%


In [53]:
get_metrics(explore_time[(explore_time['year'] >= 2009) & (explore_time['year'] <= 2014)].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 30.3%
Precision: 55.7%
Accuracy: 62.9%


In [55]:
get_metrics(explore_time[(explore_time['year'] >= 2015) & (explore_time['year'] <= 2020)].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 30.0%
Precision: 55.7%
Accuracy: 65.4%


In [56]:
get_metrics(explore_time[explore_time['year'] == 2021].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 31.0%
Precision: 51.8%
Accuracy: 73.3%


### Journal and doc type: both missing vs. both present

In [44]:
get_metrics(test_data[(test_data['doc_type'] != "[NONE]") & 
                      (test_data['journal'] != "[NONE]")].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 32.1%
Precision: 53.1%
Accuracy: 73.2%


In [43]:
get_metrics(test_data[(test_data['doc_type'] == "[NONE]") & 
                      (test_data['journal'] == "[NONE]")].sample(10000), "target_tok", f"predictions_{thresh}")

Recall: 31.0%
Precision: 53.3%
Accuracy: 55.7%
