In [1]:
import os
import optuna
import openai
import json
import logging
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
from dotenv import load_dotenv
from jinja2 import Template
import jsonschema
from jsonschema import validate, ValidationError

# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file, if one exists
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.DEBUG)  # Set to DEBUG for detailed logs
logger = logging.getLogger(__name__)

# Read the API key and fail fast with a clear message BEFORE building the
# client: constructing the client without a key is expected to fail with the
# library's own error, which would make this check unreachable if it came later.
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")

# Create the OpenAI client instance shared by the embedding and chat helpers below
client = openai.OpenAI(api_key=openai.api_key)

# Data classes for structured data
@dataclass
class Label:
    """An entity mention: its text plus a type string."""
    # Entity text. The loader fills it from the 'Label' key of a target record;
    # the reply parser fills it from 'entity_name'.
    text: str
    # Entity type. The loader fills it from the 'Type' key; the reply parser
    # always stores an empty string here.
    type: str

@dataclass
class Edge:
    label1: str
    label2: str
    relationship: str

@dataclass
class Record:
    input_text: str
    target_labels: List[Label]
    target_edges: List[Edge]

# Function to validate and clean the JSONL file
def validate_and_clean_jsonl(file_path: str, output_path: str) -> None:
    """Copy every parseable JSON line of a JSONL file into a new JSONL file.

    Whitespace-only lines are dropped silently. Lines that are not valid JSON
    are dropped with a warning that names the 1-based line number. The input
    is read completely before the output file is opened for writing.

    Args:
        file_path: Path of the JSONL file to read.
        output_path: Path of the cleaned JSONL file to write.
    """
    parsed_entries = []
    with open(file_path, 'r', encoding='utf-8') as source:
        for line_number, raw_line in enumerate(source, 1):
            stripped = raw_line.strip()
            if stripped:
                try:
                    parsed_entries.append(json.loads(stripped))
                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping invalid line {line_number}: {e}")

    # Re-serialise the surviving objects, one per line
    with open(output_path, 'w', encoding='utf-8') as sink:
        sink.writelines(json.dumps(entry) + '\n' for entry in parsed_entries)

# Function to load synthetic data from the JSONL file
def load_synthetic_data(file_path: str) -> List[Record]:
    """Load evaluation records from a JSONL file.

    Each non-blank line must be a JSON object with the keys 'input_text',
    'target_labels' (objects with 'Label' and 'Type') and 'target_edges'
    (objects with 'Label 1', 'Label 2' and 'Relationship 1 -> 2').

    Lines that cannot be turned into a Record are logged and skipped so that
    one bad line does not abort the whole load.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        The successfully parsed records, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records; skip
                # them quietly, as validate_and_clean_jsonl does.
                continue
            try:
                json_data = json.loads(line)
                record = Record(
                    input_text=json_data['input_text'],
                    target_labels=[
                        Label(text=label['Label'], type=label['Type'])
                        for label in json_data['target_labels']
                    ],
                    target_edges=[
                        Edge(
                            label1=edge['Label 1'],
                            label2=edge['Label 2'],
                            relationship=edge['Relationship 1 -> 2'],
                        )
                        for edge in json_data['target_edges']
                    ],
                )
            except json.JSONDecodeError as e:
                logger.error(f"Error parsing line {line_number}: {e}")
            except KeyError as e:
                logger.error(f"Missing key {e} in line {line_number}")
            except TypeError as e:
                # Valid JSON with the wrong shape: a top-level array, a null
                # list, or a label/edge that is not an object.
                logger.error(f"Unexpected structure in line {line_number}: {e}")
            else:
                data.append(record)
    return data

# Caching embeddings to avoid redundant API calls (keyed by the newline-normalised text)
embedding_cache = {}

# Embedding model to request, and the length of the vectors that model returns.
# The length is used for the zero-vector fallback when a request fails.
EMBEDDING_MODEL = "text-embedding-ada-002"
EMBEDDING_DIMENSIONS = 1536


def get_embeddings(texts: List[str]) -> List[np.ndarray]:
    """Return one embedding vector per input text, in input order.

    Newlines in each text are replaced by spaces before lookup. Vectors are
    served from ``embedding_cache`` when available; otherwise they are fetched
    from the OpenAI embeddings endpoint and cached. A failed request is logged
    and yields an all-zero vector, which is NOT cached so that a later call
    can retry.

    Args:
        texts: Strings to embed.

    Returns:
        A list of numpy arrays with the same length as ``texts``.
    """
    embeddings = []
    for text in texts:
        # Clean the input text
        text = text.replace("\n", " ")

        if text in embedding_cache:
            embeddings.append(embedding_cache[text])
            continue

        try:
            # The model must be passed as an argument of the request itself;
            # input must be a list.
            response = client.embeddings.create(input=[text], model=EMBEDDING_MODEL)
            embedding = np.array(response.data[0].embedding)
        except Exception as e:
            logger.error(f"Error fetching embeddings for text '{text}': {e}")
            # Zero vector of the model's real dimensionality; cosine_similarity
            # scores a zero vector as 0.0.
            embeddings.append(np.zeros(EMBEDDING_DIMENSIONS))
        else:
            # Cache the embedding for future use
            embedding_cache[text] = embedding
            embeddings.append(embedding)
    return embeddings

# Function to compute cosine similarity
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the cosine of the angle between two vectors.

    If either vector has zero length the result is defined as 0.0 rather than
    dividing by zero.
    """
    magnitude_a = np.linalg.norm(vec1)
    if magnitude_a == 0:
        return 0.0

    magnitude_b = np.linalg.norm(vec2)
    if magnitude_b == 0:
        return 0.0

    return np.dot(vec1, vec2) / (magnitude_a * magnitude_b)

