In [29]:
import os
import numpy as np
import faiss
import pandas as pd
from collections import Counter
import sys
import argparse
import joblib
from joblib import Parallel,delayed
from joblib import parallel_backend
from pandarallel import pandarallel
from  tqdm import tqdm
pandarallel.initialize()
import json
from datetime import datetime
from datetime import timedelta
import pickle



# ---------------------------------------------
# Feature flags: initialize() loads a FAISS index only for the embedding
# types whose flag is True.
# ---------------------------------------------
USE_TFIDF = True
USE_sBert = False
USE_Longformer = False
USE_doc2vec = True
# ---------------------------------------------
# Date tag embedded in every index / embedding / mapping file name read below.
date_str = '2020-12-10'

# nprobe value that initialize() assigns to each loaded index
# (query() later overwrites it with its own n_probe argument).
N_PROBE = 100
# Directory holding the FAISS index files and the pickled embedding dicts.
model_pkl_dir = 'model_pkl_dir'
# Directory holding the per-date mapping CSV.
mapping_df_dir = 'mapping_data_dir'
# Mapping table; other cells read its 'id', 'synID' and 'title' columns.
df_Mapping = pd.read_csv(os.path.join(mapping_df_dir, 'mapping_data_{}.csv'.format(date_str)),index_col=None)

# _typeID choices : tfidf, sBert, LongFormer, doc2vec
def read_indices_from_file(_typeID, date):
    """Load the serialized FAISS index for one embedding type and date.

    The index is read from ``<model_pkl_dir>/faiss_index_<_typeID>_<date>``,
    where ``model_pkl_dir`` is the module-level directory setting.

    Parameters
    ----------
    _typeID : str
        Embedding type tag used in the file name
        (initialize() passes 'sBert', 'tfidf', 'LongFormer' and 'doc2vec').
    date : str
        Date tag used in the file name, e.g. '2020-12-10'.

    Returns
    -------
    The object returned by ``faiss.read_index`` for that file.
    """
    index_path = os.path.join(model_pkl_dir, 'faiss_index_{}_{}'.format(_typeID, date))
    return faiss.read_index(index_path)

# -------------------------------------
# The pickled mapping is loaded and its values are returned as a numpy array
# _typeID choices : tfidf, sBert, LongFormer, doc2vec
# -------------------------------------
def read_vectors_from_file(_typeID, _date):
    """Load the pickled embeddings of one type and return them as an array.

    The file ``<model_pkl_dir>/doc_id2<type>Emb_<_date>.pkl`` is unpickled
    and the values of the resulting mapping are stacked, in the mapping's
    iteration order, into a numpy array.

    Parameters
    ----------
    _typeID : str
        One of 'LongFormer', 'sBert', 'tfidf', 'doc2vec'.
    _date : str
        Date tag used in the file name, e.g. '2020-12-10'.

    Returns
    -------
    numpy.ndarray
        ``np.array(list(mapping.values()))`` for the unpickled mapping.

    Raises
    ------
    ValueError
        If ``_typeID`` is not one of the supported tags.
    OSError
        If the file cannot be opened.
    """
    # Each embedding type reads the file that carries its own name.
    file_stems = {
        'LongFormer': 'doc_id2LongFormerEmb',
        'sBert': 'doc_id2sBertEmb',
        'tfidf': 'doc_id2tfidfEmb',
        'doc2vec': 'doc_id2doc2vecEmb',
    }
    if _typeID not in file_stems:
        raise ValueError(
            'Unknown embedding type {!r}; expected one of {}'.format(
                _typeID, sorted(file_stems)
            )
        )

    fname = os.path.join(model_pkl_dir, '{}_{}.pkl'.format(file_stems[_typeID], _date))
    print(fname)
    # NOTE: pickle.load can execute arbitrary code -- only load files that
    # were produced by this pipeline.
    with open(fname, 'rb') as fh:
        id_to_vector = pickle.load(fh)
    return np.array(list(id_to_vector.values()))

