# **_Basic Dependencies_**

In [7]:
import os
import sys
import pandas as pd
import openai
import pandas as pd
from glob import glob
from tqdm import tqdm
from itertools import chain
import json
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KE")

current_path = Path.cwd()
src_dir = current_path.parent
sys.path.append(str(src_dir))
# import annotation methods
from src.labels_generator.instructor_util import (generate_relations,relation_search,resort_relation, get_completion,
                                  deserialize_json_dict2,
                                  generate_relations_with_explanation,
                                  relations_tupled_2,
                                 create_sorted_relation)

# Load matcher
from src.matcher.core import SimCSE_Matcher
matcher = SimCSE_Matcher(str(src_dir/ 'artifacts/matcher_model'))


replaces = {"sentence": "{sentence}"}
# Replace the keys with values for unified relation direction
relations_map = {"customer": "supplier"}

### **_`LLM Annotator`_**

In [8]:
# llm_annotator.py
import os
import openai
from glob import glob
from typing import List, Tuple, Union
from typing import Dict, Text
import json
import random
from collections import defaultdict
import pandas as pd
import time
import numpy as np
import yaml
import re
from colorama import Fore
from tqdm import tqdm
from pathlib import Path
import sys 
src_dir = Path.cwd().parent
sys.path.append(str(src_dir)) 
from src.matcher.core import SimCSE_Matcher
from src.utils import get_logger, dotdict



dataset_columns = ['sentence', 'org_groups']
valid_models =["gpt-3.5-turbo"]

main_relations = ['supplier', 'customer']
inverse = {"customer":"supplier", "supplier":"customer", "other":"other"}
explanation_tags = ["{sentence}", "{instructions}"]
labeling_tags = ["{explanation}"]
confirm_tags = ["{company1}", "{company2}", "{relation}" , "{explanation}"]

