In [10]:
import os
import optuna
import openai
import pandas as pd
import json
import logging
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
from dotenv import load_dotenv

# Load environment variables (e.g. from a local .env file) into the process
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Validate the API key BEFORE building the client. The openai v1 client
# constructor raises its own error when no key is supplied, so a check placed
# after construction would never get the chance to emit this message.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if not _openai_api_key:
    raise ValueError("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")

# Shared client object used for both chat-completion and embedding requests
client = openai.OpenAI(api_key=_openai_api_key)

# Data classes for structured data
@dataclass
class Label:
    """A labelled piece of text: the text itself plus its type."""
    # The labelled text
    text: str
    # The type assigned to that text (field name is part of the interface,
    # even though it shadows the ``type`` builtin inside the class body)
    type: str

@dataclass
class Edge:
    """A directed relationship between two labels, all stored as strings."""
    # Text of the first (source) label
    label1: str
    # Text of the second (target) label
    label2: str
    # Description of how label1 relates to label2
    relationship: str

@dataclass
class Record:
    """One data example: an input text with its target labels and edges."""
    # Raw text the labels and edges refer to
    input_text: str
    # Labels expected for this text
    target_labels: List[Label]
    # Edges (label-to-label relationships) expected for this text
    target_edges: List[Edge]

# Function to validate and clean the JSONL file
def validate_and_clean_jsonl(file_path: str, output_path: str) -> None:
    """Copy the parseable lines of a JSONL file into a new JSONL file.

    Every non-blank line of ``file_path`` is decoded as JSON. Lines that fail
    to decode are reported through the module logger and dropped; the rest are
    re-serialised, one per line, into ``output_path``. The input is read in
    full before the output is opened.

    Args:
        file_path: Path of the JSONL file to read.
        output_path: Path of the cleaned JSONL file to write.
    """
    parsed_entries = []

    with open(file_path, 'r', encoding='utf-8') as source:
        line_number = 0
        for raw_line in source:
            line_number += 1
            candidate = raw_line.strip()
            if candidate:  # blank lines are ignored without a warning
                try:
                    parsed_entries.append(json.loads(candidate))
                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping invalid line {line_number}: {e}")

    # Re-serialise each surviving entry on its own line
    with open(output_path, 'w', encoding='utf-8') as sink:
        for parsed in parsed_entries:
            sink.write(json.dumps(parsed) + '\n')

# Function to load synthetic data from the JSONL file
def load_synthetic_data(file_path: str) -> List[Record]:
    """Load ``Record`` objects from a JSONL file, skipping unusable lines.

    Each non-blank line must decode to a JSON object with the keys
    ``input_text``, ``target_labels`` (items with ``Label`` and ``Type``) and
    ``target_edges`` (items with ``Label 1``, ``Label 2`` and
    ``Relationship 1 -> 2``). Lines that fail to decode, lack a key, or have
    an unexpected structure are logged as errors and skipped, so one bad line
    never aborts the whole load.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        The records built from the usable lines, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue  # a blank line is not an error; just skip it
            try:
                json_data = json.loads(line)
                record = Record(
                    input_text=json_data['input_text'],
                    target_labels=[
                        Label(text=label['Label'], type=label['Type'])
                        for label in json_data['target_labels']
                    ],
                    target_edges=[
                        Edge(
                            label1=edge['Label 1'],
                            label2=edge['Label 2'],
                            relationship=edge['Relationship 1 -> 2'],
                        )
                        for edge in json_data['target_edges']
                    ],
                )
            except json.JSONDecodeError as e:
                logger.error(f"Error parsing line {line_number}: {e}")
            except KeyError as e:
                logger.error(f"Missing key {e} in line {line_number}")
            except TypeError as e:
                # Valid JSON of the wrong shape (e.g. a bare list/number, or a
                # null where a list of objects is expected) lands here.
                logger.error(f"Unexpected structure in line {line_number}: {e}")
            else:
                data.append(record)
    return data

# Caching embeddings to avoid redundant API calls.
# Keys are the newline-flattened texts; values are the embedding arrays.
embedding_cache = {}


def get_embeddings(texts: List[str]) -> List[np.ndarray]:
    """Return one embedding per input text, reusing cached results.

    Newlines in each text are replaced by spaces before lookup, and the
    flattened text is the cache key. All texts not yet in ``embedding_cache``
    are sent to the embeddings endpoint together (deduplicated) instead of
    one request per text; cached texts cost no request at all.

    Args:
        texts: Strings to embed. May be empty, in which case no request is
            made and an empty list is returned.

    Returns:
        Embedding arrays in the same order as ``texts``.
    """
    cleaned_texts = [text.replace("\n", " ") for text in texts]

    # Distinct texts that still need an embedding, in first-seen order
    missing = list(dict.fromkeys(
        text for text in cleaned_texts if text not in embedding_cache
    ))

    # NOTE(review): the API reference documents a cap on the number of inputs
    # per request (2048 at the time of writing) - confirm against current docs.
    max_batch_size = 2048
    for start in range(0, len(missing), max_batch_size):
        batch = missing[start:start + max_batch_size]
        response = client.embeddings.create(
            input=batch,
            model="text-embedding-3-small"
        )
        # Each returned item carries the index of the input it belongs to, so
        # the mapping does not rely on the order of response.data.
        for item in response.data:
            embedding_cache[batch[item.index]] = np.array(item.embedding)

    return [embedding_cache[text] for text in cleaned_texts]

# Function to compute semantic similarity using embeddings
def semantic_similarity_score(predicted: str, target: str) -> float:
    """Return the cosine similarity between the embeddings of two strings.

    Both strings are embedded with a single ``get_embeddings`` call and the
    two resulting vectors are compared with ``cosine_similarity``.
    """
    predicted_vector, target_vector = get_embeddings([predicted, target])
    return cosine_similarity(predicted_vector, target_vector)

# Function to compute cosine similarity
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the cosine of the angle between two vectors.

    Args:
        vec1: First vector.
        vec2: Second vector, same length as ``vec1``.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)`` as a plain Python float, or
        0.0 when either vector has zero length (the cosine is undefined there
        and the plain formula would divide by zero and yield ``nan``).
    """
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / denominator)

