In [1]:
# Load dictionaries for companies
import pickle
import pandas as pd
from sentence_transformers import SentenceTransformer
import numpy as np

def load_pickle_file(file_path):
    """Unpickle and return the single object stored at ``file_path``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

# The pickles live in the sibling ``glanos-data`` checkout, one level above the
# notebook's working directory (not in the current directory).
# company_kw_vec maps phrases (hyphen-joined in non-sbert mode) to vectors; it is
# consulted by the embedding helpers when params["MODEL"] != "sbert".
# Only load these from a trusted source: unpickling can execute arbitrary code.
GLANOS_DATA_DIR = '../glanos-data'
company_kw_vec = load_pickle_file(f'{GLANOS_DATA_DIR}/company.kw.vec')
# company_kw_vec_hr = load_pickle_file('glanos-data/company.kw.vec.hr')
# NOTE(review): not referenced by the cells visible here -- confirm it is still needed.
company_short_definitions_kw_vec = load_pickle_file(f'{GLANOS_DATA_DIR}/company_short_definitions.kw.vec')
# company_short_definitions_kw_vec_hr = load_pickle_file('glanos-data/company_short_definitions.kw.vec.hr')

In [2]:
params = {
    # Blend factor used by get_embedding:
    #   embedding = (1 - w) * additional-definitions embedding + w * definition embedding
    "DEFINITION_WEIGHT": 0.0,
    # "sbert" encodes phrases with the SBERT model; any other value looks them up in company_kw_vec.
    "MODEL": "sbert",
}
# 'Unnamed: 6' is a header-less column in the TSV (presumably from a trailing tab -- verify); drop it.
cn_descriptors_df = pd.read_csv(
    '../glanos-data/cn_descriptors_top.tsv',
    sep='\t',
).drop(columns=['Unnamed: 6'])
cn_descriptors_df

Unnamed: 0,occurances,company,country,definition,additional definitions,keywords
0,1175619,Technavio,GB,technology research company,consultancy|technology research company,406:confident strategic decisions|403:healthca...
1,1111946,Cable News Network,US,,WarnerMedia company,1111917:warnermedia|13:warner media|6:states|6...
2,336949,Iiroc,CA,self-regulatory company,self-regulatory company,336866:investment dealers|330014:equity market...
3,242464,Rosen,US,law company,investor rights law company,241104:global investor rights|200641:purchaser...
4,196238,Schall,US,,shareholder rights litigation company,195736:national shareholder rights|179758:viol...
...,...,...,...,...,...,...
187934,50,Cdp North America,,research company,disclosure platform research company|investmen...,136:major corporations|136:financial markets|4...
187935,50,Money Carer Foundation,,social company,social company,33:own financial affairs|33:vulnerable adults|...
187936,50,Emerging Markets Private Equity,,trade company,capital company|trade company|trade private in...,40:emerging markets|36:changes|6:private inves...
187937,50,Bryght Ai,,intelligence company,conversational intelligence company|scoring pl...,2:research services


In [8]:
import torch
from tqdm import tqdm
tqdm.pandas()  # registers DataFrame.progress_apply (apply with a progress bar)

sbert_dict = {}  # memo: input string -> embedding, filled lazily by sbert_encode
sbert_model = SentenceTransformer(
    'all-MiniLM-L12-v2',
    device='cuda' if torch.cuda.is_available() else 'cpu',  # prefer the GPU when present
)

def sbert_encode(string_to_encode):
    """Return the SBERT embedding of ``string_to_encode``, memoised in ``sbert_dict``.

    Each distinct string is passed to ``sbert_model.encode`` only once; later
    calls return the cached object.
    """
    if string_to_encode not in sbert_dict:
        sbert_dict[string_to_encode] = sbert_model.encode(string_to_encode)
    return sbert_dict[string_to_encode]
    
def aggregate_embeddings(keys, embeddings_dict):
    """Mean-pool the embeddings of ``keys`` into a single vector.

    When ``params["MODEL"]`` is ``"sbert"`` each key is encoded with
    ``sbert_encode`` and ``embeddings_dict`` is ignored; otherwise each key is
    looked up in ``embeddings_dict`` (KeyError if missing).
    """
    vectors = [
        sbert_encode(key) if params["MODEL"] == "sbert" else embeddings_dict[key]
        for key in keys
    ]
    return np.mean(vectors, axis=0)


def get_embedding(row):
    """Attach definition embeddings to one row of ``cn_descriptors_df``.

    Intended for ``DataFrame.apply(..., axis=1)``. Adds to ``row``:

    * ``embedding_add_def``: mean embedding of the ``|``-separated
      ``additional definitions`` (via ``aggregate_embeddings``);
    * ``embedding_def``: embedding of ``definition``;
    * ``embedding``: ``(1 - w) * embedding_add_def + w * embedding_def`` with
      ``w = params["DEFINITION_WEIGHT"]``. When only one of the two inputs is
      present, ``embedding`` is that input's embedding alone.

    In non-sbert mode, spaces are replaced by ``-`` before phrases are looked
    up in ``company_kw_vec``.

    Returns ``np.nan`` instead of a row when both ``definition`` and
    ``additional definitions`` are missing.
    NOTE(review): under ``apply(axis=1)`` pandas presumably broadcasts that
    scalar into an all-NaN row (``company`` included), so such companies drop
    out of ``save_embeddings``; verify before relying on it.
    """
    definition = row['definition']
    additional_definitions = row['additional definitions']
    has_definition = not pd.isna(definition)
    has_additional = not pd.isna(additional_definitions)
    if not (has_definition or has_additional):
        return np.nan

    use_sbert = params["MODEL"] == "sbert"

    if has_additional:
        if not use_sbert:
            # company_kw_vec keys are hyphen-joined phrases.
            additional_definitions = additional_definitions.replace(' ', '-')
        row['embedding_add_def'] = aggregate_embeddings(additional_definitions.split('|'), company_kw_vec)

    if not has_definition:
        row['embedding'] = row['embedding_add_def']
        return row

    if use_sbert:
        row['embedding_def'] = sbert_encode(definition)
    else:
        row['embedding_def'] = np.array(company_kw_vec[definition.replace(' ', '-')])

    if not has_additional:
        # Previously this case crashed with AttributeError (NaN has no .split);
        # mirror the definition-missing case and use the definition alone.
        row['embedding'] = row['embedding_def']
        return row

    weight = params["DEFINITION_WEIGHT"]
    row['embedding'] = (1 - weight) * row['embedding_add_def'] + weight * row['embedding_def']
    return row

def save_embeddings():
    """Pickle a ``{lowercased company name: embedding}`` dict for the current ``params``.

    Reads the ``company`` and ``embedding`` columns of the global
    ``cn_descriptors_df`` and writes
    ``../glanos-data/embeddings/company_embedding_dicts_<MODEL>_<DEFINITION_WEIGHT>.pickle``,
    creating the ``embeddings`` directory if it does not exist.

    * Rows whose ``company`` is NaN are skipped.
    * Duplicate company names keep the last row's value. Names that collide
      after lowercasing silently overwrite each other.
    * The weight is rounded to 10 decimals for the file name. Float drift such
      as ``0.30000000000000004`` then yields ``..._0.3.pickle``, while ``0.0`` and
      ``1.0`` keep their usual spelling.
    """
    import os

    embedding_by_company = cn_descriptors_df.set_index('company')['embedding'].to_dict()
    company_embedding_dict = {
        company.lower(): embedding
        for company, embedding in embedding_by_company.items()
        if not pd.isna(company)
    }

    prefix = "../glanos-data/embeddings/"
    os.makedirs(prefix, exist_ok=True)
    weight_label = round(float(params["DEFINITION_WEIGHT"]), 10)
    # Optional: persist the SBERT string cache as well.
    # with open('../glanos-data/company.kw.sbert.vec', 'wb') as f:
    #     pickle.dump(sbert_dict, f)
    with open(f'{prefix}company_embedding_dicts_{params["MODEL"]}_{weight_label}.pickle', 'wb') as f:
        pickle.dump(company_embedding_dict, f)

In [9]:
# Sweep the definition weight over exactly 0.1, 0.2, ..., 1.0.
# np.arange with a float step drifts (e.g. 0.30000000000000004), and the drift
# leaked into the printed labels and the pickle file names.
# The first pass populates the sbert_dict cache; later passes reuse it, which
# is why only the first pass is slow.
for definition_weight in [round(0.1 * step, 1) for step in range(1, 11)]:
    print('definition_weight', definition_weight)
    params["DEFINITION_WEIGHT"] = definition_weight
    cn_descriptors_df = cn_descriptors_df.progress_apply(get_embedding, axis=1)
    save_embeddings()

definition_weight 0.1


100%|██████████████████████████████████████████████████████████| 187939/187939 [47:00<00:00, 66.64it/s]


definition_weight 0.2


100%|███████████████████████████████████████████████████████| 187939/187939 [00:11<00:00, 16953.94it/s]


definition_weight 0.30000000000000004


100%|███████████████████████████████████████████████████████| 187939/187939 [00:10<00:00, 17798.26it/s]


definition_weight 0.4


100%|███████████████████████████████████████████████████████| 187939/187939 [00:10<00:00, 17837.69it/s]


definition_weight 0.5


100%|███████████████████████████████████████████████████████| 187939/187939 [00:11<00:00, 16400.54it/s]


definition_weight 0.6


100%|███████████████████████████████████████████████████████| 187939/187939 [00:10<00:00, 17942.75it/s]


definition_weight 0.7000000000000001


100%|███████████████████████████████████████████████████████| 187939/187939 [00:10<00:00, 17783.30it/s]


definition_weight 0.8


100%|███████████████████████████████████████████████████████| 187939/187939 [00:13<00:00, 13815.32it/s]


definition_weight 0.9


100%|███████████████████████████████████████████████████████| 187939/187939 [00:12<00:00, 15088.98it/s]


definition_weight 1.0


100%|███████████████████████████████████████████████████████| 187939/187939 [00:10<00:00, 17900.33it/s]
