# **_Basic dependencies_**

In [2]:
import json
import pandas as pd
import os
import math
from typing import Dict, Any, Text, Tuple
import yaml
import sys
from pathlib import Path
src_dir= Path.cwd().parent
sys.path.append(str(src_dir))
import awswrangler as wr
from src.utils.s3_utils import put_json_obj
from src.utils.data import overwrite_json
from src.utils.logs import get_logger
from src.matcher.core import SimCSE_Matcher
from src.relation_extraction.infer import infer_from_trained
from src.relation_extraction.reporter import (agg_relations,
                                              process_relations,
                                              match_companies)
from src.glue.glue_etl import GlueETL

# Emoji prefix (U+1F4AB) used to tag the name of this notebook's logger.
icon = "\U0001F4AB "
logger = get_logger(f"{icon} RE JOB", log_level="INFO")
############### Variables ################
# Bookkeeping flags. None of them is read by the cells visible in this part of the notebook.
# NOTE(review): presumably left over from the pipeline job this notebook mirrors — confirm or remove.
CURRENT_STEP = "RE"
FOLLOWING = "Final"
distribute = False
##########################################
# Load the GlueETL worker; used further down as `etl.run_query(...)` to fetch the sentences.
etl = GlueETL()

# Reference it in the inference container at /opt/ml/model/code
def model_fn(model_dir: str) -> Tuple[infer_from_trained, SimCSE_Matcher]:
    """
    Build the relation extractor and the entity matcher used by the notebook.

    The extractor is created with entity detection enabled, the
    ``en_core_web_trf`` language model, GPU required, and the matcher
    artifacts found under ``src_dir / "artifacts/matcher_model"``. Its
    relation-extraction weights are then loaded from ``<model_dir>/re_model``
    with a batch size of 8.

    @params
    -------
    - model_dir: directory that contains the ``re_model`` folder.

    @returns:
    --------
    Tuple ``(relation_extractor, entity_matcher)``.
    """
    matcher_artifacts = str(src_dir / "artifacts/matcher_model")
    extractor = infer_from_trained(
        detect_entities=True,
        language_model="en_core_web_trf",
        require_gpu=True,
        load_matcher=True,
        entity_matcher=matcher_artifacts,
    )
    # Alternative matcher checkpoint kept for reference:
    # "pipeline-artifacts/matcher/all-MiniLM-Nli-All-Random-v4"
    matcher = SimCSE_Matcher(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
    extractor.load_model(
        {"model_path": os.path.join(model_dir, "re_model"), "batch_size": 8}
    )
    return extractor, matcher


def float_format(x: Any) -> float:
    """
    Coerce ``x`` to ``float`` with the built-in constructor and return the result.
    """
    as_float = float(x)
    return as_float


def input_fn(request_body: str, content_type: str) -> Dict[str, Any]:
    """
    Parse the incoming request payload.

    @params
    -------
    - request_body: raw request body.
    - content_type: MIME type of the request; only "application/json" is parsed.

    @returns:
    --------
    The decoded JSON content when ``content_type`` is "application/json",
    otherwise an empty dict.

    @raises:
    --------
    json.JSONDecodeError (a ValueError) if a JSON request body is malformed.
    """
    if content_type == "application/json":
        return json.loads(request_body)
    # Unsupported content types yield an empty payload rather than an error.
    return {}


if __name__ == "__main__":
    # Load both models once; later cells reuse `relation_extractor` and `entity_matcher`.
    # Note: a pathlib.Path is passed here although model_fn annotates `model_dir` as str.
    relation_extractor, entity_matcher = model_fn(src_dir/ "artifacts")

  from .autonotebook import tqdm as notebook_tqdm
11/20/2023 11:46:28 - INFO - matplotlib.font_manager -   generated new fontManager
11/20/2023 11:46:30 - INFO - botocore.credentials -   Found credentials in shared credentials file: ~/.aws/credentials
11/20/2023 11:46:30 - INFO - botocore.credentials -   Found credentials in shared credentials file: ~/.aws/credentials


Torch GPU Exists..
2023-11-20 11:46:35,869 — 🌌 spaCy — INFO — Language model used is en_core_web_trf
2023-11-20 11:46:35,874 — 🌌 spaCy — INFO — spaCy Work On GPU


Downloading tokenizer_config.json: 100%|██████████| 350/350 [00:00<00:00, 254kB/s]
Downloading vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 16.5MB/s]
Downloading tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 20.5MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 93.9kB/s]
Downloading config.json: 100%|██████████| 612/612 [00:00<00:00, 501kB/s]
Downloading pytorch_model.bin: 100%|██████████| 90.9M/90.9M [00:00<00:00, 407MB/s]


2023-11-20 11:46:43,746 — 💫 Relations Extractor — INFO — Loading tokenizer and model...
Load with CUDA


11/20/2023 11:46:46 - INFO - /notebooks/inferess-relation-extraction/src/relation_extraction/train_funcs.py -   Loaded model.


2023-11-20 11:46:46,498 — 💫 Relations Extractor — INFO — Done!


In [3]:
SENT = "At BDS, BOEING Co's continue to see a healthy market with solid demand for BOEING Co major platforms and programs both domestically and internationally"

In [4]:
tagged = relation_extractor.tag_sentences([SENT])

In [7]:
tagged['sents'].tolist()

["At [E2] BDS [/E2], BOEING Co's continue to see a healthy market with solid demand for [E1] BOEING Co [/E1] major platforms and programs both domestically and internationally",
 "At [E2] BDS [/E2], [E1] BOEING Co [/E1]'s continue to see a healthy market with solid demand for BOEING Co major platforms and programs both domestically and internationally"]

In [7]:
tagged = relation_extractor.predict_fn(tagged, mutate=True, reverse=True)

mutate text: 100%|██████████| 2/2 [00:00<00:00, 865.61it/s]
10/03/2023 07:37:13 - INFO - __file__ -   Tokenizing data...
tokenization: 100%|██████████| 2/2 [00:00<00:00, 966.54it/s]
tags positioning: 100%|██████████| 2/2 [00:00<00:00, 1682.43it/s]



Invalid rows/total: 0/2


100%|██████████| 1/1 [00:00<00:00, 56.13it/s]
mutate text: 100%|██████████| 2/2 [00:00<00:00, 1987.82it/s]
10/03/2023 07:37:13 - INFO - __file__ -   Tokenizing data...
tokenization: 100%|██████████| 2/2 [00:00<00:00, 1008.37it/s]
tags positioning: 100%|██████████| 2/2 [00:00<00:00, 2073.31it/s]



Invalid rows/total: 0/2


100%|██████████| 1/1 [00:00<00:00, 97.93it/s]


In [8]:
tagged

Unnamed: 0,sents,orig_sents,entity1,entity2,org_groups,idx,r_id,mutated_sents,scores
0,"At [E2] BDS [/E2], [E1] BOEING Co [/E1]'s cont...","At BDS, BOEING Co's continue to see a healthy ...",BOEING Co,BDS,"{'BOEING Co': 0, 'BDS': 1}",0,0_0,"At [E2] BDS [/E2], [E1] org-three [/E1]'s cont...","[0.05185773968696594, 0.9290893077850342, 0.01..."
1,"At [E2] BDS [/E2], BOEING Co's continue to see...","At BDS, BOEING Co's continue to see a healthy ...",BOEING Co,BDS,"{'BOEING Co': 0, 'BDS': 1}",0,0_0,"At [E2] BDS [/E2], org-seven's continue to see...","[0.001354456995613873, 0.9953155517578125, 0.0..."


In [9]:
pd.set_option('display.max_colwidth', None)
tagged

Unnamed: 0,sents,orig_sents,entity1,entity2,org_groups,idx,r_id,mutated_sents,scores
0,"At [E2] BDS [/E2], [E1] BOEING Co [/E1]'s continue to see a healthy market with solid demand for BOEING Co major platforms and programs both domestically and internationally","At BDS, BOEING Co's continue to see a healthy market with solid demand for BOEING Co major platforms and programs both domestically and internationally",BOEING Co,BDS,"{'BOEING Co': 0, 'BDS': 1}",0,0_0,"At [E2] BDS [/E2], [E1] org-three [/E1]'s continue to see a healthy market with solid demand for org-three major platforms and programs both domestically and internationally","[0.05185773968696594, 0.9290893077850342, 0.01905292645096779]"
1,"At [E2] BDS [/E2], BOEING Co's continue to see a healthy market with solid demand for [E1] BOEING Co [/E1] major platforms and programs both domestically and internationally","At BDS, BOEING Co's continue to see a healthy market with solid demand for BOEING Co major platforms and programs both domestically and internationally",BOEING Co,BDS,"{'BOEING Co': 0, 'BDS': 1}",0,0_0,"At [E2] BDS [/E2], org-seven's continue to see a healthy market with solid demand for [E1] org-seven [/E1] major platforms and programs both domestically and internationally","[0.001354456995613873, 0.9953155517578125, 0.0033300507348030806]"


In [10]:
import numpy as np
# For every r_id, average the 'scores' vectors of its rows element-wise (mean over axis 0)
# and collect the result as {r_id: [mean score per position]}.
id_scores = tagged.groupby(['r_id'])\
         .apply(lambda x : list(np.mean(x['scores'].tolist(), axis=0))).to_dict()

In [11]:
id_scores

{'0_0': [0.026606098341289908, 0.9622024297714233, 0.011191488592885435]}

In [13]:
pd.set_option("display.max_colwidth", None)
relation_extractor.predict_relations([SENT], mutate=True,reverse=True, num_positions=10)

mutate text: 100%|██████████| 2/2 [00:00<00:00, 2211.60it/s]
10/03/2023 07:37:48 - INFO - __file__ -   Tokenizing data...
tokenization: 100%|██████████| 2/2 [00:00<00:00, 950.23it/s]
tags positioning: 100%|██████████| 2/2 [00:00<00:00, 1414.13it/s]



Invalid rows/total: 0/2


