In [1]:
import os
import json
import time
from datetime import datetime

import pandas as pd
import numpy as np

import tiktoken
import gzip
from openai import OpenAI

In [2]:
# System prompt sent with every clustering request.
# The few-shot examples must themselves obey the rules stated in the prompt
# (valid JSON, list of lists, output edges identical to the input edges), and
# the backslash in "\\n" is escaped so the model sees the two characters
# backslash + n rather than a real line break in the middle of the sentence.
# NOTE(review): `model` elsewhere in this notebook is a fine-tuned checkpoint;
# confirm the corrected wording is acceptable for it before a large run.
system_prompt = """
You are a data scientist specializing in grouping plant biological interactions. Your task is to cluster similar edges while strictly adhering to the following guidelines:  

1. Exact Phrase Matching Matters:   
1.1 Consider the Entire Phrase: Treat each edge as a single, whole phrase. This includes all key biological terms and any bracketed text
1.2 Ignore Minor Surface Differences: Minor variations such as letter casing (uppercase vs. lowercase), spacing, punctuation, standard abbreviations, or singular vs. plural forms do not create new or separate edges.  

2. Strict (100%) Key Term Separation: If an edge has a different key biological term, it MUST GO into a separate cluster.  

3. Sub-identifier separation: If an edge differs by any numeric value, sub-identifier, or qualifier, they MUST BE placed in separate clusters.  

4. Avoid False Similarity: DO NOT cluster two edges together simply because they share a common word or term if their overall key term or concept is different. You should cluster them together if the semantic meaning is the same, example: "[protein] interacts with [metabolite]" and "[protein] is interacting with [metabolite]" should be clustered together.

5. Extra Descriptor Differentiation: If one edge has an extra descriptor that changes its meaning, do not group them together. However, if the extra descriptor is a synonym, then group them together.

6. Strict Synonym/Near-Synonym Grouping: Only group edges together if they refer to the exact same biological structure, process, or concept.  

7. Maintain 100% Precision: If there is any doubt about whether two edges are the same, MUST place them in separate clusters.  

8. Preserve Original Data: DO NOT introduce new items, create duplicates, or omit any edge from your final output.  

9. Output Format: Always return results in valid JSON format. You MUST USE GIVEN KEY. The output must be a list of lists, where each cluster is its own list. The total number of output entries must match the input entries.

10. Choose Cluster Representative:
  10.1 For every cluster containing more than one edge, YOU MUST choose exactly one representative edge.
  10.2 The representative edge must be the most appropriate and easy-to-understand version.
  10.3 Enclose the representative edge with '**' (double asterisks) in the output.
  10.4 For single-element clusters, no representative modification is required.

Read the input list, and return clustered edges, strictly following the given guidelines above.

Example input 1:"21329":  ['[protein] could interact with [metabolite]', '[protein] could potentially interact with [metabolite]', '[protein] might interact with [metabolite]', '[protein] most definitely interacts with [metabolite]']

Example output 1: {"21329": [["[protein] could interact with [metabolite]", "[protein] could potentially interact with [metabolite]", "**[protein] might interact with [metabolite]**"], ["[protein] most definitely interacts with [metabolite]"]]}

Example input 2: "211": ['[organism] is beneficial for [process]', '[organism] is beneficial to [organism]', '[organism] beneficial for [process]', '[organism] beneficial for [organism]', '[organism] included plant beneficial microorganisms like [organism]', '[organism] beneficial to [organism]', '[organism] provide beneficial functions to [organism]', '[organism] may be beneficial to [organism]', '[organism] is beneficial for [organism]', '[organism] beneficial to [organ]', '[organism] are beneficial to [organism]', '[organism] confers beneficial traits to [organism]', '[organism] healthy to [organism]', '[organism] useful to [organism]', '[organism] beneficial for [environment]', '[organism] colonization is often of benefit to [organism]', '[interaction] beneficial for [organism]', '[organism] may be beneficial for [organism]', '[organism] provide beneficial services to [organism]']

Example output 2: {"211": [["**[organism] is beneficial for [process]**", "[organism] beneficial for [process]"], ["**[organism] is beneficial to [organism]**", "[organism] beneficial to [organism]", "[organism] are beneficial to [organism]", "[organism] is beneficial for [organism]", "[organism] beneficial for [organism]"], ["[organism] included plant beneficial microorganisms like [organism]"], ["[organism] provide beneficial functions to [organism]"], ["[organism] may be beneficial to [organism]"], ["[organism] beneficial to [organ]"], ["[organism] confers beneficial traits to [organism]"], ["[organism] healthy to [organism]"], ["[organism] useful to [organism]"], ["[organism] beneficial for [environment]"], ["[organism] colonization is often of benefit to [organism]"], ["[interaction] beneficial for [organism]"], ["[organism] may be beneficial for [organism]"], ["[organism] provide beneficial services to [organism]"]]}

Example input 3: "213":['[gene] were co-repressed or co-activated by [gene]', '[gene] selective induction/repression of [gene]', '[gene] may be repressed by [gene]', '[gene] could be repressed by [gene]', '[gene] can be repressed by interacting with [gene]', '[gene] are induced or repressed by [gene]', '[gene] repressible by [mutant]', '[gene] were induced or repressed by [other]', '[gene] were induced or repressed by [gene]', '[gene] can be repressed by inducing [gene]', '[gene] are crucial in expressing or repressing [gene]', '[gene] inducing or repressing [gene]', '[gene] induced or repressed by [gene]', '[gene] activated or repressed by [gene]']

Example output 3: {"213": [["[gene] were co-repressed or co-activated by [gene]"], ["[gene] selective induction/repression of [gene]"], ["**[gene] may be repressed by [gene]**", "[gene] could be repressed by [gene]"], ["[gene] can be repressed by interacting with [gene]"], ["**[gene] are induced or repressed by [gene]**", "[gene] were induced or repressed by [gene]", "[gene] induced or repressed by [gene]", "[gene] inducing or repressing [gene]", "[gene] activated or repressed by [gene]"], ["[gene] repressible by [mutant]"], ["[gene] were induced or repressed by [other]"], ["[gene] can be repressed by inducing [gene]"], ["[gene] are crucial in expressing or repressing [gene]"]]}


Note: The category enclosed by "[]" denotes the node type, and in edge clustering, edges should only be clustered together if both node types match.

Warning:
Do NOT output newlines or "\\n" in the output
ALWAYS remember to select a group representative enclosed by **, this is critical, and must not be omitted, VERY IMPORTANT
BE VERY SURE that you do not add any new terms to the output, and that the output entries EXACTLY matches the input entries, that the output format is correct, and that the output is a list of lists. The total number of output entries must match the input entries.
"""