class LLMAnnotator(object):
    """This module contains code that generates prompt templates for Language Model
        APIs, such as OpenAI's GPT-3. The prompts are designed to help generate labeled
        datasets for training Relation Extraction models.

        The template consists of three prompts:

        1. Explanation prompt: This prompt is used to explain certain aspects of a given
        sentence. It is designed to help the user identify the entities and relations
        in the sentence.

        2. Label generation prompt: This prompt is responsible for generating a JSON
        object that contains the label for the given sentence. The label includes the
        type of relation between the entities and any additional information that may
        be relevant.

        3. Confirmation prompt: This prompt is used to curate the final label generated
        by the previous prompt. It allows the user to review and modify the label as
        needed.

    """
    def __init__(self, version, matcher_device='cpu'):
        self.logger = get_logger('\U0001F300 LLMAnnotator', log_level="INFO")
        versions_dirs = set(filter(None, [v if not '.dvc' in v else None for v in glob(str(src_dir / 'data/llms_datasets/templates/v*'))]))
        print("versions_dirs", versions_dirs, str(src_dir / 'data/llms_datasets/templates/v*'))
        
        self._versions = list(sorted([float(v.split('/')[-1][1:]) for v in versions_dirs]))
        # Validate the version
        if version not in self._versions:
            raise (f"Invalid template, The available _versions are {self._versions}")
        self.version = version
        self.logger.info("Loading template card...")
        # Reading template card
        with open(src_dir / f"data/llms_datasets/templates/v{str(version)}/card.yaml") as ob:
            self.card = dotdict(yaml.safe_load(ob))
        # identify instruction for annotation
        self._sme_definitions = src_dir / f"data/llms_datasets/templates/v{str(version)}/sme_definitions.yaml"
        self._llm_definitions = src_dir / f"data/llms_datasets/templates/v{str(version)}/llm_definitions.yaml"
        # 
        self.outdir = src_dir / f"data/llms_datasets/templates/v{str(version)}/output"
        self.outdir.mkdir(parents=True, exist_ok=True)
        # read sme instrctions
        with open(self._sme_definitions,"r") as ob:
            self.sme_definitions =  yaml.safe_load(ob)
        # read llm instrctions
        with open(self._llm_definitions,"r") as ob:
            self.llm_definitions =  yaml.safe_load(ob)
        # Use only approved examples
        self.valid_instructions = list(self.sme_definitions.keys())
        # Add only approved definitions from llm 
        self.valid_instructions += list(filter(None, [k if v['approved'] == True \
                                else None for k,v in self.llm_definitions.items()]))
            
        # Define conflict directory
        self._conflict_path = src_dir /f"data/llms_datasets/templates/v{self.version}/reports/conflict_sme_llm_trainset.json"
        # Load entity matcher
        self.matcher  = SimCSE_Matcher('princeton-nlp/sup-simcse-roberta-base', device=matcher_device)
        if self.valid_instructions is not None:
            self.matcher.build_index(self.valid_instructions)
        else:
            self.logger.info("no instructions provided for annoation")
        # initiate the labelled data
        self._ground_truth  = None

    @property
    def unlabeled(self):
        suffix = (src_dir / self.card['dataset']).suffix
        if  suffix == '.json':
            pd.read_json(src_dir / self.card['dataset'])
        elif suffix == '.xlsx':
            return pd.read_excel(src_dir / self.card['dataset'], index_col="index" )

    @property
    def conflicts(self):
        if os.path.isfile(self._conflict_path):
            return pd.read_json(self._conflict_path)
        else:
            self.logger.info("No conflicts had been detected the conflict file might be not exist or target dataset not labeled!!!")


    
    def insert_llm_definition(self, items:dict):
        for definition, values in items.items():
            if self.llm_definitions.get(definition, None) is not None:
                continue
            self.llm_definitions[definition] = values
        with open(self._llm_definitions, 'w') as ob:
            yaml.safe_dump(self.llm_definitions, ob)
    

    
    def get_completion(self, prompt):
        messages = [{"role": "user", "content": prompt}]
        response = None
        while not response:
            try:
                response = openai.ChatCompletion.create(
                    model=self.card['model'],
                    messages=messages,
                    temperature=self.card['temperature'] # this is the degree of randomness of the model's output
                )
            except:
                time.sleep(0.2)
            
        return response.choices[0].message["content"]
    
    def update_template(self):
        new_template = self.card.copy()
        
        for item in new_template.keys():
            text = f"Insert {Fore.BLUE}`{item}`{Fore.RESET}\nPrevious one:\n{Fore.LIGHTCYAN_EX}{new_template[item]}\n"
            print(text)
            input_item = input()
            print(f"{'*'*50}\n")
            if input_item:
                new_template[item] = input_item
        if self.card == new_template:
            print("No changes founded")
        else:
            # TODO: validate the new template
            # If Vaild 
            # pprint the new template and ask  for confirmation
            #print("The new template\n", Fore.CYAN, json.dumps(new_template, indent=4),'\n', "*"*50)
            print("#### New Template ####\n######################\n")
            for k,v in new_template.items():
                print(f'{k}\n--------------\n{Fore.GREEN}{v}{Fore.RESET}\n{"*"*50}\n')
            confirm  = input("Confirm the changes?(y|n)")
            if confirm == 'y':
                new_version = round(max(self._versions)+0.1, 2)
                # Validate the data
                if self.card['dataset'] != new_template['dataset']:
                    # Read file
                    new_data = pd.read_json(new_template['dataset'])
                    # Check required columns
                    if not all([x in new_data.columns for x in dataset_columns]):
                        raise(f"Invalid dataset must contains {dataset_columns}")

                # Validate templates 
                if new_template['model'] not in valid_models:
                    raise Exception(f"invalid open-ai model, Valid Models: {valid_models}")
                    
                if self.card['explanation_prompt'] != new_template["explanation_prompt"]:
                    if not all([x in new_template["explanation_prompt"] for x in explanation_tags]):
                        raise Exception(f"Invalid prompt for explanation, must include [{explanation_tags}]")
                        
                if self.card['labeling_prompt'] != new_template["labeling_prompt"]:
                    if not all([x in new_template["labeling_prompt"] for x in labeling_tags]):
                        raise Exception(f"Invalid prompt for labeling, must include [{labeling_tags}]")
                        
                if self.card['confirmation_prompt'] != new_template["confirmation_prompt"]:
                    if not all([x in new_template["confirmation_prompt"] for x in confirm_tags]):
                        raise Exception(f"Invalid prompt for confirmation, must include [{confirm_tags}]")
                
                # Create the directory and the card of the template
                version_dir = str(src_dir / f"data/llms_datasets/templates/v{new_version}") 
                os.mkdir(version_dir)
                os.mkdir(version_dir+'/data')
                with open (version_dir+'/card.yaml', 'w')as obj:
                    yaml.safe_dump(new_template, obj)
                
                print(f"Create new template with version {new_version}")
    
    def overwrite_card(self):
        """
        Write the current state of the `card` dictionary to the `card.yaml` file.
        """
        with open(src_dir / f"data/llms_datasets/templates/v{str(self.version)}/card.yaml", 'w') as ob:
            yaml.safe_dump(self.card, ob)

    def update_instructions(self, command: Text, instruction: Text, overwrite: bool = True) -> bool:
        """
        Update the `instructions` set in the `card` dictionary using the provided `command` and `instruction`.

        Args:
            command (Text): The name of the method to call on the `instructions` set.
            instruction (Text): The argument to the method specified in `command`.
            overwrite (bool, optional): Whether to overwrite the `card.yaml` file. Defaults to True.

        Returns:
            bool: Whether the update was successful.
        """
        try:
            self.instructions.__getattribute__(command)(instruction)
            if overwrite:
                self.overwrite_card()
            return True
        except KeyError:
            return False


    @staticmethod
    def mask_terms(sentence:Text, mask:Dict, mask_word:Text, demask=False):
        if demask:
            for v,k in dict(sorted(mask.items(),
                                    key=lambda item:item[1],
                                    reverse=True)).items():
                sentence = sentence.replace(f"{mask_word}{k}",v)
        else:
            for k,v in dict(sorted(mask.items(),
                                    key=lambda item:item[1],
                                    reverse=True)).items():
                sentence = sentence.replace(k,f"{mask_word}{v}")
        return sentence

    @staticmethod
    def dict_to_str(rules_dict):
        str_out = ''
        for key, values in rules_dict.items():
            str_out += "- {}:\n".format(key)
            str_out += '\t- ' + '\n\t- '.join(values) + '\n'
        return str_out
        
    def generate_explanation_prompt(self,
                                    sentence,
                                    org_groups=None,
                                    use_llm_definitions=True,
                                    candidate_definition=False,
                                    custom_definitions=None):
        """
        Generate an explanation prompt based on the given inputs.

        @params:
        -------
        - sentence (str): The sentence to generate the prompt for.
        - org_groups (list): Optional. List of organizational groups to mask in the sentence.
        - use_llm_definitions (bool): Optional. Flag to indicate whether to include LLM definitions in the prompt.
        - candidate_definition (bool): Optional. Flag to indicate whether to include candidate definitions in the prompt.

        @returns:
        ---------
        - str: The generated explanation prompt.

        """
        prompt = self.card['explanation_prompt']
        # If org_groups are given mask the `Orgs` on the query
        if org_groups:
            sentence = LLMAnnotator.mask_terms(sentence=sentence,
                                               mask=org_groups,
                                               mask_word="Company")
        # create_explanation_prompt
        explanation_prompt = str(self.card.explanation_prompt)
        # Create definitions
        definitions = ''
        output_rules = ''
        cand_definitions = None
        # Select slice of definitions based on semantic similarity between query and all the definitions
        
        if custom_definitions:
            for definition in custom_definitions:
                definitions += '- {}\n'.format(definition)  
        
        elif candidate_definition:
            cand_definitions = {x[0]:x[1] for x in self.matcher.search(sentence, threshold=0.1, top_k=10)} 
            
            for definition in cand_definitions.keys():
                definitions += '- {}\n'.format(definition)

        else:
            for definition in self.sme_definitions.keys():
                definitions += '- {}\n'.format(definition)

        for output_rule in self.card.explanation_output_rules:
            if isinstance(output_rule, dict):
                output_rules += dict_to_str(output_rule)
            else:
                output_rules += "- {}\n".format(output_rule)

        explanation_prompt = explanation_prompt.replace('{sentence}', sentence)
        explanation_prompt = explanation_prompt.replace('{definitions}', definitions)
        explanation_prompt = explanation_prompt.replace('{explanation_output_rules}', output_rules)

        return explanation_prompt


    def generate_relation_prompt(self, explanation):
        prompt = self.card['labeling_prompt']
        prompt = prompt.replace('{explanation}',  explanation)
        return prompt

    def generate_confirmation(self, company1, company2, relation, explanation):
        prompt = self.card['confirmation_prompt']
        prompt = prompt.replace('{company1}',  company1)
        prompt = prompt.replace('{company2}',  company2)
        prompt = prompt.replace('{explanation}',  explanation)
        prompt = prompt.replace('{relation}',  relation)
        return prompt
        
    def annotate(self, datapoint):
        """
        Annotates a datapoint with explanations, relations, and confirmations if they do not already exist.

        Args:
            datapoint (dict): A dictionary representing a datapoint.

        Returns:
            dict: The annotated datapoint.
        """
        if not datapoint.get('explanation'): 
            # If the datapoint does not have an explanation, generate an explanation prompt and get the completion.
            datapoint['explanation_prompt'] = self.generate_explanation_prompt(sentence=datapoint['sentence'],
                                                                               org_groups=datapoint.get('org_groups'))
            
            explanation = self.get_completion(datapoint['explanation_prompt'])
            datapoint['explanation'] = LLMAnnotator.mask_terms(sentence = explanation,
                            mask=datapoint.get('org_groups'),
                            mask_word="Company", 
                            demask=True)
            
            
        if not datapoint.get('ser_relations'):
            # If the datapoint does not have a relation, generate a relation prompt and get the completion.
            datapoint['relation_prompt'] = self.generate_relation_prompt(datapoint['explanation'])
            datapoint['ser_relations'] = self.get_completion(datapoint['relation_prompt'])
            try:
                datapoint['relations'] = deserialize_relations(datapoint['ser_relations'])
                llms_relations, other_relations = establish_company_relations(datapoint, self.matcher)
                datapoint['llms_relations'] = llms_relations
                datapoint['other_relations'] = other_relations

            except:
                datapoint['relations'] = 'undefined'
                datapoint['llms_relations'] = 'undefined'
                datapoint['other_relations'] = 'undefined'

        if not datapoint.get('confirmation') and self.card.get('confirm'):
            # If the datapoint does not have a confirmation, generate a confirmation prompt and get the completion.
            if datapoint['relations'] == 'undefined':
                datapoint['confirmation_prompt'] = 'undefined'
                datapoint['confirmation'] = 'undefined'
            else:
                datapoint['confirmation_prompt'] = self.generate_confirmation(datapoint['company1'],
                                                        datapoint['company2'],
                                                        datapoint['relations'],
                                                        datapoint['explanation'])
                datapoint['confirmation'] = self.get_completion(datapoint['confirmation_prompt'])
        return datapoint
    
    def generate_labels(self, batch:Tuple[int,int]):
        batch_name = f'batch_{batch[0]}_{batch[1]}.json'
        file_name = src_dir / f'data/llms_datasets/templates/v{str(self.version)}/data/{batch_name}'
        if os.path.exists(file_name):
            self.logger.info("This batch is already exist")
            data = pd.read_json(file_name)
            data = pd.concat([data ,
                              self.unlabeled_data[len(data)+batch[0]:batch[1]]], axis = 0)                
        else:   
            data = self.unlabeled_data[batch[0]:batch[1]]
        annotations = []
        count  = 0
        for i, datapoint in tqdm(data.iterrows(), total= data.shape[0]):
            datapoint = datapoint.to_dict()
            datapoint['index'] = i
            if self.card.get('tagged'):
                annotations.append(self.annotate(datapoint))                 
            else: 
                # Annotate all possible pairs
                datapoint_pairs= LLMAnnotator.get_random_company_pairs(datapoint['org_groups'],
                                                                      self.card.get("max_rel", 7))
                for pair in datapoint_pairs:
                    db_pair = datapoint.copy()
                    db_pair['company1'] = pair[0]
                    db_pair['company2'] = pair[1]
                    # self.logger.info(f"For {pair}, we have datapoint: \n{datapoint}")
                    annotations.append(self.annotate(db_pair))
            count += 1
                
            if (count%10) == 0:
                pd.DataFrame(annotations).to_json(file_name)
        pd.DataFrame(annotations).to_json(file_name)
        return pd.DataFrame(annotations)

    def is_conflict(self, row: Dict, threshold: float = 0.85) -> Tuple[bool, Tuple[str, str, str]]:
        """
        Check if there is a conflict between two entities based on their expected relation and their actual relations.

        Args:
        - row (Dict): A dictionary containing the information about the entities and their relations.
        - threshold (float): A float value between 0 and 1 that determines the minimum similarity score for the relations to be considered aligned.

        Returns:
        - A tuple containing a boolean value indicating whether the relations are aligned or not, and a tuple of the two entities and their expected relation.
        """
        
        # Initialize the value of align to False until we prove otherwise.
        align = False
        
        # Define the two entities to search for.
        c1 = row.get('entity_1')
        c2 = row.get('entity_2')
        
        # Determine the expected relation between c2 and c1.
        expected_relation = row.get('inf_relations')
        
        # Get the organization groups and the SME relation.
        org_groups  = row.get('org_groups')
        sme_relation = (c2, expected_relation, c1)
        
        # Initialize defaultdict to carry org_ids as keys with values carrying all companies and their aliases.
        id2c = defaultdict(lambda: [])
            
        # Group the companies by their organization ID.
        for k,v in org_groups.items():
            id2c[v].append(k)
        
        # Set the expected relation to "other" if the label is 0.
        if row.get('Label') == 0:
            expected_relation = "other"
        
        # Set the SME relation based on the main relations.
        elif main_relations[0] == expected_relation:
            sme_relation = (c2,main_relations[0], c1)
        elif main_relations[1] == expected_relation:
            sme_relation = (c1,main_relations[0], c2)
            
        # Initialize defaultdict to carry supplier names as keys with values carrying customer names.
        llm_relations= defaultdict(lambda : [])
        
        # Get all the supplier-customer relations.
        if isinstance(row['llms_relations'], list):
            for llm_relation in row['llms_relations']:
                if llm_relation[1] == 'supplier':
                    supplier = llm_relation[0]
                    supplier_id = org_groups.get(supplier)
                    supplier_names = id2c[supplier_id] if supplier_id else [supplier]
                    
                    customer = llm_relation[2]
                    customer_id = org_groups.get(customer)
                    customer_names = id2c[customer_id] if customer_id else [customer]
                    
                    for supplier_name in supplier_names:
                        llm_relations[supplier_name] += customer_names
        
        # Get the supplier names.
        llm_suppliers = list(llm_relations.keys())
        
        # Check if the relation had been detected
        expected_supplier = sme_relation[0]
        expected_customer = sme_relation[2]
        
        # If the expected relation is "other".
        if expected_relation == "other":
            # If there is no relation between the reporter and other companies.
            if len(llm_relations) == 0:
                align = True
            # If there are relations between the reporter and other companies.
            else:
                # Check if the supplier is found in the llm_relations.
                align = not self.matcher.similarity(expected_supplier,list(llm_relations.keys())).max() > threshold
        
        # If the expected relation is not "other".
        else:
            # If there are relations between the reporter and other companies.
            if len(llm_relations) > 0:
                # Get the similarity scores between the expected supplier and the llm_suppliers.
                sim_scores = self.matcher.similarity(expected_supplier, llm_suppliers)
                max_score = sim_scores.max()
                max_idx  = sim_scores.argmax()
                
                # If the maximum similarity score is greater than the threshold or the expected supplier is found in the llm_suppliers.
                if max_score > threshold  or any([expected_supplier in x for x in llm_suppliers  ]):
                    # Check if the expected customer is found in the llm_relations.
                    align = self.matcher.similarity(expected_customer,llm_relations[llm_suppliers[max_idx]] ).max() > threshold \
                            or any([[expected_customer.lower() in x.lower() for x in y] for y in llm_relations[llm_suppliers[max_idx]]  ])
        
        # Return the align and sme_relation values.
        return align, sme_relation
    
    def detect_conflicts(self, threshold:float=0.85,save:bool = True)->pd.DataFrame:
        '''Search conflicts within llms annotator compared with ground truth labels
        '''
        # Read generated labels
        llm_labels = pd.read_json(f"data/llms_datasets/templates/v{self.version}/labels/labels.json")
        llm_labels = llm_labels.query("inf_relations.notnull()")

        # Determine the basic two lists: sme_relations (expert annotations) and align_bool (True if llms align with ground truth).
        sme_relations = []
        align_bool = []
        for i, row in tqdm(llm_labels.iterrows(), total=llm_labels.shape[0]):
            align, sme_relation = self.is_conflict(row.to_dict(), threshold=threshold)
            align_bool.append(align)
            sme_relations.append(sme_relation)

        llm_labels['sme_relations'] = sme_relations
        llm_labels['align'] = align_bool

        true_ratio = llm_labels.query("align == True").shape[0] / llm_labels.shape[0]
        self.logger.info("Alignment percentage: {:.2f}%".format(true_ratio*100))
        conflicts = llm_labels.query("align == False")

        if save:
            if not os.path.exists(f"data/llms_datasets/templates/v{self.version}/reports"):
                os.mkdir(f"data/llms_datasets/templates/v{self.version}/reports")
            conflicts.to_json(self._conflict_path, index=True)
        return conflicts

    @staticmethod
    def get_companies_and_relation(relation: dict) -> tuple:
        """
        Extracts the companies and relation from a dictionary.

        Parameters:
        relation (dict): A dictionary containing the relation between two companies.

        Returns:
        tuple: A tuple containing the first company, the relation, and the second company in the relation.

        Example:
        >>> relation = {'company_1': 'Health Net Inc.', 'relation': 'supplier', 'company_2': 'LA Care'}
        >>> get_companies_and_relation(relation)
        ('Health Net Inc.', 'supplier', 'LA Care')
        """
        keys = np.array(list(relation.keys()))

        # Get the index of the 'relation' key in the dictionary
        relation_idx = np.where(keys == 'relation')[0][0]

        # Get the first and second company names based on the 'relation' key index
        company_1 = relation[keys[relation_idx-1] if relation_idx > 0 else keys[1]]
        company_2 = relation[keys[relation_idx+1] if relation_idx<len(keys) else keys[0]]

        return company_1, relation['relation'], company_2
    @staticmethod
    def get_random_company_pairs(org_groups, max_relation=5):
        """
        Returns a list of randomly-selected pairs of companies from a dictionary of company groups.

        Parameters:
            org_groups (dict): A dictionary mapping company keys to group values.
            max_relation (int): The maximum number of pairs to return.

        Returns:
            A list of randomly-selected pairs of companies.
        """
        # Create a dictionary called 'ids2org' that maps each value in 'org_groups' to a list of keys that have that value
        ids2org = defaultdict(lambda : [])
        for key ,val in org_groups.items():
            ids2org[val].append(key)

        # Create a list of all possible pair-wise combinations of the values in 'ids2org', and randomly choose 5 of those combinations
        availabel_relations = []
        comp_keys = list(ids2org.keys())
        for i in range(len(comp_keys)):
            for j in range(i+1, len(comp_keys)):
                relation_t = comp_keys[i], comp_keys[j]
                availabel_relations.append(relation_t)

        # For each of the 5 chosen combinations, randomly choose one key from each group of 'org_groups'
        #  that corresponds to the two values in the combination
        n_relations = max_relation if len(availabel_relations) > max_relation else len(availabel_relations)
        random_pairs = random.sample(availabel_relations, n_relations)
        company_pairs = []
        for pair in random_pairs:
            company1 = random.choice(ids2org[pair[0]])
            company2 = random.choice(ids2org[pair[1]])
            company_pairs.append((company1, company2))
        return company_pairs

    
    