100%|██████████| 1/1 [00:00<00:00, 77.28it/s]
mutate text: 100%|██████████| 2/2 [00:00<00:00, 1835.58it/s]
10/03/2023 07:37:48 - INFO - __file__ -   Tokenizing data...
tokenization: 100%|██████████| 2/2 [00:00<00:00, 1135.28it/s]
tags positioning: 100%|██████████| 2/2 [00:00<00:00, 1713.71it/s]



Invalid rows/total: 0/2


100%|██████████| 1/1 [00:00<00:00, 80.58it/s]
  tagged_frame.loc[:, 'scores'] =  labels, score


Unnamed: 0_level_0,relations,orig_sents,org_groups
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,"[{'BDS': 'customer', 'BOEING Co': 'supplier', 'score': 0.9617}]","At BDS, BOEING Co's continue to see a healthy market with solid demand for BOEING Co major platforms and programs both domestically and internationally","{'BOEING Co': 0, 'BDS': 1}"


# **_Simple Evaluation_**

In [14]:

from itertools import chain
def is_sc(relations):
    """
    Flag whether any predicted relation assigns the label 'supplier'.

    @params
    -------
    - relations: list of relation dicts (entity name -> label, plus 'score'),
      or a missing value (None / NaN) for rows without a prediction.

    @returns:
    --------
    1 if any value in any relation dict equals 'supplier', otherwise 0.
    Empty and non-list inputs return 0. The original fell through a bare
    ``0`` expression and returned None for truthy non-list inputs such as NaN.
    """
    if not relations or not isinstance(relations, list):
        return 0
    return int(any(label == 'supplier' for rel in relations for label in rel.values()))
# Evaluate model performance on the simple data for a sanity check.
simple_data = pd.read_excel(src_dir /"data/raw/simple_sentences_cs_report.xlsx")
predictions = relation_extractor.predict_frame(simple_data, sentence_column = 'sentence', mutate=True, reverse=True)
# Rows without a prediction keep a missing value in 'relations'.
simple_data.loc[predictions.index, 'relations'] = predictions['relations']
simple_data.loc[:, 're_prediction'] = simple_data['relations'].apply(is_sc)
# Assign the filled column back instead of calling fillna(inplace=True) on a column
# selection: the chained in-place form warns in pandas 2.x and does not update the
# frame when Copy-on-Write is enabled.
simple_data['re_prediction'] = simple_data['re_prediction'].fillna(0)
simple_data.loc[:, 're_correct_prediction'] = simple_data['re_prediction'] == simple_data['true_label']
# Accuracy: share of rows whose prediction equals the true label.
simple_data.query("true_label == re_prediction").shape[0] / len(simple_data)

mutate text: 100%|██████████| 500/500 [00:00<00:00, 15420.46it/s]
10/03/2023 07:38:29 - INFO - __file__ -   Tokenizing data...
tokenization: 100%|██████████| 500/500 [00:00<00:00, 9538.54it/s]
tags positioning: 100%|██████████| 500/500 [00:00<00:00, 85032.32it/s]



Invalid rows/total: 0/500


100%|██████████| 63/63 [00:00<00:00, 93.27it/s]
mutate text: 100%|██████████| 500/500 [00:00<00:00, 20245.32it/s]
10/03/2023 07:38:30 - INFO - __file__ -   Tokenizing data...
tokenization: 100%|██████████| 500/500 [00:00<00:00, 9632.95it/s]
tags positioning: 100%|██████████| 500/500 [00:00<00:00, 85542.18it/s]



Invalid rows/total: 0/500


100%|██████████| 63/63 [00:00<00:00, 94.32it/s]
  tagged_frame.loc[:, 'scores'] =  labels, score


0.8914893617021277

### **_Read entity tagged SEC sentences with Athena_**

In [51]:
# Read every row of the filing-text table for a fixed set of reporter CIKs.
# NOTE(review): '0000320193' is listed twice in the IN clause; the duplicate is redundant.
query_data = etl.run_query("""SELECT * FROM "legacyevents"."filingtexttext_parquet" where reporter_cik IN ('0000012927',
    '0000037996',
    '0000104169',
    '0000320193',
    '0001047122',
    '0000078003',
    '0000789019',
    '0000320193',
    '0001467858')
""")

Running query:
 SELECT * FROM "legacyevents"."filingtexttext_parquet" where reporter_cik IN ('0000012927',
    '0000037996',
    '0000104169',
    '0000320193',
    '0001047122',
    '0000078003',
    '0000789019',
    '0000320193',
    '0001467858')

self.bucket: ecomap-dl-pipeline
results_file_prefix: queries/e86b0b1c-6720-4d0f-8427-d44aa2ff4be1.csv
filename: e86b0b1c-6720-4d0f-8427-d44aa2ff4be1.csv
Query results shape: (18850, 7)


### **_Extract NER Tags_**

In [52]:
# Run the extractor's `spacy_loader.predictor` over the raw sentences. It returns four
# outputs (sentences, entity spans, organisation groups, aliases) that the next cell
# stores as columns of `query_data`.
sents, spans, group_docs, aliases_docs = relation_extractor.spacy_loader.predictor(query_data['sentence'])

  0%|          | 0/22 [00:00<?, ?it/s]

In [53]:
# Store the predictor outputs as columns; the returned sentences overwrite the
# original 'sentence' column.
query_data.loc[:, "sentence"] = sents
query_data.loc[:, "spans"] = spans
query_data.loc[:, "org_groups"] = group_docs
query_data.loc[:, "aliases"] = aliases_docs
# Count distinct group ids per sentence: several surface forms may share one id in
# `org_groups`, so the values (ids) are counted rather than the keys (names).
query_data.loc[:, 'num_orgs'] = query_data['org_groups']\
          .apply(lambda x : len(set(x.values()))).tolist()
# Keep only sentences that mention more than one distinct organisation group.
query_data = query_data.query('num_orgs > 1')
query_data.reset_index(drop=True, inplace=True)

### **_Detect supply-chain_**
`Can be ignored`

In [11]:
from src.sc_classifier.trainer import Trainer
from src.sc_classifier.config.core import config
config.train_args.load_pretrained = True
sc_model = Trainer(config=config,load_data=False)

root==> /notebooks/inferess-relation-extraction


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Downloading (…)lve/main/config.json:   0%|          | 0.00/568 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/439M [00:00<?, ?B/s]

2023-09-21 13:13:46,120 — SCClassifier — INFO — loading checkpoint from `sc_model`


Downloading (…)okenizer_config.json:   0%|          | 0.00/393 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/152 [00:00<?, ?B/s]

2023-09-21 13:13:46,748 — SCClassifier — INFO — inference mode...


In [12]:
# Score every sentence with the supply-chain classifier and keep, per row, the maximum
# of `scores` along axis 1 together with the predicted label.
# NOTE(review): assumes `scores` is a 2-D array of per-class scores — confirm.
scores, preds = sc_model.predict_seq(query_data['sentence'] , max_length=128)
query_data.loc[:, 'sc_score'] = scores.max(1)
query_data.loc[:, 'sc_label'] = preds

