In [1]:
from sentence_transformers import SentenceTransformer, util
from sentence_transformers import CrossEncoder
import pandas as pd
import numpy as np
import time
import pickle
import os
import re

In [4]:
embedding_model_name = 'all-mpnet-base-v2' #all-mpnet-base-v2 #all-MiniLM-L6-v2
embedding_model = SentenceTransformer(embedding_model_name)
embedding_model.save(embedding_model_name)

cross_encoder_model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2' # cross-encoder/stsb-roberta-base #cross-encoder/ms-marco-MiniLM-L-6-v2 
cross_encoder_model = CrossEncoder(cross_encoder_model_name) 
cross_encoder_model.save(cross_encoder_model_name)

In [8]:
pickle_filename = f'{embedding_model_name}_name.pkl'
col_to_embed = "Company_name" #Company_name #Company_name_industry
parent_dir = os.path.dirname(os.getcwd())
company_names_path = os.path.join(parent_dir, "all_company_names_final.csv")
pickle_filepath = os.path.join(parent_dir, pickle_filename)

df = pd.read_csv(company_names_path)

if os.path.exists(pickle_filepath ):
    with open(pickle_filepath , "rb") as fIn:
       cache_data = pickle.load(fIn)
       company_names = cache_data['company_names']
       company_names_embeddings = cache_data['company_names_embeddings']
       u3_nums = cache_data['u3_nums']

else:

    company_names = df['Company_name'].values.tolist()
    company_names_for_embedding = df[col_to_embed].values.tolist()
    u3_nums = df['u3_num'].values.tolist()
    company_names_embeddings = embedding_model.encode(company_names_for_embedding, show_progress_bar=True, convert_to_tensor=True)

    with open(pickle_filepath, "wb") as fOut:
        pickle.dump({'company_names': company_names, 'company_names_embeddings': company_names_embeddings, 'u3_nums': u3_nums}, 
                    fOut
                    )


In [9]:
def return_top_k_most_similar(query, top_k=100):
    query_embedding = embedding_model.encode(query, convert_to_tensor=False)
    hits = util.semantic_search(query_embedding, company_names_embeddings, top_k=top_k)
    hits = hits[0]  
    candidates = [company_names[hit['corpus_id']] for hit in hits]
    return hits, candidates

def re_rank(query, hits, candidates):
    sentence_pairs = [[query, candidate] for hit,candidate in zip(hits, candidates)]
    ce_scores = cross_encoder_model.predict(sentence_pairs)
    for ce_score, hit in zip(ce_scores, hits):
        hit['cross-encoder_score'] = ce_score

    # Sort by score, highest score at the top
    hits = sorted(hits, key=lambda x: x['cross-encoder_score'], reverse=True)
    candidates_reranked = [company_names[hit['corpus_id']] for hit in hits]
    return hits, candidates_reranked

def get_match(query, thresh_cross_encoder_score=0):

    hits, candidates = return_top_k_most_similar(query, top_k=100)
    hits, candidates_reranked = re_rank(query, hits, candidates)

    company_found = None
    for hit in hits:

        candidate = company_names[hit['corpus_id']]
        if float(hit['cross-encoder_score']) > thresh_cross_encoder_score :
            company_found = candidate
        else:
            pass
        break
    return company_found

In [10]:
queries = [
"General Electric (Conglomerate)", 'Danaher (Medical Technology)'
'Berkshire Hathaway - Insurance',  'Berkshire Hathaway (Insurance)', 
'Lee Enterprises (Media)', "Jordan's (Retail)", 'RC Willey (Retail)', 
'Berkshire Hathaway (Insurance)', 'Xilinx (Semiconductors)', 
'Caterpillar (Heavy Machinery)', 'Xerox (Technology)', 
'Samsonite (Luggage)', 'Netflix (Entertainment)', 
'Marvel (Entertainment)', 'Medicine of the Angels (Cannabis)', 
'Coda (Cannabis)', 'Intel (Semiconductors)', 
'Morgan Stanley (Investment Banking)', 'Nvidia (Semiconductors)',
"Pepsi (Food & Beverage)", "Pepsi"]

In [15]:
# demo

for query in queries:

    match = get_match(query)
    if match is None:
        # remove industry and try to match again
        query = re.sub(r"(\(.+\))|\[.+\]",'', query)
        #print(query)
        match = get_match(query) 
        if match is None:
            print(f"No match found for query: '{query}'")
        else:
            print(f"match: '{match}' found for query: '{query}'")

    else:
        print(f"match: '{match}' found for query: '{query}'")


match: 'General Electric Co PLC' found for query: 'General Electric (Conglomerate)'
No match found for query: 'Danaher Berkshire Hathaway - Insurance'
match: 'Berkshire Hathaway Inc' found for query: 'Berkshire Hathaway (Insurance)'
match: 'Lee Enterprises Inc' found for query: 'Lee Enterprises (Media)'
match: 'Jordan American Holdings Inc' found for query: 'Jordan's (Retail)'
match: 'RCL Retail Ltd' found for query: 'RC Willey (Retail)'
match: 'Berkshire Hathaway Inc' found for query: 'Berkshire Hathaway (Insurance)'
match: 'Xilinx Inc' found for query: 'Xilinx (Semiconductors)'
match: 'Caterpillar Inc' found for query: 'Caterpillar (Heavy Machinery)'
match: 'Xantrex Technology Inc' found for query: 'Xerox (Technology)'
match: 'Samsonite LLC/OLD' found for query: 'Samsonite (Luggage)'
match: 'Netflix Inc' found for query: 'Netflix (Entertainment)'
match: 'Marvel Entertainment LLC' found for query: 'Marvel (Entertainment)'
No match found for query: 'Medicine of the Angels '
match: 'Cod

In [16]:
queries = ['Amazon (in reference to a partnership)', 'Unilver (industries: personal care and household products)']

In [17]:
# demo

for query in queries:

    match = get_match(query)
    if match is None:
        # remove industry and try to match again
        query = re.sub(r"(\(.+\))|\[.+\]",'', query)
        #print(query)
        match = get_match(query) 
        if match is None:
            print(f"No match found for query: '{query}'")
        else:
            print(f"match: '{match}' found for query: '{query}'")

    else:
        print(f"match: '{match}' found for query: '{query}'")


match: 'Amazon.com Inc' found for query: 'Amazon '
match: 'Uniliver Nepal Ltd' found for query: 'Unilver '
