# Inputs

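* `path`: Directory prefix under which the files below must be present. File names are appended to it directly (e.g. `path+'entities.pkl'`), so a non-empty value must end with a path separator: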
    * `entities.pkl`: Pickle file that stores a dictionary. In this dictionary, keys are IDs of entities. For each entity, the value is a dictionary with a single key:
        * `Labels`: List of strings where each string is an alternative name/label of the organization.
       Will be unpickled as:
       ```python
       with open(path+'entities.pkl','rb') as f:
           entity_labels=pickle.load(f)
       ```
     
    * `commonness.json`: Stores estimated commonness values. Should be calculated beforehand. In this JSON file, keys are mentions (looked up in lowercase) and values are JSON objects. For each mention for which a commonness value has been calculated, the object maps entity IDs to their corresponding commonness values. Example:
    ```json
    {
        "mention_1":{
                         "entity_1":0.9,
                         "entity_2":0.05,
                         "entity_3":0.05
        },
        "mention_2": {
                         "entity_1":0.6,
                         "entity_4":0.3
        },
        "mention_3":{
                         "entity_5":1.0
        }
    }
            
    ```
    Mention–entity pairs with a commonness of 0 do not need to be included.
    * `link_prob.json`: Stores estimated link probabilities. Should be calculated beforehand. Keys are mentions (looked up in lowercase) and values are link probabilities.
    * `popularity.json`: Stores estimated entity popularity values. Should be calculated beforehand. Keys are entity IDs (as strings) and values are popularities.
    * `train_preds.pkl` and `dev_preds.pkl`: Can be obtained by running `Prediction with Biencoder.ipynb`
    * `entity_pool.pkl`
    * `train.jsonl`
    * `dev.jsonl`
    
# Outputs
* Model is written to `"lgbm12.pkl"`. Also, threshold selection (for NIL mention detection) is performed at the end of the notebook.

In [None]:
import random
import numpy as np
import pickle
import json
from fuzzywuzzy import fuzz
from sklearn.linear_model import LogisticRegression
import pandas as pd
import itertools
import time
import lightgbm as lgb

In [None]:
# Prefix for every input file; file names are concatenated onto it directly,
# so a non-empty value has to end with a path separator.
path = ""

In [None]:
# --- Run configuration -------------------------------------------------------
# Sample files (one JSON object per line).
train_fname = "train.jsonl"
monitor_fname = "dev.jsonl"

# Pickled candidate lists for the two sample files above.
file_train_cands = "train_preds.pkl"
file_monitor_cands = "dev_preds.pkl"

# Number of candidates turned into feature rows for each mention.
num_cands = 12

# Seed handed to the `random` and NumPy generators.
seed = 0

In [None]:
# Seed NumPy's global generator and Python's `random` module with the same
# configured value so that repeated runs draw the same random numbers.
np.random.seed(seed)
random.seed(seed)

In [None]:
# Load the train and dev (monitor) samples; both files hold one JSON object
# per line.
def _read_jsonl(file_path):
    """Return the list of JSON objects stored one per line in *file_path*.

    Lines that are empty or whitespace-only (for example a trailing blank
    line at the end of the file) are skipped instead of raising a
    ``json.JSONDecodeError``.
    """
    with open(file_path, mode="r", encoding="utf-8") as file:
        return [json.loads(line) for line in file if line.strip()]


# The files live under `path`, like every other input of this notebook
# (the previously used name `path_b` is not defined anywhere).
train_samples = _read_jsonl(path + train_fname)
print(len(train_samples))

monitor_samples = _read_jsonl(path + monitor_fname)
print(len(monitor_samples))

In [None]:
# Entity dictionary: looked up below by stringified entity ID, each value
# exposing a 'Labels' list of alternative names.
# NOTE(review): pickle.load executes arbitrary code from the file -- only use
# with files from a trusted source.
with open(path + 'entities.pkl', mode='rb') as entities_file:
    entity_labels = pickle.load(entities_file)

In [None]:
# Entity pool: indexed by a gold entity ID during threshold selection, where a
# predicted entity ID (as string) is tested for membership in the stored value.
# presumably the collection of IDs accepted as equivalent to the gold one --
# TODO confirm against the code that produced the file.
with open(path + 'entity_pool.pkl', mode="rb") as pool_file:
    entity_pool = pickle.load(pool_file)

In [None]:
# Candidate predictions for the train and dev samples. Each object is indexed
# by sample position and then by candidate position; a candidate holds its
# score at index 0 and its entity ID at index 1.
with open(path + file_train_cands, mode='rb') as train_cands_file:
    train_init = pickle.load(train_cands_file)

with open(path + file_monitor_cands, mode='rb') as monitor_cands_file:
    monitor_init = pickle.load(monitor_cands_file)

In [None]:
# Split every sample into its gold entity ID ('label_id', which may be None)
# and its mention text ('mention'), keeping sample order.
train_correct_entities = []
train_mentions = []
for sample in train_samples:
    train_correct_entities.append(sample['label_id'])
for sample in train_samples:
    train_mentions.append(sample['mention'])

monitor_correct_entities = []
monitor_mentions = []
for sample in monitor_samples:
    monitor_correct_entities.append(sample['label_id'])
for sample in monitor_samples:
    monitor_mentions.append(sample['mention'])

In [None]:
# Pre-computed statistics used as features:
#   commonness       -- mention -> {entity ID -> commonness}
#   link_probability -- mention -> link probability
#   popularity       -- entity ID -> popularity
with open(path + 'commonness.json', mode='r', encoding='utf-8') as commonness_file:
    commonness = json.load(commonness_file)

with open(path + 'link_prob.json', mode='r', encoding='utf-8') as link_prob_file:
    link_probability = json.load(link_prob_file)

with open(path + 'popularity.json', mode='r', encoding='utf-8') as popularity_file:
    popularity = json.load(popularity_file)

In [None]:
# Build one feature row per (training mention, candidate) pair, using the
# first `num_cands` candidates of every mention. All lists below stay aligned:
# position k in each of them describes the same pair.
train_unique_id = []          # index of the sample the row belongs to
train_entity = []             # candidate entity ID
train_bert_scores = []        # candidate score stored at index 0
train_fw_scores = []          # best token_set_ratio over the entity labels
train_fw_scores2 = []         # best token_sort_ratio over the entity labels
train_commonness = []
train_popularity = []
train_link_probability = []
train_target = []             # 1 if the candidate is the gold entity, else 0

for sample_idx in range(len(train_samples)):
    ranked = train_init[sample_idx]
    mention = train_mentions[sample_idx]
    for rank in range(num_cands):
        candidate = ranked[rank]
        entity_key = str(candidate[1])
        labels = entity_labels[entity_key]['Labels']

        # Fuzzy string similarities, scaled to [0, 1]; 0 when there are no labels.
        train_fw_scores.append(
            max([0, *(fuzz.token_set_ratio(mention, lbl) / 100 for lbl in labels)])
        )
        train_fw_scores2.append(
            max([0, *(fuzz.token_sort_ratio(mention, lbl) / 100 for lbl in labels)])
        )

        # Statistics are keyed by the lowercased mention / stringified entity
        # ID; anything missing counts as 0.
        mention_key = mention.lower()
        train_commonness.append(commonness.get(mention_key, {}).get(entity_key, 0.))
        train_popularity.append(popularity.get(entity_key, 0.))
        train_link_probability.append(link_probability.get(mention_key, 0.))

        train_bert_scores.append(candidate[0])

        # Target: exact match with the gold ID. Samples without a gold entity
        # (None) only produce negatives.
        gold_id = train_correct_entities[sample_idx]
        is_gold = gold_id is not None and entity_key == gold_id
        train_target.append(1 if is_gold else 0)

        train_unique_id.append(sample_idx)
        train_entity.append(candidate[1])

In [None]:
# Assemble the per-candidate training lists into one table
# (one row per mention/candidate pair).
train_columns = {
    'ID': train_unique_id,
    'Commonness': train_commonness,
    'BERT': train_bert_scores,
    'Popularity': train_popularity,
    'Link_Probability': train_link_probability,
    'FW': train_fw_scores,
    'FW2': train_fw_scores2,
    'Entity': train_entity,
    'Target': train_target,
}
train_df = pd.DataFrame(train_columns)

In [None]:
# Build one feature row per (dev mention, candidate) pair, using the first
# `num_cands` candidates of every mention, then collect the rows in
# `monitor_df`. All lists below stay aligned position by position.
monitor_bert_scores = []
monitor_fw_scores = []
monitor_fw_scores2 = []
monitor_target = []
monitor_unique_id = []
monitor_entity = []
monitor_commonness = []
monitor_popularity = []
monitor_link_probability = []

for i, (this_mention, gold_entity) in enumerate(
        zip(monitor_mentions, monitor_correct_entities)):
    candidates = monitor_init[i]
    # Statistics are keyed by the lowercased mention.
    mention_key = this_mention.lower()
    for j in range(num_cands):
        bert_score = candidates[j][0]
        entity_id = candidates[j][1]
        entity_key = str(entity_id)
        this_ent_labels = entity_labels[entity_key]['Labels']

        # Fuzzy string similarities scaled to [0, 1]: the best score over all
        # labels of the candidate, 0 when the entity has no labels.
        monitor_fw_scores.append(max(
            (fuzz.token_set_ratio(this_mention, lbl) / 100 for lbl in this_ent_labels),
            default=0))
        monitor_fw_scores2.append(max(
            (fuzz.token_sort_ratio(this_mention, lbl) / 100 for lbl in this_ent_labels),
            default=0))

        monitor_bert_scores.append(bert_score)

        # Missing statistics count as 0.
        monitor_commonness.append(commonness.get(mention_key, {}).get(entity_key, 0.))
        monitor_popularity.append(popularity.get(entity_key, 0.))
        monitor_link_probability.append(link_probability.get(mention_key, 0.))

        # Target: exact match with the gold ID of THIS dev sample. (The
        # training gold list was compared here before, which mislabelled the
        # dev rows.) Samples without a gold entity only produce negatives.
        is_gold = gold_entity is not None and entity_key == gold_entity
        monitor_target.append(1 if is_gold else 0)

        monitor_unique_id.append(i)
        monitor_entity.append(entity_id)

monitor_df = pd.DataFrame({
    'ID': monitor_unique_id,
    'Commonness': monitor_commonness,
    'BERT': monitor_bert_scores,
    'Popularity': monitor_popularity,
    'Link_Probability': monitor_link_probability,
    'FW': monitor_fw_scores,
    'FW2': monitor_fw_scores2,
    'Entity': monitor_entity,
    'Target': monitor_target,
})

In [None]:
def get_thresholded_preds(th,pred,scores):
    """Mask predictions whose score falls below a threshold.

    Args:
        th: Minimum score a prediction needs in order to be kept.
        pred: Sequence of predictions.
        scores: Sequence of scores, indexed in parallel with ``pred``.

    Returns:
        A new list as long as ``pred`` in which entry ``k`` is ``pred[k]``
        when ``scores[k] >= th`` and ``None`` otherwise.
    """
    return [
        pred[pos] if scores[pos] >= th else None
        for pos in range(len(pred))
    ]

In [None]:
# Force the listed feature columns of both tables to float.
float_feature_dtypes = dict.fromkeys(
    ['Commonness', 'BERT', 'Popularity', 'Link_Probability', 'FW2'],
    float,
)
train_df = train_df.astype(float_feature_dtypes)
monitor_df = monitor_df.astype(float_feature_dtypes)

In [None]:
# Parameter dictionary handed to lgb.train: binary objective monitored with
# binary log loss.
params = {
    'max_bin': 63,
    'learning_rate': 0.1,
    'min_data_in_leaf': 100,
    'bagging_freq': 1,
    'bagging_fraction': 0.9,
    'lambda_l1': 1,
    'lambda_l2': 1,
    'min_gain_to_split': 1,
    'objective': 'binary',
    'metric': 'binary_logloss',
    "is_unbalance": False,
    "seed": 25,
    "extra_trees": False,
}

In [None]:
# Columns fed to the model, in this order ('FW' is built above but not used).
selected_features = [
    'Commonness',
    'BERT',
    'FW2',
    'Popularity',
    'Link_Probability',
]

In [None]:
# Train for 100 boosting rounds on the training rows, evaluating on both the
# dev rows and the training rows and reporting the metric every 50 rounds.
# NOTE(review): `verbose_eval` was dropped from lgb.train in LightGBM 4.0 in
# favour of the lgb.log_evaluation callback -- confirm the installed version.
d_train = lgb.Dataset(train_df[selected_features], label=train_df['Target'].values)
d_monitor = lgb.Dataset(monitor_df[selected_features], label=monitor_df['Target'].values)
model_lgb = lgb.train(
    params,
    d_train,
    num_boost_round=100,
    valid_sets=[d_monitor, d_train],
    verbose_eval=50,
)

In [None]:
# Plot the feature importances of the trained model, measured by gain.
lgb.plot_importance(booster=model_lgb, importance_type='gain')

In [None]:
# Score every training row with the trained model and keep the result in a
# new 'Pred' column.
train_feature_matrix = train_df[selected_features]
train_df['Pred'] = model_lgb.predict(train_feature_matrix)

In [None]:
# For every mention ID keep only the row with the highest model score; its
# entity and score become the prediction for that mention.
temp = train_df.copy(deep=True)
top_row_labels = temp.groupby('ID')['Pred'].idxmax().values
temp = temp.loc[top_row_labels][['ID', 'Pred', 'Entity']]
train_scores = temp['Pred'].values
train_entities = temp['Entity'].values

In [None]:
# Find the best threshold based on micro-averaged accuracy on the training
# samples. A prediction counts as correct when
#   * it was rejected (None) and the sample has no gold entity, or
#   * the sample has a gold entity and the prediction, as a string, is in
#     that entity's pool.
# The first threshold reaching the highest count wins.
thresholds = np.arange(0,1.001,0.001)
best_th = None
max_score = 0
for candidate_th in thresholds:
    kept_preds = get_thresholded_preds(candidate_th, train_entities, train_scores)
    n_correct = 0
    for pos, gold in enumerate(train_correct_entities):
        guess = kept_preds[pos]
        if gold is None:
            correct = guess is None
        else:
            correct = str(guess) in entity_pool[gold]
        if correct:
            n_correct += 1
    if n_correct > max_score:
        max_score, best_th = n_correct, candidate_th
print("Best threshold: ",best_th)

In [None]:
# Report micro-averaged accuracy on the training samples for three
# thresholds: no rejection (0), 0.5, and the selected best threshold.
# Correctness is judged as in the threshold search: a rejected prediction is
# right only when there is no gold entity; otherwise the stringified
# prediction must be in the gold entity's pool.
print("TRAIN")
thresholds = [0, 0.5,best_th]
n_samples = len(train_correct_entities)
for cut in thresholds:
    print("Threshold: ",cut)
    kept_preds = get_thresholded_preds(cut, train_entities, train_scores)
    n_correct = 0
    for pos, gold in enumerate(train_correct_entities):
        guess = kept_preds[pos]
        if gold is None:
            correct = guess is None
        else:
            correct = str(guess) in entity_pool[gold]
        if correct:
            n_correct += 1
    print('Micro Average Accuracy: ',np.round((100*n_correct)/n_samples,2),'%\n')

In [None]:
# Persist the trained model as a pickle in the current working directory
# (the file name is not prefixed with `path`).
with open('lgbm12.pkl', mode='wb') as model_file:
    pickle.dump(model_lgb, model_file)