def establish_company_relations(datapoint, matcher):
    """
    Assigns relationships between companies based on certain criteria.

    Args:
        datapoint (dict): A dictionary containing the sentence to be processed.
        matcher (object): An instance of the fuzzywuzzy string matching class.

    Returns:
        tuple: A tuple containing the LLMS relations and other relations.

    """
    global main_relations
    global inverse
    org_groups = datapoint['org_groups']
    relations = datapoint['relations']
    matcher_built = False
    # Collect all companies mentioned in the relations and create a dictionary with each unique company as a key
    llms_companies = []
    for relation in relations:
        llms_companies += [relation.get('company_1'), relation.get('company_2')]
    llms_companies = {k:None for k in set(llms_companies)}
    llms_ids = {k:i for i,k in enumerate(set(llms_companies))}
    ids_llms = {i:k for i,k in enumerate(set(llms_companies))}
    # Check if each company in the dictionary is mentioned in the sentence, and if not, try to match it with a known organization
    for company in list(llms_companies.keys()):
        if company in datapoint['sentence']:
            llms_companies[company] = company
        else:
            if matcher_built is False: 
                matcher.build_index(list(org_groups.keys()))
                matcher_built = True
            
            matches = matcher.search(company, threshold=0.95, top_k = 3)
            if len(matches) > 0:
                llms_companies[company] = matches[0][0]
            else:
                llms_companies.pop(company)
                ids_llms.pop(llms_ids.pop(company))

    # Create a dictionary called 'ids2org' that maps each value in 'org_groups' to a list of keys that have that value
    ids2org = defaultdict(lambda : [])
    for key ,val in llms_ids.items():
        ids2org[val].append(key)

    # Create a list of all possible pair-wise combinations of the values in 'ids2org', and randomly choose 5 of those combinations
    availabel_relations = []
    comp_keys = list(ids2org.keys())
    for i in range(len(comp_keys)):
        for j in range(i+1, len(comp_keys)):
            relation_t = tuple(sorted([comp_keys[i], comp_keys[j]]))
            availabel_relations.append(relation_t)
    exist_relations = []
    llms_relations = []
    for relation in relations:
        c1 = relation.get('company_1')
        c2 = relation.get('company_2')
        if not all([c1 in llms_companies.keys() , c2 in llms_companies.keys()]):
            continue
        relation = relation.get('relation')
        if main_relations[0] == relation:
            llms_relations.append((c1,main_relations[0], c2))
        elif main_relations[1] == relation:
            llms_relations.append((c2,main_relations[0], c1))
        else:
            llms_relations.append((c1, relation, c2))
        
        if not all([c1,c2,relation]):
            continue 
        c1_id = llms_ids.get(c1)
        c2_id = llms_ids.get(c2)    
        exist_relations.append(tuple(sorted([c1_id, c2_id])))
        
    other_ids = set(availabel_relations) ^ set(exist_relations)
    other_relations = []
    for pair in other_ids: 
        c1 = llms_companies[ids_llms[pair[0]]]
        c2 = llms_companies[ids_llms[pair[1]]]    
        other_relations.append((c1,'other', c2))
    return llms_relations, other_relations

