# agg_utils.py

In [1]:
from tqdm import tqdm
from collections.abc import Iterable as iterable
from typing import List, Text,Tuple, Iterable, Dict
from itertools import chain
import random
from collections.abc import Iterable as iterable

# method to add Label based on the relations columns
def sc_label_from_relations(relation_tuples, main_relations):
    if not relation_tuples:
        return 0
    if len(relation_tuples) == 0:
        return 0
    for relation_tuple in relation_tuples:
        if len(relation_tuple) != 3:
            continue
        elif  relation_tuple[1] in main_relations:
            return 1
    return 0

def check_relation_tuples(relations: List[Iterable]) -> bool:
    """
    Check if the relations list is in the correct format.
    """
    if not all(isinstance(relation, iterable) and len(relation) == 3 for relation in relations):
        return False
    return True

def return_possible_pairs(ids_set:List):
    return list(zip(
            list(chain(*[[ids_set[x]]*(len(ids_set)-1-x) \
                 for x in range(len(ids_set))])),
            list(chain(*[[ids_set[i] for i in range(x+1, len(ids_set))]\
                 for x in range(len(ids_set))]))))


def get_other_relations(ids2org):
    """
    Returns a list of other relations between companies based on the dictionary of company groups passed as input.
    The maximum number of other relations is determined by the max_others parameter.

    @params
    ids2org
    org_groups (dict): A dictionary with company ids to map each group of ents.
    max_others (int): The maximum number of other relations to return.

    @returns
    --------
    list: A list of other relations between companies.
    """
    # Sort company keys (IDs) in ascending order
    comp_keys = sorted(ids2org.keys())

    # Generate all possible pairs of IDs
    other_ids = set(return_possible_pairs(comp_keys))

    # Create 'other_relations' tuples for each pair of IDs and return
    return [(ids2org[pair[0]][0] , 'other', ids2org[pair[1]][0]) for pair in other_ids]



def eval_relations(relation, default=[]):
    try:
        re = eval(relation)
    except:
        re = default
    return re

def eval_relation_data(output, relations_map):
    """
    Evaluate and process relation data in the output DataFrame.

    @params:
    - output: DataFrame containing relation data.
    - relations_map: Dictionary mapping relations.

    @returns:
    - None (modifies 'output' DataFrame in place).
    """
    # Resort sme_relations to unify the relations directions
    if not isinstance(output['sme_relations'].iloc[0], list):
        tqdm.pandas(desc="Eval sme_relations")
        output['sme_relations'] = output['sme_relations'].progress_apply(eval)

    # Evaluate and process 'relations' column if it exists
    if 'relations' in output.columns:
        if not isinstance(output['relations'].iloc[0], list):
            tqdm.pandas(desc="Eval relations")
            output['relations'] = output['relations'].progress_apply(eval_relations)
    else:
        # If 'relations' column doesn't exist, create it with None values
        output['relations'] = None
    
    # Evaluate and process 'org_groups' column
    if not isinstance(output['org_groups'].iloc[0], dict):
        tqdm.pandas(desc="Eval org_groups")
        output['org_groups'] = output['org_groups'].progress_apply(eval_relations, default={})
    
    # Resort sme_relations based on the provided 'relations_map'
    tqdm.pandas(desc="Resort sme relations")
    output['sme_relations'] = output['sme_relations'].progress_apply(lambda x: \
                              resort_relation((x[0], x[1], x[2]),
                                            relations_map))

def resort_relation(relation_tuple:Tuple, relations_map:Dict)->Tuple:
    """Resorts a tuple to match the order of the main relation."""
    c1, relation, c2 = relation_tuple
    return [c1, relation, c2]\
            if not relations_map.get(relation)\
            else [c2, relations_map.get(relation), c1]



# data_aggregator.py

