### Imports

In [12]:
import json
import pandas as pd
import os
import collections
import math
from typing import Dict, Any, Text, Tuple
import yaml
import sys
from pathlib import Path
import re
from tqdm import tqdm

import awswrangler as wr
from src.utils.s3_utils import put_json_obj
from src.utils.data import overwrite_json
from src.utils.logs import get_logger
from src.relation_extraction.infer import infer_from_trained
from src.matcher.core import SimCSE_Matcher
from src.relation_extraction.reporter import (agg_relations,
                                              process_relations,
                                              match_companies)
from src.glue.glue_etl import GlueETL


# Emoji prefix (U+1F4AB) that is embedded in the logger name below.
icon = "\U0001F4AB "
# Notebook logger at INFO level, obtained from the project's get_logger helper.
logger = get_logger(f"{icon} test", log_level="INFO")

# Load GlueEtl Worker
# `etl.run_query` is used further down to execute the SQL query.
# NOTE(review): constructing GlueETL presumably creates AWS clients (botocore
# credential log lines are emitted right after this cell) — confirm.
etl = GlueETL()

10/01/2023 20:20:43 - INFO - botocore.credentials -   Found credentials in shared credentials file: ~/.aws/credentials
10/01/2023 20:20:52 - INFO - botocore.credentials -   Found credentials in shared credentials file: ~/.aws/credentials


In [2]:
# Turn off pandas' chained-assignment check for the whole notebook (default='warn').
pd.set_option("mode.chained_assignment", None)

In [3]:
# Environment the notebook runs on; switch by (un)commenting one of the lines below.
machine = "local"
#machine = "paperspace"

# Project root per known environment. The `src.*` imports in the imports cell
# only resolve once this directory is on sys.path.
src_dir_by_machine = {
    "local": Path.cwd().parent,
    "paperspace": Path("/notebooks/inferess-relation-extraction/"),
}

# Fail loudly on an unknown environment instead of hitting a NameError on
# `src_dir` further down.
if machine not in src_dir_by_machine:
    raise ValueError(
        f"Unknown machine {machine!r}; expected one of {sorted(src_dir_by_machine)}"
    )
src_dir = src_dir_by_machine[machine]

# Re-running this cell must not keep appending the same entry to sys.path.
if str(src_dir) not in sys.path:
    sys.path.append(str(src_dir))


In [4]:
# Sentence-embedding model handed to the matcher; the matcher's `similarity`
# method is used further down to compare company names.
matcher_model_name = "sentence-transformers/all-MiniLM-L6-v2"
entity_matcher = SimCSE_Matcher(model_name_or_path=matcher_model_name)


  from .autonotebook import tqdm as notebook_tqdm


### Test set from annotated

In [5]:
# Labelled relationship sentences, read from two CSV files
# (the file names mark them as the "pos" and "neg" sets).
positive_sentences_csv = "test_pipeline_data/labeled-relationships/labeled-relationships-sentences_pos_entity_in_sentence.csv"
negative_sentences_csv = "test_pipeline_data/labeled-relationships/labeled-relationships-sentences_neg_entity_in_sentence.csv"

labeled_relationships_sentences_pos = pd.read_csv(positive_sentences_csv)
labeled_relationships_sentences_neg = pd.read_csv(negative_sentences_csv)


In [7]:
# Column names and overall size of the positive set.
positive_sentences = labeled_relationships_sentences_pos
print(positive_sentences.columns)
print(positive_sentences.shape)

# Size after keeping one row per distinct (key, related_entity) pair.
pair_columns = ["key", "related_entity"]
positive_pairs = positive_sentences.drop_duplicates(subset=pair_columns)
print(positive_pairs.shape)


Index(['key', 'company_name', 'related_entity', 'relationship_type',
       'sentence'],
      dtype='object')
(150228, 5)
(6775, 5)


In [None]:
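# # debug (disabled) - check overlap in positive and negative samples; NOTE: `labeled_relationships_sentences` used below is not defined in this notebook (the positive frame is `labeled_relationships_sentences_pos`)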

# # form join_id by joining all keys except sentence
# labeled_relationships_sentences["join_id"] = labeled_relationships_sentences.apply(lambda x: "_".join([str(x[k]) for k in x.keys() if k != "sentence"]), axis=1)
# labeled_relationships_sentences_neg["join_id"] = labeled_relationships_sentences_neg.apply(lambda x: "_".join([str(x[k]) for k in x.keys() if k != "sentence"]), axis=1)

# labeled_relationships_sentences["acc_join_id"] = labeled_relationships_sentences.apply(lambda x: "_".join([str(x[k]) for k in x.keys() if k not in ["relationship_type", "join_id", "sentence"]]), axis=1)
# labeled_relationships_sentences_neg["acc_join_id"] = labeled_relationships_sentences_neg.apply(lambda x: "_".join([str(x[k]) for k in x.keys() if k not in ["relationship_type", "join_id", "sentence"]]), axis=1)


# join_id_true = set(labeled_relationships_sentences["join_id"].to_list())
# join_id_false = set(labeled_relationships_sentences_neg["join_id"].to_list())

# acc_join_id_true = set(labeled_relationships_sentences["acc_join_id"].to_list())
# acc_join_id_false = set(labeled_relationships_sentences_neg["acc_join_id"].to_list())


# # debug 
# print(len(join_id_false))

# print(len(acc_join_id_false))

# len(join_id_true.intersection(join_id_false))

# len(acc_join_id_false.intersection(acc_join_id_true))


In [9]:
# Distinct `key` values (accession numbers) of the positive and negative sets,
# in order of first appearance.
accessionnumber_list_pos = labeled_relationships_sentences_pos["key"].drop_duplicates().tolist()
accessionnumber_list_neg = labeled_relationships_sentences_neg["key"].drop_duplicates().tolist()

for accession_numbers in (accessionnumber_list_pos, accessionnumber_list_neg):
    print(len(accession_numbers))

# Number of accession numbers that occur in both sets.
shared_accession_numbers = set(accessionnumber_list_pos) & set(accessionnumber_list_neg)
print(len(shared_accession_numbers))



3077
1458
521


### SEC data from Athena

In [26]:
# Create the SQL query string restricted to the accession numbers in
# accessionnumber_list_pos, then run it through the GlueETL worker.

# An empty list would render `IN ()`, which is not valid SQL; stop early with a
# clear message instead of sending a broken query.
if not accessionnumber_list_pos:
    raise ValueError("accessionnumber_list_pos is empty; cannot build the SQL IN (...) clause")

# Render every accession number as a single-quoted SQL string literal and double
# any embedded single quote. Relying on str(list) is fragile: Python's repr
# switches to double quotes for values that contain a single quote, and
# double-quoted tokens are not SQL string literals.
accession_number_literals = ", ".join(
    "'{}'".format(str(accession_number).replace("'", "''"))
    for accession_number in accessionnumber_list_pos
)

query_string = """SELECT * FROM "ecomap-release"."text_relationship_20221201" where accession_number IN ({})""".format(accession_number_literals)


query_data = etl.run_query(query_string)


Running query:
 SELECT * FROM "ecomap-release"."text_relationship_20221201" where accession_number IN ('0001628280-17-002606', '0001628280-17-002859', '0001628280-17-005885', '0001628280-17-006835', '0001644406-17-000024', '0001645383-17-000015', '0001649338-17-000158', '0001654954-17-001584', '0001654954-17-002864', '0001654954-17-010721', '0001669162-17-000006', '0001628280-17-012164', '0001637757-17-000007', '0001678463-17-000031', '0001679268-17-000006', '0001683168-17-000284', '0000813762-17-000010', '0000844965-17-000003', '0000845877-17-000018', '0000856982-17-000011', '0000894081-17-000036', '0000023197-17-000069', '0000354707-17-000020', '0000731802-17-000041', '0000763744-17-000049', '0000910612-17-000011', '0000915358-17-000014', '0000918965-17-000022', '0000936395-17-000064', '0000939798-17-000041', '0000944480-17-000013', '0001005414-17-000011', '0001022408-17-000009', '0001024725-17-000005', '0001027884-17-000082', '0001029831-17-000009', '0001047469-17-001086', '00016507

In [28]:
#query_data.to_csv("test_pipeline_data/labeled-relationships/old_model_relation_data_pos.csv", index=False)



In [165]:

# Read the old-model relations from the local CSV file and show its columns.
old_model_relations_csv = "test_pipeline_data/labeled-relationships/old_model_relation_data_pos.csv"
query_data = pd.read_csv(old_model_relations_csv)

query_data.columns

Index(['valid_from', 'last_seen', 'reporter_id', 'reporter_name',
       'relationship_type', 'reported_company_id', 'reported_company',
       'date_from', 'date_to', 'confidence', 'accession_number'],
      dtype='object')

In [166]:
# Shape after keeping one row per distinct
# (accession_number, reporter_name, reported_company) triple.
relation_key_columns = ["accession_number", "reporter_name", "reported_company"]
query_data.drop_duplicates(subset=relation_key_columns).shape


(30455, 11)

In [167]:
# How many of the positive accession numbers also appear in query_data.
# note - all accession numbers in accessionnumber_list_pos are present in query_data
queried_accession_numbers = set(query_data["accession_number"].unique())
len(set(accessionnumber_list_pos) & queried_accession_numbers)



3077

In [68]:

# total matching relations in query_data

def _make_join_id(frame, columns):
    """Return one lower-cased key per row of `frame`: the string form of `columns`, joined with "_"."""
    column_values = [frame[column] for column in columns]
    return [
        "_".join(str(value).lower() for value in row_values)
        for row_values in zip(*column_values)
    ]


# The labelled frames name their columns key / company_name / related_entity,
# while query_data uses accession_number / reporter_name / reported_company.
# Indexing the labelled frames with the query_data names raises KeyError, so
# each frame gets its own column list here (listed in corresponding order).
# NOTE(review): the negative frame is assumed to share the positive frame's
# schema — confirm.
labeled_relation_columns = ["key", "company_name", "related_entity"]
query_relation_columns = ['accession_number', 'reporter_name', 'reported_company']

labeled_relationships_sentences_pos["join_id"] = _make_join_id(
    labeled_relationships_sentences_pos, labeled_relation_columns)

labeled_relationships_sentences_neg["join_id"] = _make_join_id(
    labeled_relationships_sentences_neg, labeled_relation_columns)

query_data["join_id"] = _make_join_id(query_data, query_relation_columns)

print("total relations count")

print(labeled_relationships_sentences_pos.join_id.nunique())
print(labeled_relationships_sentences_neg.join_id.nunique())
print(query_data.join_id.nunique())

print("intersection count with query data")

# Distinct relation keys shared between each labelled set and query_data.
query_join_ids = set(query_data["join_id"].unique())
print(len(set(labeled_relationships_sentences_pos["join_id"].unique()).intersection(query_join_ids)))
print(len(set(labeled_relationships_sentences_neg["join_id"].unique()).intersection(query_join_ids)))



total relations count
6770
2315
30454
intersection count with query data
1773
9


In [173]:
# One row per distinct (key, company_name, related_entity, relationship_type).
# relationship_type must be part of this first de-duplication: de-duplicating on
# the three entity columns only keeps an arbitrary first label per pair, which
# turns the conflict filter below into a no-op.
labeled_relation_columns = ['key', 'company_name', 'related_entity', 'relationship_type']
unique_company_relations = (
    labeled_relationships_sentences_pos[labeled_relation_columns]
    .drop_duplicates()
    # rename columns to match query_data
    .rename(columns={"key": "accession_number",
                     "company_name": "reporter_name",
                     "related_entity": "reported_company"})
)

print(unique_company_relations.shape)
print(unique_company_relations.columns)

# drop multiple rows with same accession_number, reporter_name, reported_company due to multiple relationship_type
# (keep=False flags every row of such a group, so ambiguous pairs are removed entirely)
has_conflicting_types = unique_company_relations.duplicated(
    subset=["accession_number", "reporter_name", "reported_company"], keep=False)
unique_company_relations = unique_company_relations[~has_conflicting_types].copy()

print(unique_company_relations.shape)


(6775, 4)
Index(['accession_number', 'reporter_name', 'reported_company',
       'relationship_type'],
      dtype='object')
(6775, 4)


False    6775
Name: count, dtype: int64

In [157]:
# Display the column index of the old-model relations frame.
query_data.columns

Index(['valid_from', 'last_seen', 'reporter_id', 'reporter_name',
       'relationship_type', 'reported_company_id', 'reported_company',
       'date_from', 'date_to', 'confidence', 'accession_number', 'join_id'],
      dtype='object')

In [174]:
# For every accession number, collect the distinct reported companies that the
# old-model data lists for it (the group key is the accession number itself).
accession_to_reported_map = {
    accession: group["reported_company"].unique().tolist()
    for accession, group in tqdm(query_data.groupby("accession_number"))
}


def check_entity_in_old_relations(entity_matcher, accession_number, query_reported_company, threshold=0.95, reported_map=None):
    """Find the old-model reported company that corresponds to a labelled one.

    The candidates are the reported companies stored for `accession_number`.
    Three checks run in order and the first hit wins:

    1. case-insensitive exact match            -> "exact_match"
    2. best similarity score >= `threshold`    -> "sim_match"
    3. first word of the query occurs as a
       whole word in a candidate               -> "fword_match"

    Args:
        entity_matcher: object exposing `similarity(query, candidates)` that
            yields one score per candidate.
        accession_number: key used to look up the candidate companies.
        query_reported_company: company name to look for.
        threshold: minimum similarity score accepted as a match.
        reported_map: optional mapping accession number -> list of company
            names. Defaults to the module-level `accession_to_reported_map`.

    Returns:
        Tuple `(matched_name, match_type)`; `("", "no_match")` when nothing
        matches, the accession number is unknown, or the query is not a string.
    """
    if reported_map is None:
        reported_map = accession_to_reported_map

    # A missing/NaN query cannot match anything.
    if not isinstance(query_reported_company, str):
        return "", "no_match"

    # Unknown accession numbers yield no candidates instead of a KeyError;
    # non-string entries (e.g. NaN from missing values) are skipped because
    # they cannot be lower-cased or compared.
    candidates = [name for name in reported_map.get(accession_number, []) if isinstance(name, str)]
    query_lower = query_reported_company.lower()

    # Check - query_reported_company has exact match in any reported company
    for candidate in candidates:
        if candidate.lower() == query_lower:
            return candidate, "exact_match"

    # Check - by cosine similarity (skipped when there is nothing to compare with)
    if candidates:
        similarities = entity_matcher.similarity(query_reported_company, candidates)
        best_index, best_score = None, None
        for index, score in enumerate(similarities):
            # Strict ">" keeps the earliest candidate when scores tie.
            if score >= threshold and (best_score is None or score > best_score):
                best_index, best_score = index, score
        if best_index is not None:
            return candidates[best_index], "sim_match"

    # Check - first word of query company is present in any reported company
    first_word = query_lower.split(" ")[0].strip(".,-:;")
    fword_match_name = ""
    if first_word:
        # The word is escaped so that regex metacharacters in a company name
        # (e.g. "(", "+") are matched literally. An empty first word is
        # skipped: r"\b\b" would match almost every candidate.
        pattern = re.compile(r"\b{}\b".format(re.escape(first_word)))
        fword_match_name = next(
            (candidate for candidate in candidates if pattern.search(candidate.lower())),
            "",
        )

    # debug log
    if fword_match_name:
        print(f"{query_reported_company} -- {fword_match_name} -- ")
    else:
        print(f"{query_reported_company} -- no match -- ")

    # The first-word match is the weakest check; it is only reached when neither
    # the exact nor the similarity check succeeded.
    if fword_match_name:
        return fword_match_name, "fword_match"
    return "", "no_match"

    

def _match_labeled_row(row):
    """Look up one labelled relation among the old-model reported companies of its accession number."""
    return check_entity_in_old_relations(entity_matcher, row.accession_number, row.reported_company)


# One (matched_name, match_type) tuple per labelled relation.
old_reported_match = unique_company_relations.apply(_match_labeled_row, axis=1)

# Split the tuples into two parallel lists and attach them as columns.
old_reported_company, old_match_type = [], []
for matched_name, match_kind in old_reported_match:
    old_reported_company.append(matched_name)
    old_match_type.append(match_kind)

unique_company_relations["old_reported_company"] = old_reported_company
unique_company_relations["old_match_type"] = old_match_type



100%|██████████| 3077/3077 [00:00<00:00, 21787.14it/s]


Mallinckrodt -- Mallinckrodt PLC -- 
OLP Brooklyn Pavilion LLC -- no match -- 
Wal-Mart -- no match -- 
ThermoFisher Scientific -- no match -- 
Linde -- Linde PLC -- 
Zions -- Zions Bancorporation NA -- 
USPS -- no match -- 
Lumentum -- Lumentum Holdings Inc -- 
Check Point -- no match -- 
Auto Zone -- Advance Auto Parts Inc -- 
Alfa Wassermann S.p.A. -- Alfa Wassermann SpA -- 
Clinigen Group -- Clinigen Group PLC -- 
Vista -- no match -- 
Wella -- Wella AG -- 
FOX -- FOX SRL -- 
Antero -- no match -- 
Tesoro -- no match -- 
Norwegian Cruise Line -- no match -- 
LivaNova -- LivaNova PLC -- 
Brown Forman -- Brown-Forman Corp -- 
Realogy -- Realogy Holdings Corp -- 
LATEL L.L.C. -- LATEL LLC -- 
Kyowa Hakko Kirin Co., Ltd. -- Kyowa Kirin Co Ltd -- 
Laclede Gas Company -- no match -- 
Dish -- DISH Broadband LLC -- 
GENBAND US L.L.C. -- GENBAND US LLC -- 
AT and T Mobility -- AT and T Corp -- 
LR Advisors L.L.C. -- LR Advisors LLC -- 
Funding L.L.C. -- Funding LLC -- 
Delta Airlines -- Del

In [169]:
# Display the columns of the labelled relations frame after the matching step.
unique_company_relations.columns

Index(['accession_number', 'reported_company', 'reporter_name',
       'relationship_type', 'old_reported_company', 'old_match_type'],
      dtype='object')

In [180]:
# Number of labelled relations per match type
# (exact_match / sim_match / fword_match / no_match).
unique_company_relations["old_match_type"].value_counts()

old_match_type
sim_match      3540
exact_match    1773
fword_match     781
no_match        681
Name: count, dtype: int64

### Compare the labelled relations with old pipeline relations

In [175]:
# Old-model relations with one row per distinct
# (accession_number, reporter_name, reported_company, relationship_type).
old_relation_columns = ["accession_number", "reporter_name", "reported_company", 'relationship_type']
unique_query_data = query_data.drop_duplicates(subset=old_relation_columns).copy()

In [None]:
# unique_company_relations - relations from labelled data

# query_data - relations from old model


In [186]:
print(unique_query_data.shape)

# Remove every (accession_number, reporter_name, reported_company) triple that
# still occurs on more than one row, i.e. carries more than one relationship_type;
# keep=False flags all rows of such a triple. Only the four relation columns are kept.
relation_key = ["accession_number", "reporter_name", "reported_company"]
is_ambiguous = unique_query_data.duplicated(subset=relation_key, keep=False)
unique_query_data = unique_query_data.loc[~is_ambiguous, relation_key + ["relationship_type"]]

print(unique_query_data.shape)


(30391, 12)
(30391, 4)


In [177]:
# Show the columns of both frames before joining them.
for frame in (unique_company_relations, unique_query_data):
    print(frame.columns)


Index(['accession_number', 'reporter_name', 'reported_company',
       'relationship_type', 'old_reported_company', 'old_match_type'],
      dtype='object')
Index(['valid_from', 'last_seen', 'reporter_id', 'reporter_name',
       'relationship_type', 'reported_company_id', 'reported_company',
       'date_from', 'date_to', 'confidence', 'accession_number'],
      dtype='object')


In [188]:
def _lowercase_key(frame, columns):
    """Join the lower-cased string form of `columns` with "_" for every row of `frame`."""
    return [
        "_".join(str(value).lower() for value in values)
        for values in zip(*(frame[column] for column in columns))
    ]


# The labelled side is keyed on the matched old-model company name, the
# old-model side on its own reported company.
unique_company_relations["join_id"] = _lowercase_key(
    unique_company_relations, ['accession_number', 'old_reported_company'])
unique_query_data["join_id"] = _lowercase_key(
    unique_query_data, ['accession_number', 'reported_company'])

# do inner join unique_company_relations and unique_query_data on accession_number, reported_company
unique_company_relations_joined = pd.merge(
    unique_company_relations,
    unique_query_data,
    how="inner",
    on="join_id",
    suffixes=("_labeled", "_old"),
)




In [190]:
# Size and columns of the joined frame, one per line.
print(
    unique_company_relations_joined.shape,
    unique_company_relations_joined.columns,
    sep="\n",
)


(6073, 11)
Index(['accession_number_labeled', 'reporter_name_labeled',
       'reported_company_labeled', 'relationship_type_labeled',
       'old_reported_company', 'old_match_type', 'join_id',
       'accession_number_old', 'reporter_name_old', 'reported_company_old',
       'relationship_type_old'],
      dtype='object')


In [191]:
# Flag the rows where the labelled relationship type equals the old-model one.
labeled_type = unique_company_relations_joined["relationship_type_labeled"]
old_type = unique_company_relations_joined["relationship_type_old"]
unique_company_relations_joined["relation_match"] = labeled_type.eq(old_type)


In [192]:
# Count the joined relations whose relationship type agrees (True) or differs (False).
unique_company_relations_joined["relation_match"].value_counts()

relation_match
True     6019
False      54
Name: count, dtype: int64

### Summary


- The positive labelled data contains 6775 unique relations in total.
- The old model's data was pulled from Athena and used as `query_data`.
- Each labelled `reported_company` was mapped to the old model's reported companies using several matching methods (exact, similarity, and first-word match).
- Relations matched between the two sources: 6073.
- Of these, the relationship type agrees for 6019 (~99%) and differs for 54.