# Function to compute semantic similarity using embeddings
def semantic_similarity_score(predicted: str, target: str) -> float:
    """Return the cosine similarity between the embeddings of two strings."""
    predicted_vector, target_vector = get_embeddings([predicted, target])
    return cosine_similarity(predicted_vector, target_vector)

# Define JSON schema for validation
# Shape of one element of the "entities" array
_ENTITY_ITEM_SCHEMA = {
    "type": "object",
    "properties": {
        "entity_name": {"type": "string"},
    },
    "required": ["entity_name"],
}

# Shape of one element of the "relationships" array
_RELATIONSHIP_ITEM_SCHEMA = {
    "type": "object",
    "properties": {
        "source_entity": {"type": "string"},
        "target_entity": {"type": "string"},
        "relationship_type": {"type": "string"},
        "relationship_strength": {"type": "integer"},
    },
    "required": ["source_entity", "target_entity", "relationship_type", "relationship_strength"],
}

# Top-level reply: an object that must carry both arrays
schema = {
    "type": "object",
    "properties": {
        "entities": {"type": "array", "items": _ENTITY_ITEM_SCHEMA},
        "relationships": {"type": "array", "items": _RELATIONSHIP_ITEM_SCHEMA},
    },
    "required": ["entities", "relationships"],
}

# Helper: decode the JSON value carried by a model reply
def _load_json_payload(output: str):
    """Decode the JSON in a model reply, tolerating surrounding text.

    The reply is first parsed strictly. If that fails (for example because the
    JSON is wrapped in a Markdown code fence or followed by a delimiter), the
    first position at which a complete JSON object can be decoded is used.

    Raises:
        json.JSONDecodeError: The strict-parse error, if no object is found.
    """
    try:
        return json.loads(output)
    except json.JSONDecodeError as e:
        strict_error = e

    decoder = json.JSONDecoder()
    start = output.find('{')
    while start != -1:
        try:
            payload, _ = decoder.raw_decode(output, start)
        except json.JSONDecodeError:
            # Not a decodable object at this brace; try the next one
            start = output.find('{', start + 1)
        else:
            return payload
    raise strict_error


# Function to parse and validate the model's JSON output
def parse_output_to_structured(output: str) -> Tuple[List[Label], List[Edge]]:
    """Convert a model reply into predicted labels and edges.

    The reply is decoded as JSON and validated against ``schema``. Entities
    become ``Label`` objects with an empty ``type`` (the schema does not ask
    the model for one); relationships become ``Edge`` objects.

    Args:
        output: Raw text returned by the model.

    Returns:
        ``(labels, edges)``. Both lists are empty when the reply cannot be
        decoded or does not satisfy the schema; the problem is logged rather
        than raised.
    """
    predicted_labels = []
    predicted_edges = []

    try:
        json_output = _load_json_payload(output)
        validate(instance=json_output, schema=schema)  # Validate against the schema

        # 'type' is left empty: the reply format carries no entity type
        predicted_labels = [
            Label(text=entity['entity_name'], type='')
            for entity in json_output['entities']
        ]
        predicted_edges = [
            Edge(
                label1=relationship['source_entity'],
                label2=relationship['target_entity'],
                relationship=relationship['relationship_type'],
            )
            for relationship in json_output['relationships']
        ]
    except (json.JSONDecodeError, ValidationError) as e:
        logger.warning(f"Invalid JSON output: {e}")
    except KeyError as e:
        logger.error(f"Missing key in JSON output: {e}")

    return predicted_labels, predicted_edges