In [3]:
def prepare_data(array_of_strings):
    """Turn prompt strings into request records and report their token count.

    Args:
        array_of_strings (iterable of str): Prompt strings of the form
            "<id>: <payload>". Everything before the first ":" becomes the
            record id; a string without ":" is used as its own id.

    Returns:
        list[dict]: One dict per input string with the keys
            'id'       - text before the first ":",
            'text'     - the complete, unmodified input string,
            'n_tokens' - number of tokens of 'text' under the cl100k_base
                         encoding.

    Prints the number of records and the summed token count.
    """
    # NOTE(review): cl100k_base may not be the tokenizer of the model that is
    # actually queried, so the printed total is presumably only an estimate.
    encoding = tiktoken.get_encoding("cl100k_base")
    input_data = []
    total_tokens = 0

    for text in array_of_strings:
        # Split once only: the payload itself may contain further colons.
        record_id = text.split(":", 1)[0]
        n_tokens = len(encoding.encode(text))
        total_tokens += n_tokens

        input_data.append({
            'id': record_id,
            'text': text,
            'n_tokens': n_tokens
        })

    print(f"Total number of submissions: {len(input_data)}")
    print(f"Total input tokens: {total_tokens}")

    return input_data


def create_batch_file(completion_data, query_dir, timestamp, system_prompt, model="o3-mini", model_prefix=None, options=None, chunk_number=None):
    """Write one chat-completion request per record to a JSONL batch file.

    The file is named
    ``<prefix>_completion_requests_<timestamp>[_chunk_<n>].jsonl`` and placed
    in ``query_dir``; ``<prefix>`` is ``model_prefix`` or, when that is None,
    ``model``.

    Args:
        completion_data (list): Dicts providing at least 'id' and 'text'.
        query_dir (str): Directory that receives the batch file.
        timestamp (str): Timestamp embedded in the filename.
        system_prompt (str): System message shared by all requests.
        model (str): Model name placed in each request body.
        model_prefix (str): Filename prefix; falls back to ``model``.
        options (dict): Extra request-body entries (e.g. temperature), merged
            after the fixed keys. Required.
        chunk_number (optional): Appended to the filename when given.

    Returns:
        str: Path of the written batch file.

    Raises:
        ValueError: If ``options`` is None.
    """
    print(f"Model: {model}, Model prefix: {model_prefix}, Options: {options}")
    if options is None:
        raise ValueError("Options are required")

    # Assemble the filename from its optional parts.
    prefix = model if model_prefix is None else model_prefix
    chunk_suffix = '' if chunk_number is None else f'_chunk_{chunk_number}'
    batch_file_path = os.path.join(
        query_dir, f'{prefix}_completion_requests_{timestamp}{chunk_suffix}.jsonl'
    )

    def _as_request(record):
        # One batch line: a POST to the chat-completions endpoint.
        return {
            "custom_id": record['id'],
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "response_format": {"type": "json_object"},
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": record['text']},
                ],
                **options
            },
        }

    with open(batch_file_path, 'w') as handle:
        handle.writelines(
            json.dumps(_as_request(record)) + '\n' for record in completion_data
        )

    return batch_file_path