### **_`LLM Instructor`_**

In [9]:
    
from typing import List, Tuple, Text 
from src.utils import get_logger
from pathlib import Path
import sys
import pandas as pd
from tqdm.auto import tqdm
import copy
from src.labels_generator.llm_annotator import establish_company_relations, main_relations
src_dir = Path.cwd().parent
sys.path.append(str(src_dir))
from src.labels_generator.instructor_util import deserialize_json_dict2, relations_tupled_2
# from src.labels_generator import LLMAnnotator
from src.labels_generator.utils import relation_search

class LLMInstructor(LLMAnnotator):
    def __init__(self, version,
                matcher_device='cpu',
                deserialize_func = deserialize_json_dict2,
                tuple_func= relations_tupled_2
                 ):
        """
        Initializes an LLMInstructor object.
        Tasks expected from that module:
        - Find conflicts and generate definition to resolve conflicted points
        - Provide list of rules attached with attr

        Args:
        - version (str): the version of the LLM model to use
        - matcher_device (str): the device to use for the matcher model (default: 'cpu')

        Returns:
        - None
        """
        super().__init__(version, matcher_device)
        self.logger = get_logger('\U0001F300 LLMInstructor', log_level="INFO")
        self.name = 'llm-instructor'
        self.deserialize_func = deserialize_func
        self.tuple_func = tuple_func
        self._ground_truth = None
        self.default_relation_columns = ['entity_2', 'inf_relations', 'entity_1','Label']

    @property
    def ground_truth(self):
        if isinstance(self._ground_truth, pd.DataFrame):
            return self._ground_truth
        else: 
            self.logger.warn("No facts loaded, please use `load_facts` to read labelled data")
            
        
    def load_facts(self,
                   path: Path,
                   feature_column:str,
                   relation_columns: List[str],
                   **kwargs):
        """
        Load labelled data from a specified path, considering different file formats.

        #params:
        --------
        - path (Path): The path to the data file.
        - relation_columns (List[str]): List of column names containing relation data.
        - **kwargs: Additional keyword arguments for data reading (e.g., for pandas read_json or read_excel).

        @raises:
        --------
        - ValueError: If the 'path' argument is not an instance of 'Path'.
        - FileExistsError: If the file format is not supported (not .json or .xlsx).
        """
        self.feature = feature_column
        
        if not isinstance(path, Path):
            raise ValueError("Invalid path type, must be Path from pathlib")

        suffix = path.suffix
        self.facts_source = path
        if suffix == '.json':
            self._ground_truth = pd.read_json(path, **kwargs)
        elif suffix == '.xlsx':
            self._ground_truth = pd.read_excel(path, **kwargs)
        else:
            raise FileExistsError("Invalid path, must be .json or .xlsx")

        if "sme_relations" in self._ground_truth.columns:
            # Process sme_relations to unify relation directions
            if not isinstance(self._ground_truth['sme_relations'].iloc[0], list):
                tqdm.pandas(desc="Evaluate sme_relations")
                self._ground_truth['sme_relations'] = self._ground_truth['sme_relations'].progress_apply(eval)
            tqdm.pandas(desc="Resort sme relations")
            self._ground_truth['sme_relations'] = self._ground_truth['sme_relations'].progress_apply(lambda x:
                resort_relation((x[0], x[1], x[2]), self.card.relations_map))

        else:
            tqdm.pandas(desc="Create sme_relations column")
            relation_columns = relation_columns or self.default_relation_columns
            self._ground_truth['sme_relations'] = self._ground_truth[relation_columns].progress_apply(lambda x:
                create_sorted_relation(x[0], x[1], x[2], x[3], relations_map=self.card.relations_map),
                axis=1)
        
    def generate_rule_prompt(self, sentence: str,
                             fact: Tuple[str],
                             source_idx: int,
                             compelition=True) -> str:
        """
        Generates a rule prompt by replacing placeholders in a card's rule_prompt template.

        @params:
        --------
        - sentence (str): The main sentence to be inserted into the rule_prompt.
        - facts (Tuple[str]): A tuple of facts to be incorporated into the rule_prompt.
        - source_idx (int): Index of the source for the rule_prompt.

        @returns:
        ---------
        - str: The generated rule prompt with placeholders replaced.

        @raises:
        --------
        - ValueError: If the card does not contain a rule_prompt.

        """
        # Retrieve the rule_prompt template from the card
        rule_prompt = str(self.card.get('rule_prompt'))

        # Check if rule_prompt exists in the card
        if not rule_prompt:
            raise ValueError("Card must contain rule_prompt")

        # Replace placeholders in the rule_prompt template
        rule_prompt = rule_prompt.replace("{sentence}", sentence)
        facts_str = ''
        facts_str += '{} {} {}'.format(fact[0], self.card[f"{fact[1]}_expression"], fact[2])
        if fact[1] in main_relations:
            facts_str += '{} {} {}'.format(fact[2], self.card[f"{inverse[fact[1]]}_expression"], fact[0])
        
        # Add facts and intro phrases of each relation
        rule_prompt = rule_prompt.replace("{facts}", facts_str)
        rule_prompt = rule_prompt.replace("{intro}", self.card[f"{fact[1]}_intro"])
        if compelition:
            llm_definitions = self.get_completion(prompt=rule_prompt)
            return llm_definitions
        return rule_prompt
    
    
    def annotate_point(self,
                       row,
                       candidate_definition=True,
                       custom_definitions=None):
        row["exp_prompt"] = self.generate_explanation_prompt(row[self.feature],
                                                             candidate_definition=candidate_definition,
                                                             custom_definitions=custom_definitions)
        # Generate relations and parse it
        explanation = self.get_completion(prompt=row["exp_prompt"])
        row['explanation'] = explanation
        #row['rel_prompt'] = self.generate_relation_prompt(explanation)
        #relation_completion = get_completion(prompt=row['rel_prompt'])
        #row['relation_completion'] = relation_completion

        # try:
        row['explanation'] = eval(row['explanation']) #or explanation
        row['relations'] = self.tuple_func(row['explanation'])
        row['defintions'] = [x['span with definition']['definition'] for x in row['explanation']]
        row['relation_phrase']  = [x['relation phrase'] for x in row['explanation']]
        
        matched_defintions = self.matcher.search(row['defintions'], threshold=0.9, top_k=1)
        semantic_score = 0
        if len(matched_defintions[0]) > 0 : 
            row['matched_definitions'] = [x[0][0] for x in matched_defintions ]\
                                         if len(matched_defintions)> 0  else None
            semantic_score = annotator.matcher.similarity(row["matched_definitions"], row[self.feature])
        # Test if annotation aligns
        align = False
        if isinstance(row['relations'], list):
            align = relation_search(row['sme_relations'], row['relations'], annotator.matcher)

        row["source_file"] = str(annotator.facts_source)
        row["semantic_score"] = round(float(semantic_score.max()), 3)
        row["align"] = align
            
        return row
        
    def extract_rules(self, row):
        """
        Extract rules based on the given row.

        @params:
        --------
        - row (dict): The row containing the necessary information.

        @returns:
        ---------
        - dict: The generated definitions and their corresponding information.

        """
        _generated_definitions = {}
        # Generate definitions to solve certain pattern
        out = self.generate_rule_prompt(sentence=row[self.feature],
                                        fact=row['sme_relations'],
                                        source_idx=row['idx'],
                                        compelition=True)
        # Deserialize the output
        llm_definitions = deserialize_json_dict2(out)
        llm_definitions = list(llm_definitions.values())
        # Calculate the semantic similarity between the definition and the report
        semantic_scores = self.matcher.similarity(llm_definitions, row[self.feature])
        # Return only sentences semantically similar with threshold > 0.5
        indices = np.where(semantic_scores > 0.4)[0]
        # Annotate the report with the generated definition for validation
        valid_definitions = {llm_definitions[i]: semantic_scores[i] for i in indices}
        if len(valid_definitions) == 0:
            return {}
        for definition, semantic_score in valid_definitions.items():
            # Ingest the definition into the definitions dataset
            annotation = self.annotate_point(dict(row), custom_definitions=[definition])
            # Identify if definition solved the pattern
            if isinstance(annotation['relations'], list):
                align = relation_search(annotation['sme_relations'], annotation['relations'], self.matcher)
            else:
                align = False
            _generated_definitions[definition] = {
                                                  "sentence": annotation[self.feature],
                                                  "source_file": str(self.facts_source),
                                                  "source_idx": annotation['idx'],
                                                  "source_align": align,
                                                  "source_semantic_score": round(float(semantic_score[0]), 3),
                                                  "approved": False,
                                                  "rejected": False,
                                                  "type": "LLM"}

            self.logger.info("sentence: {}\ndefinition: {}\nrelation_phrase: {}\nalign: {}"
                             .format(annotation[self.feature],
                                     definition,
                                     annotation['relation_phrase'],
                                     align))
        return _generated_definitions

    def discover_rules(self, b_start=None, b_end=None):
        """
        Extract rules based on the provided data batch.

        @params:
        --------
        - b_start (int): Optional. Start index of the data batch.
        - b_end (int): Optional. End index of the data batch.

        @returns:
        ---------
        None
        """
        b_start = b_start or 0
        b_end = b_end or len(annotator.ground_truth)
        data_batch = annotator._ground_truth[b_start:b_end]
        _definitions_records = []
        _generated_definitions = []

        for i, row in tqdm(data_batch.iterrows(),
                           total=len(data_batch),
                           desc="Searching rules"):
            row = row.to_dict()
            row['idx'] = i

            # Annotate with LLM
            row = annotator.annotate_point(row, candidate_definition=True)

            # Test if annotation aligns
            align = False
            if isinstance(row['relations'], list):
                align = relation_search(row['sme_relations'], row['relations'], annotator.matcher)

            semantic_score = 0
            if len(row.get("matched_definitions", [])) > 0:
                semantic_score = annotator.matcher.similarity(row["matched_definitions"], row[self.feature])

            row["source_file"] = str(annotator.facts_source)
            row["semantic_score"] = round(float(semantic_score.max()), 3)
            row["align"] = align
            _definitions_records.append(row)

            if align is False:
                print("resolve definition for following sentence\n: {}".format(row[self.feature]))
                _llm_definitions = annotator.extract_rules(row)
                _generated_definitions.append(_llm_definitions)
                # annotator.insert_instruction(_llm_definitions)
        pd.DataFrame(_definitions_records).to_excel(self.outdir/f'b{b_start}_b{b_end}_{self.facts_source.name}')
        for _definition in _generated_definitions:
            annotator.insert_llm_definition(_definition)