# Module-level slots filled in by initialize(); None means "not loaded".
index_sBert = index_tfidf = index_longformer = index_doc2vec = None
vectors_sBert = vectors_longformer = vectors_tfidf = vectors_doc2vec = None

def initialize(date_str):
    """Load the FAISS indices and embedding vectors for ``date_str``.

    An index is loaded (and its ``nprobe`` set to ``N_PROBE``) only for the
    embedding types whose ``USE_*`` flag is true; the other index globals are
    left untouched.  Loading the vectors is attempted for every embedding
    type on a best-effort basis: when a load fails, a warning is printed and
    the corresponding ``vectors_*`` global keeps its previous value.

    Parameters
    ----------
    date_str : str
        Date tag of the files to load, e.g. '2020-12-10'.

    Returns
    -------
    None
        Results are stored in the module-level ``index_*`` / ``vectors_*``
        globals.
    """
    global index_sBert, index_tfidf, index_longformer, index_doc2vec
    global vectors_sBert, vectors_longformer, vectors_tfidf, vectors_doc2vec

    if USE_sBert:
        index_sBert = read_indices_from_file('sBert', date_str)
        index_sBert.nprobe = N_PROBE
    if USE_TFIDF:
        index_tfidf = read_indices_from_file('tfidf', date_str)
        index_tfidf.nprobe = N_PROBE
    if USE_Longformer:
        index_longformer = read_indices_from_file('LongFormer', date_str)
        index_longformer.nprobe = N_PROBE
    if USE_doc2vec:
        index_doc2vec = read_indices_from_file('doc2vec', date_str)
        index_doc2vec.nprobe = N_PROBE

    # Best-effort loads, in the same order as before so the printed file
    # names keep their order.
    vectors_longformer = _read_vectors_or_keep('LongFormer', date_str, vectors_longformer)
    vectors_sBert = _read_vectors_or_keep('sBert', date_str, vectors_sBert)
    vectors_tfidf = _read_vectors_or_keep('tfidf', date_str, vectors_tfidf)
    vectors_doc2vec = _read_vectors_or_keep('doc2vec', date_str, vectors_doc2vec)
    return


def _read_vectors_or_keep(_typeID, date_str, current):
    """Return freshly loaded vectors, or ``current`` (after a warning) if loading fails."""
    try:
        return read_vectors_from_file(_typeID, date_str)
    except Exception as err:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate, and report the failure.
        print('WARNING: could not load {} vectors for {}: {!r}'.format(_typeID, date_str, err))
        return current


INFO: Pandarallel will run on 40 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [30]:
def query(
    doc_ID=None,
    synID=None,
    find_NN=20,
    min_count_threshold=1,
    n_probe=20
):
    """Return the synIDs that the loaded indices agree are near the query.

    Every loaded index (tfidf, sBert, LongFormer, doc2vec -- those whose
    module-level ``index_*`` is not None) is searched for ``find_NN`` nearest
    neighbours of the query vector.  An id is returned when at least
    ``min_count_threshold`` indices list it.

    Parameters
    ----------
    doc_ID : optional
        Value looked up in ``df_Mapping['id']``; when given, the matching
        row's ``synID`` replaces the ``synID`` argument.
    synID : optional
        Used to pick the query vector as ``vectors[synID]``.
        NOTE(review): this assumes synID is the row position of the article
        in every vectors array -- confirm against the code that builds them.
    find_NN : int
        Number of hits requested from each index.  The query itself normally
        takes one of them.
    min_count_threshold : int
        Minimum number of indices that must return an id.
    n_probe : int
        Written to ``nprobe`` on every searched index; the value stays set
        on the index after the call.

    Returns
    -------
    list or None
        Neighbour ids in first-seen order, excluding ``synID`` and the ``-1``
        placeholder ids; ``None`` when neither ``doc_ID`` nor ``synID`` is
        given.

    Raises
    ------
    IndexError
        If ``doc_ID`` is given but has no row in ``df_Mapping``.
    """
    if doc_ID is None and synID is None:
        return None
    if doc_ID is not None:
        matching_syn_ids = df_Mapping.loc[df_Mapping['id'] == doc_ID, 'synID'].values
        if len(matching_syn_ids) == 0:
            raise IndexError('doc_ID {!r} not found in df_Mapping'.format(doc_ID))
        synID = matching_syn_ids[0]

    indices = (index_tfidf, index_sBert, index_longformer, index_doc2vec)
    vectors = (vectors_tfidf, vectors_sBert, vectors_longformer, vectors_doc2vec)

    votes = Counter()
    for _index, _vector in zip(indices, vectors):
        if _index is None:
            continue
        _index.nprobe = n_probe
        _, neighbour_ids = _index.search(
            np.array([_vector[synID]]).astype(np.float32),
            find_NN
        )
        # Keep every hit: the query is not guaranteed to be the first one
        # (tied distances between duplicate articles, approximate search),
        # so slicing off position 0 could discard a real neighbour.  The
        # query itself is removed by the ``k != synID`` filter below.
        votes.update(neighbour_ids[0])

    return [
        k for k, v in votes.items()
        if v >= min_count_threshold and k > -1 and k != synID
    ]
    
    