In [15]:
from pathlib import Path
import sys
from typing import Tuple, List, Text, Dict
from collections import defaultdict
from itertools import chain
import yaml
import pandas as pd
import random
class DataAggregator():
    '''Aggregating the annoated files from the LLMs from multiple files        
    '''
    def __init__(self,
                 dataset_name,
                 output_dir,
                 entity_matcher,
                 relation_direction={"customer":"supplier"},
                 filer_names=['entity_1', 'firstEntity', 'filer' ],
                 relations_key = 'relations',
                 text_col= 'sentence',
                 lm_name="en_core_web_trf",
                 lm_type='spacy',
                 ):

        # Define the basic information about the files which will be used
        self._dataset_name = dataset_name
        self.filer_names = filer_names
        self.text_col= text_col
        self.relations_key=relations_key
        

        self.relation_direction = relation_direction
        self.main_relations = list(self.relation_direction.keys()) + list(self.relation_direction.values())
        self.output_dir = output_dir

        with open(dataset_name) as o:
            self.data_files = yaml.safe_load(o)
        
        # Construct language model for entity extraction and matching
        if lm_type=='spacy':
            from src.language_model.spacy_loader import SpacyLoader
            self.lm = SpacyLoader(lm_name,
                                  entity_matcher= entity_matcher,
                                  load_matcher=True)

    def read_and_prepare_datafiles(self):
        # read file and store each into it's coresponding key_value
        for key in list(self.data_files.keys()):
            print(key,'\n--------------')
            self.data_files[key]['data'] = pd.read_excel(self.data_files[key]['dir'])
            self.data_files[key]['data'] = self.process_labeled_data(self.data_files[key]['data']).reset_index(drop=True)
            self.data_files[key]['data']['idx'] = ["{}_{}".format(key, i)\
                                            for i in range(self.data_files[key]['data'].shape[0])]
            self.data_files[key]['data']['Label'] = self.data_files[key]['data']['relations']\
                                                .apply(lambda x : sc_label_from_relations(x, self.main_relations))        

    def process_labeled_data(self,
                             data:pd.DataFrame): 
        """
        Process labeled data by creating 'org_groups' if not existent and deserializing data.

        @params:
        - data: DataFrame containing labeled data.
        - text_col: Column containing text data.
        - relation_direction: Dictionary mapping relations.

        @returns:
        - None (modifies 'data' DataFrame in place).
        """
        # Create 'org_groups' if not existent
        if not all([x in data.columns for x in ['spans', 'org_groups']]):
            sents, spans, org_groups, aliases = self.lm.predictor(data[self.text_col])
            data[self.text_col] = sents
            data.loc[:, self.text_col] = sents
            data.loc[:, "spans"] = spans
            data.loc[:, "org_groups"] = org_groups
        # Deserialize data and evaluate relations
        eval_relation_data(data, self.relation_direction)
        return data


    def create_re_dataset(self,
                          data,
                          threshold:float=0.9,
                          max_others:int=3,
                          basic_columns:list=[],
                          only_filer=False,
                          ) -> pd.DataFrame:
        """
        Create a relation extraction dataset.

        @params
        -------
        - data: A pandas dataframe containing the following keys:
            * sentence: A string representing a sentence containing relevant text.
            * relations: A list of tuples representing the relationship between companies.

        @returns
        --------
        - dataset: A pandas dataframe containing the following columns:
        - idx: A unique identifier for the datapoint.
        - sentence: A string representing the sentence containing relevant text.
        - entity_2: A string representing the second entity in the relation.
        - relation: A string representing the relation between the two entities.
        - entity_1: A string representing the first entity in the relation.
        """

        # Apply the `extract_relations_from_llm` function to each datapoint in the data dataframe
        tqdm.pandas(desc="extract relations")
        results = data.progress_apply(lambda x: self.extract_relations_from_llm(datapoint=x,
                                threshold=threshold,
                                max_others=max_others,
                                only_filer = only_filer), axis=1)

        # Create new columns in the output dataframe to store the results of the `extract_relations_from_llm` function
        data['llm_relations']  = results.apply(lambda x : x[0]).tolist()
        data['other_relations']  = results.apply(lambda x : x[1]).tolist()
        relation_columns = ['llm_relations', 'other_relations']
        # Create a list of all possible pair-wise combinations of the values in 'ids2org', and randomly choose 5 of those combinations
        columns = relation_columns + [self.text_col, self.relations_key] + basic_columns
        re_dataset = []
        for _, row in data[columns].iterrows():
            row = row.to_dict()
            for r_column in relation_columns:
                # Iterate over relations and ingest row for each relation
                for relation_tuple in row[r_column]:
                    row['entity_2'] = relation_tuple[0]
                    row['relation'] = relation_tuple[1]
                    row['entity_1'] = relation_tuple[2]
                    re_dataset.append(dict(row))

        # Return a pandas dataframe containing the relation extraction dataset
        dataset = pd.DataFrame(re_dataset)[[ self.text_col,
                                            'entity_2',
                                            'relation',
                                            'entity_1'] + basic_columns]
        return dataset

        
    def extract_relations_from_llm(self,
                                   datapoint,
                                   threshold:float=0.9,
                                   only_filer = False,
                                   max_others=3):
        """
        Create a dataset for relation extraction training.
        @params
        -------
        - datapoint: A dictionary containing the following keys:
             * org_groups: A dictionary of company names associated with an integer identifier.
             * relations: A list of tuples representing the relationship between companies.
        - threshold: The similarity threshold for matching company names.

        @returns
        --------
        - llms_relations: A list of tuples representing the relationships between companies that were successfully matched.
        - other_relations: A list of tuples representing the relationships between companies that were not matched.

        @raises
        -------
        - ValueError: If the relations list in the datapoint is invalid.
        """
        r_others = True
        # establish org_groups
        group2id = datapoint['org_groups']
        id2group = defaultdict(list)
        for k,v in group2id.items():
            id2group[v].append(k)

        # define llms relations
        relations = datapoint[self.relations_key]

        # build index for org_groups
        if len(group2id) > 0:
            self.lm.entity_matcher.build_index(list(group2id.keys()))

        # Assert the relations on the right format
        if not check_relation_tuples(relations):
            raise ValueError("Invlid relations list on the datapoint, must be List[Tuple[Text, Text, Text]]")
        # Collect all companies mentioned in the relations and create a dictionary with each unique company as a key
        llms_companies = set()
        if isinstance(relations, list):
            llms_companies = list(set(chain(*[[x[0], x[2]] for x in relations])))

        # match the llm_companies to assign id according to group2id
        llms_co_matches = self.lm.entity_matcher.search(llms_companies, threshold=threshold, top_k=2)\
                          if len(llms_companies) > 0 else []
        # Create map the merge org_groups with llm_companies
        llms_ids = {}
        for co_match, llm_company in zip(llms_co_matches, llms_companies):
            # If match found
            if len(co_match) > 0:
                llms_ids[llm_company] = group2id[co_match[0][0]]

            # check if llm_company valid and add it to 
            elif llm_company in datapoint[self.text_col]:
                group2id[llm_company] = max(id2group.keys()) + 1 if len(id2group.keys()) > 0 else 1
                id2group[group2id[llm_company]] = [llm_company]
                llms_ids[llm_company] = group2id[llm_company]

        # Create a dictionary mapping IDs to company names
        llms_names = {k: id2group[v][0] for k, v in llms_ids.items()}

        # get all possible paris from the llms_ids
        availabel_relations = return_possible_pairs(sorted(set(group2id.values())))

        # Define all the exist relations from LLM with pairs tuples
        exist_relations = []
        llms_relations = []
        if isinstance(relations, list):
            for relation in relations:
                c1, c1_name = relation[0], llms_names.get(relation[0])
                c2, c2_name = relation[2], llms_names.get(relation[2])
                c1_id = llms_ids.get(c1)
                c2_id = llms_ids.get(c2)
                if None in [c1_id, c2_id]:
                    continue
                llms_relations.append((c1_name, relation[1], c2_name))
                exist_relations.append(tuple(sorted([c1_id, c2_id])))
        # Define all possible other relation pairs within the sentence
        other_ids = list(set(availabel_relations) ^ set(exist_relations))
        
        # Create relations tuples for other_relations
        other_relations = [(id2group[pair[0]][0],
                            'other',
                            id2group[pair[1]][0]) \
                           for pair in other_ids]
        # If llm return nothign, get all possible relations as `other`
        if len(llms_relations) == 0 and len(other_relations) == 0:
            other_relations = get_other_relations(id2group)
        # If only filer include only relations with filer
        if only_filer:
            # Find Filer name as mentioned on the sentence
            filer_column = list(set(datapoint.keys()).intersection(self.filer_names))
            filer_column = filer_column[0] if len(filer_column) > 0 else None
            given_filer = datapoint[filer_column] if filer_column else None
            if given_filer:
                filer_name = group2id.get(given_filer)
                if not filer_name and len(group2id) > 0:
                    filer_scope = list(group2id.keys())
                    filer_sim = self.lm.entity_matcher.similarity(given_filer,filer_scope)
                    if filer_sim.max() > threshold:
                        filer_name = filer_scope[filer_sim.argmax()]
            if filer_name:
                llms_relations = list(filter(None, [x if filer_name in [x[0], x[2]]\
                                                    else None  for x in llms_relations] ))
                other_relations = list(filter(None, [x if filer_name in [x[0], x[2]] \
                                                     else None  for x in other_relations] ))
        # Based on max_other return random sample
        other_relations = random.sample(other_relations, min(len(other_relations), max_others))
        return llms_relations, other_relations


