In [46]:
import json
import csv
import pandas as pd
import random
import numpy as np
import google.generativeai as genai
import json
import time
import re

# JSON to CSV Conversion

In [47]:
def parse_edges_to_csv(input_file, output_file):
    """
    Parse edges.jsonl and extract subject, predicate, and object columns to CSV
    """
    with open(input_file, 'r') as jsonl_file, open(output_file, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        
        # Write header
        writer.writerow(['subject', 'predicate', 'object'])
        
        # Process each line in the JSONL file
        for line in jsonl_file:
            try:
                data = json.loads(line.strip())
                subject = data.get('subject', '')
                predicate = data.get('predicate', '')
                obj = data.get('object', '')
                
                writer.writerow([subject, predicate, obj])
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line: {line}")
                continue

In [48]:
input_file = "./data/example_edges.jsonl"
output_file = "edges_output.csv"
    
parse_edges_to_csv(input_file, output_file)
print(f"CSV file created: {output_file}")

CSV file created: edges_output.csv


# Sub graph Preparation (random predicate removal)

In [49]:
def select_chunk_and_remove_predicates(input_csv, chunk_size=100, predicate_removal_percent=50, output_file='modified_chunk.csv'):
    """
    Select a random chunk from the CSV and remove a percentage of edges for each unique predicate.
    
    Args:
        input_csv: Path to the input CSV file
        chunk_size: Number of rows to select (default: 100)
        predicate_removal_percent: Percentage of edges to remove for each unique predicate (default: 50)
        output_file: Path to save the modified chunk
    
    Returns:
        tuple: (original_chunk_df, modified_chunk_df)
    """
    # Read the CSV
    df = pd.read_csv(input_csv)
    
    # Select a random chunk
    if chunk_size >= len(df):
        chunk_df = df.copy()
    else:
        start_idx = random.randint(0, len(df) - chunk_size)
        chunk_df = df.iloc[start_idx:start_idx + chunk_size].copy()
    
    # Store original chunk
    original_chunk = chunk_df.copy()
    modified_chunk = chunk_df.copy()
    
    # Get unique predicates in the chunk
    unique_predicates = modified_chunk['predicate'].unique()
    
    total_removed = 0
    
    # Remove specified percentage of edges for each unique predicate
    for predicate in unique_predicates:
        predicate_indices = modified_chunk[modified_chunk['predicate'] == predicate].index.tolist()
        num_to_remove_pred = int(len(predicate_indices) * (predicate_removal_percent / 100))
        
        if num_to_remove_pred > 0:
            indices_to_remove_pred = random.sample(predicate_indices, num_to_remove_pred)
            modified_chunk.loc[indices_to_remove_pred, 'predicate'] = ''
            total_removed += num_to_remove_pred
    
    # Save modified chunk to CSV
    modified_chunk.to_csv(output_file, index=False)
    
    print(f"Original chunk size: {len(original_chunk)}")
    print(f"Removed {predicate_removal_percent}% of edges for each unique predicate")
    print(f"Total edges with predicates removed: {total_removed}")
    print(f"Modified chunk size: {len(modified_chunk)}")
    print(f"Modified chunk saved to: {output_file}")
    
    return original_chunk, modified_chunk

In [50]:
original, modified = select_chunk_and_remove_predicates(
    'edges_output.csv',
    chunk_size=100,
    predicate_removal_percent=50,
    output_file='modified_chunk_50%_removed.csv'
)

Original chunk size: 100
Removed 50% of edges for each unique predicate
Total edges with predicates removed: 40
Modified chunk size: 100
Modified chunk saved to: modified_chunk_50%_removed.csv


In [51]:
print("Original Chunk:")
print(original.head())
print("\nModified Chunk:")
print(modified.head())
print("\nRemoved Rows:")


Original Chunk:
          subject                 predicate         object
1041   CL:0000576          biolink:produces  NCBIGene:3553
1042   CL:0000576          biolink:produces  NCBIGene:3553
1043  CHEBI:30411  biolink:applied_to_treat  MONDO:0002012
1044  CHEBI:17303            biolink:causes  UMLS:C0040210
1045  CHEBI:17303            biolink:causes  UMLS:C0040210

Modified Chunk:
          subject       predicate         object
1041   CL:0000576                  NCBIGene:3553
1042   CL:0000576                  NCBIGene:3553
1043  CHEBI:30411                  MONDO:0002012
1044  CHEBI:17303  biolink:causes  UMLS:C0040210
1045  CHEBI:17303                  UMLS:C0040210

Removed Rows:


# Random Edge Assignment

In [52]:
# function to randomly assign edges to nodes from the list of unique predicates 
def randomly_assign_edges(input_csv, unique_predicates, output_file='randomly_assigned_edges.csv'):
    """
    Randomly assign edges to nodes from the list of unique predicates and save to new CSV.
    
    Args:
        input_csv: Path to the input CSV file
        unique_predicates: List of unique predicates
        output_file: Path to save the new CSV with randomly assigned edges
    """
    df = pd.read_csv(input_csv)
    
    # Fill empty predicates with random choices from unique_predicates
    for idx, row in df.iterrows():
        if pd.isna(row['predicate']) or row['predicate'] == '' or str(row['predicate']).strip() == '':
            df.at[idx, 'predicate'] = random.choice(unique_predicates)
    
    # Save to new CSV
    df.to_csv(output_file, index=False)
    print(f"Randomly assigned edges saved to: {output_file}")
    
    
    return df

In [53]:
# list of all unique predicates in the dataset which are not empty
unique_predicates = modified['predicate'].unique()
unique_predicates = unique_predicates[unique_predicates != '']
print("Unique predicates found:")
print(unique_predicates)
print(len(unique_predicates))

# Run the improved function
result_df = randomly_assign_edges('modified_chunk_50%_removed.csv', unique_predicates, output_file='randomly_assigned_edges.csv')



Unique predicates found:
['biolink:causes' 'biolink:affects' 'biolink:negatively_correlated_with'
 'biolink:has_part' 'biolink:preventative_for_condition'
 'biolink:positively_correlated_with' 'biolink:regulates'
 'biolink:target_for' 'biolink:decreases_response_to' 'biolink:treats'
 'biolink:occurs_in' 'biolink:manifestation_of' 'biolink:correlated_with'
 'biolink:has_participant' 'biolink:composed_primarily_of'
 'biolink:contributes_to' 'biolink:applied_to_treat' 'biolink:disrupts'
 'biolink:develops_from' 'biolink:related_to'
 'biolink:acts_upstream_of_negative_effect' 'biolink:has_phenotype'
 'biolink:produces' 'biolink:subclass_of'
 'biolink:acts_upstream_of_positive_effect' 'biolink:located_in'
 'biolink:directly_physically_interacts_with' 'biolink:precedes'
 'biolink:associated_with_increased_likelihood_of'
 'biolink:physically_interacts_with' 'biolink:interacts_with'
 'biolink:in_taxon' 'biolink:has_input']
33
Randomly assigned edges saved to: randomly_assigned_edges.csv


# Gemini LLM Edge Assignment

In [54]:
api = "AIzaSyCQVqiw_JyVbMrko4TpplqS0bf2GJCtgr8"

def fill_missing_predicates_llm_base(input_df, unique_predicates, output_file='llm_filled_predicates.csv', 
                                    metrics_file='llm_metrics.json', responses_file='llm_responses.json'):
    """
    Use Gemini API to fill in missing predicates in the DataFrame using a single batch prompt.
    
    Args:
        input_df: DataFrame with potential missing predicates
        unique_predicates: List of unique predicates to choose from
        output_file: Path to save the new CSV with LLM filled predicates
        metrics_file: Path to save metrics about the LLM usage
        responses_file: Path to save all LLM responses for analysis
    
    Returns:
        tuple: (filled_df, metrics, responses)
    """
    # Configure Gemini API
    genai.configure(api_key=api)
    model = genai.GenerativeModel('gemini-2.5-flash')
    
    df = input_df.copy()
    start_time = time.time()
    
    # Find all empty predicate rows
    empty_mask = df['predicate'].isna() | (df['predicate'] == '') | (df['predicate'].str.strip() == '')
    empty_indices = df[empty_mask].index.tolist()
    empty_count = len(empty_indices)
    
    print(f"Found {empty_count} empty predicates to fill")
    
    if empty_count == 0:
        print("No empty predicates found!")
        return df, {}, []
    
    # Build single large prompt with all missing predicates
    predicate_list = ', '.join(unique_predicates)
    
    batch_prompt = f"""You are a biomedical knowledge graph expert. Complete the missing predicates for these triples.

Available predicates: {predicate_list}

Instructions: For each numbered triple, respond with ONLY the most appropriate predicate from the list above.

Triples to complete:
"""
    
    # Add all empty predicate cases to the prompt
    case_mapping = {}  # Maps case number to dataframe index
    for case_num, idx in enumerate(empty_indices, 1):
        row = df.iloc[idx]
        batch_prompt += f"{case_num}. Subject: {row['subject']} | Object: {row['object']}\n"
        case_mapping[case_num] = idx
    
    batch_prompt += f"""
Expected response format:
1. predicate_name
2. predicate_name
3. predicate_name
...

Respond with ONLY the numbered list of predicates, nothing else."""

    print(f"Sending batch request for {empty_count} predicates...")
    
    llm_filled_count = 0
    fallback_count = 0
    successful_requests = 0
    failed_requests = 0
    
    try:
        # Single API request for all missing predicates
        response = model.generate_content(
            batch_prompt,
            generation_config=genai.types.GenerationConfig(
                max_output_tokens=empty_count * 10,  # Adjust based on number of predicates
                temperature=0.3,
                candidate_count=1
            )
        )
        
        response_text = response.text.strip()
        print("✓ Batch API request successful")
        successful_requests = 1
        
        # Parse the response to extract individual predicates
        response_lines = response_text.split('\n')
        
        # Use regex to extract numbered responses
        predicate_suggestions = {}
        for line in response_lines:
            line = line.strip()
            if line:
                # Match patterns like "1. biolink:treats" or "1) biolink:treats" or "1 biolink:treats"
                match = re.match(r'^(\d+)[\.\)\s]+(.+)', line)
                if match:
                    case_num = int(match.group(1))
                    suggested_predicate = match.group(2).strip()
                    
                    # Clean up the suggestion (remove quotes, extra text)
                    suggested_predicate = suggested_predicate.replace('"', '').replace("'", "")
                    
                    # Try exact match first
                    if suggested_predicate in unique_predicates:
                        predicate_suggestions[case_num] = suggested_predicate
                    else:
                        # Try partial matching
                        for predicate in unique_predicates:
                            if predicate in suggested_predicate or suggested_predicate in predicate:
                                predicate_suggestions[case_num] = predicate
                                break
        
        print(f"✓ Successfully parsed {len(predicate_suggestions)} predicates from response")
        
        # Apply the suggestions to the dataframe
        for case_num, idx in case_mapping.items():
            if case_num in predicate_suggestions:
                suggested_predicate = predicate_suggestions[case_num]
                df.at[idx, 'predicate'] = suggested_predicate
                llm_filled_count += 1
                print(f"✓ Row {idx}: Filled with '{suggested_predicate}'")
            else:
                # Fallback to random selection
                fallback_predicate = random.choice(unique_predicates)
                df.at[idx, 'predicate'] = fallback_predicate
                fallback_count += 1
                print(f"⚠ Row {idx}: No suggestion found, used random '{fallback_predicate}'")
        
        # Estimate token usage
        input_tokens = len(batch_prompt.split()) * 1.3
        output_tokens = len(response_text.split()) * 1.3
        
        # Store response details
        responses = [{
            'batch_request': True,
            'total_cases': empty_count,
            'prompt': batch_prompt,
            'response_text': response_text,
            'parsed_suggestions': predicate_suggestions,
            'case_mapping': case_mapping,
            'success': True,
            'estimated_input_tokens': input_tokens,
            'estimated_output_tokens': output_tokens
        }]
        
    except Exception as e:
        print(f"✗ Batch API request failed: {e}")
        failed_requests = 1
        
        # Fallback: fill all with random predicates
        for idx in empty_indices:
            fallback_predicate = random.choice(unique_predicates)
            df.at[idx, 'predicate'] = fallback_predicate
            fallback_count += 1
        
        responses = [{
            'batch_request': True,
            'total_cases': empty_count,
            'prompt': batch_prompt,
            'error': str(e),
            'success': False,
            'fallback_used': True
        }]
        
        input_tokens = len(batch_prompt.split()) * 1.3
        output_tokens = 0
    
    end_time = time.time()
    total_time = end_time - start_time
    
    # Create metrics summary
    metrics = {
        'total_empty_predicates': empty_count,
        'llm_filled_count': llm_filled_count,
        'fallback_count': fallback_count,
        'successful_requests': successful_requests,
        'failed_requests': failed_requests,
        'total_requests': successful_requests + failed_requests,
        'success_rate': successful_requests / (successful_requests + failed_requests) if (successful_requests + failed_requests) > 0 else 0,
        'llm_success_rate': llm_filled_count / empty_count if empty_count > 0 else 0,
        'total_processing_time_seconds': total_time,
        'estimated_total_input_tokens': input_tokens,
        'estimated_total_output_tokens': output_tokens,
        'estimated_total_tokens': input_tokens + output_tokens,
        'batch_processing': True,
        'speed_improvement': f"~{empty_count}x faster than individual requests"
    }
    
    # Save files
    df.to_csv(output_file, index=False)
    
    with open(metrics_file, 'w') as f:
        json.dump(metrics, f, indent=2)
    
    with open(responses_file, 'w') as f:
        json.dump(responses, f, indent=2)
    
    # Print summary
    print(f"\n=== Batch LLM Processing Complete ===")
    print(f"LLM filled predicates saved to: {output_file}")
    print(f"Total predicates filled by LLM: {llm_filled_count}/{empty_count}")
    print(f"Fallback (random) assignments: {fallback_count}")
    print(f"LLM success rate: {metrics['llm_success_rate']:.2%}")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Estimated tokens used: {int(input_tokens + output_tokens)}")
    print(f"Speed improvement: ~{empty_count}x faster than individual requests!")
    print(f"Metrics saved to: {metrics_file}")
    print(f"Responses saved to: {responses_file}")
    
    return df, metrics, responses

In [55]:
# Test the LLM function with a subset of data first - Fix the indexing issue
test_df = modified.copy().reset_index(drop=True)  # Reset index to 0, 1, 2, 3...

print(f"DataFrame shape: {test_df.shape}")
print(f"Index range: {test_df.index.min()} to {test_df.index.max()}")

filled_df, metrics, responses = fill_missing_predicates_llm_base(
    test_df,
    unique_predicates,
    output_file='gemini_filled_test.csv',
    metrics_file='gemini_metrics_test.json',
    responses_file='gemini_responses_test.json'
)

DataFrame shape: (100, 3)
Index range: 0 to 99
Found 40 empty predicates to fill
Sending batch request for 40 predicates...


E0000 00:00:1759381559.366299  710520 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


✓ Batch API request successful
✓ Successfully parsed 40 predicates from response
✓ Row 0: Filled with 'biolink:has_part'
✓ Row 1: Filled with 'biolink:has_part'
✓ Row 2: Filled with 'biolink:treats'
✓ Row 4: Filled with 'biolink:causes'
✓ Row 5: Filled with 'biolink:contributes_to'
✓ Row 7: Filled with 'biolink:has_phenotype'
✓ Row 9: Filled with 'biolink:contributes_to'
✓ Row 10: Filled with 'biolink:has_participant'
✓ Row 12: Filled with 'biolink:subclass_of'
✓ Row 13: Filled with 'biolink:occurs_in'
✓ Row 18: Filled with 'biolink:directly_physically_interacts_with'
✓ Row 21: Filled with 'biolink:regulates'
✓ Row 22: Filled with 'biolink:associated_with_increased_likelihood_of'
✓ Row 24: Filled with 'biolink:produces'
✓ Row 27: Filled with 'biolink:has_input'
✓ Row 35: Filled with 'biolink:has_phenotype'
✓ Row 38: Filled with 'biolink:manifestation_of'
✓ Row 43: Filled with 'biolink:directly_physically_interacts_with'
✓ Row 47: Filled with 'biolink:contributes_to'
✓ Row 51: Filled wi

# in progress

In [56]:
import requests
import gzip
import chromadb
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
import uuid

# URLs for the .gz files (tool descriptions)
urls = [
    # "https://ftp.ncbi.nlm.nih.gov/pub/lu/PubTator3/bioconcepts2pubtator3.gz",
    "https://ftp.ncbi.nlm.nih.gov/pub/lu/PubTator3/gene2pubtator3.gz"
]

local_files = ["gene2pubtator3.gz"]

# Function to download the .gz files from URLs
def download_file(url, local_path):
    response = requests.get(url)
    with open(local_path, 'wb') as f:
        f.write(response.content)
    print(f"Downloaded file: {local_path}")

# Download the files
# for url, local_file in zip(urls, local_files):
#     download_file(url, local_file)

# Initialize ChromaDB client (new method)
client = chromadb.Client()

# Create or get the collection - handle existing collection
try:
    collection = client.create_collection("pubtator_data")
    print("Created new collection: pubtator_data")
except Exception as e:
    if "already exists" in str(e):
        print("Collection already exists, getting existing collection...")
        collection = client.get_collection("pubtator_data")
        print("Retrieved existing collection: pubtator_data")
    else:
        raise e

# Load the transformer model for embedding
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize text splitter (chunk text into smaller pieces)
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

def embed_text(text):
    """Embed text into vector using pre-trained model."""
    embedding = model.encode(text)
    return embedding

# Function to read and process the .gz files
def process_gz_file(file_path):
    with gzip.open(file_path, 'rt') as f:
        return f.readlines()

# Check if collection already has data
existing_count = collection.count()
print(f"Collection currently has {existing_count} documents")

# Only process files if collection is empty or you want to add more data
if existing_count == 0:
    print("Collection is empty, processing files...")
    # Process and insert the data into ChromaDB
    for file in local_files:
        print(f"Processing file: {file}")
        lines = process_gz_file(file)
        
        # Process first 100 lines for testing (remove this limit for full processing)
        lines = lines[:100]  # Limit for testing
        
        # Split lines into smaller chunks
        for line_idx, line in enumerate(lines):
            chunks = splitter.split_text(line.strip())  # Split long descriptions into chunks
            
            # For each chunk, generate embedding and store in ChromaDB
            for chunk_idx, chunk in enumerate(chunks):
                if chunk.strip():  # Only process non-empty chunks
                    embedding = embed_text(chunk)
                    unique_id = f"{file}_{line_idx}_{chunk_idx}_{str(uuid.uuid4())[:8]}"
                    
                    # Add document and metadata (such as source file and chunk position) to ChromaDB
                    collection.add(
                        ids=[unique_id],
                        documents=[chunk],
                        metadatas=[{"source": file, "line_idx": line_idx, "chunk_idx": chunk_idx}],
                        embeddings=[embedding.tolist()]
                    )
    
    print("Data inserted into ChromaDB.")
else:
    print("Collection already contains data. Skipping file processing.")

# Query ChromaDB for relevant documents (example query)
query = "What is gene expression in biological research?"
query_embedding = embed_text(query)

# Perform the retrieval to get top-k similar documents
results = collection.query(
    query_embeddings=[query_embedding.tolist()],
    n_results=5  # Retrieve top 5 similar documents
)

print("Retrieved documents:")
for doc, metadata in zip(results["documents"][0], results["metadatas"][0]):
    print(f"Document: {doc[:200]}... (Source: {metadata['source']})")



Collection already exists, getting existing collection...
Retrieved existing collection: pubtator_data
Collection currently has 100 documents
Collection already contains data. Skipping file processing.
Retrieved documents:
Document: 40757000	Gene	1791	TdT|Terminal deoxynucleotidyl transferase	PubTator3... (Source: gene2pubtator3.gz)
Document: 40757000	Gene	374	AREG|amphiregulin	PubTator3... (Source: gene2pubtator3.gz)
Document: 40757000	Gene	108155	OGT	PubTator3... (Source: gene2pubtator3.gz)
Document: 40757000	Gene	9332	CD163	PubTator3... (Source: gene2pubtator3.gz)
Document: 40757000	Gene	283871	Pgp	PubTator3... (Source: gene2pubtator3.gz)
Collection currently has 100 documents
Collection already contains data. Skipping file processing.
Retrieved documents:
Document: 40757000	Gene	1791	TdT|Terminal deoxynucleotidyl transferase	PubTator3... (Source: gene2pubtator3.gz)
Document: 40757000	Gene	374	AREG|amphiregulin	PubTator3... (Source: gene2pubtator3.gz)
Document: 40757000	Gene	108155	

In [57]:
def query_domain_knowledge(query_text, collection, model, top_k=3):
    """
    Query the ChromaDB collection for relevant domain knowledge.
    
    Args:
        query_text: Text to search for in the knowledge base
        collection: ChromaDB collection containing domain knowledge
        model: SentenceTransformer model for embeddings
        top_k: Number of top relevant documents to retrieve (default: 3)
    """
    query_embedding = embed_text(query_text)
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k
    )
    return results