# Function to score structured output
def _match_credit(similarity: float) -> float:
    """Convert a similarity into credit: 1.0 above 0.9, 0.5 above 0.8, else 0."""
    if similarity > 0.9:
        return 1.0
    if similarity > 0.8:
        return 0.5
    return 0.0


def _edge_credit(predicted_edge: Edge, target_edge: Edge) -> float:
    """Return the credit one predicted edge earns against one target edge.

    All three parts (both labels and the relationship) must clear a threshold,
    so the credit is decided by the weakest of the three similarities.
    """
    weakest = semantic_similarity_score(predicted_edge.label1, target_edge.label1)
    # Once one part is at or below the partial threshold the edge cannot earn
    # credit, so the remaining similarity lookups are skipped.
    if weakest > 0.8:
        weakest = min(
            weakest,
            semantic_similarity_score(predicted_edge.label2, target_edge.label2),
        )
    if weakest > 0.8:
        weakest = min(
            weakest,
            semantic_similarity_score(predicted_edge.relationship, target_edge.relationship),
        )
    return _match_credit(weakest)


def score_structured_output(
    predicted_labels: List[Label],
    predicted_edges: List[Edge],
    target_labels: List[Label],
    target_edges: List[Edge]
) -> float:
    """Score predicted labels and edges against the targets, in [0, 1].

    Each target earns the BEST credit any prediction gives it: 1.0 for a
    similarity above 0.9, 0.5 above 0.8, otherwise 0. Taking the best rather
    than the first qualifying prediction means a weak match listed early can
    no longer hide an exact match listed later. A predicted label only counts
    when its type equals the target's type; an edge needs both labels and the
    relationship to clear the threshold.

    Args:
        predicted_labels: Labels extracted from the model output.
        predicted_edges: Edges extracted from the model output.
        target_labels: Expected labels.
        target_edges: Expected edges.

    Returns:
        The mean of the label score and the edge score. A side with no
        targets contributes 0.
    """
    # Scoring for labels
    label_credit = 0.0
    for target_label in target_labels:
        best = 0.0
        for predicted_label in predicted_labels:
            if predicted_label.type != target_label.type:
                continue  # wrong type can never match; skip the embedding lookup
            best = max(best, _match_credit(
                semantic_similarity_score(predicted_label.text, target_label.text)
            ))
            if best == 1.0:
                break  # full credit cannot be improved on
        label_credit += best
    label_score = label_credit / len(target_labels) if target_labels else 0.0

    # Scoring for edges
    edge_credit = 0.0
    for target_edge in target_edges:
        best = 0.0
        for predicted_edge in predicted_edges:
            best = max(best, _edge_credit(predicted_edge, target_edge))
            if best == 1.0:
                break
        edge_credit += best
    edge_score = edge_credit / len(target_edges) if target_edges else 0.0

    # Aggregate and normalize scores: labels and edges weigh equally
    return 0.5 * label_score + 0.5 * edge_score