# _Using LLMInstructor_

In [10]:
# Real the data
version = 2.4
annotator = LLMInstructor(version=version)
methods = [method for method in dir(LLMInstructor) if callable(getattr(LLMInstructor, method)) and not method.startswith("__")]

#Load Facts
annotator.load_facts(src_dir / "data/tasks/finetune_llm_on_label_1/source_data.xlsx",
                     feature_column="sentence",
                     index_col='index',
                     relation_columns=["entity_2", "inf_relations", "entity_1", "Label"])
methods

versions_dirs {'/notebooks/inferess-relation-extraction/data/llms_datasets/templates/v1.5', '/notebooks/inferess-relation-extraction/data/llms_datasets/templates/v1.7', '/notebooks/inferess-relation-extraction/data/llms_datasets/templates/v1.8', '/notebooks/inferess-relation-extraction/data/llms_datasets/templates/v2.0', '/notebooks/inferess-relation-extraction/data/llms_datasets/templates/v1.9', '/notebooks/inferess-relation-extraction/data/llms_datasets/templates/v1.6', '/notebooks/inferess-relation-extraction/data/llms_datasets/templates/v2.4'} /notebooks/inferess-relation-extraction/data/llms_datasets/templates/v*
2023-11-21 13:06:35,922 — 🌀 LLMAnnotator — INFO — Loading template card...