def submit_batch_job(client, batch_file_path, description, poll_interval=1, max_wait=600):
    """Upload a JSONL batch file and create a batch job for it.

    Args:
        client: OpenAI client exposing ``files.create``, ``files.retrieve``
            and ``batches.create``.
        batch_file_path (str): Path of the JSONL request file to upload.
        description (str): Stored in the batch metadata under "description".
        poll_interval (int | float): Seconds between upload-status checks.
        max_wait (int | float): Maximum total seconds to wait for the upload
            to reach the "processed" status.

    Returns:
        The batch object returned by ``client.batches.create``.

    Raises:
        RuntimeError: If the uploaded file reports the status "error".
        TimeoutError: If the file is not "processed" within ``max_wait``.
    """
    # Close the handle as soon as the upload call returns.
    with open(batch_file_path, "rb") as batch_fh:
        batch_input_file = client.files.create(
            file=batch_fh,
            purpose="batch"
        )

    # The object returned by files.create is a snapshot: it has to be fetched
    # again on every iteration, otherwise a status other than "processed"
    # would never change and the loop would never end.
    waited = 0
    while batch_input_file.status != "processed":
        if batch_input_file.status == "error":
            raise RuntimeError(
                f"Upload of {batch_file_path} failed (file id {batch_input_file.id})"
            )
        if waited >= max_wait:
            raise TimeoutError(
                f"File {batch_input_file.id} still has status "
                f"'{batch_input_file.status}' after {waited} seconds"
            )
        time.sleep(poll_interval)
        waited += poll_interval
        batch_input_file = client.files.retrieve(batch_input_file.id)

    batch = client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={"description": description}
    )
    print(f"batch created\n{batch.id}")
    return batch