# Helper: turn similarity values into full, partial or no credit
def _similarity_credit(similarities) -> float:
    """Return 1.0 if every similarity exceeds 0.9, 0.5 if every one exceeds 0.8, else 0.0."""
    if all(value > 0.9 for value in similarities):
        return 1.0
    if all(value > 0.8 for value in similarities):
        return 0.5  # Partial credit
    return 0.0


# Helper: best credit among candidate comparisons
def _best_credit(credits) -> float:
    """Return the highest credit in an iterable, stopping as soon as 1.0 is seen."""
    best = 0.0
    for credit in credits:
        if credit > best:
            best = credit
        if best >= 1.0:
            break  # cannot improve further; skip the remaining comparisons
    return best


# Function to score the model's output
def score_structured_output(
    predicted_labels: List[Label],
    predicted_edges: List[Edge],
    target_labels: List[Label],
    target_edges: List[Edge]
) -> float:
    """Score predicted labels and edges against the targets.

    Each target earns the BEST credit any prediction gives it: 1.0 when all
    compared texts have similarity above 0.9, 0.5 when all are above 0.8,
    otherwise 0. Labels compare their text; edges compare both endpoint
    labels and the relationship text.

    A predicted label whose ``type`` is empty is treated as "type unknown" and
    is not rejected on type. This matters because parse_output_to_structured
    always produces empty types, so requiring equality would make label
    credit impossible for typed targets. A non-empty predicted type must
    still equal the target's type.

    Returns:
        ``0.5 * label_score + 0.5 * edge_score``, where each part is the summed
        credit divided by the number of targets (0 when there are no targets).
    """
    def label_credit(predicted_label: Label, target_label: Label) -> float:
        # Check the type first so that incompatible pairs cost no embedding lookup
        if predicted_label.type and predicted_label.type != target_label.type:
            return 0.0
        return _similarity_credit(
            [semantic_similarity_score(predicted_label.text, target_label.text)]
        )

    def edge_credit(predicted_edge: Edge, target_edge: Edge) -> float:
        return _similarity_credit([
            semantic_similarity_score(predicted_edge.label1, target_edge.label1),
            semantic_similarity_score(predicted_edge.label2, target_edge.label2),
            semantic_similarity_score(predicted_edge.relationship, target_edge.relationship),
        ])

    # Scoring for labels
    total_labels = len(target_labels)
    matched_labels = sum(
        _best_credit(label_credit(predicted, target) for predicted in predicted_labels)
        for target in target_labels
    )
    label_score = matched_labels / total_labels if total_labels > 0 else 0

    # Scoring for edges
    total_edges = len(target_edges)
    matched_edges = sum(
        _best_credit(edge_credit(predicted, target) for predicted in predicted_edges)
        for target in target_edges
    )
    edge_score = matched_edges / total_edges if total_edges > 0 else 0

    # Aggregate and Normalize Scores
    final_score = 0.5 * label_score + 0.5 * edge_score  # Adjust weights as needed

    return final_score

