In [16]:
from sentence_transformers import SentenceTransformer, util
from sentence_transformers import CrossEncoder
import pandas as pd
import numpy as np
import time
import pickle
import os
import re

In [17]:
# Bi-encoder: embeds company names and queries for the first-stage top-k retrieval.
embedding_model_name = 'all-mpnet-base-v2' #all-mpnet-base-v2 #all-MiniLM-L6-v2
embedding_model = SentenceTransformer(embedding_model_name)
# Write a local copy (a directory named after the model) only when it is not
# there yet, so re-running this cell does not rewrite the model files each time.
if not os.path.isdir(embedding_model_name):
    embedding_model.save(embedding_model_name)

# Cross-encoder: scores (query, candidate) pairs to re-rank the retrieved candidates.
cross_encoder_model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2' # cross-encoder/stsb-roberta-base #cross-encoder/ms-marco-MiniLM-L-6-v2 
cross_encoder_model = CrossEncoder(cross_encoder_model_name)
# Same guard as above; the name contains a slash, so the copy lands in a nested directory.
if not os.path.isdir(cross_encoder_model_name):
    cross_encoder_model.save(cross_encoder_model_name)

In [18]:
# Load the company corpus and its embeddings, reusing an on-disk cache when it
# is still consistent with the CSV and with the column chosen for embedding.
pickle_filename = f'{embedding_model_name}_industry.pkl'
col_to_embed = "Company_name_industry" #Company_name #Company_name_industry
company_names_path = r"all_company_names_final.csv"
df = pd.read_csv(company_names_path)

cache_data = None
if os.path.exists(pickle_filename):
    # SECURITY NOTE: pickle.load can execute arbitrary code. Only point
    # pickle_filename at a cache file written by this notebook.
    with open(pickle_filename, "rb") as fIn:
        cache_data = pickle.load(fIn)

    # The cache is only valid for the CSV rows and the column it was built from.
    # Caches written before 'col_to_embed' was recorded are accepted as-is for
    # that check, so existing cache files keep working.
    cache_is_stale = (
        len(cache_data['company_names']) != len(df)
        or cache_data.get('col_to_embed', col_to_embed) != col_to_embed
    )
    if cache_is_stale:
        print(f"{pickle_filename} does not match {company_names_path}; recomputing embeddings.")
        cache_data = None

if cache_data is not None:
    company_names = cache_data['company_names']
    company_names_embeddings = cache_data['company_names_embeddings']
    u3_nums = cache_data['u3_nums']
else:
    # company_names is what gets returned as a match; the text that is actually
    # embedded comes from col_to_embed, which may differ from the plain name.
    company_names = df['Company_name'].values.tolist()
    company_names_for_embedding = df[col_to_embed].values.tolist()
    u3_nums = df['u3_num'].values.tolist()
    company_names_embeddings = embedding_model.encode(company_names_for_embedding, show_progress_bar=True, convert_to_tensor=True)

    with open(pickle_filename, "wb") as fOut:
        pickle.dump(
            {
                'company_names': company_names,
                'company_names_embeddings': company_names_embeddings,
                'u3_nums': u3_nums,
                'col_to_embed': col_to_embed,
            },
            fOut,
        )


In [19]:
# inputs
main_dir = "wsj"    # root folder holding both the input and the output sub-folders
ner_dir = "gemma"   # sub-folder of main_dir whose files are listed and mapped
result_col_name = f"company_found_{ner_dir}"  # column of each input file holding the extracted company name

# outputs
ner_mapped = f"mapped_{ner_dir}"  # sub-folder of main_dir receiving one mapped file per input file

# exist_ok=True makes this cell safe to re-run and avoids the check-then-create race.
os.makedirs(os.path.join(main_dir, ner_mapped), exist_ok=True)

In [20]:
def return_top_k_most_similar(query, top_k=100):
    """Retrieve the corpus entries closest to `query` in embedding space.

    Args:
        query: Text to look up.
        top_k: Maximum number of hits requested from the semantic search.

    Returns:
        A `(hits, candidates)` tuple: `hits` is the first result list returned
        by `util.semantic_search` (entries carry a 'corpus_id'), and
        `candidates` holds the `company_names` entry for each hit, in the
        same order.
    """
    encoded_query = embedding_model.encode(query, convert_to_tensor=False)
    search_results = util.semantic_search(encoded_query, company_names_embeddings, top_k=top_k)
    # Only the first result list is used.
    hits = search_results[0]

    candidates = []
    for hit in hits:
        candidates.append(company_names[hit['corpus_id']])
    return hits, candidates