def create_multiple_batch_files(client, input_array, timestamp, query_dir, system_prompt, model, model_prefix, options, max_batch_size=50_000, description="Edge embeddings batch job"):
    """Split the requests into chunks, then write and submit one batch per chunk.

    After every successful submission the batch ids and batch file paths
    collected so far are written to ``batch_info_<timestamp>.csv`` in
    ``query_dir``, so batches that were already submitted can still be found
    if a later chunk fails.

    Args:
        client: OpenAI client, passed through to ``submit_batch_job``.
        input_array (list): Request records, sliced into chunks.
        timestamp (str): Timestamp used in all generated filenames.
        query_dir (str): Directory for the batch files and the CSV.
        system_prompt (str): System prompt passed to ``create_batch_file``.
        model (str): Model name passed to ``create_batch_file``.
        model_prefix (str): Filename prefix passed to ``create_batch_file``.
        options (dict): Request options passed to ``create_batch_file``.
        max_batch_size (int): Maximum number of records per chunk.
        description (str): Batch description passed to ``submit_batch_job``.

    Returns:
        tuple[list, list]: ``(batch_ids, batch_file_paths)`` in chunk order.

    Raises:
        ValueError: If ``max_batch_size`` is smaller than 1.
    """
    print(f"Model: {model}, Model prefix: {model_prefix}, Options: {options}")
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be a positive integer, got {max_batch_size}")

    # Split input_array into chunks of at most max_batch_size records.
    chunks = [input_array[i:i+max_batch_size] for i in range(0, len(input_array), max_batch_size)]
    print(f"Number of chunks: {len(chunks)}")

    batch_ids = []
    batch_file_paths = []
    batch_info_path = os.path.join(query_dir, f"batch_info_{timestamp}.csv")

    def _save_batch_info():
        # Persist everything submitted so far (overwrites the previous state).
        batch_info = pd.DataFrame({"batch_id": batch_ids, "batch_file_path": batch_file_paths})
        batch_info.to_csv(batch_info_path, index=False)

    for i, chunk in enumerate(chunks):
        batch_file_path = create_batch_file(chunk, query_dir, timestamp, system_prompt, model=model, model_prefix=model_prefix, options=options, chunk_number=i)
        batch = submit_batch_job(client, batch_file_path, description)
        # Record id and path together, only once the submission succeeded, so
        # both lists always have the same length.
        batch_ids.append(batch.id)
        batch_file_paths.append(batch_file_path)
        _save_batch_info()

    if not chunks:
        # Still write the (header-only) CSV when there was nothing to submit.
        _save_batch_info()

    return batch_ids, batch_file_paths


def wait_for_completion(client, batch_id, initial_pause=4, max_pause=300):
    """Poll a batch job until it completes, doubling the pause between polls.

    Args:
        client: OpenAI client exposing ``batches.retrieve``.
        batch_id (str): Id of the batch to wait for.
        initial_pause (int | float): Seconds to sleep after the first poll.
        max_pause (int | float): Upper bound for the pause between polls.

    Returns:
        The batch object once its status is "completed".

    Raises:
        RuntimeError: If the batch reaches the status "failed", "expired" or
            "cancelled", none of which can ever turn into "completed".
    """
    pause = initial_pause
    while True:
        batch = client.batches.retrieve(batch_id)
        print(batch.status)

        if batch.status == "completed":
            print("Batch completed")
            return batch

        # Stop instead of polling forever on a batch that will never complete.
        if batch.status in ("failed", "expired", "cancelled"):
            raise RuntimeError(f"Batch {batch_id} ended with status '{batch.status}'")

        pause = min(pause, max_pause)
        print(f"Batch not completed, pausing for {pause} seconds")
        time.sleep(pause)
        pause *= 2


def wait_for_completion_for_multiple_batches(client, batch_ids, initial_pause=4):
    """Poll several batch jobs until all of them have completed.

    Between polling rounds the function sleeps; the pause doubles after each
    round and is capped at 300 seconds. Batches that already reached a final
    status are not retrieved again.

    Args:
        client: OpenAI client exposing ``batches.retrieve``.
        batch_ids (list): Ids of the batches to wait for.
        initial_pause (int | float): Seconds to sleep after the first round.

    Returns:
        list: The completed batch objects, ordered like ``batch_ids``
        (duplicate ids are returned once).

    Raises:
        RuntimeError: Once every batch has reached a final status and at least
            one of them is "failed", "expired" or "cancelled".
    """
    unsuccessful_final = ("failed", "expired", "cancelled")
    pause = initial_pause
    # None marks a batch that has not been polled yet.
    batch_statuses_dict = {batch_id: None for batch_id in batch_ids}
    completed_batches = {}

    while True:
        for batch_id, last_status in batch_statuses_dict.items():
            if last_status == "completed" or last_status in unsuccessful_final:
                continue
            batch = client.batches.retrieve(batch_id)
            batch_statuses_dict[batch_id] = batch.status
            if batch.status == "completed":
                completed_batches[batch_id] = batch

        if all(status == "completed" for status in batch_statuses_dict.values()):
            print("All batches completed")
            print(batch_statuses_dict)
            # Input order, so the result lines up with batch_ids.
            return [completed_batches[batch_id] for batch_id in batch_statuses_dict]

        print(batch_statuses_dict)

        # If nothing is still running, the remaining batches can never become
        # "completed"; report them instead of waiting forever.
        still_running = [
            batch_id for batch_id, status in batch_statuses_dict.items()
            if status != "completed" and status not in unsuccessful_final
        ]
        if not still_running:
            not_completed = {
                batch_id: status for batch_id, status in batch_statuses_dict.items()
                if status != "completed"
            }
            raise RuntimeError(f"Batches ended without completing: {not_completed}")

        if pause > 300:
            pause = 300

        print(f"Batch not completed, pausing for {pause} seconds")
        time.sleep(pause)
        pause *= 2