In [58]:
# Optimized RAG-based function using single batch prompt with enhanced explanations and domain knowledge references
def fill_missing_predicates_llm_with_domain_knowledge(input_df, unique_predicates, collection, model,
                                                 output_file='llm_rag_filled_predicates.csv', 
                                                 metrics_file='llm_rag_metrics.json', 
                                                 responses_file='llm_rag_responses.json'):
    """
    Use Gemini API with RAG to fill in missing predicates in DataFrame using optimized single batch prompt.
    
    Args:
        input_df: DataFrame with potential missing predicates
        unique_predicates: List of unique predicates to choose from
        collection: ChromaDB collection for domain knowledge retrieval
        model: SentenceTransformer model for embeddings
        output_file: Path to save the new CSV with LLM filled predicates
        metrics_file: Path to save metrics about the LLM usage
        responses_file: Path to save all LLM responses for analysis
    
    Returns:
        tuple: (filled_df, metrics, responses)
    """
    # Configure Gemini API
    genai.configure(api_key=api)
    llm_model = genai.GenerativeModel('gemini-2.5-flash')
    
    df = input_df.copy()
    start_time = time.time()
    
    # Find all empty predicate rows
    empty_mask = df['predicate'].isna() | (df['predicate'] == '') | (df['predicate'].str.strip() == '')
    empty_indices = df[empty_mask].index.tolist()
    empty_count = len(empty_indices)
    
    print(f"Found {empty_count} empty predicates to fill")
    
    if empty_count == 0:
        print("No empty predicates found!")
        return df, {}, []
    
    # Retrieve domain knowledge for all cases and build single batch prompt
    predicate_list = ', '.join(unique_predicates)
    
    # Build comprehensive batch prompt with RAG context for all cases
    batch_prompt = f"""You are a biomedical knowledge graph expert. Complete the missing predicates for these triples using the provided domain knowledge context.

Available predicates: {predicate_list}

Instructions: 
1. For each numbered triple, provide ONLY the most appropriate predicate from the available list
2. Use the provided domain knowledge context to make informed decisions
3. Each response must include a brief explanation referencing the context that supports your choice
4. Format: "X. predicate_name | Explanation: [brief justification citing relevant context]"

Cases to complete:
"""
    
    # Collect all contexts and build case mapping
    case_mapping = {}
    case_contexts = {}
    
    for case_num, idx in enumerate(empty_indices, 1):
        row = df.iloc[idx]
        case_mapping[case_num] = idx
        
        # Retrieve relevant domain knowledge context for this case
        query_text = f"relationship between {row['subject']} and {row['object']}"
        retrieval_results = query_domain_knowledge(query_text, collection, model, top_k=3)
        
        context = "\n".join(retrieval_results["documents"][0]) if retrieval_results["documents"][0] else "No specific context found."
        case_contexts[case_num] = {
            'context': context,
            'subject': row['subject'],
            'object': row['object'],
            'query': query_text
        }
        
        # Add case to batch prompt with context
        batch_prompt += f"""
{case_num}. Subject: {row['subject']} | Object: {row['object']}
   Domain Context: {context[:500]}...
   
"""
    
    batch_prompt += f"""
Expected response format:
1. predicate_name | Explanation: Based on the context mentioning [specific reference], this predicate best describes...
2. predicate_name | Explanation: The domain knowledge indicates [specific reference], supporting this relationship...
...

Respond with ONLY the numbered list of predicates and explanations, nothing else."""

    print(f"Sending optimized batch request for {empty_count} predicates with RAG context...")
    
    llm_filled_count = 0
    fallback_count = 0
    successful_requests = 0
    failed_requests = 0
    detailed_responses = []
    
    try:
        # Single optimized API request for all missing predicates with RAG context
        response = llm_model.generate_content(
            batch_prompt,
            generation_config=genai.types.GenerationConfig(
                max_output_tokens=empty_count * 50,  # More tokens for explanations
                temperature=0.1,  # Lower temperature for more consistent responses
                candidate_count=1
            )
        )
        
        response_text = response.text.strip()
        print("✓ Batch RAG API request successful")
        successful_requests = 1
        
        # Parse the response to extract predicates and explanations
        response_lines = response_text.split('\n')
        predicate_suggestions = {}
        explanations = {}
        
        for line in response_lines:
            line = line.strip()
            if line and '|' in line:
                # Match patterns like "1. predicate_name | Explanation: ..."
                parts = line.split('|', 1)
                if len(parts) == 2:
                    predicate_part = parts[0].strip()
                    explanation_part = parts[1].strip()
                    
                    # Extract case number and predicate
                    match = re.match(r'^(\d+)[\.\)\s]+(.+)', predicate_part)
                    if match:
                        case_num = int(match.group(1))
                        suggested_predicate = match.group(2).strip()
                        
                        # Clean up the suggestion
                        suggested_predicate = suggested_predicate.replace('"', '').replace("'", "")
                        
                        # Try exact match first
                        if suggested_predicate in unique_predicates:
                            predicate_suggestions[case_num] = suggested_predicate
                            explanations[case_num] = explanation_part
                        else:
                            # Try partial matching
                            for predicate in unique_predicates:
                                if predicate in suggested_predicate or suggested_predicate in predicate:
                                    predicate_suggestions[case_num] = predicate
                                    explanations[case_num] = explanation_part
                                    break
        
        print(f"✓ Successfully parsed {len(predicate_suggestions)} predicates with explanations from response")
        
        # Apply the suggestions to the dataframe
        for case_num, idx in case_mapping.items():
            case_context = case_contexts[case_num]
            
            if case_num in predicate_suggestions:
                suggested_predicate = predicate_suggestions[case_num]
                explanation = explanations.get(case_num, "No explanation provided")
                df.at[idx, 'predicate'] = suggested_predicate
                llm_filled_count += 1
                print(f"✓ Row {idx}: Filled with '{suggested_predicate}' - {explanation}")
                
                detailed_responses.append({
                    'row': idx,
                    'case_num': case_num,
                    'subject': case_context['subject'],
                    'object': case_context['object'],
                    'context': case_context['context'],
                    'suggested_predicate': suggested_predicate,
                    'explanation': explanation,
                    'success': True,
                    'method': 'RAG_LLM'
                })
            else:
                # Fallback to random selection
                fallback_predicate = random.choice(unique_predicates)
                df.at[idx, 'predicate'] = fallback_predicate
                fallback_count += 1
                print(f"⚠ Row {idx}: No suggestion found, used random '{fallback_predicate}'")
                
                detailed_responses.append({
                    'row': idx,
                    'case_num': case_num,
                    'subject': case_context['subject'],
                    'object': case_context['object'],
                    'context': case_context['context'],
                    'suggested_predicate': fallback_predicate,
                    'explanation': "Fallback: No valid suggestion from LLM",
                    'success': False,
                    'method': 'Random_Fallback'
                })
        
        # Estimate token usage
        input_tokens = len(batch_prompt.split()) * 1.3
        output_tokens = len(response_text.split()) * 1.3
        
        # Store comprehensive response details
        responses = [{
            'batch_request': True,
            'total_cases': empty_count,
            'prompt': batch_prompt,
            'response_text': response_text,
            'parsed_suggestions': predicate_suggestions,
            'explanations': explanations,
            'case_mapping': case_mapping,
            'case_contexts': case_contexts,
            'detailed_responses': detailed_responses,
            'success': True,
            'estimated_input_tokens': input_tokens,
            'estimated_output_tokens': output_tokens,
            'rag_enhanced': True
        }]
        
    except Exception as e:
        print(f"✗ Batch RAG API request failed: {e}")
        failed_requests = 1
        
        # Fallback: fill all with random predicates
        for idx in empty_indices:
            fallback_predicate = random.choice(unique_predicates)
            df.at[idx, 'predicate'] = fallback_predicate
            fallback_count += 1
            
            detailed_responses.append({
                'row': idx,
                'suggested_predicate': fallback_predicate,
                'explanation': f"API Error Fallback: {str(e)}",
                'success': False,
                'method': 'Error_Fallback'
            })
        
        responses = [{
            'batch_request': True,
            'total_cases': empty_count,
            'prompt': batch_prompt,
            'error': str(e),
            'success': False,
            'fallback_used': True,
            'detailed_responses': detailed_responses,
            'rag_enhanced': True
        }]
        
        input_tokens = len(batch_prompt.split()) * 1.3
        output_tokens = 0
    
    end_time = time.time()
    total_time = end_time - start_time
    
    # Create enhanced metrics summary
    metrics = {
        'total_empty_predicates': empty_count,
        'llm_filled_count': llm_filled_count,
        'fallback_count': fallback_count,
        'successful_requests': successful_requests,
        'failed_requests': failed_requests,
        'total_requests': successful_requests + failed_requests,
        'success_rate': successful_requests / (successful_requests + failed_requests) if (successful_requests + failed_requests) > 0 else 0,
        'llm_success_rate': llm_filled_count / empty_count if empty_count > 0 else 0,
        'total_processing_time_seconds': total_time,
        'estimated_total_input_tokens': input_tokens,
        'estimated_total_output_tokens': output_tokens,
        'estimated_total_tokens': input_tokens + output_tokens,
        'batch_processing': True,
        'rag_enhanced': True,
        'explanations_provided': True,
        'speed_improvement': f"~{empty_count}x faster than individual requests",
        'context_retrieval_enabled': True
    }
    
    # Save files
    df.to_csv(output_file, index=False)
    
    with open(metrics_file, 'w') as f:
        json.dump(metrics, f, indent=2)
    
    with open(responses_file, 'w') as f:
        json.dump(responses, f, indent=2)
    
    # Print comprehensive summary
    print(f"\n=== Optimized RAG-Enhanced LLM Processing Complete ===")
    print(f"LLM filled predicates saved to: {output_file}")
    print(f"Total predicates filled by LLM: {llm_filled_count}/{empty_count}")
    print(f"Fallback (random) assignments: {fallback_count}")
    print(f"LLM success rate: {metrics['llm_success_rate']:.2%}")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Estimated tokens used: {int(input_tokens + output_tokens)}")
    print(f"Speed improvement: ~{empty_count}x faster than individual requests!")
    print(f"RAG context: ✓ Enhanced with domain knowledge")
    print(f"Explanations: ✓ Provided with context references")
    print(f"Metrics saved to: {metrics_file}")
    print(f"Responses saved to: {responses_file}")
    
    return df, metrics, responses

