In [None]:
from openai import OpenAI

# Do not hard-code credentials in the notebook. When no api_key argument is
# given, the OpenAI client reads the key from the OPENAI_API_KEY environment
# variable, so export that variable before starting the kernel.
client = OpenAI()

In [None]:
def get_prediction(news_prompt, client, model="o1-preview"):
    """Ask a chat model to extract entity/relation triplets from a news text.

    The fixed instruction prompt (entity categories, allowed relation verbs and
    one worked example) is sent as a single user message with ``news_prompt``
    appended after the ``INPUT_TEXT:`` marker.

    Args:
        news_prompt: The news text to analyse.
        client: Object exposing ``chat.completions.create(model=..., messages=...)``
            and returning a response with ``choices[0].message.content``.
        model: Name of the chat model to query. Defaults to ``"o1-preview"``.

    Returns:
        The content of the first choice of the response, unparsed. It is also
        printed so progress is visible while the notebook runs.
    """
    # NOTE: the worked example deliberately keeps the single-quoted list style;
    # the reply is normalised to JSON by the caller.
    instructions = """
    Note that the entities should not be generic, numerical or temporal (like dates or percentages). Entities must be classified into the following categories: ORG: Organizations other than government or regulatory bodies
    ORGGOV: Government bodies (e.g., "United States Government")
    ORGREG: Regulatory bodies (e.g., "Federal Reserve")
    PERSON: Individuals (e.g., "Elon Musk")
    GPE: Geopolitical entities such as countries, cities, etc. (e.g., "Germany")
    COMP: Companies (e.g., "Google")
    PRODUCT: Products or services (e.g., "iPhone")
    EVENT: Specific and Material Events (e.g., "Olympic Games", "Covid-19")
    SECTOR: Company sectors or industries (e.g., "Technology sector")
    ECON_INDICATOR: Economic indicators (e.g., "Inflation rate"), numerical value like "10%" is not a ECON_INDICATOR; FIN_INSTRUMENT: Financial and market instruments (e.g., "Stocks", "Global Markets")
    CONCEPT: Abstract ideas or notions or themes (e.g., "Inflation", "AI", "Climate Change")
    The relationships 'r' between these entities must be represented by one of the following relation verbs set: Has, Announce, Operate_In, Introduce, Produce, Control, Participates_In, Impact, Positive_Impact_On, Negative_Impact_On, Relate_To, Is_Member_Of, Invests_In, Raise, Decrease. Remember to conduct entity disambiguation, consolidating different phrases or acronyms that refer to the same entity (for instance, "UK Central Bank", "BOE" and "Bank of England" should be unified as "Bank of England"). Simplify each entity of the triplet to be less than four words.
    From this text, your output Must be in python dict made up of nested list made up of ['h', 'type', 'r', 'o', 'type'], where the relationship 'r' must be in the given relation verbs set above. Only output the list. 
    As an Example, consider the following news excerpt:
    Input: 'Apple Inc. is set to introduce the new iPhone 14 in the technology sector this month. The product's release is likely to positively impact Apple's stock value.' 
    OUTPUT:{"output":[['Apple Inc.', 'COMP', 'Introduce', 'iPhone 14', 'PRODUCT'], ['Apple Inc.', 'COMP', 'Operate_In', 'Technology Sector', 'SECTOR'], ['iPhone 14', 'PRODUCT', 'Positive_Impact_On', 'Apple's Stock Value', 'FIN_INSTRUMENT']]} Dont use output like "Based on the given input text, here is the output in the required format:"
    The output MUST be json object made up of nested list and MUST not be anything apart from above OUTPUT. Return only the json object. 
    INPUT_TEXT:
    """

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": instructions + news_prompt}],
    )
    prediction = response.choices[0].message.content
    print(prediction)
    return prediction

In [None]:
import re
import json