def process_output(client, batch, output_dir, timestamp, batch_file_path, chunk_number=None, model_prefix="completion"):
    """Download a batch's output file and store it, plus a small log, on disk.

    If the target output file already exists nothing is downloaded and its
    path is returned unchanged.

    Args:
        client: OpenAI client exposing ``files.content``.
        batch: Batch object providing ``id`` and ``output_file_id``.
        output_dir (str): Directory for the output and log files.
        timestamp (str): Timestamp embedded in the filenames.
        batch_file_path (str): Path of the request file; recorded in the log.
        chunk_number (optional): Chunk index embedded in the filenames.
        model_prefix (str): Filename prefix.

    Returns:
        str: Path of the output JSONL file.

    Raises:
        RuntimeError: If the batch has no output file id.
    """
    # NOTE(review): without a chunk number the name has no model prefix and no
    # ".jsonl" extension; kept as-is so files from earlier runs are still found.
    output_fname = f'embedding_output_{timestamp}'
    if chunk_number is not None:
        output_fname = f'{model_prefix}_output_{timestamp}_chunk_{chunk_number}.jsonl'
    output_path = os.path.join(output_dir, output_fname)
    if os.path.exists(output_path):
        print(f"output file already exists, skipping {output_path}")
        return output_path

    if batch.output_file_id is None:
        raise RuntimeError(f"Batch {batch.id} has no output file to download")

    content = client.files.content(batch.output_file_id).text
    # Keep every non-blank line. Slicing off the last element of the split
    # would drop a real record whenever the content lacks a trailing newline.
    # Only "\n" is treated as a separator so that other line-break characters
    # inside a JSON string cannot split a record.
    records = [line for line in content.split("\n") if line.strip()]

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in records)
    print(f"batch output file saved to {output_path}")

    # Save batch info
    batch_info = {
        "batch_file_path": batch_file_path,
        "batch_id": batch.id,
        "output_file_id": batch.output_file_id
    }

    log_fname = f'{model_prefix}_log_{timestamp}.json'
    if chunk_number is not None:
        log_fname = f'{model_prefix}_log_{timestamp}_chunk_{chunk_number}.json'
    log_path = os.path.join(output_dir, log_fname)

    with open(log_path, 'w') as f:
        json.dump(batch_info, f, indent=4)

    return output_path


def process_output_for_multiple_batches(client, batch_ids, batch_file_paths, output_dir, timestamp, model_prefix):
    """Retrieve every batch and hand it to process_output, one chunk per batch.

    Batch ids and batch file paths are paired positionally; the position is
    also used as the chunk number.

    Returns:
        list: The output paths returned by process_output, in input order.
    """
    output_paths = []
    pairs = zip(batch_ids, batch_file_paths)
    for chunk_number, (batch_id, batch_file_path) in enumerate(pairs):
        batch = client.batches.retrieve(batch_id)
        saved_to = process_output(
            client,
            batch,
            output_dir,
            timestamp,
            batch_file_path,
            chunk_number=chunk_number,
            model_prefix=model_prefix,
        )
        output_paths.append(saved_to)

    return output_paths


def load_output_jsonl(file_path, client=None, api_key=None):
    """Read a batch output JSONL file into a DataFrame.

    Args:
        file_path (str): Path of a plain-text (uncompressed) JSONL file.
        client: Unused; kept so existing callers keep working.
        api_key: Unused; kept so existing callers keep working.

    Returns:
        pandas.DataFrame: One row per non-blank line with the columns
            'id'        - the line's "custom_id",
            'embedding' - the message content of the first choice in the
                          response body (the column name is kept for
                          compatibility with existing callers).
        An input without records yields an empty frame that still has both
        columns.
    """
    records = []
    with open(file_path, 'rt', encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing empty line) are not valid JSON.
            if not line.strip():
                continue
            json_obj = json.loads(line)
            content = json_obj['response']['body']['choices'][0]['message']['content']
            records.append({
                'id': json_obj['custom_id'],
                'embedding': content
            })

    return pd.DataFrame(records, columns=['id', 'embedding'])