In [59]:

# Ensure unique predicates are available
test_df_rag = modified.copy().reset_index(drop=True)  # Reset index to 0, 1, 2, 3...
if 'unique_predicates'in globals():
    unique_predicates = modified['predicate'].unique()
    unique_predicates = unique_predicates[unique_predicates != '']
    
    filled_df_rag, metrics_rag, responses_rag = fill_missing_predicates_llm_with_domain_knowledge(
    test_df_rag,
    unique_predicates,
    collection,
    model,
    output_file='gemini_rag_filled_test.csv',
    metrics_file='gemini_rag_metrics_test.json',
    responses_file='gemini_rag_responses_test.json'
    )
    
    print("RAG-enhanced LLM processing complete.")
# else:

Found 40 empty predicates to fill
Sending optimized batch request for 40 predicates with RAG context...
Sending optimized batch request for 40 predicates with RAG context...


E0000 00:00:1759381609.366532  710520 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


✓ Batch RAG API request successful
✓ Successfully parsed 40 predicates with explanations from response
✓ Row 0: Filled with 'biolink:produces' - Explanation: T-helper cells (CL:0000576) are known to produce Interleukin-2 (IL2, NCBIGene:3553).
✓ Row 1: Filled with 'biolink:produces' - Explanation: T-helper cells (CL:0000576) are known to produce Interleukin-2 (IL2, NCBIGene:3553).
✓ Row 2: Filled with 'biolink:contributes_to' - Explanation: Dysregulation of calcium ions (CHEBI:30411) can contribute to the development of various cancers, including ovarian cancer (MONDO:0002012).
✓ Row 4: Filled with 'biolink:contributes_to' - Explanation: Glucose (CHEBI:17303) is a primary metabolic fuel, and its altered metabolism is crucial for tumor (UMLS:C0040210) growth and proliferation.
✓ Row 5: Filled with 'biolink:regulates' - Explanation: Phosphatidylinositol-3,4,5-trisphosphate (CHEBI:63631) is a signaling lipid that regulates various membrane activities, including transmembrane transporter ac