In [18]:
data_aggregator.main_relations

['customer', 'supplier']

In [16]:
import sys
from pathlib import Path
src_dir = Path.cwd().parent
sys.path.append(str(src_dir))
data_aggregator = DataAggregator(dataset_name= src_dir / 'data/config/llm_aligned_0_1_huge_complex.yaml',
                                 output_dir= src_dir /'data/raw/aggregated_data.json',
                                 entity_matcher= str(src_dir/"artifacts/matcher_model/")
                                 )


2023-12-18 08:51:42,049 — 🌌 spaCy — INFO — Language model used is en_core_web_trf
2023-12-18 08:51:42,050 — 🌌 spaCy — INFO — spaCy Work On GPU


In [14]:
data_aggregator.read_and_prepare_datafiles()

all_0 
--------------


Eval sme_relations: 100%|██████████| 877/877 [00:00<00:00, 112221.75it/s]
Eval relations: 100%|██████████| 877/877 [00:00<00:00, 86503.88it/s]
Resort sme relations: 100%|██████████| 877/877 [00:00<00:00, 608503.66it/s]


all_1 
--------------


Eval sme_relations: 100%|██████████| 1767/1767 [00:00<00:00, 113383.85it/s]
Eval relations: 100%|██████████| 1767/1767 [00:00<00:00, 68599.97it/s]
Resort sme relations: 100%|██████████| 1767/1767 [00:00<00:00, 439112.17it/s]