# Define the objective function for Optuna
def create_objective(records: List[Record], prompt_templates: List[str]) -> callable:
    """
    Build an Optuna objective that closes over the records and prompt templates.

    The returned function samples a prompt template, a temperature in
    [0.0, 0.3] and a max_tokens value in [200, 600]. It then queries the chat
    model once per record and returns the mean score of the records that
    produced a non-empty reply (0 if none did).
    """
    def objective(trial):
        # Sample the trial's configuration (order of suggest calls is kept stable)
        chosen_template = trial.suggest_categorical('prompt_template', prompt_templates)
        template = Template(chosen_template)
        temperature = trial.suggest_float('temperature', 0.0, 0.3)  # Lower temperature for more deterministic outputs
        max_tokens = trial.suggest_int('max_tokens', 200, 600)     # Ensure enough tokens for complete JSON

        def request_completion(prompt):
            """Send one prompt with this trial's settings; return "" on any API failure."""
            try:
                completion = client.chat.completions.create(
                    model="gpt-4",  # Replace with "gpt-4o-2024-08-06" if that's your specific model
                    messages=[{"role": "user", "content": prompt}],
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                return completion.choices[0].message.content
            except Exception as e:
                logger.error(f"OpenAI API call failed: {e}")
                return ""

        score_sum = 0.0
        scored_records = 0

        for record in records:
            # Fill the template with this record's text
            rendered_prompt = template.render(input_text=record.input_text)

            try:
                response_content = request_completion(rendered_prompt)

                if not response_content:
                    logger.warning(f"No response for record: {record.input_text}")
                    continue

                # Log the API response for debugging
                logger.debug(f"API Response for input '{record.input_text}': {response_content}")

                # Parse the reply, then score it against the record's targets
                labels, edges = parse_output_to_structured(response_content)
                score_sum += score_structured_output(
                    labels, edges,
                    record.target_labels, record.target_edges
                )
                scored_records += 1

            except Exception as e:
                logger.error(f"An error occurred while processing record: {e}")
                continue  # Skip to the next record

        # Mean over the records that were actually scored
        average_score = score_sum / scored_records if scored_records > 0 else 0

        logger.info(f"Trial Score: {average_score} (Temperature: {temperature}, Max Tokens: {max_tokens})")

        return average_score

    return objective

# Define prompt templates
# Candidate prompts, rendered with Jinja2; `{{ input_text }}` receives the record text.
# The JSON example is wrapped in {% raw %} so Jinja2 leaves its braces alone.
# The final step used to ask the model to append a literal, never-substituted
# "{completion_delimiter}" placeholder, which contradicted "no additional text"
# and made the reply invalid JSON; it now asks for the JSON object only.
prompt_templates = [
    """
    -Goal-
    Given a text document that is potentially relevant to this activity, identify all digital evidence entities and their relationships, focusing on forensic artifacts.

    -Steps-
    1. **Identify all digital evidence entities.** Digital evidence can encompass a wide range of data types and sources, including but not limited to:

        1. **Personal Identifiers:** Name, Address, Phone number, Email address, Social Security number, Date of birth
        2. **Network Information:** IP address, MAC address, Login credentials
        3. **Communication Records:** Emails, Text messages, Social media messages and posts
        4. **Financial Data:** Bank account information, Credit card numbers, Transaction ID, Cryptocurrency wallet addresses
        5. **Location Data:** GPS latitude and longitude
        6. **Device Information:** Device type and model, Operating system and version, Installed applications, System logs
        7. **Internet Activity:** Browsing URL, Search queries

            For each identified entity, extract the following information:
            - **entity_name:** Name of the entity, capitalized

    2. **Identify all pairs of related entities.** From the entities identified in step 1, determine all pairs of (source_entity, target_entity) that are clearly related to each other. Common types of relationships between digital evidence include:

        1. **Communication Relationships:** e.g., [Phone number A, calls, Phone number B]
        2. **Ownership/Association:** e.g., [Person, owns, Device]
        3. **Temporal Relationships:** e.g., [File A, created before, File B]
        4. **Spatial Relationships:** e.g., [Device, located at, GPS coordinates]
        5. **Causal Relationships:** e.g., [Malware installation, causes, Data breach]
        6. **Data Flow:** e.g., [File, transferred from, Device A, to, Device B]
        7. **Access Relationships:** e.g., [User, accesses, File]
        8. **Modification Relationships:** e.g., [User, edits, Document]
        9. **Financial Transactions:** e.g., [Account A, transfers funds to, Account B]
        10. **Social Connections:** e.g., [User A, friends with, User B] on a social network
        11. **Software Interactions:** e.g., [Application, generates, Log file]
        12. **Content Relationships:** e.g., [Document A, contains similar text to, Document B]

            For each pair of related entities, extract the following information:
            - **source_entity:** Name of the source entity, as identified in step 1
            - **target_entity:** Name of the target entity, as identified in step 1
            - **relationship_type:** One of the relationship types listed above
            - **relationship_strength:** A numeric score indicating the strength of the relationship between the source entity and target entity (1-10, where 10 is the strongest)

    3. **Format the Output in JSON:**
        - Structure the output as a JSON object with two main keys: `"entities"` and `"relationships"`.
        - **Entities:** An array of objects, each representing an entity with its `entity_name`.
        - **Relationships:** An array of objects, each detailing the relationship between two entities, including the `source_entity`, `target_entity`, `relationship_type`, and `relationship_strength`.

            **Example Output:**
            ```json
            {% raw %}
            {
                "entities": [
                    {"entity_name": "Jessica Harper"},
                    {"entity_name": "John Doe"},
                    {"entity_name": "@jessica_harps"},
                    {"entity_name": "Instagram"},
                    {"entity_name": "jessicatravels.com"}
                ],
                "relationships": [
                    {
                        "source_entity": "Jessica Harper",
                        "target_entity": "@jessica_harps",
                        "relationship_type": "operates",
                        "relationship_strength": 8
                    },
                    {
                        "source_entity": "Jessica Harper",
                        "target_entity": "John Doe",
                        "relationship_type": "tagged",
                        "relationship_strength": 7
                    },
                    {
                        "source_entity": "Jessica Harper",
                        "target_entity": "jessicatravels.com",
                        "relationship_type": "linked",
                        "relationship_strength": 6
                    }
                ]
            }
            {% endraw %}
            ```

        4. **Return the JSON Output:**
            - Ensure the entire JSON object is returned as the output without additional text or explanations.
            - Do not append a completion delimiter, code fence, or any other text after the closing brace of the JSON object.

    -Real Data-
    Text: {{ input_text }}
    """,
    # You can add more refined prompt variations here if needed
]

def main(input_file_path: str = 'input.jsonl', n_trials: int = 50, seed: int = 42) -> None:
    """Run the prompt/parameter search and log the best configuration found.

    Args:
        input_file_path: JSONL file with the evaluation records.
        n_trials: Number of Optuna trials to run.
        seed: Seed for the TPE sampler, so that runs are repeatable.
    """
    # Optional: Validate and clean the JSONL file
    # clean_file_path = 'clean_input.jsonl'
    # validate_and_clean_jsonl(input_file_path, clean_file_path)
    # records = load_synthetic_data(clean_file_path)

    # Load synthetic data from the JSONL file
    records = load_synthetic_data(input_file_path)
    logger.info(f"Loaded {len(records)} records from {input_file_path}")

    if not records:
        logger.error("No valid records found. Exiting.")
        return

    # Create the objective function with access to the records and prompt templates
    objective = create_objective(records, prompt_templates)

    # Create an Optuna study (higher score is better)
    study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=seed))

    # Optimize the study
    study.optimize(objective, n_trials=n_trials)

    # Output the best parameters. Optuna raises ValueError when the study has
    # no completed trial (e.g. n_trials=0), so report that instead of crashing.
    try:
        best_params = study.best_params
        best_value = study.best_value
    except ValueError:
        logger.error("No trial completed; there are no best parameters to report.")
        return

    logger.info(f"Best trial parameters: {best_params}")
    logger.info(f"Best trial score: {best_value}")

    # Optionally, save the study results
    # study.trials_dataframe().to_csv('optuna_study_results.csv', index=False)

# Run the search only when this module is executed as the main program,
# not when it is imported.
if __name__ == '__main__':
    main()
                

  from .autonotebook import tqdm as notebook_tqdm
DEBUG:httpx:load_ssl_context verify=True cert=None trust_env=True http2=False
DEBUG:httpx:load_verify_locations cafile='C:\\Users\\omar2\\.conda\\envs\\graphRAG\\Library\\ssl\\cacert.pem'
INFO:__main__:Loaded 19 records from input.jsonl
[I 2024-09-17 04:58:13,853] A new study created in memory with name: no-name-8dc36d95-9531-4aa4-a777-16c57afc3b0f
DEBUG:openai._base_client:Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'user', 'content': '\n    -Goal-\n    Given a text document that is potentially relevant to this activity, identify all digital evidence entities and their relationships, focusing on forensic artifacts.\n\n    -Steps-\n    1. **Identify all digital evidence entities.** Digital evidence can encompass a wide range of data types and sources, including but not limited to:\n\n        1. **Personal Identifiers:** Name, Address, Phone number, Email address, Soc

KeyboardInterrupt: 