Evaluate sme_relations: 100%|██████████| 2910/2910 [00:00<00:00, 102014.52it/s]
Resort sme relations: 100%|██████████| 2910/2910 [00:00<00:00, 451701.44it/s]


['annotate',
 'annotate_point',
 'detect_conflicts',
 'dict_to_str',
 'discover_rules',
 'extract_rules',
 'generate_confirmation',
 'generate_explanation_prompt',
 'generate_labels',
 'generate_relation_prompt',
 'generate_rule_prompt',
 'get_companies_and_relation',
 'get_completion',
 'get_random_company_pairs',
 'insert_llm_definition',
 'is_conflict',
 'load_facts',
 'mask_terms',
 'overwrite_card',
 'update_instructions',
 'update_template']

## Extract rules from annotated table

In [36]:
# discover rules by focusing on conflicts between human annotator and LLM annotation
annotator.discover_rules(b_start=0, b_end=5)

Searching rules:   0%|          | 0/5 [00:00<?, ?it/s]

resolve definition for following sentence
: Since NGL Energy Partners December 2013 acquisition of Gavilon Energy, NGL Energy Partners have purchased crude oil and natural gas from and sold crude oil and natural gas to WPX.
2023-10-07 23:14:19,948 — 🌀 LLMInstructor — INFO — sentence: Since NGL Energy Partners December 2013 acquisition of Gavilon Energy, NGL Energy Partners have purchased crude oil and natural gas from and sold crude oil and natural gas to WPX.
definition: Supplier companies may be referred to as entities from which another company purchases goods or services.
relation_phrase: ['NGL Energy Partners is a customer of WPX.']
align: True
resolve definition for following sentence
: On August 7, 2007, Sypris Solutions Inc entered into a comprehensive settlement agreement with Dana to resolve all outstanding disputes between the parties, terminate previously approved arbitration payments and replace three existing supply agreements with a single, revised contract running throu

### Generate annotation with candidate definitions

In [60]:
row = annotator.ground_truth.sample(1).iloc[0]
row['idx'] = row.name

In [61]:
row = annotator.annotate_point(dict(row), candidate_definition=True)

In [63]:
from pprint import pprint
print("expected: {}".format(row['sme_relations']), "\n------------\n")
print("predicted: {}".format(row['relations']), "\n------------\n")
pprint(row['explanation'])


expected: ['AEROJET ROCKETDYNE HOLDINGS, INC.', 'supplier', 'NASA'] 
------------

predicted: [] 
------------