100%|[32m██████████[0m| 73/73 [00:10<00:00,  7.05batch/s]


In [13]:
# Keep only rows labelled 1 with a score above 0.95.
# NOTE(review): presumably label 1 means "supply-chain sentence" — confirm against the classifier.
query_data = query_data.query('sc_label == 1 and sc_score > 0.95').reset_index(drop=True)

### **_Extract relations_**

In [54]:
# Predict relations between ORG entities, reusing the spans, organisation groups and
# aliases computed in the NER step instead of detecting them again.
predictions= relation_extractor.predict_relations(
    sentences=query_data["sentence"].tolist(),
    ent="ORG",
    spans=query_data["spans"].tolist(),
    org_groups=query_data["org_groups"].tolist(),
    aliases=query_data["aliases"].tolist(),
    mutate=True, # NOTE(review): presumably swaps entity names for placeholders (cf. 'org-three' in `mutated_sents` above) — confirm
    reverse=True # aggregate average score between both directions
)


mutate text: 100%|██████████| 10925/10925 [00:00<00:00, 20283.68it/s]
09/22/2023 09:10:30 PM [INFO]: Tokenizing data...
tokenization: 100%|██████████| 10925/10925 [00:02<00:00, 5302.96it/s]
tags positioning: 100%|██████████| 10925/10925 [00:00<00:00, 68493.10it/s]



Invalid rows/total: 0/10925


100%|██████████| 1366/1366 [00:26<00:00, 51.87it/s]
mutate text: 100%|██████████| 10925/10925 [00:00<00:00, 19929.19it/s]
09/22/2023 09:11:00 PM [INFO]: Tokenizing data...
tokenization: 100%|██████████| 10925/10925 [00:02<00:00, 5350.93it/s]
tags positioning: 100%|██████████| 10925/10925 [00:00<00:00, 60677.87it/s]



Invalid rows/total: 0/10925


100%|██████████| 1366/1366 [00:26<00:00, 51.44it/s]


In [58]:
# Start from an empty 'relations' column, fill the rows that received predictions
# (aligned on the index of `predictions`), then drop the rows that got none.
query_data['relations'] = None
query_data.loc[predictions.index.values, "relations"] = predictions["relations"]
query_data.dropna(subset=['relations'],inplace=True)

In [59]:
pd.set_option("display.max_colwidth", None)
query_data[40: 60].drop(['spans'], axis=1)

Unnamed: 0,accessionnumber,reporter_name,reporter_normalizedname,reporter_cik,sentence_id,sentence,filedasofdate,org_groups,aliases,num_orgs,relations
40,0001467858-21-000037,General Motors Co,General Motors Co,1467858,159,Factors that affect future funding requirements for General Motors Co US defined benefit plans generally affect the required funding for non US plans.,2021-02-10,"{'General Motors Co': 0, 'non US': 1}",[],2,"[{'non US': 'other', 'General Motors Co': 'other', 'score': 0.798}]"
41,0001467858-21-000037,General Motors Co,General Motors Co,1467858,162,"For EBIT adjusted and General Motors Co other non GAAP measures, once General Motors Co have made an adjustment in the current period for an item, General Motors Co will also adjust the related non GAAP measure in any future periods in which there is an impact from the item",2021-02-10,"{'General Motors Co': 0, 'non GAAP measure': 1, 'non GAAP': 2}",[],3,"[{'non GAAP measure': 'customer', 'General Motors Co': 'supplier', 'score': 0.6078}]"
42,0001467858-21-000037,General Motors Co,General Motors Co,1467858,171,For these reasons General Motors Co believe these non GAAP measures are useful for General Motors Co investors,2021-02-10,"{'General Motors Co': 0, 'non GAAP measures': 1}",[],2,"[{'non GAAP measures': 'supplier', 'General Motors Co': 'customer', 'score': 0.4937}]"
43,0001467858-21-000037,General Motors Co,General Motors Co,1467858,186,"Forward Looking Statements This report and the other reports filed by General Motors Co with the SEC from time to time, as well as statements incorporated by reference herein and related comments by General Motors Co management, may include ""forward looking statements"" within the meaning of the US federal securities laws",2021-02-10,"{'General Motors Co': 0, 'SEC': 1}",[],2,"[{'SEC': 'other', 'General Motors Co': 'other', 'score': 0.9924}]"
44,0001467858-21-000037,General Motors Co,General Motors Co,1467858,192,"Further, as an entity operating in the financial services sector, GM Financial is required to comply with a wide variety of laws and regulations that may be costly to adhere to and may affect General Motors Co consolidated operating results",2021-02-10,"{'General Motors Co': 0, 'GM Financial': 1}",[],2,"[{'GM Financial': 'supplier', 'General Motors Co': 'customer', 'score': 0.8124}]"
45,0001467858-21-000037,General Motors Co,General Motors Co,1467858,196,"Furthermore, these non GAAP measures allow investors the opportunity to measure and monitor General Motors Co performance against General Motors Co externally communicated targets and evaluate the investment decisions being made by management to improve ROIC adjusted.",2021-02-10,"{'General Motors Co': 0, 'non GAAP measures': 1}",[],2,"[{'non GAAP measures': 'other', 'General Motors Co': 'other', 'score': 0.9978}]"
46,0001467858-21-000037,General Motors Co,General Motors Co,1467858,198,"GM Financial The amounts presented for GM Financial have been adjusted to include the effect of General Motors Co tax attributes on GM Financial's deferred tax positions and provision for income taxes, which are not applicable to GM Financial on a stand alone basis, and to eliminate the effect of transactions between GM Financial and the other members of the consolidated group",2021-02-10,"{'General Motors Co': 0, 'GM Financial's': 1, 'GM Financial': 1}",[],2,"[{'GM Financial's': 'supplier', 'General Motors Co': 'customer', 'score': 0.482}]"
47,0001467858-21-000037,General Motors Co,General Motors Co,1467858,199,"GM Financial did not have borrowings outstanding against General Motors Co revolving credit facilities at December 31, 2020 and 2019.",2021-02-10,"{'General Motors Co': 0, 'GM Financial': 1}",[],2,"[{'GM Financial': 'other', 'General Motors Co': 'other', 'score': 0.5163}]"
48,0001467858-21-000037,General Motors Co,General Motors Co,1467858,200,"GM Financial faces a number of business, economic and financial risks that could impair its access to capital and negatively affect its business and operations, which in turn could impede its ability to provide leasing and financing to customers and commercial lending to General Motors Co dealers",2021-02-10,"{'General Motors Co': 0, 'GM Financial': 1}",[],2,"[{'GM Financial': 'supplier', 'General Motors Co': 'customer', 'score': 0.8507}]"
49,0001467858-21-000037,General Motors Co,General Motors Co,1467858,201,"GM Financial has access to $16.5 billion of General Motors Co revolving credit facilities with exclusive access to the 364 day, $2.0 billion facility.",2021-02-10,"{'General Motors Co': 0, 'GM Financial': 1}",[],2,"[{'GM Financial': 'other', 'General Motors Co': 'other', 'score': 0.9824}]"


In [60]:
import math
from typing import Tuple, List, Text, Dict
def top_n_size(x, y, z=None):
    """
    Choose how many top scores per relation enter the final scoring.

    Input -
    x - count of supplier relation sentences
    y - count of customer relation sentences
    z - count of other relation sentences (optional; a falsy value such as
        None or 0 means only x and y are considered)

    Balanced approach towards all relations:
    returns the minimum of "20% of the count of each relation" (rounded up),
    so the least frequent relation bounds the cut-off.

    Raises AssertionError if any considered count is not strictly positive.
    """
    counts = (x, y, z) if z else (x, y)
    assert all(count > 0 for count in counts)
    return min(math.ceil(count * 0.2) for count in counts)

def top_n_size_new(x, y):
    """
    Choose the top-n cut-off from the supplier and customer counts only.

    Input -
    x - count of supplier relation sentences
    y - count of customer relation sentences

    - The count of "other" relations is ignored.
    - Counts that differ by exactly one: return the smaller count, so both
      relations compete on near-equal footing.
    - Otherwise: return the larger of "50% of each count" (rounded up), which
      slightly favours the relation that occurs more often.

    Raises AssertionError if either count is not strictly positive.
    """
    assert (x > 0) and (y > 0)

    counts_are_adjacent = abs(x - y) == 1
    if counts_are_adjacent:
        return min(x, y)

    return max(math.ceil(count * 0.5) for count in (x, y))


def log_sum_top_n(scores, top_n_size):
    """
    Combine a relation's scores as
    ``mean(scores) * (1 + ln(sum of the top_n_size largest scores))``.

    scores      - non-empty list of scores for one relation type
    top_n_size  - how many of the largest scores feed the logarithmic term

    The result is negative when the top-n sum is below 1/e. An empty
    ``scores`` raises ZeroDivisionError and a non-positive top-n sum raises
    ValueError (from ``math.log``).
    """
    mean_score = sum(scores) / len(scores)
    strongest = sorted(scores, reverse=True)[:top_n_size]
    return mean_score * (1 + math.log(sum(strongest)))
    
    

def agg_relation_score(company_relation_score: Dict, top_n_approach: str):
    """
    Aggregate the per-sentence scores of one company into one score per
    relation type.

    Input:
    company_relation_score: mapping with the optional keys "supplier_scores",
        "customer_scores" and "other_scores", each holding a list of scores.
        Missing keys are treated as empty lists.
    top_n_approach: "old" or "new"
        "old" - top_n_size function
        "new" - top_n_size_new function
        Only consulted when two or more relation types have scores.

    Returns:
    Dict with the keys "supplier", "customer" and "other" (in that order).
    A relation type without scores keeps the placeholder value 0. With a
    single relation type present, all of its scores are used for the top-n
    term; with several present, a shared top-n size is used for each of them.

    Raises:
    ValueError if two or more relation types have scores and top_n_approach
    is neither "old" nor "new". The original failed with UnboundLocalError
    in that case.
    """
    supplier_scores = company_relation_score.get("supplier_scores", [])
    customer_scores = company_relation_score.get("customer_scores", [])
    other_scores = company_relation_score.get("other_scores", [])

    label_scores = {"supplier": 0, "customer": 0, "other": 0}

    # The customer -> supplier -> other order fixes the positional arguments passed
    # to the top-n helpers and the order in which the scores are aggregated.
    scores_by_label = {
        "customer": customer_scores,
        "supplier": supplier_scores,
        "other": other_scores,
    }
    present = {label: scores for label, scores in scores_by_label.items() if scores}

    # Zero or one relation type with scores: no cross-relation balancing is needed.
    if len(present) < 2:
        for label, scores in present.items():
            label_scores[label] = log_sum_top_n(scores, len(scores))
        return label_scores

    # Two or more relation types have scores: derive one shared top-n size.
    counts = [len(scores) for scores in present.values()]
    if top_n_approach == "old":
        n = top_n_size(*counts)
    elif top_n_approach == "new":
        # The "new" rule looks at two counts only: customer and supplier when all
        # three types are present, otherwise the two types that are present.
        n = top_n_size_new(counts[0], counts[1])
    else:
        raise ValueError(
            f'top_n_approach must be "old" or "new", got {top_n_approach!r}'
        )

    for label, scores in present.items():
        label_scores[label] = log_sum_top_n(scores, n)

    return label_scores
    
def get_winning_relation(company_relation_score: Dict, top_n_approach: str  ):
    """
    Input:
    ------
    Dict with list of scores for each relation type.
    {"customer_scores": [..],  "supplier_scores": [..], "other_scores": [..]}

    top_n_approach: "old" or "new"
    "old" - top_n_size function
    "new" - top_n_size_new function

    Return:
    -------
    Dict with the aggregated score of each relation type ("supplier",
    "customer", "other") plus the key "winning_relation".

    Only relation types that actually have scores compete for the win. The
    aggregated score of a weak relation can be negative, and it must not lose
    to the 0 placeholder of a relation type that has no scores at all.
    If no relation type has scores, or more than one competing relation type
    shares the maximum score, the winning relation is "supplier".
    """

    # get the scores for each relation type: Dict[relation_key, aggregated_score_for_relation]
    relation_scores = agg_relation_score(company_relation_score, top_n_approach)

    # relation types backed by at least one score; fall back to all types when none is
    evidenced = [
        label for label in relation_scores
        if company_relation_score.get(f"{label}_scores")
    ]
    candidates = evidenced or list(relation_scores)

    # find the max score among the candidates and the relation type(s) reaching it
    max_score = max(relation_scores[label] for label in candidates)
    max_score_relations = [
        label for label in candidates if relation_scores[label] == max_score
    ]

    # if there is more than one relation type with the max score, return "supplier"
    # NOTE(review): this documented tie rule also applies when "supplier" itself is
    # not among the tied relation types — confirm that this is intended.
    if len(max_score_relations) > 1:
        winning_relation = "supplier"
    else:
        winning_relation = max_score_relations[0]

    relation_scores["winning_relation"] = winning_relation

    return relation_scores


    

In [61]:
from typing import Dict, List, Tuple
import string
import re
from itertools import chain
from collections import defaultdict
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple

def process_relations(
    all_relations: pd.DataFrame,
    matcher,
    _ingest: bool = False,
    match_thresh: float = 0.973,
    top_n_approach: str = "new"
    
) -> Dict[str, dict]:
    """
    Build one relations report per filing: cluster organisation names, keep the
    relations that involve the reporter, and aggregate their scores per company.

    @params
    -------
    - all_relations (pd.DataFrame): one row per sentence. The columns read here are
      "accessionnumber", "aliases", "org_groups", "reporter_name", "reporter_cik",
      "filedasofdate", "sentence", "sentence_id" and "relations".
    - matcher: object exposing `build_index(names, return_emb=True)` and
      `search(query, threshold=...)`, used to group similar organisation names.
    - _ingest (bool): when True, the filing-level fields are wrapped as {"S": value};
      otherwise the raw values are stored.
      NOTE(review): the {"S": ...} shape looks like a DynamoDB string attribute — confirm.
    - match_thresh (float): similarity threshold passed to `matcher.search`.
    - top_n_approach (str): "old" or "new", forwarded to `get_winning_relation`.

    @returns:
    --------
    Dict keyed by accession number. Each value holds "PK", "reporterName", "cik",
    "accessionNumber", "type", "filingDate" and "relations". "relations" maps a
    company's representative name to {"sentences": [...], "scores": {...},
    "aggregation_results": {...}}.
    """
    # NOTE(review): `all_items` and `all_aliases` are never used below.
    all_items = []
    all_aliases = []
    # One report entry per accession number.
    Items = defaultdict(dict)

    # Group relations by accessionnumber and process each file
    for _, group in all_relations.groupby("accessionnumber"):
        # Collect the distinct pairs found in the "aliases" column of this filing and
        # index them in both directions.
        # NOTE(review): the first element of each pair is treated as the name and the
        # second as the alias — confirm against the NER predictor's output.
        aliases = list(filter(None, group["aliases"].tolist()))
        aliases = list(chain(*aliases))
        aliases = set([tuple(l) for l in aliases])
        alias2name = defaultdict(list)
        name2alias = defaultdict(list)
        for k, v in aliases:
            name2alias[k].append(v)
            alias2name[v].append(k)
        # Per-company accumulators: relations (sentences + scores) and the name cluster
        # each representative name stands for.
        companies_relations = defaultdict(dict)
        company_represents = defaultdict(list)
        # All distinct organisation names (keys of the org_groups dicts) in this filing
        all_orgs = np.array(list(set(chain(*group["org_groups"]))))
        # Build index for names and return embeddings for each name
        embs = matcher.build_index(all_orgs.tolist(), return_emb=True)
        # Cluster org names with certain threshold
        results = matcher.search(tuple((all_orgs.tolist(), embs)) ,  threshold=match_thresh)
        # Initialize basic variables for clustering:
        #   f_list -> flat list of every name already assigned to a cluster
        #   ids_c  -> running counter used as the cluster id
        #   org2id -> organisation name mapped to the id of its cluster
        #   id2org -> cluster id mapped to its names, longest name first
        f_list = []
        ids_c = 0
        org2id = {}
        id2org = {}
        # Loop over the matches to cluster names by similarity
        for c, matches in zip(all_orgs.tolist(), results):
            # Skip names that were already placed in an earlier cluster
            if c in f_list:
                continue
            # Get all names returned by the search for this name (first item of each match)
            n_matches  = [x[0] for x in matches]
            # Extend the cluster with aliases, when any were found
            n_matches = n_matches +  list(chain(*[alias2name.get(x, []) for x in n_matches]))
            n_matches = n_matches + list(chain(*[name2alias.get(x, []) for x in n_matches]))
            for name in n_matches:
                org2id[name] = ids_c
            # Drop empty names and sort so the longest name comes first; it is used as
            # the representative name of the company
            id2org[ids_c] = sorted(set(filter(None, n_matches)), key=lambda x: len(x), reverse=True)
            ids_c += 1
            f_list += n_matches
        # Names that match the reporter's name above the threshold.
        # Note: this search receives a plain string, unlike the (names, embeddings)
        # tuple passed to `matcher.search` above.
        reporter_names = [x[0] for x in matcher.search(group["reporter_name"].iloc[0], threshold=match_thresh)]
        # Every surface form that counts as a mention of the reporter: the reporter name
        # itself, its alias-table entries and the names matched by the search
        reporter_mentions = set(
            alias2name.get(group["reporter_name"].iloc[0], [])
            + list(chain(*[name2alias.get(x, []) for x in reporter_names]))
            + [group["reporter_name"].iloc[0]]
            + reporter_names
        )
        # Process each relation in the group
        for _, raw in group.iterrows():
            for rel in raw["relations"]:
                # Work on a copy: keys are popped below and the frame must stay intact
                relation = rel.copy()
                # NOTE(review): `scores` is never used.
                scores= defaultdict(list)
                # Only relations that mention the reporter are kept; the reporter's key is
                # removed so that the remaining entity is the counterpart company
                for reporter in reporter_mentions:
                    # NOTE(review): `and relation` is redundant — a truthy `.get` already
                    # implies a non-empty dict.
                    if relation.get(reporter, None) and relation:
                        relation.pop(reporter)
                        # Nothing left once a second reporter mention has been removed
                        # from a relation that was already handled
                        if not relation:
                            continue
                        score = relation.pop("score")
                        company = list(relation.keys())[0]
                        # Get the representative name of the company
                        representative = id2org[org2id[company]]
                        company_represents[representative[0]] = representative
                        if not companies_relations[representative[0]].get('sentences'):
                            companies_relations[representative[0]]['sentences'] = []
                            companies_relations[representative[0]]['scores'] = defaultdict(list)
                        
                        # add the score under "<label>_scores", where <label> is the
                        # counterpart company's predicted label
                        companies_relations[representative[0]]['scores']\
                        ["{}_scores".format(relation[company])].append(score)
                        # add sentence info
                        companies_relations[representative[0]]['sentences'].append(
                            {
                                "sentence": raw["sentence"],
                                "sentence_id": raw["sentence_id"],
                                "relation": relation[company],
                                "score": score,
                            }
                        )
                
        # Filing-level fields are taken from the first row of the group
        intial_vals = group.iloc[0]
        
        Items[intial_vals["accessionnumber"]] = dict()
        sec_item = Items[intial_vals["accessionnumber"]]
        sec_item["PK"] = {"S": f"an#{intial_vals['accessionnumber']}"} if _ingest else intial_vals['accessionnumber']
        sec_item["reporterName"] = {"S": intial_vals["reporter_name"]} if _ingest else intial_vals['reporter_name']
        sec_item["cik"] = {"S": str(intial_vals["reporter_cik"])} if _ingest else intial_vals['reporter_cik']
        sec_item["accessionNumber"] = {"S": intial_vals["accessionnumber"]} if _ingest else  intial_vals["accessionnumber"]
        sec_item["type"] = {"S": "relationship"}if _ingest else "relationship"
        sec_item["filingDate"] = {"S": intial_vals["filedasofdate"]}if _ingest else intial_vals['filedasofdate']
    
        # Aggregate the collected scores of every company into per-relation scores and a
        # winning relation
        for co in list(companies_relations.keys()):
            companies_relations[co]['aggregation_results'] = get_winning_relation(companies_relations[co]['scores'],
                                                                                  top_n_approach=top_n_approach)
        sec_item['relations'] = dict(companies_relations)

    return dict(Items)


In [62]:
relations_report = process_relations(query_data, entity_matcher, top_n_approach="new")

09/22/2023 09:17:08 PM [INFO]: Loading faiss with AVX2 support.
09/22/2023 09:17:08 PM [INFO]: Successfully loaded faiss with AVX2 support.
09/22/2023 09:17:08 PM [INFO]: Encoding embeddings for sentences...
09/22/2023 09:17:08 PM [INFO]: Building index...
09/22/2023 09:17:08 PM [INFO]: StandardGpuResources not found in faiss, Use CPU-version faiss
09/22/2023 09:17:08 PM [INFO]: Finished
09/22/2023 09:17:08 PM [INFO]: Encoding embeddings for sentences...
09/22/2023 09:17:08 PM [INFO]: Building index...
09/22/2023 09:17:08 PM [INFO]: StandardGpuResources not found in faiss, Use CPU-version faiss
09/22/2023 09:17:08 PM [INFO]: Finished
09/22/2023 09:17:08 PM [INFO]: Encoding embeddings for sentences...
09/22/2023 09:17:08 PM [INFO]: Building index...
09/22/2023 09:17:08 PM [INFO]: StandardGpuResources not found in faiss, Use CPU-version faiss
09/22/2023 09:17:08 PM [INFO]: Finished
09/22/2023 09:17:08 PM [INFO]: Encoding embeddings for sentences...
09/22/2023 09:17:08 PM [INFO]: Building

In [63]:
relations_report.keys()

dict_keys(['0000012927-21-000011', '0000012927-23-000007', '0000037996-21-000012', '0000078003-21-000038', '0000104169-21-000033', '0000320193-21-000105', '0001467858-21-000037', '0001467858-23-000029'])

### **_Print reports_**

In [64]:
from pprint import pprint
relations_report['0000012927-21-000011']

{'PK': '0000012927-21-000011',
 'reporterName': 'BOEING CO',
 'cik': 12927,
 'accessionNumber': '0000012927-21-000011',
 'type': 'relationship',
 'filingDate': '2021-02-01',
 'relations': {'747 Program': {'sentences': [{'sentence': '747 Program BOEING Co are currently producing at a rate of 0.5 aircraft per month',
     'sentence_id': 6,
     'relation': 'supplier',
     'score': 0.5047}],
   'scores': defaultdict(list, {'supplier_scores': [0.5047]}),
   'aggregation_results': {'supplier': 0.15959063907955104,
    'customer': 0,
    'other': 0,
    'winning_relation': 'supplier'}},
  '787 Program': {'sentences': [{'sentence': '787 Program During 2020, BOEING Co experienced significant reductions in deliveries due to the impacts of COVID 19 on BOEING Co customers as well as production issues and associated rework',
     'sentence_id': 7,
     'relation': 'customer',
     'score': 0.5242}],
   'scores': defaultdict(list, {'customer_scores': [0.5242]}),
   'aggregation_results': {'supplie

09/17/2023 11:54:31 AM [INFO]: Loading faiss with AVX2 support.
09/17/2023 11:54:31 AM [INFO]: Successfully loaded faiss with AVX2 support.
09/17/2023 11:54:31 AM [INFO]: Encoding embeddings for sentences...
09/17/2023 11:54:31 AM [INFO]: Building index...
09/17/2023 11:54:31 AM [INFO]: Use CPU-version faiss
09/17/2023 11:54:31 AM [INFO]: Finished


TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]

In [None]:
# Build the log frame for the rows that were just processed.
log_frame = etl.create_logs(valid_data)

if len(query_data.index) == 0:
    logger.info("Didn't find any valid sentence for supply classification")
else:
    # Write the processed relation rows; only the partitions present in the
    # frame are replaced (mode="overwrite_partitions").
    wr.s3.to_parquet(
        df=valid_data,
        path=f"{etl.database_path}/{etl.relations_table}",
        dataset=True,
        database=etl.database,
        table=etl.relations_table,
        partition_cols=[*etl.relations_partitions.keys()],
        mode="overwrite_partitions",
        boto3_session=etl.session,
    )

    # Write the matching log rows the same way.
    wr.s3.to_parquet(
        df=log_frame,
        path=f"{etl.database_path}/{etl.logs_table}",
        dataset=True,
        database=etl.database,
        table=etl.logs_table,
        partition_cols=[*etl.logs_partitions.keys()],
        mode="overwrite_partitions",
        boto3_session=etl.session,
    )

    # Record which files made it through.
    # NOTE(review): `success` and `file_ids` must already exist in the kernel;
    # this cell does not define them.
    logger.info(f"Ingested files with ids: {file_ids} successfully\u2705")
    success |= set(log_frame["accessionnumber"])
    failed = set(file_ids).difference(success)
    # update_response = etl.add_results(CURRENT_STEP, FOLLOWING, success, failed)


2023-07-26 15:55:45,453 — 💫  RE JOB — INFO — Ingested files with ids: ['0001193125-21-114283'] successfully✅


In [None]:
log_frame

Unnamed: 0,accessionnumber,valid_for_supply,supply_estimations,valid_for_re,relation_estimations,filedasofdate
0,0001193125-21-114283,24,24,24,24,2021-04-13


In [None]:
# Link the organisation names found in `query_data` to records of the `company`
# DynamoDB table through `match_companies`.
# NOTE(review): the recorded run of this cell failed with a ValidationException because
# the table had no index named "companyprefix-normalizedname-index"; confirm the index
# exists (or that the name is current) before re-running. `org_links` stays undefined
# when this call fails, which breaks the cells below that read it.
# NOTE(review): presumably `match_thresh` / `cand_thresh` are the similarity cut-offs for
# "matches" vs "candidates" - verify against `match_companies`.
org_links = match_companies(
    predictions=query_data,
    entity_matcher=entity_matcher,
    etl_worker=etl,
    lookup_table="company",
    index_column="companyprefix-normalizedname-index",
    attribute_name="companyprefix",
    prefix_len=2,
    sort_len=5,
    normalized_column="normalizedname",
    id_column="rgid",
    database_type="dynamodb",
    match_thresh=0.973,
    cand_thresh=0.90,
    top_k=5,
    index_memory="cuda",
)


2023-09-17 11:25:00,800 — 🌍 ETL — INFO — Query DyanmoDB to find the `companyprefix` match companies to search for links


  0%|          | 0/1029 [00:00<?, ?it/s]


ClientError: An error occurred (ValidationException) when calling the Query operation: The table does not have the specified index: companyprefix-normalizedname-index

In [None]:
org_links

NameError: name 'org_links' is not defined

In [None]:
relations_items = process_relations(valid_data, entity_matcher, org_links)

In [None]:
failed_insertions = etl.batch_write_dynamodb_items(relations_items, "Predictions")

Batch insertion into `Predictions` Table:   0%|          | 0/215 [00:00<?, ?it/s]

Success rate 201/215 


In [None]:
with open("sample_data/relations_items_1.json", "w") as j:
    json.dump(relations_items, j)

In [None]:
for fail in failed_insertions:
    print(fail['Exception'])

An error occurred (ProvisionedThroughputExceededException) when calling the BatchWriteItem operation (reached max retries: 9): The level of configured provisioned throughput for one or more global secondary indexes of the table was exceeded. Consider increasing your provisioning level for the under-provisioned global secondary indexes with the UpdateTable API
An error occurred (ProvisionedThroughputExceededException) when calling the BatchWriteItem operation (reached max retries: 9): The level of configured provisioned throughput for one or more global secondary indexes of the table was exceeded. Consider increasing your provisioning level for the under-provisioned global secondary indexes with the UpdateTable API
An error occurred (ProvisionedThroughputExceededException) when calling the BatchWriteItem operation (reached max retries: 9): The level of configured provisioned throughput for one or more global secondary indexes of the table was exceeded. Consider increasing your provision

In [None]:
def predict_fn(input_data: dict, model: tuple) -> dict:
    """
    Predicts relations between entities in the input data using a trained
    relation extraction model and an entity matching model.

    Args:
        input_data: A dictionary containing input data to be processed.
        model: A tuple containing a trained relation extraction model and
            an entity matching model.

    Returns:
        A dictionary containing the processed input data.
    """
    # Print input data type and content
    print("Type of input:", type(input_data))
    print("Input data:", input_data)

    # Unpack the models
    relation_extractor, entity_matcher = model
    # Set the job input data
    etl.job = input_data

    # Update ETL starting status
    etl.update_starting(CURRENT_STEP, 1, add=False)

    # Block job files
    file_ids = etl.block_job_files(
        task=CURRENT_STEP,
        number_files=etl.config["job"][f"{CURRENT_STEP}_max_files"],
        distribute=distribute,
    )
    while file_ids:
        # Load data with filters
        query_data = etl.load_with_filter(
            etl.relations_table, col="accessionnumber", condition="isin", value=file_ids
        )

        # Filter for valid data
        valid_idx = query_data.query("supply_label == 1").index
        valid_data = query_data.iloc[valid_idx].copy()

        # Convert JSON strings to Python objects
        valid_data["spans"] = valid_data.spans.apply(json.loads)
        valid_data["org_groups"] = valid_data.org_groups.apply(json.loads)
        valid_data["aliases"] = valid_data.aliases.apply(json.loads)
        valid_data.reset_index(drop=True, inplace=True)
        # Set invalid files as successed
        success = set(query_data.accessionnumber.unique()) - set(
            valid_data.accessionnumber.unique()
        )

        # Predict relations
        predictions = relation_extractor.predict_relations(
            sentences=valid_data["sentence"].tolist(),
            ent="ORG",
            spans=valid_data["spans"].tolist(),
            org_groups=valid_data["org_groups"].tolist(),
            aliases=valid_data["aliases"].tolist(),
        )

        # Fill NaN values in predictions
        predictions.relations.fillna({}, inplace=True)
        valid_data.loc[predictions.index.values, "relations"] = None
        valid_data.loc[predictions.index.values, "relations"] = predictions["relations"]

        # Create logs
        log_frame = etl.create_logs(query_data)

        if query_data.shape[0] > 0:
            # Ingest new data
            wr.s3.to_parquet(
                df=query_data,
                database=etl.database,
                table=etl.relations_table,
                dataset=True,
                path=etl.database_path + "/" + etl.relations_table,
                partition_cols=list(etl.relations_partitions.keys()),
                mode="overwrite_partitions",
                boto3_session=etl.session,
            )

            # Update logs
            wr.s3.to_parquet(
                df=log_frame,
                database=etl.database,
                table=etl.logs_table,
                dataset=True,
                path=etl.database_path + "/" + etl.logs_table,
                partition_cols=list(etl.logs_partitions.keys()),
                mode="overwrite_partitions",
                boto3_session=etl.session,
            )

            # Log success and failure
            logger.info(f"Ingested files with ids: {file_ids} successfully\u2705")
            success |= set(log_frame["accessionnumber"])
            failed = set(file_ids) - success
            update_response = etl.add_results(CURRENT_STEP, FOLLOWING, success, failed)

        else:
            logger.info("Didn't find any valid sentence for supply classification")

        # Block job files
        file_ids = etl.block_job_files(
            task=CURRENT_STEP,
            number_files=etl.config["job"][f"{CURRENT_STEP}_max_files"],
            distribute=distribute,
        )
    return input_data

In [None]:
etl.load_table("logs")

Unnamed: 0,valid_for_supply,supply_estimations,valid_for_re,relation_estimations,filedasofdate,accessionnumber
0,10,10,0,0,2023-01-03,0001477932-23-000002
1,244,244,14,0,2023-01-03,0001477932-23-000012
2,200,200,18,0,2023-01-04,0001493152-23-000346
3,56,56,1,0,2023-01-05,0001096906-23-000016
4,104,104,8,0,2023-01-06,0001091818-23-000002
...,...,...,...,...,...,...
88,132,132,16,0,2023-01-31,0001373715-23-000035
89,195,195,21,0,2023-01-31,0001437749-23-002137
90,629,629,70,0,2023-01-31,0001467858-23-000029
91,55,55,3,0,2023-01-31,0001674796-23-000007


In [None]:
df = etl.run_query(f"SELECT * FROM {etl.sentences_table} limit 10 ")

Running query:
 SELECT * FROM sec_sentences limit 10 
self.bucket: ecomap-dl-pipeline
results_file_prefix: queries/bbc4dfb9-8b9e-4efb-9e87-cdcaca9415a5.csv
filename: bbc4dfb9-8b9e-4efb-9e87-cdcaca9415a5.csv
Query results shape: (10, 8)


### _Read relations table_

In [None]:
# Fetch every row of the relations table that already has a relations payload.
query_string = f"SELECT * from {etl.relations_table} WHERE relations IS NOT NULL"
all_relations = etl.run_query(query_string)

# These columns come back as JSON text; decode each one into Python objects.
for json_col in ("relations", "org_groups", "spans", "aliases"):
    all_relations.loc[:, json_col] = all_relations[json_col].apply(json.loads)

Running query:
 SELECT * from sec_relations WHERE relations IS NOT NULL
self.bucket: ecomap-dl-pipeline
results_file_prefix: queries/4f2200f0-f128-439d-a845-7b09e747d14c.csv
filename: 4f2200f0-f128-439d-a845-7b09e747d14c.csv
Query results shape: (686, 15)


### _Description of the Predictions table_

In [None]:
etl.dynamodb.describe_table(TableName="Predictions")

{'Table': {'AttributeDefinitions': [{'AttributeName': 'PK',
    'AttributeType': 'S'},
   {'AttributeName': 'SK', 'AttributeType': 'S'},
   {'AttributeName': 'accessionNumber', 'AttributeType': 'S'},
   {'AttributeName': 'extractedName', 'AttributeType': 'S'}],
  'TableName': 'Predictions',
  'KeySchema': [{'AttributeName': 'PK', 'KeyType': 'HASH'},
   {'AttributeName': 'SK', 'KeyType': 'RANGE'}],
  'TableStatus': 'ACTIVE',
  'CreationDateTime': datetime.datetime(2023, 6, 22, 17, 30, 38, 709000, tzinfo=tzlocal()),
  'ProvisionedThroughput': {'LastIncreaseDateTime': datetime.datetime(2023, 6, 24, 20, 13, 35, 718000, tzinfo=tzlocal()),
   'LastDecreaseDateTime': datetime.datetime(2023, 6, 24, 20, 28, 51, 137000, tzinfo=tzlocal()),
   'NumberOfDecreasesToday': 0,
   'ReadCapacityUnits': 1,
   'WriteCapacityUnits': 1},
  'TableSizeBytes': 147812,
  'ItemCount': 355,
  'TableArn': 'arn:aws:dynamodb:us-east-1:276066050088:table/Predictions',
  'TableId': '23c1ad11-6d95-48b6-97ea-b8edbb6e65a3

### _Read Predictions table_

In [None]:
etl.dynamodb.query

<bound method ClientCreator._create_api_method.<locals>._api_call of <botocore.client.DynamoDB object at 0x7f4f63bfbaf0>>

In [None]:
entity_matcher.similarity("amazon web services inc", "amazon web services")

0.9882376194000244

In [None]:
entity_matcher.build_index(["Ahmed", "Mohamed"])

06/25/2023 05:28:31 PM [INFO]: Loading faiss with AVX2 support.
06/25/2023 05:28:31 PM [INFO]: Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
06/25/2023 05:28:31 PM [INFO]: Loading faiss.
06/25/2023 05:28:31 PM [INFO]: Successfully loaded faiss.
06/25/2023 05:28:31 PM [INFO]: Encoding embeddings for sentences...
06/25/2023 05:28:31 PM [INFO]: Building index...
06/25/2023 05:28:31 PM [INFO]: Use GPU-version faiss
06/25/2023 05:28:32 PM [INFO]: Finished


In [None]:
import re, string

# Sample names used to prototype the (prefix, sort-key) lookup pairs
all_companies = [
    name for name in [" the Es ", "the Yah  oo", "Whatever", "whatever"] if name
]

punctuation_re = re.compile(f"[{re.escape(string.punctuation)}]")

queries = []
for company in all_companies:
    # Shared normalisation: trim, drop punctuation, lowercase
    cleaned = punctuation_re.sub("", company.strip()).lower()

    # Two-character partition prefix, taken after removing "the" and all spaces
    prefix = cleaned.replace("the", "").replace(" ", "")[:2]
    if not prefix:
        continue

    sort_key = cleaned
    if sort_key.split(" ")[0] == "the":
        # Names starting with the article are queried twice: once with it
        # (first 8 characters) and once without it (below).
        queries.append((prefix, sort_key[:8].strip()))
        sort_key = sort_key[4:].strip()
    queries.append((prefix, sort_key[:4].strip()))

In [None]:
# Link the organisations of every stored relation row (`all_relations`) to the `company`
# table, build the relation items and push them to the `Predictions` DynamoDB table.
org_links = match_companies(
    predictions=all_relations,
    entity_matcher=entity_matcher,
    etl_worker=etl,
    lookup_table="company",
    index_column="companyprefix-normalizedname-index",
    attribute_name="companyprefix",
    prefix_len=2,
    sort_len=5,
    normalized_column="normalizedname",
    id_column="rgid",
    database_type="dynamodb",
    match_thresh=0.95,
    cand_thresh=0.80,
    top_k=5,
)
# Assemble one datetime column from the separate year / month / day columns.
all_relations["date"] = pd.to_datetime(all_relations[["year", "month", "day"]])
# NOTE(review): presumably `process_relations` reads the `date` column added above - confirm.
relations_items = process_relations(all_relations, entity_matcher, org_links)
# Holds whatever `batch_write_dynamodb_items` returns; the next cell displays it.
failed = etl.batch_write_dynamodb_items(relations_items, "Predictions")

2023-06-25 17:52:24,494 — 🌍 ETL — INFO — Query DyanmoDB to find the `companyprefix` match companies to search for links


100%|██████████| 398/398 [01:17<00:00,  5.16it/s]


2023-06-25 17:53:42,831 — 🌍 ETL — INFO — Found 685425 items with prefix lookup


06/25/2023 05:53:44 PM [INFO]: Encoding embeddings for sentences...


  0%|          | 0/5266 [00:00<?, ?it/s]

06/25/2023 05:54:55 PM [INFO]: Building index...
06/25/2023 05:54:55 PM [INFO]: Use GPU-version faiss
06/25/2023 05:54:55 PM [INFO]: Finished


Batch insertion into `Predictions` Table:   0%|          | 0/15 [00:00<?, ?it/s]

Success rate 15/15 


In [None]:
failed

In [None]:
org_links

{'AbbVie': {'matches': ['70103090817',
   '70102857055',
   '70103090817',
   '70102857055',
   '70107358422',
   '70100071981',
   '70102555268',
   '70107860876',
   '70107115942',
   '70106585558',
   '70109214783',
   '70107358422',
   '70100071981',
   '70102555268',
   '70107860876',
   '70107115942',
   '70106585558',
   '70109214783',
   '70107358422',
   '70100071981',
   '70102555268',
   '70107860876',
   '70107115942',
   '70106585558',
   '70109214783'],
  'matches_names': ['abbvie as',
   'abbvie as',
   'abbvie ltd',
   'abbvie ltd',
   'abbvie ltd'],
  'candidates': [],
  'candidates_names': []},
 'Arma Services': {'matches': ['70102839734',
   '70109378754',
   '70110162991',
   '70100787243',
   '70108646615'],
  'matches_names': ['arma services inc',
   'arma ltd',
   'arma group ltd',
   'arma international ltd',
   'arma holdings ltd'],
  'candidates': [],
  'candidates_names': []},
 'eBay': {'matches': [],
  'matches_names': [],
  'candidates': ['70104828456',
   

In [None]:
entity_matcher.model.device

device(type='cuda', index=0)

2023-06-24 17:06:21,105 — 🌍 ETL — INFO — Query DyanmoDB to find the `companyprefix` match companies to search for links


100%|██████████| 398/398 [01:09<00:00,  5.73it/s]


2023-06-24 17:07:31,612 — 🌍 ETL — INFO — Found 685425 items with prefix lookup


06/24/2023 05:07:33 PM [INFO]: Encoding embeddings for sentences...


  0%|          | 0/5266 [00:00<?, ?it/s]

06/24/2023 05:08:23 PM [INFO]: Building index...
06/24/2023 05:08:23 PM [INFO]: Use CPU-version faiss
06/24/2023 05:08:24 PM [INFO]: Finished


Batch insertion into `Predictions` Table:   0%|          | 0/15 [00:00<?, ?it/s]

Success rate 15/15 


In [None]:
# 1 for every extracted name that has at least one match, 0 otherwise
found_mask = [int(len(link["matches"]) > 0) for link in org_links.values()]

In [None]:
sum(found_mask) / len(found_mask)

0.7016129032258065

In [None]:
# One record per extracted name, keeping only the human-readable name lists
links = [
    {
        "extracted_name": extracted_name,
        "matches_names": link["matches_names"],
        "candidates_names": link["candidates_names"],
    }
    for extracted_name, link in org_links.items()
]

In [None]:
links = pd.DataFrame(links)

In [None]:
links

Unnamed: 0,extracted_name,matches_names,candidates_names
0,Yandex,"[yandex, yandex inc, yandex llc, yandex oy, yandex.delivery]","[yandex laboratories, yandex nv, yandex toloka, yandex.ukraine llc, yandex.music]"
1,Microsoft Corporation's,"[microsoft corp, microsoft ltd, microsofts inc]","[microsoft research ltd, microsoft corporation (i) pvt ltd, microsoft studios, microsoft product development ltd, microsoft certified professional magazine, microsoft consulting services, microsoft licensing inc]"
2,Reality Labs,"[reality labs inc, reality lab as, reality lab ltd]","[realism labs ltd, realism labs inc, automatic labs inc, conversa labs, intelligent labs, inspired labs ltd, publish lab as]"
3,Knowledge,[knowledge ltd],"[knowledgereserve ltd, knowledgement ltd, knowledge and practice, knowledge group, knowledge solutions ltd, knowledgems ltd, knowledge circle, knowledgearc ltd, knowledgeforall ltd]"
4,Convergence Pharmaceuticals Ltd,"[convergence pharmaceuticals ltd, convergence pharmaceuticals inc, convergence pharmaceuticals holdings ltd]","[national pharmaceuticals corp, national pharmaceuticals inc, national pharmaceutical co ltd, convergent therapeutics inc, national pharmaceutical council, general pharmaceutical council, general pharmaceuticals ltd]"
...,...,...,...
367,A and E: Biography,[],"[a and e inc, a and e anodizing inc, a and e support ltd, a and e direct ltd, a and a, a and e partnership, a and e developments ltd, a and e international ltd, a and c e developments ltd, world e and c co ltd]"
368,Woodcrest,"[woodcrest ltd, woodcrest llc, woodcrest worldwide ltd, woodcrest holdings ltd, woodcrest partners llc, woodcrest services ltd, woodcrest development inc, woodcrest management, woodcrest services inc, woodcrest consulting ltd]",[]
369,Intel Corporation,"[intel corp, intel gmbh, intel, intel systems ltd, intel corporation (uk) ltd, intel international inc, intel investment ltd, intel co, intel tech ltd, intel service llc]",[]
370,Citigroup Global Markets Inc,"[citigroup global markets inc, citigroup global markets ltd, citigroup global markets llc, citigroup global markets holdings inc, citigroup global markets holdings gmbh, citigroup global markets services gmbh, citigroup global markets securities ltd, citigroup global markets international llc, citigroup markets inc, citigroup global markets (proprietary) ltd]",[]


In [None]:
# Show full cell contents, then sample names that ended up with no match at all
pd.set_option("display.max_colwidth", None)
links.loc[links["matches_names"].apply(lambda names: len(names) == 0)].sample(19)

Unnamed: 0,extracted_name,matches_names,candidates_names
187,Pinners,[],"[pinners pvt ltd, pinnerscots ltd, pinner solutions ltd, pinner ltd, pinnerica ltd, pinner partners ltd, pinneraccountants ltd, pinner heights ltd, pinnergy ltd, pinnerod holding as]"
172,WORLD WRESTLING ENTERTAINMENTINC Network,[],"[world wrestling entertainment inc, world wrestling entertainment (international) ltd, world wrestling legends llc, wwe network llc, world championship wrestling inc, world of sport wrestling ltd, world war wrestling ltd, world wrestling entertainment canada inc, world entertainment network, wwes ltd]"
314,WORLD WRESTLING ENTERTAINMENTINC Performance Centers,[],"[world wrestling entertainment inc, world wrestling entertainment (international) ltd, world championship wrestling inc, world of sport wrestling ltd, world wrestling legends llc, world wrestling entertainment canada inc, wwe studios production inc, wwes ltd, world war wrestling ltd, euro-american wrestling group inc]"
339,Major Payment Institution,[],"[payment finance, payment systems ltd, major payment systems llc, payment holdings ltd, payment experts ltd, payment management systems ltd, payment resources ltd, national payment systems corp, payment industry insights ltd, payment pathways inc]"
293,Higon Information Technology Co Ltd,[],"[advanced hi-tech corp, ns hi tech co ltd, global hi-tech co ltd]"
194,the ICE Clearing Houses,[],"[ice homes ltd, ice house ltd, ice house rentals ltd, ice house partners llc, ice house consulting ltd, ice house associates ltd, ice projects ltd, icequest ltd, ice investment house ltd, ice clear uk ltd]"
82,Intuitive System Leasing,[],"[system leasing ltd, system leasing und finanz ag kuesnacht, systems leasing corp, intelligent leasing ltd, simply leasing ltd, automatic leasing inc, systems leasing trust no vii, intuitive systems corp, finance and leasing solutions ltd, interact leasing and finance ltd]"
191,Workplace Service Delivery,[],"[workplace service solutions ltd, workplace services inc, workplace services ltd, workplace systems ltd, workplaces that work, workplace solutions inc, workplace management services ltd, total workplace solutions ltd, workplace wellbeing services ltd, workplace project services ltd]"
31,ATMP JV,[],"[atmp consulting group llc, atmp properties ltd, atmp manufacturing community ltd, global atm solutions ltd]"
338,non US,[],"[inspired by us ltd, simple solution 4 u ltd, a and u capital partner ltd, a and u capital ltd, consult international ltd, consult worldwide ltd, consult group worldwide ltd, total specialties usa inc, a and u holdings ltd, occasion usa corp]"


In [None]:
"RITUXAN HYCELA".lower()

'rituxan hycela'

In [None]:
%%time
attribute_name = "companyprefix"
key_condition_expression = (
    f"{attribute_name} = :val1 and begins_with(normalizedname, :val2)"
)

# Example lookup for "Clinic Sub Inc": partition key "cl", names starting with "clinic"
query_kwargs = {
    "TableName": "company",
    "IndexName": "companyprefix-normalizedname-index",
    "KeyConditionExpression": key_condition_expression,
    "ExpressionAttributeValues": {
        ":val1": {"S": "cl"},
        ":val2": {"S": "clinic"},
    },
}

# A single Query call returns one page only; keep following `LastEvaluatedKey`
# so that large prefixes are not silently truncated.
matched_items = []
while True:
    response = etl.dynamodb.query(**query_kwargs)
    matched_items.extend(response.get("Items", []))
    last_key = response.get("LastEvaluatedKey")
    if not last_key:
        break
    query_kwargs["ExclusiveStartKey"] = last_key

outs = pd.DataFrame(
    [item["normalizedname"]["S"] for item in matched_items],
    columns=["normalized_name"],
)

CPU times: user 53 ms, sys: 107 µs, total: 53.1 ms
Wall time: 118 ms


In [None]:
outs.query("normalized_name.str.contains('c s')")

Unnamed: 0,normalized_name
44,clinic for hospitals and therapeutic services
108,clinic scandinavia
109,clinic service center srl
110,clinic service corp
111,clinic service vorarlberg gmbh
112,clinic software ltd
113,clinic solutions ltd
114,clinic source
115,clinic spots
116,clinic success systems ltd


In [None]:
import json

with open("sample_data/org_links_1.json", "w") as obj:
    json.dump(org_links, obj)

In [None]:
all_relations.to_json("sample_data/relations_1.json")

In [None]:
# Number of extracted names with at least one match
sum(len(link["matches"]) > 0 for link in org_links.values())

47

etl.run_que

In [None]:
data = etl.run_query(
    "SELECT * FROM legacyevents.filingtext_ecomap_sec_filing_text_predictions_parquet"
)

Running query:
 SELECT * FROM legacyevents.filingtext_ecomap_sec_filing_text_predictions_parquet
self.bucket: ecomap-dl-pipeline
results_file_prefix: queries/e5b08301-f74e-4121-8717-6c4e0ccad8bc.csv
filename: e5b08301-f74e-4121-8717-6c4e0ccad8bc.csv
Query results shape: (149014, 7)


In [None]:
data["filedasofdate"].unique()

array(['2023-01-12', '2023-01-05', '2023-01-24', '2023-01-31',
       '2023-01-13', '2023-01-30', '2023-01-26', '2023-01-27',
       '2023-01-10', '2023-01-17', '2023-01-06', '2023-01-25',
       '2023-01-19', '2023-01-20', '2023-01-04', '2023-01-03',
       '2023-01-18', '2023-01-09', '2023-01-11', '2023-05-01',
       '2023-04-18'], dtype=object)

In [None]:
# WARNING: destructive maintenance cell. With `reset_all` set it loads the logs table,
# zeroes `relation_estimations` on every row and writes the frame back to S3.
# NOTE(review): the flag is hard-coded to True, so "Run All" performs the reset every
# time; consider defaulting it to False.
# NOTE(review): presumably the reset makes the RE step treat every file as pending
# again - confirm against how the ETL selects files.
reset_all = True
if reset_all:
    log_frame = etl.load_table("logs")
    log_frame.loc[:, "relation_estimations"] = 0
    wr.s3.to_parquet(
        df=log_frame,
        database=etl.database,
        table=etl.logs_table,
        dataset=True,
        path=etl.database_path + "/" + etl.logs_table,
        # NOTE(review): other cells partition the logs table with
        # `list(etl.logs_partitions.keys())`; confirm "file_id" is still the right column.
        partition_cols=["file_id"],
        mode="overwrite_partitions",
        boto3_session=etl.session,
    )

In [None]:
import json
import logging
import boto3
import pandas as pd
import awswrangler as wr

logger = logging.getLogger(__name__)


def predict_fn(input_data: dict, model: tuple) -> dict:
    """
    Predicts relations between entities in the input data using a trained
    relation extraction model and an entity matching model.

    Args:
        input_data: A dictionary containing input data to be processed.
        model: A tuple containing a trained relation extraction model and
            an entity matching model.

    Returns:
        A dictionary containing the processed input data.
    """
    # Print input data type and content
    print("Type of input:", type(input_data))
    print("Input data:", input_data)

    # Unpack the models
    relation_extractor, entity_matcher = model
    # Set the job input data
    etl.job = input_data

    # Update ETL starting status
    etl.update_starting(CURRENT_STEP, 1, add=False)

    # Block job files
    file_ids = etl.block_job_files(
        task=CURRENT_STEP,
        number_files=etl.config["job"][f"{CURRENT_STEP}_max_files"],
        distribute=False,
    )
    all_reports = {}
    # Process the input data
    while file_ids:
        # Load data with filters
        query_data = etl.load_with_filter(
            etl.relations_table, col="file_id", condition="isin", value=file_ids
        )

        # Filter for valid data
        valid_idx = query_data.query("supply_label == 1").index
        valid_data = query_data.iloc[valid_idx].copy()

        # Convert JSON strings to Python objects
        valid_data["spans"] = valid_data.spans.apply(json.loads)
        valid_data["org_groups"] = valid_data.org_groups.apply(json.loads)
        valid_data["aliases"] = valid_data.aliases.apply(json.loads)

        # Predict relations
        predictions = relation_extractor.predict_relations(
            sentences=valid_data["sentence"].tolist(),
            ent="ORG",
            spans=valid_data["spans"].tolist(),
            org_groups=valid_data["org_groups"].tolist(),
            aliases=valid_data["aliases"].tolist(),
        )

        # Fill NaN values in predictions
        predictions.relations.fillna({}, inplace=True)

        # Add predicted relations to valid_data
        valid_data.loc[:, "relations"] = predictions["relations"].tolist()

        org_links = match_companies(
            predictions=all_relations,
            entity_matcher=entity_matcher,
            etl_worker=etl,
            lookup_table="company",
            index_column="companyprefix-normalizedname-index",
            attribute_name="companyprefix",
            prefix_len=2,
            sort_len=5,
            normalized_column="normalizedname",
            id_column="rgid",
            database_type="dynamodb",
            match_thresh=0.95,
            cand_thresh=0.80,
            top_k=5,
        )
        relations_items = process_relations(all_relations, entity_matcher, org_links)
        failed = etl.batch_write_dynamodb_items(relations_items, "Predictions")
        # Convert relations to JSON strings
        query_data.loc[valid_idx, "relations"] = list(
            map(
                lambda x: json.dumps(x, default=float_format),
                valid_data["relations"].tolist(),
            )
        )

        # Create logs
        log_frame = etl.create_logs(query_data, "relations")

        if query_data.shape[0] > 0:
            # Ingest new data
            wr.s3.to_parquet(
                df=query_data,
                database=etl.database,
                table=etl.relations_table,
                dataset=True,
                path=etl.database_path + "/" + etl.relations_table,
                partition_cols=list(etl.relations_partitions.keys()),
                mode="overwrite_partitions",
                boto3_session=etl.session,
            )

            # Update logs
            wr.s3.to_parquet(
                df=log_frame,
                database=etl.database,
                table=etl.logs_table,
                dataset=True,
                path=etl.database_path + "/" + etl.logs_table,
                partition_cols=list(etl.logs_partitions.keys()),
                mode="overwrite_partitions",
                boto3_session=etl.session,
            )

            # Log success and failure
            logger.info(f"Ingested files with ids: {file_ids} successfully\u2705")
            success = set(log_frame["file_id"])
            failed = set(file_ids) - success
            update_response = etl.add_results(CURRENT_STEP, FOLLOWING, success, failed)

        else:
            logger.info("Didn't find any valid sentence for supply classification")

        # Block job files
        file_ids = etl.block_job_files(
            task=CURRENT_STEP,
            number_files=etl.config["job"][f"{CURRENT_STEP}_max_files"],
            distribute=True,
        )

    # Return input data
    return input_data

Type of input <class 'dict'>
Input data: {'TaskToken': 'AQBwAAAAKgAAAAMAAAAAAAAAAS+nYaB28dpTtEEDpf8s2G0fEfkYP3GRdk5rrBYpZ4dqLiTkkZ4c9zpQhUvhhZTwkYae2F6FjuqtHIJuvf8pkXV8XhAR8cRQF5DFgYQP9DitQisxmasIjm+oqI7AlmNZV9XwCHbw951p7VbJes4RQIljPUe8cPaFlBdFsUt3znOMIvzW8DoBHAwE6fA4gOFxg/H5u5wyyhZenc+gLgndrrrf/ZSb/5FGGcNACCYwtrSoqiEDEiOHpOiMyZWhxbAGJPwKyU4Cjbnv0OPKEZP8lC3Ee2at7ovfcPcM8zezSLUYs2MfA0dL6q4vwY2okBxWqSZgFqQ2F5PgxQOUUSgy/9RbDw7CzGO5/1FqB5AU5DsIvlB8T9JjvPJpxxvC+g956Fufg6pHQgzFLS0wtSYneS9HnJ9+Ly5RF0Xhpbb1WeD4UgL0ZWbQ4U/t7G+XUJklPyfTgTmpKX9+c+dvbtnPQZnERGwqZaBFyrOf8PnZbJCefMlUou5sW8z/SaGYQA5o02fDPJZ/XsfNRC/EvOPV/BIiJKf47Nwp3xWpDN3lh6NG5vjvMEMWLe/YKLzlJvk1g+2sGFrkMGLBpQbGrWAY7rO/bwYOYxqnCyh8kL+DEzxfGC/c', 'TaskType': 'local-inference', 'FileIds': ['0001493152-22-009683']}
2023-05-06 12:17:21,599 — 🌍 ETL — INFO — Read adaptor values
Running query:
 SELECT file_id
                        FROM etl_logs
                        WHERE relation_estimations <> valid_for_re 
self.bucket: ecomap-dl-pipe

mutate text: 100%|██████████| 144/144 [00:00<00:00, 6828.33it/s]
05/06/2023 12:17:24 PM [INFO]: Tokenizing data...
tokenization: 100%|██████████| 144/144 [00:00<00:00, 847.51it/s]
tags positioning: 100%|██████████| 144/144 [00:00<00:00, 18473.15it/s]



Invalid rows/total: 0/144


100%|██████████| 3/3 [00:01<00:00,  2.45it/s]

2023-05-06 12:17:26,089 — 🌍 ETL — INFO — Reseting Adaptor Values to default
2023-05-06 12:17:26,130 — 🌍 ETL — INFO — Query DyanmoDB to find the `tri` match companies to search for links



  return asarray(a).ndim


2023-05-06 12:17:27,001 — 🌍 ETL — INFO — Found 10035 item with the `tri` indecies


05/06/2023 12:17:27 PM [INFO]: Encoding embeddings for sentences...
05/06/2023 12:17:28 PM [INFO]: Building index...
05/06/2023 12:17:28 PM [INFO]: Finished


Running query:
 SELECT * FROM etl_logs where file_id IN ('0000002488-22-000016','0000025232-22-000005')
self.bucket: ecomap-dl-pipeline
results_file_prefix: queries/dd45f0b6-e515-437f-b2e6-0475b44f6d16.csv
filename: dd45f0b6-e515-437f-b2e6-0475b44f6d16.csv
Query results shape: (2, 7)
2023-05-06 12:17:32,583 — 🌍 ETL — INFO — Read adaptor values
2023-05-06 12:17:32,605 — 🌍 ETL — INFO — Adaptor Updated
Running query:
 SELECT * FROM etl_logs where file_id IN ('0000002488-22-000016','0000025232-22-000005')
self.bucket: ecomap-dl-pipeline
results_file_prefix: queries/ba9cdbf1-f0f9-4775-a766-b6d46943339c.csv
filename: ba9cdbf1-f0f9-4775-a766-b6d46943339c.csv
Query results shape: (2, 7)
2023-05-06 12:17:36,241 — 🌍 ETL — INFO — Reseting Adaptor Values to default
2023-05-06 12:17:36,280 — 💫 RE JOB — INFO — Ingested files with ids: ['0000025232-22-000005', '0000002488-22-000016'] successfully✅


In [None]:
%%time
predict_fn({}, (relation_extractor, entity_matcher))

Type of input <class 'dict'>
Input data: {}
2023-05-06 11:24:55,535 — 🌍 ETL — INFO — Read adaptor values
Running query:
 SELECT file_id
                            FROM etl_logs
                            WHERE relation_estimations <> valid_for_re 
self.bucket: ecomap-dl-pipeline
results_file_prefix: queries/0b716d03-5196-4b28-a0a0-e4b5288c3728.csv
filename: 0b716d03-5196-4b28-a0a0-e4b5288c3728.csv
Query results shape: (0, 1)
2023-05-06 11:24:57,470 — 🌍 ETL — INFO — Adaptor Updated
2023-05-06 11:24:57,507 — 💫 RE JOB — INFO — Estimating files with ids []
CPU times: user 329 ms, sys: 0 ns, total: 329 ms
Wall time: 1.97 s


In [None]:
# sa = hashlib.sha256(s['A'].to_json().encode()).hexdigest()

In [None]:
etl.load_table("logs")

Unnamed: 0,datapoints,ner_estimations,valid_for_supply,supply_estimations,valid_for_re,relation_estimations,file_id
0,3872,3872,413,413,25,25,0000005513-22-000030
1,2051,2051,515,515,116,116,0000908937-22-000007


In [None]:
query_data = etl.load_athena_with_filter(
    etl.relations_table, col="file_id", condition="isin", value=file_id, columns="*"
)

Running query:
 SELECT * FROM sec_relations where file_id IN ('0000005513-22-000030','0000908937-22-000007')
self.bucket: ecomap-dl-pipeline
results_file_prefix: queries/88e461ba-a6a2-4de6-a95d-56aa8a90dbc5.csv
filename: 88e461ba-a6a2-4de6-a95d-56aa8a90dbc5.csv
Query results shape: (928, 14)


In [None]:
query_data

In [None]:
from io import StringIO
import re
import string

# Download the company reference TSV from S3 and decode it to text
obj = etl.s3.Object("ecomap-dl-pipeline", "glue-db/2022-12-01-company.tsv")
file_content = obj.get()["Body"].read().decode("utf-8")
# Parse the tab-separated text into a DataFrame
df = pd.read_csv(StringIO(file_content), sep="\t")

_PUNCTUATION_RE = re.compile(f"[{re.escape(string.punctuation)}]")


def _to_tri(name):
    """Return the first three characters of `name` once punctuation is removed,
    the text is lowercased and spaces are dropped."""
    return _PUNCTUATION_RE.sub("", name).lower().replace(" ", "")[:3]


df["tri"] = df["normalized_name"].apply(_to_tri)

data = df[["inferess_entity_id", "tri", "normalized_name", "reference_name"]]

# One-off setup switches; both are off, so running the cell only reads data.
CREATE_COMPANIES_TABLE = False
UPLOAD_COMPANIES = False

if CREATE_COMPANIES_TABLE:
    # Register the `inferess_companies` table in the Glue catalog
    wr.catalog.create_parquet_table(
        database=etl.database,
        table="inferess_companies",
        path=etl.database_path + "/inferess_companies",
        columns_types={
            "inferess_entity_id": "int",
            "tri": "string",
            "normalized_name": "string",
            "reference_name": "string",
        },
        compression="snappy",
        parameters={"source": "s3"},
        boto3_session=etl.session,
    )

if UPLOAD_COMPANIES:
    # Replace the table contents with the frame built above
    wr.s3.to_parquet(
        df=data,
        database=etl.database,
        table="inferess_companies",
        dataset=True,
        path=etl.database_path + "/inferess_companies",
        mode="overwrite",
        boto3_session=etl.session,
    )

# Read back the rows whose `tri` value occurs in the local frame
inferess_companies = etl.load_athena_with_filter(
    table="inferess_companies",
    col="tri",
    condition="isin",
    value=data["tri"].unique().tolist(),
)