def fix_json_string(json_string):
    """Normalise a model reply into a pretty-printed, valid JSON string.

    Repairs are attempted from least to most invasive so that input which is
    already valid is never damaged:

    1. Parse the text as JSON unchanged (apostrophes inside double-quoted
       strings, e.g. ``"Apple's Stock"``, stay intact).
    2. Parse it as a Python literal, which accepts single-quoted strings and
       trailing commas.
    3. Fall back to the regex repair: replace unescaped single quotes with
       double quotes and drop trailing commas before ``}`` or ``]``.

    A surrounding Markdown code fence (three backticks, optional language tag)
    is removed before parsing.

    Args:
        json_string: Raw text expected to hold a JSON or Python-literal value.

    Returns:
        The parsed value re-serialised with ``json.dumps(..., indent=2)``.

    Raises:
        TypeError: If ``json_string`` is not a ``str``.
        ValueError: If the text cannot be parsed even after the repairs.
    """
    import ast  # local import: only needed for the Python-literal fallback

    if not isinstance(json_string, str):
        raise TypeError(
            f"expected a string, got {type(json_string).__name__}"
        )

    text = json_string.strip()

    # Chat models frequently wrap their answer in a Markdown code fence.
    fenced = re.fullmatch(r"```[\w-]*[ \t]*\r?\n(.*?)```", text, flags=re.DOTALL)
    if fenced:
        text = fenced.group(1).strip()

    # 1. Already valid JSON: do not touch quotes or commas at all.
    try:
        return json.dumps(json.loads(text), indent=2)
    except json.JSONDecodeError:
        pass

    # 2. Python-literal syntax (single quotes, trailing commas). TypeError
    #    covers literals that parse but are not JSON-serialisable (e.g. sets).
    try:
        return json.dumps(ast.literal_eval(text), indent=2)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        pass

    # 3. Last resort: textual repair.
    # Replace single quotes with double quotes
    repaired = re.sub(r"(?<!\\)'", '"', text)
    # Handle the case where trailing commas might be problematic in arrays or objects
    repaired = re.sub(r",(\s*[}\]])", r"\1", repaired)

    try:
        # Attempt to parse the corrected JSON string to ensure it's valid
        json_obj = json.loads(repaired)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format after fixing: {e}") from e

    return json.dumps(json_obj, indent=2)

In [None]:
import json

def get_triplets(prediction):
    """Return the value stored under the ``"output"`` key of a model reply.

    The raw reply is first normalised by ``fix_json_string`` and the resulting
    JSON text is then decoded. Errors raised while normalising or decoding,
    and a missing ``"output"`` key, propagate to the caller.
    """
    parsed_reply = json.loads(fix_json_string(prediction))
    return parsed_reply["output"]
    

In [None]:
import csv

# Column layout of the output CSV: row id, the five triplet fields, the source
# text, and a wiki_page column that this script always leaves empty.
header = ["id", "subject", "subject_type", "relation", "object", "object_type", "input", "wiki_page"]
output_filename = "data/o1_knowledge_extraction.csv"
input_file_path = 'data/triplex_knowledge_extraction.csv'

# newline='' on both files, as the csv module recommends, so quoted fields with
# embedded line breaks are read and written unchanged.
with open(input_file_path, mode='r', newline='') as input_file, \
        open(output_filename, mode='w', newline='') as output_file:
    writer = csv.writer(output_file)
    writer.writerow(header)
    reader = csv.DictReader(input_file)

    previous_id = None
    for row in reader:
        current_id = row['idx']
        # Consecutive rows sharing an idx are handled once: only the first
        # row of each run triggers a model call.
        if current_id == previous_id:
            continue
        previous_id = current_id

        print(current_id)
        prediction = get_prediction(row['input'], client)
        try:
            triplets = get_triplets(prediction)
        except Exception:
            # Unparseable reply: report the id and move on. Catching Exception
            # (not a bare except) keeps Ctrl-C / kernel interrupt working.
            print("Error: " + str(current_id))
            continue

        for triplet in triplets:
            # A triplet must supply exactly the five middle columns; anything
            # else would shift the remaining cells under the wrong headers.
            if not isinstance(triplet, (list, tuple)) or len(triplet) != 5:
                print("Skipping malformed triplet for " + str(current_id) + ": " + repr(triplet))
                continue
            # Build a fresh row instead of mutating the parsed triplet.
            writer.writerow([current_id, *triplet, row['input'], ""])


print(f"CSV file '{output_filename}' has been created successfully.")