def load_output_for_multiple_batches(output_paths, client=None, api_key=None):
    """Load every output file with load_output_jsonl and stack the frames.

    The per-file frames are concatenated in the order of ``output_paths``;
    each frame keeps its own row index.
    """
    frames = []
    for path in output_paths:
        frames.append(load_output_jsonl(path, client, api_key))
    return pd.concat(frames)



In [4]:

# Load the pre-computed edge clusters and report how many there are.
cluster_df = pd.read_parquet(
    '/home/mads/connectome/data/embeddings/edge_embeddings/clustering/edge_clusters_6_max_clust_30.parquet'
)
print(f"Total number of clusters: {len(cluster_df)}")

# Optional filter (disabled): keep only clusters with an edge containing "interact".
#cluster_df = cluster_df[cluster_df['ids_in_cluster'].apply(lambda x: any("interact" in item.lower() for item in x))]

# Keep clusters with more than two members, then draw a reproducible random
# sample of 10 of them (seed 42).
has_enough_members = cluster_df.cluster_size > 2
cluster_df = cluster_df[has_enough_members].sample(n=10, random_state=42)


def _format_cluster_input(row):
    # One prompt string per cluster: "<row index>: <edges in the cluster>".
    return f"{row.name}: {row['ids_in_cluster']}"


cluster_df["input"] = cluster_df.apply(_format_cluster_input, axis=1)
inputs = cluster_df["input"].tolist()


Total number of clusters: 116976


In [8]:
# Load the OpenAI API key. Surrounding whitespace is stripped because a
# trailing newline in the key file would otherwise become part of the key.
with open("data/api_key.txt", "r") as f:
    api_key = f.read().strip()
# Initialize OpenAI client
client = OpenAI(api_key=api_key)

# Directories for the request files and the downloaded results; created on
# demand (no error if they already exist).
query_dir = "/home/mads/connectome/data/predictions/queries"
os.makedirs(query_dir, exist_ok=True)

output_dir = "/home/mads/connectome/data/predictions/output"
os.makedirs(output_dir, exist_ok=True)

# Either submit new batches ("y") or resume waiting for batches submitted in
# an earlier run ("n").
run_from_scratch = input("Do you want to run from scratch? (y/n): ").strip().lower()
#assert that run_from_scratch is either "y" or "n", and give a message if it is not
assert run_from_scratch in ["y", "n"], "run_from_scratch must be either 'y' or 'n'"
batch_size = 5

model_prefix = "v9"
model = "ft:gpt-4o-mini-2024-07-18:mutwil-lab:4omini-v9-train-1306-test-78:B9u5DUsf"
options = {"temperature": 0}
description = "Edge disambiguation"


if run_from_scratch == "y":

    # Generate a timestamp for file identification
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f"Timestamp file identifier: {timestamp}")

    # Turn the cluster prompt strings into request records.
    input_array = prepare_data(inputs)

    batch_ids, batch_file_paths = create_multiple_batch_files(
        client=client,
        input_array=input_array,
        timestamp=timestamp,
        query_dir=query_dir,
        system_prompt=system_prompt,
        model=model,
        model_prefix=model_prefix,
        options=options,
        max_batch_size=batch_size,
        description=description,
    )

elif run_from_scratch == "n":
    # Resume: read the batch ids and request file paths saved by an earlier
    # run, identified by the timestamp the user enters.
    timestamp = input("Enter the timestamp for the run you want to load: ").strip()
    batch_info = pd.read_csv(os.path.join(query_dir, f"batch_info_{timestamp}.csv"))
    batch_ids = batch_info["batch_id"].tolist()
    batch_file_paths = batch_info["batch_file_path"].tolist()


# Block until every batch has completed, then download, load and export.
wait_for_completion_for_multiple_batches(client, batch_ids, initial_pause=4)

output_paths = process_output_for_multiple_batches(client, batch_ids, batch_file_paths, output_dir, timestamp, model_prefix)

output_df = load_output_for_multiple_batches(output_paths, client, api_key)

output_df.to_excel(f"/home/mads/connectome/data/predictions/{model_prefix}_output_{timestamp}.xlsx", index=False)