# Improved parsing function
def _strip_code_fence(text: str) -> str:
    """Return *text* stripped, without a surrounding Markdown code fence."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the whole opening fence line; it may carry a tag such as ```json
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def parse_output_to_structured(output: str) -> Tuple[List[Label], List[Edge]]:
    """Parse model output into predicted labels and edges.

    The output is first read as JSON (a surrounding Markdown code fence is
    removed beforehand, so fenced JSON is not mistaken for plain text). A
    JSON object may carry ``entities`` (items with ``text`` and ``type``) and
    ``relationships`` (items with ``entity1``, ``entity2`` and ``type``);
    malformed items are logged and skipped rather than discarding the rest.

    If the output is not valid JSON, it is parsed line by line instead:
    ``a -> b -> relation`` lines become edges and ``text: type`` lines become
    labels.

    Args:
        output: Raw text returned by the model; ``None`` is treated as empty.

    Returns:
        A ``(labels, edges)`` tuple; either list may be empty.
    """
    predicted_labels = []
    predicted_edges = []

    text = _strip_code_fence(output or "")

    try:
        json_output = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Output is not valid JSON. Attempting to parse as text.")
        # Fallback to text parsing if JSON parsing fails
        for line in text.split('\n'):
            if '->' in line:  # Edge case (relationship extraction)
                parts = line.split('->')
                if len(parts) == 3:
                    predicted_edges.append(Edge(
                        label1=parts[0].strip(),
                        label2=parts[1].strip(),
                        relationship=parts[2].strip()
                    ))
            elif ':' in line:  # Entity case
                parts = line.split(':')
                if len(parts) == 2:
                    predicted_labels.append(Label(
                        text=parts[0].strip(),
                        type=parts[1].strip()
                    ))
        return predicted_labels, predicted_edges

    # Valid JSON that is not an object (a list, string, number...) has no
    # 'entities' / 'relationships' keys to read.
    if not isinstance(json_output, dict):
        logger.warning("JSON output is not an object; nothing extracted.")
        return predicted_labels, predicted_edges

    entities = json_output.get('entities')
    relationships = json_output.get('relationships')

    for entity in entities if isinstance(entities, list) else []:
        try:
            predicted_labels.append(Label(text=entity['text'], type=entity['type']))
        except (KeyError, TypeError) as e:
            logger.warning(f"Skipping malformed entity {entity!r}: {e}")

    for relationship in relationships if isinstance(relationships, list) else []:
        try:
            predicted_edges.append(Edge(
                label1=relationship['entity1'],
                label2=relationship['entity2'],
                relationship=relationship['type']
            ))
        except (KeyError, TypeError) as e:
            logger.warning(f"Skipping malformed relationship {relationship!r}: {e}")

    return predicted_labels, predicted_edges

# Define the objective function for Optuna
def objective(trial):
    """Optuna objective: score one sampled prompt configuration.

    Samples a prompt template, a temperature and a token limit from ``trial``,
    sends the formatted prompt to the chat model, parses the reply and scores
    it against the targets of the module-level ``record`` (assigned in the
    ``__main__`` block of this file).

    Returns:
        The structured-output score, or 0.0 if the request, parsing or
        scoring raises.
    """
    # Search space: prompt wording plus two sampling parameters
    template_choices = [
        "Extract entities and their relationships from the following text:\n\n{}",
        "Identify all key entities and relationships in the text below:\n\n{}",
        "Analyze the document and list all entities and their connections:\n\n{}"
    ]
    prompt_template = trial.suggest_categorical('prompt_template', template_choices)
    temperature = trial.suggest_float('temperature', 0.0, 1.0)
    max_tokens = trial.suggest_int('max_tokens', 150, 500)

    # Fill the chosen template with the record's input text
    prompt = prompt_template.format(record.input_text)

    try:
        completion = client.chat.completions.create(
            model="gpt-4o-2024-08-06",
            messages=[
                {"role": "user", "content": prompt}
            ],
            temperature=temperature,
            max_tokens=max_tokens
        )
        reply_text = completion.choices[0].message.content

        # Turn the reply into labels/edges, then compare with the targets
        labels, edges = parse_output_to_structured(reply_text)
        return score_structured_output(
            labels, edges,
            record.target_labels, record.target_edges
        )
    except Exception as e:
        logger.error(f"An error occurred during API call: {e}")
        return 0.0  # A failed trial counts as the worst score

# Main execution
if __name__ == "__main__":
    # Clean the JSONL file and save to a new file
    validate_and_clean_jsonl('synthetic_data.jsonl', 'cleaned_synthetic_data.jsonl')

    # Load the cleaned synthetic data
    synthetic_data = load_synthetic_data('cleaned_synthetic_data.jsonl')

    if not synthetic_data:
        logger.error("No data available after cleaning.")
        # `exit()` is an interactive helper injected by the `site` module and
        # is not guaranteed to exist; raising SystemExit always works.
        raise SystemExit(1)

    # Use the first record for this example; `objective` reads this
    # module-level name.
    record = synthetic_data[0]

    # Run optimization
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=50)

    # Output the best trial
    logger.info(f"Best trial: {study.best_trial}")


[I 2024-09-12 20:57:26,414] A new study created in memory with name: no-name-571b5d22-d8ae-4b45-ad16-6ac16a502c61
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST 

KeyboardInterrupt: 