all_other 
--------------


Eval sme_relations: 100%|██████████| 111/111 [00:00<00:00, 83931.45it/s]
Eval relations: 100%|██████████| 111/111 [00:00<00:00, 234897.95it/s]
Resort sme relations: 100%|██████████| 111/111 [00:00<00:00, 277288.71it/s]


huge_1 
--------------


Eval sme_relations: 100%|██████████| 688/688 [00:00<00:00, 113252.79it/s]
Eval relations: 100%|██████████| 688/688 [00:00<00:00, 70461.52it/s]
Resort sme relations: 100%|██████████| 688/688 [00:00<00:00, 524002.39it/s]


huge_1_complex 
--------------


Eval sme_relations: 100%|██████████| 1550/1550 [00:00<00:00, 112284.69it/s]
Eval relations: 100%|██████████| 1550/1550 [00:00<00:00, 76231.46it/s]
Resort sme relations: 100%|██████████| 1550/1550 [00:00<00:00, 562677.10it/s]


In [15]:
# read file and store each into it's coresponding key_value
for key in list(data_aggregator.data_files.keys()):
    print(key,'\n--------------' )
    dataset = data_aggregator.create_re_dataset(data=data_aggregator.data_files[key]['data'].reset_index(drop=True),
                           threshold=0.9,
                           max_others=2,
                           basic_columns=['Label', 'concept_class', 'idx', 'org_groups', 'spans', 'sme_relations'],
                           only_filer=True)
    data_aggregator.data_files[key]['dataset'] = dataset

all_0 
--------------


extract relations: 100%|██████████| 877/877 [02:19<00:00,  6.30it/s]


all_1 
--------------


extract relations: 100%|██████████| 1767/1767 [04:49<00:00,  6.11it/s]


all_other 
--------------


extract relations: 100%|██████████| 111/111 [00:17<00:00,  6.39it/s]


huge_1 
--------------


extract relations: 100%|██████████| 688/688 [01:52<00:00,  6.11it/s]


huge_1_complex 
--------------


extract relations: 100%|██████████| 1550/1550 [04:15<00:00,  6.07it/s]


In [17]:
pd.set_option("display.max_colwidth", None)

