In [None]:
import json
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_prediction(model, tokenizer, text, entity_types, predicates, device="cpu", max_length=2048):
    """Run the Triplex model on ``text`` and return the decoded generation.

    The prompt asks for NER over ``entity_types`` and triplet extraction over
    ``predicates``; both lists are embedded as JSON objects
    (``{"entity_types": [...]}`` / ``{"predicates": [...]}``). The prompt is
    sent as a single user turn through the tokenizer's chat template.

    Args:
        model: Causal LM exposing ``generate`` (loaded elsewhere in the notebook).
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.
        text: Source text to analyse.
        entity_types: List of entity-type labels to recognise.
        predicates: List of relation labels allowed in triplets.
        device: Device the input ids are moved to; must match the model's
            device. Defaults to ``"cpu"`` (the previously hard-coded value).
        max_length: Forwarded to ``generate``. It bounds prompt + generated
            tokens together, so a long ``text`` leaves less room for output.

    Returns:
        The first generated sequence decoded with special tokens removed.
        For decoder-only models ``generate`` returns the prompt tokens followed
        by the continuation, so this string normally contains the prompt as
        well; the answer is located downstream via its ```json fence.
    """
    # Kept byte-for-byte (including indentation and whitespace-only lines):
    # this is the exact text the model sees, and changing it can change output.
    input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates.
      
        **Entity Types:**
        {entity_types}
        
        **Predicates:**
        {predicates}
        
        **Text:**
        {text}
        """

    message = input_format.format(
        entity_types=json.dumps({"entity_types": entity_types}),
        predicates=json.dumps({"predicates": predicates}),
        text=text,
    )

    messages = [{'role': 'user', 'content': message}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    generated = model.generate(input_ids=input_ids, max_length=max_length)
    return tokenizer.decode(generated[0], skip_special_tokens=True)


In [None]:
import json
import re

def extract_json_from_string(input_string: str) -> "dict | None":
    """Parse the first ```json fenced object found in ``input_string``.

    Args:
        input_string: Text (e.g. decoded model output) that may contain a
            Markdown code fence of the form ```json { ... } ```.

    Returns:
        The parsed JSON object, or ``None`` when no fenced object is found or
        its contents are not valid JSON (a message is printed in both cases).
    """
    # Non-greedy body anchored on the closing fence: the match ends at the
    # FIRST "}" that is followed (after optional whitespace) by ```, so two
    # fenced blocks are never merged into one invalid span. Nested braces are
    # still captured because an inner "}" is not followed by the fence.
    json_match = re.search(r'```json\s*(\{.*?\})\s*```', input_string, re.DOTALL)

    if json_match is None:
        print("No JSON found in the string.")
        return None

    try:
        return json.loads(json_match.group(1))
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return None


In [None]:
def get_entity_dict(string_list):
    """Map entity ids to entity descriptions from Triplex output lines.

    A line counts as an entity definition when it holds exactly one ``[`` and
    exactly one ``]``; lines with two markers (relations) are ignored. The key
    is the text between the brackets. The value is everything starting two
    characters after ``]``, whitespace-stripped. Later duplicates of a key
    overwrite earlier ones.

    NOTE(review): the fixed two-character skip assumes a one-character
    separator (e.g. ",") directly after ``]``; confirm against real output.
    """
    def _has_single_marker(line):
        # Entity lines carry one [id]; relation lines carry two.
        return line.count('[') == 1 and line.count(']') == 1

    return {
        line[line.find('[') + 1:line.find(']')]: line[line.find(']') + 2:].strip()
        for line in string_list
        if _has_single_marker(line)
    }

In [None]:
def extract_relation(input_string):
    """Return the relation text between the first ``[..]`` marker and the next ``[``.

    For a relation line such as ``"[1] Has [2]"`` this returns ``"Has"``.
    If no ``[`` follows the first ``]``, the rest of the string is used.
    The result is whitespace-stripped.
    """
    start_bracket = input_string.find('[')
    end_bracket = input_string.find(']', start_bracket + 1)
    next_bracket = input_string.find('[', end_bracket + 1)
    # str.find returns -1 when the character is absent; slicing up to -1 would
    # silently drop the final character, so fall back to the end of the string.
    if next_bracket == -1:
        next_bracket = len(input_string)

    return input_string[end_bracket + 1:next_bracket].strip()

In [None]:
import re