In [38]:
# Load the indices and embedding vectors for the configured date.
initialize(date_str=date_str)


model_pkl_dir/doc_id2sBertEmb_2020-12-10.pkl
model_pkl_dir/doc_id2LongFormerEmb_2020-12-10.pkl
model_pkl_dir/doc_id2tfidfEmb_2020-12-10.pkl
model_pkl_dir/doc_id2doc2vecEmb_2020-12-10.pkl


In [42]:
# Article (by synID) whose neighbours are inspected in this cell.
input_syn_id = 250
res = query(synID=input_syn_id, find_NN=10, min_count_threshold=2, n_probe=20)
print(input_syn_id, res)

# Title of the query article first, then the title of each returned neighbour.
for r in [input_syn_id, *res]:
    print(df_Mapping.loc[df_Mapping['synID'] == r, 'title'])

250 [778]
250    Whiteford tells Detroit Rep. who received raci...
Name: title, dtype: object
778    Whiteford tells Detroit Rep. who received raci...
Name: title, dtype: object


In [45]:
# Article (by synID) whose neighbours are inspected in this cell.
input_syn_id = 100
res = query(synID=input_syn_id, find_NN=10, min_count_threshold=2, n_probe=20)
print(input_syn_id, res)

# Title of the query article first, then the title of each returned neighbour.
for r in [input_syn_id, *res]:
    print(df_Mapping.loc[df_Mapping['synID'] == r, 'title'])

100 [490, 16]
100    Protesters descend on Secretary of State Jocel...
Name: title, dtype: object
490    Armed Thugs Showed Up at the House of Michigan...
Name: title, dtype: object
16    Armed Thugs Showed Up at the House of Michigan...
Name: title, dtype: object


In [46]:
# Article (by synID) whose neighbours are inspected in this cell.
input_syn_id = 32
res = query(synID=input_syn_id, find_NN=10, min_count_threshold=2, n_probe=20)
print(input_syn_id, res)

# Title of the query article first, then the title of each returned neighbour.
for r in [input_syn_id, *res]:
    print(df_Mapping.loc[df_Mapping['synID'] == r, 'title'])

32 [511, 1652]
32    Seattle mayor Jenny Durkan won`t seek reelecti...
Name: title, dtype: object
511    Seattle mayor Jenny Durkan won`t seek reelecti...
Name: title, dtype: object
1652    Seattle mayor Jenny Durkan won`t seek reelecti...
Name: title, dtype: object


In [48]:
# Article (by synID) whose neighbours are inspected in this cell.
input_syn_id = 150
res = query(synID=input_syn_id, find_NN=10, min_count_threshold=2, n_probe=20)
print(input_syn_id, res)

# Title of the query article first, then the title of each returned neighbour.
for r in [input_syn_id, *res]:
    print(df_Mapping.loc[df_Mapping['synID'] == r, 'title'])

150 [1002]
150    Stock market rally amid COVID-19 creates disto...
Name: title, dtype: object
1002    Stock market rally amid COVID-19 creates disto...
Name: title, dtype: object