In [18]:
all_dataset= pd.concat([data_aggregator.data_files[k]['dataset'] for k in data_aggregator.data_files.keys()], axis=0)

In [30]:
all_dataset.shape

(6497, 11)

In [19]:
concept_class_remapping = {'agreement': ['supply_purchase_agreement',
               'services agreement',
               'agreement_and_partnership'],
'licensing_and_ip': ['royalties',
                     'licensing_and_ip',
                     'legal_and_regulatory'],
'supply_chain': ['supply_chain', 'product_related'],
'revenue': ['revenue', 'royalties'],
'real_estate': ['real_estate'],
'financial_statements': ['investment_related', 'financial_statements'],
'other': ['unknown'] 
}

# reverse the mapping
reverse_concept_class_remapping = {}
for k,v in concept_class_remapping.items():
    for x in v:
        reverse_concept_class_remapping[x] = k  
        

In [20]:
# remap the concept class
all_dataset["concept_class_remapped"] = all_dataset["concept_class"].apply(lambda x: reverse_concept_class_remapping.get(x, x))

In [21]:
# don't include other class
all_dataset = all_dataset[all_dataset["concept_class_remapped"] != "other"]


In [22]:
all_dataset["concept_class_remapped"].value_counts()

revenue                 3263
supply_chain            1102
agreement                875
licensing_and_ip         601
financial_statements     523
real_estate              133
Name: concept_class_remapped, dtype: int64

In [24]:
from src.utils.preprocess import word_search
founded1 = all_dataset.apply(lambda x : word_search(x['entity_1'], x['sentence']), axis =1 )
founded2 = all_dataset.apply(lambda x : word_search(x['entity_2'], x['sentence']), axis =1 )

In [25]:
(founded1.apply(len) == 0).sum()==0,(founded2.apply(len) == 0).sum()==0

(True, True)

# test_re_dataset_creation.py

In [17]:
from typing import List, Tuple, Text, Iterable
from itertools import chain
from functools import partial
from collections import defaultdict
from typing import List, Tuple, Text
import random
from collections.abc import Iterable as iterable


def test_check_relation_tuples():
    assert check_relation_tuples([]) == True
    assert check_relation_tuples([(1, 2, 3)]) == True
    assert check_relation_tuples([(1, 2)]) == False
    assert check_relation_tuples([(1, 2, 3), (4, 5, 6), (7, 8, 9)]) == True

def test_return_possible_pairs():
    assert return_possible_pairs([1, 2, 3]) == [(1, 2), (1, 3), (2, 3)]
    assert return_possible_pairs([]) == []


test_check_relation_tuples()
test_return_possible_pairs()


def test_point():
    return {
 'filer': 'ADVANCED MICRO DEVICES INC corp',
 'sentence': 'In addition, five customers, including Sony and Microsoft, accounted for approximately 95% of the net revenue attributable to ADVANCED MICRO DEVICES Inc Enterprise, Embedded and Semi Custom segment',
 'relations': [
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony'],
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft'],
              ],
 'org_groups': {'ADVANCED MICRO DEVICES Inc': 0, 'Microsoft': 1, 'Sony': 2}}


############################### Test only_filer=true ###############################

llms_relations, other_relations = data_aggregator.extract_relations_from_llm(test_point(),
                            threshold= 0.9 ,
                            only_filer= True,
                            max_others= 1) 
assert llms_relations == [('ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony'),
                          ('ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft')]
assert other_relations == []

############################### Test only_filer=False & max_other=1 ###############################

llms_relations, other_relations = data_aggregator.extract_relations_from_llm(test_point(),
                            threshold= 0.9 ,
                            only_filer= False,
                            max_others= 1) 
assert llms_relations == [('ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony'),
                          ('ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft')]
assert other_relations == [('Microsoft', 'other', 'Sony')]

############################### Test Changing Names slightly ###############################
datapoint = test_point()
datapoint['relations'] = [
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony Inc'],
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft inc'],
              ]
llms_relations, other_relations = data_aggregator.extract_relations_from_llm(datapoint,
                            threshold= 0.9 ,
                            only_filer= False,
                            max_others= 1) 
assert llms_relations == [('ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony'),
                          ('ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft')]
assert other_relations == [('Microsoft', 'other', 'Sony' )]