def extract_triplets(input_string, entity_dict):
    """Turn one relation line into ``[subject, subject_type, relation, object, object_type]``.

    Args:
        input_string: A relation line containing numeric ``[id]`` markers,
            e.g. ``"[1] Has [2]"``.
        entity_dict: Mapping of id string -> entity description, as built by
            ``get_entity_dict``. Ids missing from it are skipped.

    Each resolved description is split on ``":"``; field 0 becomes the
    name column and field 1 the type column (any further ``":"``-separated
    fields are dropped). NOTE(review): confirm the model emits entities in
    that order; if it emits ``TYPE:name`` the columns are swapped.

    Raises:
        IndexError: fewer than two ids resolve, or a description has no ``":"``.
    """
    relation = extract_relation(input_string)

    resolved = []
    for entity_id in re.findall(r'\[(\d+)\]', input_string):
        if entity_id in entity_dict:
            resolved.append(entity_dict[entity_id])

    head_parts = resolved[0].split(':')
    tail_parts = resolved[1].split(':')
    return [head_parts[0], head_parts[1], relation, tail_parts[0], tail_parts[1]]

In [None]:
def get_triplets(prediction):
    """Extract ``[subject, subject_type, relation, object, object_type]`` rows from a prediction.

    Args:
        prediction: Raw decoded model output containing a ```json fenced
            object with an ``"entities_and_triples"`` list of strings.

    Returns:
        One 5-element list per relation line (a line with exactly two ``[``
        and two ``]``), in the order the model emitted them.

    Raises:
        ValueError: no parsable JSON object was found, or it lacks the
            ``"entities_and_triples"`` key.
        Exceptions from ``extract_triplets`` (e.g. ``IndexError`` when a
        relation references ids that do not resolve) propagate unchanged.
    """
    parsed = extract_json_from_string(prediction)
    # extract_json_from_string signals failure by returning None; raise an
    # explicit error instead of "'NoneType' object is not subscriptable".
    if not isinstance(parsed, dict) or 'entities_and_triples' not in parsed:
        raise ValueError("prediction has no JSON object with an 'entities_and_triples' list")

    lines = parsed['entities_and_triples']
    entity_dict = get_entity_dict(lines)

    triplets = []
    for line in lines:
        if line.count('[') == 2 and line.count(']') == 2:
            triplets.append(extract_triplets(line, entity_dict))
    return triplets


In [None]:
# Tokenizer and model for the Triplex checkpoint. The model is placed on CPU
# and switched to inference mode.
# NOTE(review): trust_remote_code=True executes code shipped with the hub
# repository; pin a revision if this notebook must be reproducible/safe.
tokenizer = AutoTokenizer.from_pretrained("sciphi/triplex", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("sciphi/triplex", trust_remote_code=True)
model = model.to('cpu').eval()

# Label sets embedded (in this order) into every prompt by get_prediction.
entity_types = [
    "Government Body",
    "Regulatory Body",
    "PERSON",
    "Geopolitical Entity",
    "COMPANY",
    "PRODUCT",
    "EVENT",
    "SECTOR",
    "ECON_INDICATOR",
    "CONCEPT",
]
predicates = [
    "Has",
    "Announce",
    "Operate_In",
    "Introduce",
    "Produce",
    "Control",
    "Participates_In",
    "Impact",
    "Positive_Impact_On",
    "Negative_Impact_On",
    "Relate_To",
    "Is_Member_Of",
    "Invests_In",
    "Raise",
    "Decrease",
]

In [None]:
import csv

header = ["id", "subject", "subject_type", "relation", "object", "object_type", "input", "wiki_page"]
output_filename = "data/triplex_knowledge_extraction.csv"
input_file_path = 'data/mistral_knowledge_extraction.csv'

# Read the input CSV (columns used: "id", "input") and write one output row per
# extracted triplet. The model is run once per run of consecutive rows sharing
# an id; non-adjacent repeats of an id are processed again.
# newline='' on both files is what the csv module requires so that quoted
# fields containing newlines (likely in "input") round-trip correctly.
with open(input_file_path, mode='r', newline='') as input_file, \
        open(output_filename, mode='w', newline='') as output_file:
    reader = csv.DictReader(input_file)
    writer = csv.writer(output_file)
    writer.writerow(header)

    previous_id = None
    for row in reader:
        current_id = row['id']
        if current_id == previous_id:
            continue
        previous_id = current_id

        print(current_id)
        text = row['input']
        prediction = get_prediction(model, tokenizer, text, entity_types, predicates)
        try:
            triplets = get_triplets(prediction)
        except Exception as exc:
            # Skip this document but keep the run going. Catch Exception, not a
            # bare except, so KeyboardInterrupt can still stop a long run.
            print(f"Error: {current_id} ({exc!r})")
            continue

        for triplet in triplets:
            # id, the five triplet fields, source text, empty wiki_page column.
            writer.writerow([current_id, *triplet, text, ""])

print(f"CSV file '{output_filename}' has been created successfully.")