[{'explanation': 'AEROJET ROCKETDYNE HOLDINGS INC reports that Principal '
                 'customers include the DoD, NASA, Boeing, Lockheed Martin, '
                 'Orbital Sciences Corporation, Raytheon Company, and ULA.',
  'financial_trade': [],
  'link between span and defintion': 'The span mentions the names of companies '
                                     'that are the principal customers of '
                                     'AEROJET ROCKETDYNE HOLDINGS INC. These '
                                     'companies are referred to as customers '
                                     'because they are a source of revenue for '
                                     'AEROJET ROCKETDYNE HOLDINGS INC.',
  'nothing': [],
  'relation phrase': 'AEROJET ROCKETDYNE HOLDINGS INC is a supplier of DoD, '
                     'NASA, Boeing, Lockheed Martin, Orbital Sciences

### Generate definitions out of row

In [64]:
generated_definitions = annotator.extract_rules(row)

2023-10-07 23:32:08,339 — 🌀 LLMInstructor — INFO — sentence: AEROJET ROCKETDYNE HOLDINGS INC reports that Principal customers include the DoD, NASA, Boeing, Lockheed Martin, Orbital Sciences Corporation, Raytheon Company, and ULA.
definition: Supplier companies may be referred to as entities that provide goods or services to their customers.
relation_phrase: ['AEROJET ROCKETDYNE HOLDINGS INC is a supplier of DoD, NASA, Boeing, Lockheed Martin, Orbital Sciences Corporation, Raytheon Company, and ULA.']
align: True
2023-10-07 23:32:13,352 — 🌀 LLMInstructor — INFO — sentence: AEROJET ROCKETDYNE HOLDINGS INC reports that Principal customers include the DoD, NASA, Boeing, Lockheed Martin, Orbital Sciences Corporation, Raytheon Company, and ULA.
definition: Customer companies may be referred to as entities that purchase goods or services from their suppliers.
relation_phrase: ['AEROJET ROCKETDYNE HOLDINGS INC is a supplier of DoD, NASA, Boeing, Lockheed Martin, Orbital Sciences Corporation, 

## Annotate custom sentence

In [11]:
sentence = "The loss of EchoStar or the heavy equipment OEM as a customer, a deterioration in either customer’s overall business, or a decrease in either customer’s volume of sales, could result in decreased sales for CalAmp Corp and could have a material adverse impact on CalAmp Corp ability to grow CalAmp Corp business."
sentence = "ALABAMA POWER CO reports that Southern Company's financial success is directly tied to the satisfaction of its customers"
sentence = "Southern Company, Alabama Power, Georgia Power, Mississippi Power (with the exception of its cost-based MRA electric tariffs described below), and Southern Company Gas each have a diversified base of customers and no single customer or industry comprises 10% or more of each company's revenues"
sentence = '''As part of our merger with Spansion, we acquired agreements with Fujitsu Semiconductor Limited ("FSL"), XMC and SK Hynix Inc. ("SK Hynix"). Agreements with FSL include agreements for the supply of product wafer foundry services, sort services and assembly and test services relating to the microcontroller and analog businesses. These agreements are at competitive market rates and enable us to leverage FSL's existing manufacturing capabilities and relationships with its partners spanning across various technologies, processes, geometries and wafer sizes in their wafer fabrication facilities and package solutions in their back end manufacturing facilities, until such time that we can either move these internally to our fabrication and back end facilities or find alternative solutions. For FSL, the fabrication facilities are all located in Japan, while the back end facilities are in Japan and other Asian countries. The supply agreements do not call for any minimum purchase commitments. The arrangement with XMC provides production support for advanced NOR technology products at 65, 45 and development of 32 nanometers. The arrangement with SK Hynix provides for the development and supply of SLC NAND products at the 4x and 3x nodes.	'''
explanation_prompt = annotator.generate_explanation_prompt(sentence)

In [12]:
def get_completion(self, prompt):
    response = None
    while not response:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4-1106-preview",
                messages=messages,
                temperature=self.card['temperature'] # this is the degree of randomness of the model's output
            )
        except:
            time.sleep(0.2)


In [13]:
from openai import OpenAI
client = OpenAI(api_key=os.environ['OPENAI_API_KE'])

In [14]:
ner_prompt = '''
Extratc the organization names as entities from the following sentence:

#Instructions:
- Explain why you extracted every entity with using the context of each entity

Sentence:TRIUMPH GROUP Inc reports that Systems and Support has experienced an increase in its military end market primarily from volume on military rotorcraft, and Aerospace Structures has experienced a decrease in its military end market due to reduced volume in the C-130 program and certain military rotorcraft


'''

In [14]:
messages = [{"role": "user", "content": explanation_prompt}]
response = client.chat.completions.create(
    model=annotator.card['model'],
    messages=messages
)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


11/21/2023 13:06:59 - INFO - httpx -   HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


In [15]:
print(response.choices[0].message.content)

[{
"explanation": "The report mentions that as part of their merger with Spansion, they acquired agreements with Fujitsu Semiconductor Limited (FSL), XMC, and SK Hynix Inc. These agreements include supply agreements for various services and products. Therefore, the companies mentioned in the report are in a supplier and customer relation.",
"span with definition": {"span": "As part of our merger with Spansion, we acquired agreements with Fujitsu Semiconductor Limited (\"FSL\"), XMC and SK Hynix Inc. (\"SK Hynix\").", "definition": "Supplier companies may be referred to as vendors or providers or sellers of services, products, or materials to customer companies."},
"link between span and defintion": "The report explicitly mentions that the agreements include supply agreements, which aligns with the definition of a supplier and customer relation. The companies mentioned in the report, namely FSL, XMC, and SK Hynix, are referred to as providers or sellers of services and products to the a

# **_Update Prompts Template_**

In [39]:
rule_prompt= """
Your task is to create a definition from the following report which is associated with supply chain relation fact

## Report
{sentence}

## Facts
- {facts}



## Output structure
{
"definition 1": <Defination can be used to indicate this relation>
"definition 2": <Another definition can be used to indicate this relation menationed on the facts>
...
}

## Rules to follow for creating output 
- Write one definition or more that can be used to describe why {facts} from the context of the Report
- Don't include company or products names.
- Definitions must align with the report.
- Definition must be aplicable to descripe how {facts} based on the report
- Definitions must be semantically close to the report
- Phrase the definition with this intro {intro}.
- Encapsulate your answer in 20 words or less.
"""
annotator.card['rule_prompt'] = rule_prompt


In [49]:
rule_prompt= """
Your task is to create a definition from the following report which is associated with supply chain relation fact

## Report
{sentence}

## Facts
- {facts}



## Output structure
{
"definition 1": <Defination can be used to indicate this relation>
"definition 2": <Another definition can be used to describe how {facts}>
...
}

## Rules to follow for creating output 
- Write one definition or more that can be used to describe why {facts} from the context of the Report
- Don't include company or products names
- Phrase the definition with this intro {intro}
- Encapsulate your answer in 20 words or less
"""
annotator.card['rule_prompt'] = rule_prompt


In [67]:
Explanation_prompt  ='''
Your task is to provide explanation about the relation between companies mentioned in the report given in ``` quote.
possible relation - {supplier, customer, financial_trade, nothing}

##Report
```
{sentence}

```

## Definitions
here is some defintions that might help to understand how to identify relation between companies

{definitions}

  
## definitions to follow for creating output -
{explanation_output_rules}


## Output structure
[{
"explanation": explain with clarity who is supplier or customer or in financial_trade and to whom, with respect to the definitions, choose to answer with `nothing` if no definition fits with the report,
"span with definition": {"span": span from the report that refer to certain definition, "definition": The exact text of definition that represent the explanation},
"link between span and defintion": explain how the span and the definition are aligned with eachother,
"relation phrase": Company_X is (supplier of or customer of or  in finantial_trade with or nothing) company_Y,
'supplier_and_customer' : [ {'customer': 'company_acting_as_customer', 'supplier': 'company_acting_as_supplier'} ],
'financial_trade': [ [company1_name, company2_name] , [company1_name, company2_name] ], 
'nothing': [ [company1_name, company2_name] , [company1_name, company2_name]]
}]

'''

annotator.card['explanation_prompt'] = Explanation_prompt

In [66]:
Explanation_prompt  ='''
Your task is to provide explanation about the relation between companies mentioned in the report given in ``` quote.
possible relation - {supplier, customer, financial_trade, nothing}

##Report
```
{sentence}

```

## Definitions
here is some defintions that might help to understand how to identify relation between companies

{definitions}

  
## definitions to follow for creating output -
{explanation_output_rules}


## Output Structure
[{
"explanation": Provide a clear explanation of whether the companies are suppliers, customers, involved in financial trade, or if no definition applies to the report. Choose to answer with `nothing` if no definition fits the report.
"span with definition": {"span": The span from the report that refers to a certain definition, "definition": The exact text of the definition that represents the explanation},
"link between span and definition": Explain how the span and the definition align with each other,
"relation phrase": Company_X is (a supplier of, a customer of, involved in financial trade with, or nothing) company_Y,
'supplier_and_customer': [{'customer': 'company_acting_as_customer', 'supplier': 'company_acting_as_supplier'}],
'financial_trade': [[company1_name, company2_name], [company1_name, company2_name]],
'nothing': [[company1_name, company2_name], [company1_name, company2_name]]
}]

'''

annotator.card['explanation_prompt'] = Explanation_prompt

In [None]:
## Definitions
Here are some definitions that might help to understand how to identify the relation between companies:

{definitions}


## Definitions to Follow for Creating Output
{explanation_output_rules}

## Output Structure
[{
"explanation": Provide a clear explanation of whether the companies are suppliers, customers, involved in financial trade, or if no definition applies to the report. Choose to answer with `nothing` if no definition fits the report.
"span with definition": {"span": The span from the report that refers to a certain definition, "definition": The exact text of the definition that represents the explanation},
"link between span and definition": Explain how the span and the definition align with each other,
"relation phrase": Company_X is (a supplier of, a customer of, involved in financial trade with, or nothing) company_Y,
'supplier_and_customer': [{'customer': 'company_acting_as_customer', 'supplier': 'company_acting_as_supplier'}],
 'financial_trade': [[company1_name, company2_name], [company1_name, company2_name]],
 'nothing': [[company1_name, company2_name], [company1_name, company2_name]]
}]

'''

# **_Write updates_**

In [None]:
print(annotator.card['explanation_prompt'])


Your task is to provide explanation about the relation between companies mentioned in the report given in ``` quote.
possible relation - {supplier, customer, financial_trade, nothing}

##Report
```
{sentence}

```

## Definitions
here is some defintions that might help to understand how to identify relation between companies
{definitions}


## Complex definitions for financial_trade, supplier or customer relation
- If two companies are involved in collaboration agreement or joint development, then think in following steps -
  1. find which company is paying money to another company in joint development
  2. if it is not clear who is paying money, then relation is financial_trade
  3. if it is clear who is payee, the payee is customer

  
## definitions to follow for creating output -
{explanation_output_rules}


## Output structure
[
{
"explanation": explain with clarity who is supplier or customer or financial_trade and to whom, with respect to the definition, choose to answer with `

In [112]:
annotator.card['explanation_prompt'] ="""
Your task is to provide explanation about the relation between companies mentioned in the report given in ``` quote.
possible relation - {supplier, customer, financial_trade, nothing}

##Report
```
{sentence}

```

## Definitions
here is some defintions that might help to understand how to identify relation between companies
{definitions}


## Complex definitions for financial_trade, supplier or customer relation
- If two companies are involved in collaboration agreement or joint development, then think in following steps -
  1. find which company is paying money to another company in joint development
  2. if it is not clear who is paying money, then relation is financial_trade
  3. if it is clear who is payee, the payee is customer

  
## definitions to follow for creating output -
{explanation_output_rules}


## Output structure
[{
"explanation": explain with clarity who is supplier or customer or financial_trade and to whom, with respect to the definition, choose to answer with `nothing` if no definition fits with the report 
"span with definition": {"span": the text span that indicate certain definition, "definition": The exact text of definition that represent the explanation}
"link between span and defintion": explain how the span and the definition are aligned with eachother,
"relation phrase": Company_X is (supplier of or customer of or in finantial_trade with or nothing) company_Y ,
'customer_supplier_relations' : [ {'is customer': 'company_acting_as_customer', 'is supplier': 'company_acting_as_supplier'} ],
'financial_trade': [ [company1_name, company2_name] , [company1_name, company2_name] ], 
'nothing': [ [company1_name, company2_name] , [company1_name, company2_name]]
}]
"""

In [200]:

rule = {"Supplier companies may be referred to as companies that are eligible for potential milestone payments and royalties based on the commercialization of products arising from research under a collaboration agreement.":
 {
  'source_file': str(annotator.facts_source),
  'source_idx': 189,
  "source_semantic_score": 0.7847796,
  "approved":True,
  "type":"LLM"
 }
}

In [None]:
##################Refactor Definition File##################

# f_dif = annotator.card.instructions.copy()
# d_fs = {}

# for definition in list(chain(*f_dif.values())):
#     d_fs[definition] = {
#         "approved":True,
#         "source_file":"unknown",
#         "source_idx":"unknown",
#         "source_align":"unknown",
#         "type":"SME" }


In [121]:
with open(annotator._definitions, 'w') as ob:
    yaml.safe_dump(d_fs, ob)

In [68]:
import yaml
with open(src_dir / f"data/llms_datasets/templates/v{str(annotator.version)}/card.yaml", 'w') as obj:
    yaml.safe_dump(dict(annotator.card), obj)

In [72]:
import yaml
with open(src_dir / f"data/llms_datasets/templates/v{str(annotator.version)}/definitions.yaml", 'w') as obj:
    yaml.safe_dump(dict(annotator.instructions), obj)

In [70]:
!pip install InstructorEmbedding

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting InstructorEmbedding
  Downloading InstructorEmbedding-1.0.1-py2.py3-none-any.whl (19 kB)
Installing collected packages: InstructorEmbedding
Successfully installed InstructorEmbedding-1.0.1
[0m

In [71]:
from InstructorEmbedding import INSTRUCTOR
model_ins = INSTRUCTOR('hkunlp/instructor-base')

sentence = "3D ActionSLAM: wearable person tracking in multi-floor environments"
instruction = "Represent the Science title:"
embeddings = model_ins.encode([[instruction,sentence]])
print(embeddings.shape)

# --------------------------------------------------------------------------------

import json
import numpy as np
import faiss
import torch

# Check if CUDA is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ", device)

# FAISS index setup
dimension = 768  # Instructor-XL output dimension
index_ins = faiss.IndexFlatL2(dimension)

# Extract and vectorize data
db_filename = 'arxiv-metadata-10000.json'
num_lines = 10000
batch_size = 4

# Load all papers from JSON
with open(db_filename, 'r') as f:
    papers = [json.loads(line) for line in f]

# Extract the papers' titles and abstracts
texts = [f"{paper['title']}: {paper['abstract']}" for paper in papers]

# Preparation for encoding
instructions = ["Represent the science titles and abstracts: "] * len(texts)

# Prepare the inputs
inputs = [[instr, txt] for instr, txt in zip(instructions, texts)]

# Create vectors using Instructor
vectors = model_ins.encode(
    sentences=inputs[:num_lines],
    batch_size=batch_size,
    show_progress_bar=True,
    convert_to_numpy=True,
    device=str(device)
)

# Add the vectors to the FAISS index
index_ins.add(np.array(vectors).astype('float32'))

print(f"Added {num_lines} papers to the FAISS index.")

# --------------------------------------------------------------------------------

def search_ins(query, k=5):
    vector = model_ins.encode(["Represent the query to a science database: ", query])
    _, indices = index_ins.search(np.array(vector[1]).reshape(1, -1).astype('float32'), k)
    return indices[0]

for query in queries:
    print(f"Question: {query}\n")
    line_numbers = search_ins(query, k=2)
    print_paper_details(line_numbers)
    print('-'*80)

10/08/2023 00:04:15 - INFO - sentence_transformers.SentenceTransformer -   Load pretrained SentenceTransformer: hkunlp/instructor-base


Downloading (…)62736/.gitattributes:   0%|          | 0.00/1.48k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/270 [00:00<?, ?B/s]

Downloading (…)/2_Dense/config.json:   0%|          | 0.00/115 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

Downloading (…)15e6562736/README.md:   0%|          | 0.00/66.2k [00:00<?, ?B/s]

Downloading (…)e6562736/config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/439M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)"spiece.model";:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)62736/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.43k [00:00<?, ?B/s]

Downloading (…)6562736/modules.json:   0%|          | 0.00/461 [00:00<?, ?B/s]

load INSTRUCTOR_Transformer


ModuleNotFoundError: No module named 'fused_layer_norm_cuda'