############################### Test When All Others ###############################
datapoint = test_point()
datapoint['relations'] = [
  ['ADVANCED MICRO DEVICES Inc', 'other', 'Sony Inc'],
  ['ADVANCED MICRO DEVICES Inc', 'other', 'Microsoft inc'],
              ]
llms_relations, other_relations = data_aggregator.extract_relations_from_llm(datapoint,
                            threshold= 0.9 ,
                            only_filer= False,
                            max_others= 1) 
assert llms_relations ==[('ADVANCED MICRO DEVICES Inc', 'other', 'Sony'),
 ('ADVANCED MICRO DEVICES Inc', 'other', 'Microsoft')]
assert other_relations == [('Microsoft', 'other', 'Sony')]

############################### Test When All Others & Only Filer ###############################

datapoint = test_point()
datapoint['relations'] = [
  ['ADVANCED MICRO DEVICES Inc', 'other', 'Sony Inc'],
  ['ADVANCED MICRO DEVICES Inc', 'other', 'Microsoft inc'],
              ]
llms_relations, other_relations = data_aggregator.extract_relations_from_llm(datapoint,
                            threshold= 0.9 ,
                            only_filer= True,
                            max_others= 0) 
assert llms_relations == [('ADVANCED MICRO DEVICES Inc', 'other', 'Sony'),
 ('ADVANCED MICRO DEVICES Inc', 'other', 'Microsoft')]
assert other_relations == []

############################### Test Adding LLM Relation Not Exist On OrgGroups With Only Filer ###############################
datapoint = test_point()
datapoint['relations'] = [
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony Inc'],
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft inc'],
  ['MISTAKE', 'supplier', 'WRONG NAME'],
    
              ]
llms_relations, other_relations = data_aggregator.extract_relations_from_llm(datapoint,
                            threshold= 0.9 ,
                            only_filer= False,
                            max_others= 1) 
assert llms_relations == [('ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony'),
 ('ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft')]

assert other_relations == [('Microsoft', 'other', 'Sony')]


#########Test Adding LLM Relation Not Exist On OrgGroups Without Only Filer & max_other=2  ###############################

datapoint = test_point()
datapoint['sentence'] = 'MISTAKE is supplier WRONG NAME of In addition, five customers, including Sony and Microsoft, accounted for approximately 95% of the net revenue attributable to ADVANCED MICRO DEVICES Inc Enterprise, Embedded and Semi Custom segment'
datapoint['relations'] = [
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony Inc'],
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft inc'],
  ['MISTAKE', 'supplier', 'WRONG NAME'],
              ]
llms_relations, other_relations = data_aggregator.extract_relations_from_llm(datapoint,
                            threshold= 0.9 ,
                            only_filer= False,
                            max_others= 2) 
assert llms_relations == [('ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony'),
 ('ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft'),
 ('MISTAKE', 'supplier', 'WRONG NAME')]
assert len(other_relations) == 2


#########Test having no relation  ###############################

datapoint = test_point()
datapoint['sentence'] = 'MISTAKE is supplier WRONG NAME of In addition, five customers, including Sony and Microsoft, accounted for approximately 95% of the net revenue attributable to ADVANCED MICRO DEVICES Inc Enterprise, Embedded and Semi Custom segment'
datapoint['relations'] = []
llms_relations, other_relations = data_aggregator.extract_relations_from_llm(datapoint,
                            threshold= 0.9 ,
                            only_filer= False,
                            max_others= 2) 
assert llms_relations == []
assert len(other_relations) == 2


In [11]:
def test_point():
    return {
 'filer': 'ADVANCED MICRO DEVICES INC corp',
 'sentence': 'In addition, five customers, including Sony and Microsoft, accounted for approximately 95% of the net revenue attributable to ADVANCED MICRO DEVICES Inc Enterprise, Embedded and Semi Custom segment',
 'relations': [
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Sony'],
  ['ADVANCED MICRO DEVICES Inc', 'supplier', 'Microsoft'],
              ],
 'org_groups': {'ADVANCED MICRO DEVICES Inc': 0, 'Microsoft': 1, 'Sony': 2}}


############################### Test only_filer=true ###############################

llms_relations, other_relations = data_aggregator.extract_relations_from_llm(test_point(),
                            threshold= 0.9 ,
                            only_filer= True,
                            max_others= 1) 