def re_rank(query, hits, candidates):
    """Re-order retrieval hits by cross-encoder relevance to `query`.

    Args:
        query: Text the candidates are scored against.
        hits: Hit dicts (each with a 'corpus_id'); every scored dict is
            mutated in place to gain a 'cross-encoder_score' entry.
        candidates: Candidate strings paired position-wise with `hits`.

    Returns:
        A `(hits, candidates_reranked)` tuple: a new list of the hit dicts
        sorted by descending 'cross-encoder_score', and the matching
        `company_names` entries in that order. Both are empty lists when
        `hits` is empty.
    """
    # Nothing to score: skip the model call rather than pass it an empty batch.
    if not hits:
        return [], []

    # zip() keeps the pairing bounded by the shorter of the two sequences.
    sentence_pairs = [[query, candidate] for _, candidate in zip(hits, candidates)]
    ce_scores = cross_encoder_model.predict(sentence_pairs)
    for ce_score, hit in zip(ce_scores, hits):
        hit['cross-encoder_score'] = ce_score

    # Sort by score, highest score at the top
    hits = sorted(hits, key=lambda x: x['cross-encoder_score'], reverse=True)
    candidates_reranked = [company_names[hit['corpus_id']] for hit in hits]
    return hits, candidates_reranked

def get_match(query, thresh_cross_encoder_score=0, top_k=100):
    """Return the best-matching company name for `query`, or None.

    Candidates are retrieved with `return_top_k_most_similar`, re-ranked with
    `re_rank`, and only the top re-ranked hit is considered: it is accepted
    when its cross-encoder score is strictly greater than the threshold.

    Args:
        query: Extracted company text to resolve.
        thresh_cross_encoder_score: Score the top hit must exceed.
        top_k: Number of candidates retrieved before re-ranking.

    Returns:
        The `company_names` entry of the top hit, or None when the query is
        blank, nothing was retrieved, or the top score does not exceed the
        threshold.
    """
    # A blank string cannot identify a company; do not let it match anything.
    if isinstance(query, str) and not query.strip():
        return None

    hits, candidates = return_top_k_most_similar(query, top_k=top_k)
    hits, _ = re_rank(query, hits, candidates)
    if not hits:
        return None

    # re_rank returns hits best-first, so if the first one fails the threshold
    # every later one would too.
    best_hit = hits[0]
    if float(best_hit['cross-encoder_score']) > thresh_cross_encoder_score:
        return company_names[best_hit['corpus_id']]
    return None

In [21]:
# Company name -> u3_num, built once so each match is a dict look-up rather
# than a full scan of `df`. setdefault keeps the first row when a name repeats.
u3_num_by_company = {}
for corpus_name, corpus_u3_num in zip(df['Company_name'].values.tolist(), df['u3_num'].values.tolist()):
    u3_num_by_company.setdefault(corpus_name, corpus_u3_num)

# Results for extracted names already resolved, shared across batches so a
# repeated name is not sent through both models again.
match_by_query = {}


def _map_company(query):
    """Resolve one extracted name to `(company_name, u3_num)`.

    Returns `(nan, nan)` when `query` is null or no candidate clears the
    threshold used by `get_match`. Raises KeyError if a matched name is absent
    from `df`, which would mean the embedding corpus and `df` are out of sync.
    """
    if pd.isnull(query):
        return np.nan, np.nan
    if query in match_by_query:
        return match_by_query[query]

    match = get_match(query)
    if match is None:
        # remove industry and try to match again
        stripped_query = re.sub(r"(\(.+\))|\[.+\]", '', query).strip()
        # Retry only if stripping left a different, non-empty query.
        if stripped_query and stripped_query != query:
            match = get_match(stripped_query)

    if match is None:
        result = (np.nan, np.nan)
    else:
        result = (match, u3_num_by_company[match])
    match_by_query[query] = result
    return result


# Sorted so batches are processed in a reproducible order.
for filename in sorted(os.listdir(os.path.join(main_dir, ner_dir))):

    # A batch whose output file already exists is skipped, so an interrupted
    # run can be resumed.
    filepath_save = os.path.join(main_dir, ner_mapped, filename)
    if os.path.exists(filepath_save):
        continue

    df_batch = pd.read_csv(os.path.join(main_dir, ner_dir, filename))
    vid_ids = df_batch['vid_id'].values.tolist()
    company_names_found = df_batch[result_col_name].values.tolist()

    start_time = time.time()
    company_names_mapped = []
    # Named differently from the corpus-level `u3_nums` so it is not overwritten.
    u3_nums_mapped = []

    for query in company_names_found:
        mapped_name, mapped_u3_num = _map_company(query)
        company_names_mapped.append(mapped_name)
        u3_nums_mapped.append(mapped_u3_num)

    df_batch_save = pd.DataFrame(
        zip(vid_ids, company_names_found, company_names_mapped, u3_nums_mapped),
        columns=['vid_id', result_col_name, 'Company_name_mapped', 'u3_num'],
    )
    df_batch_save.to_csv(filepath_save, index=False)

    print(f"{filename} completed. Time elapsed:{time.time() - start_time}")

batch_10.csv completed. Time elapsed:9.773066997528076


In [None]:
# TODO: combine all the intermediate per-batch results, dropping duplicate entities within the same vid_id,
# and save the combined table to a new file

