In [None]:
import json
import csv
import pandas as pd
import random
import numpy as np
# import google.generativeai as genai
import json
import time
import re
import requests
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb
import uuid
import re
import boto3
import json
import time
import boto3
import json
import time
import re

# JSON to CSV Conversion

In [2]:
def parse_edges_to_csv(input_file, output_file):
    """
    Parse edges.jsonl and extract subject, predicate, and object columns to CSV
    """
    with open(input_file, 'r') as jsonl_file, open(output_file, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        
        # Write header
        writer.writerow(['subject', 'predicate', 'object'])
        
        # Process each line in the JSONL file
        for line in jsonl_file:
            try:
                data = json.loads(line.strip())
                subject = data.get('subject', '')
                predicate = data.get('predicate', '')
                obj = data.get('object', '')
                
                writer.writerow([subject, predicate, obj])
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON line: {line}")
                continue

In [3]:
input_file = "./data/example_edges.jsonl"
output_file = "edges_output.csv"
    
parse_edges_to_csv(input_file, output_file)
print(f"CSV file created: {output_file}")

CSV file created: edges_output.csv


# Sub graph Preparation (random predicate removal)

In [4]:
def select_chunk_and_remove_predicates(input_csv, chunk_size=100, predicate_removal_percent=50, output_file='modified_chunk.csv'):
    """
    Select a random chunk from the CSV and remove a percentage of edges for each unique predicate.
    
    Args:
        input_csv: Path to the input CSV file
        chunk_size: Number of rows to select (default: 100)
        predicate_removal_percent: Percentage of edges to remove for each unique predicate (default: 50)
        output_file: Path to save the modified chunk
    
    Returns:
        tuple: (original_chunk_df, modified_chunk_df)
    """
    # Read the CSV
    df = pd.read_csv(input_csv)
    
    # Select a random chunk
    if chunk_size >= len(df):
        chunk_df = df.copy()
    else:
        start_idx = random.randint(0, len(df) - chunk_size)
        chunk_df = df.iloc[start_idx:start_idx + chunk_size].copy()
    
    # Store original chunk
    original_chunk = chunk_df.copy()
    modified_chunk = chunk_df.copy()
    
    # Get unique predicates in the chunk
    unique_predicates = modified_chunk['predicate'].unique()
    
    total_removed = 0
    
    # Remove specified percentage of edges for each unique predicate
    for predicate in unique_predicates:
        predicate_indices = modified_chunk[modified_chunk['predicate'] == predicate].index.tolist()
        num_to_remove_pred = int(len(predicate_indices) * (predicate_removal_percent / 100))
        
        if num_to_remove_pred > 0:
            indices_to_remove_pred = random.sample(predicate_indices, num_to_remove_pred)
            modified_chunk.loc[indices_to_remove_pred, 'predicate'] = ''
            total_removed += num_to_remove_pred
    
    # Save modified chunk to CSV
    modified_chunk.to_csv(output_file, index=False)
    
    print(f"Original chunk size: {len(original_chunk)}")
    print(f"Removed {predicate_removal_percent}% of edges for each unique predicate")
    print(f"Total edges with predicates removed: {total_removed}")
    print(f"Modified chunk size: {len(modified_chunk)}")
    print(f"Modified chunk saved to: {output_file}")
    
    return original_chunk, modified_chunk

In [5]:
original, modified = select_chunk_and_remove_predicates(
    './data/alzheimers_triples.csv',
    chunk_size=200,
    predicate_removal_percent=50,
    output_file='modified_chunk_50%_removed.csv'
)

Original chunk size: 200
Removed 50% of edges for each unique predicate
Total edges with predicates removed: 99
Modified chunk size: 200
Modified chunk saved to: modified_chunk_50%_removed.csv


In [6]:
print("Original Chunk:")
print(original.head())
print("\nModified Chunk:")
print(modified.head())
print("\nRemoved Rows:")


Original Chunk:
          subject                                      predicate  \
1020  CHEBI:35475  biolink:treats_or_applied_or_studied_to_treat   
1021  CHEBI:35475  biolink:treats_or_applied_or_studied_to_treat   
1022  CHEBI:35475  biolink:treats_or_applied_or_studied_to_treat   
1023  CHEBI:35475  biolink:treats_or_applied_or_studied_to_treat   
1024  CHEBI:35475  biolink:treats_or_applied_or_studied_to_treat   

             object  
1020  MONDO:0021042  
1021  MONDO:0005335  
1022  MONDO:0008383  
1023  MONDO:0004972  
1024  MONDO:0100010  

Modified Chunk:
          subject                                      predicate  \
1020  CHEBI:35475                                                  
1021  CHEBI:35475  biolink:treats_or_applied_or_studied_to_treat   
1022  CHEBI:35475  biolink:treats_or_applied_or_studied_to_treat   
1023  CHEBI:35475                                                  
1024  CHEBI:35475                                                  

             obje

# Random Edge Assignment

In [7]:
# function to randomly assign edges to nodes from the list of unique predicates 
def randomly_assign_edges(input_csv, unique_predicates, output_file='randomly_assigned_edges.csv'):
    """
    Randomly assign edges to nodes from the list of unique predicates and save to new CSV.
    
    Args:
        input_csv: Path to the input CSV file
        unique_predicates: List of unique predicates
        output_file: Path to save the new CSV with randomly assigned edges
    """
    df = pd.read_csv(input_csv)
    
    # Fill empty predicates with random choices from unique_predicates
    for idx, row in df.iterrows():
        if pd.isna(row['predicate']) or row['predicate'] == '' or str(row['predicate']).strip() == '':
            df.at[idx, 'predicate'] = random.choice(unique_predicates)
    
    # Save to new CSV
    df.to_csv(output_file, index=False)
    print(f"Randomly assigned edges saved to: {output_file}")
    
    
    return df

In [8]:
# list of all unique predicates in the dataset which are not empty
unique_predicates = modified['predicate'].unique()
unique_predicates = unique_predicates[unique_predicates != '']
print("Unique predicates found:")
print(unique_predicates)
print(len(unique_predicates))

# Run the improved function
result_df = randomly_assign_edges('modified_chunk_50%_removed.csv', unique_predicates, output_file='randomly_assigned_edges.csv')



Unique predicates found:
['biolink:treats_or_applied_or_studied_to_treat' 'biolink:subclass_of'
 'biolink:related_to' 'biolink:contributes_to' 'biolink:has_phenotype']
5
Randomly assigned edges saved to: randomly_assigned_edges.csv


# Bedrock LLM Edge Assignment

In [None]:
def fill_missing_predicates_llm_bedrock(input_df, unique_predicates, output_file='bedrock_filled_predicates.csv', 
                                        metrics_file='bedrock_metrics.json', responses_file='bedrock_responses.json'):
    """
    Use AWS Bedrock API with openai.gpt-oss-120b-1:0 to fill in missing predicates using a single batch prompt.
    """
    # Configure AWS Bedrock client
    try:
        bedrock_client = boto3.client(
            'bedrock-runtime',
            region_name='us-east-1',
        )
    except Exception as e:
        print(f"Failed to configure Bedrock client: {e}")
        bedrock_client = None
    
    df = input_df.copy()
    start_time = time.time()
    
    # Find all empty predicate rows
    empty_mask = df['predicate'].isna() | (df['predicate'] == '') | (df['predicate'].str.strip() == '')
    empty_indices = df[empty_mask].index.tolist()
    empty_count = len(empty_indices)
    
    print(f"Found {empty_count} empty predicates to fill")
    
    if empty_count == 0:
        print("No empty predicates found!")
        return df, {}, []
    
    # Build simplified prompt for better parsing
    predicate_list = '\n'.join([f"- {pred}" for pred in unique_predicates])
    
    batch_prompt = f"""Fill missing predicates for biomedical relationships. Choose ONLY from these options:

{predicate_list}

Cases:
"""
    
    # Add all empty predicate cases to the prompt
    case_mapping = {}
    for case_num, idx in enumerate(empty_indices, 1):
        row = df.iloc[idx]
        batch_prompt += f"{case_num}. {row['subject']} _____ {row['object']}\n"
        case_mapping[case_num] = idx
    
    batch_prompt += f"""

Respond with ONLY numbered predicates (one per line):
1. predicate_name
2. predicate_name
etc."""

    print(f"Sending batch request to AWS Bedrock for {empty_count} predicates...")
    
    llm_filled_count = 0
    successful_requests = 0
    failed_requests = 0
    responses = []
    
    try:
        if bedrock_client:
            # Prepare the request body for openai.gpt-oss-120b-1:0
            request_body = {
                "messages": [
                    {
                        "role": "user",
                        "content": batch_prompt
                    }
                ],
                "max_tokens": empty_count * 5000,  # Reduced tokens for just predicates
                "temperature": 0.1,  # Lower temperature for consistency
                "top_p": 0.9
            }
            
            # Single API request for all missing predicates
            response = bedrock_client.invoke_model(
                modelId="openai.gpt-oss-120b-1:0",
                body=json.dumps(request_body),
                contentType="application/json",
                accept="application/json"
            )
            
            # Parse the response
            response_body = json.loads(response['body'].read())
            response_text = response_body['choices'][0]['message']['content'].strip()
            
            print("✓ Batch Bedrock API request successful")
            print(f"Response preview: {response_text[:200]}...")
            successful_requests = 1
        else:
            print("No Bedrock client available")
            failed_requests = 1
            response_text = ""
        
        # Parse the response with improved logic
        predicate_suggestions = {}
        if response_text:
            for line in response_text.split('\n'):
                line = line.strip()
                if line:
                    # Match patterns like "1. predicate_name" or "1) predicate_name"
                    match = re.match(r'^(\d+)[\.\)\s]+(.+)', line)
                    if match:
                        case_num = int(match.group(1))
                        suggested_predicate = match.group(2).strip()
                        
                        # Clean up the suggestion
                        suggested_predicate = suggested_predicate.replace('"', '').replace("'", "")
                        # Take only the first token if multiple words
                        suggested_predicate = suggested_predicate.split()[0] if suggested_predicate.split() else suggested_predicate
                        
                        # Try exact match first
                        if suggested_predicate in unique_predicates:
                            predicate_suggestions[case_num] = suggested_predicate
                        else:
                            # Try partial matching
                            for predicate in unique_predicates:
                                if predicate in suggested_predicate or suggested_predicate in predicate:
                                    predicate_suggestions[case_num] = predicate
                                    break
        
        print(f"✓ Successfully parsed {len(predicate_suggestions)} predicates from response")
        
        # Fill predicates: use LLM suggestions where available, first predicate otherwise
        first_predicate = unique_predicates[0]
        for case_num, idx in case_mapping.items():
            if case_num in predicate_suggestions:
                suggested_predicate = predicate_suggestions[case_num]
                df.at[idx, 'predicate'] = suggested_predicate
                llm_filled_count += 1
                print(f"✓ Row {idx}: Filled with '{suggested_predicate}'")
            else:
                # Use first predicate instead of random
                df.at[idx, 'predicate'] = first_predicate
                print(f"→ Row {idx}: Used first predicate '{first_predicate}'")
        
        # Estimate token usage
        input_tokens = len(batch_prompt.split()) * 1.3
        output_tokens = len(response_text.split()) * 1.3 if response_text else 0
        
        # Store response details
        responses = [{
            'batch_request': True,
            'total_cases': empty_count,
            'prompt': batch_prompt,
            'response_text': response_text,
            'parsed_suggestions': predicate_suggestions,
            'case_mapping': case_mapping,
            'success': successful_requests > 0,
            'estimated_input_tokens': input_tokens,
            'estimated_output_tokens': output_tokens,
            'model_used': 'openai.gpt-oss-120b-1:0',
            'provider': 'AWS Bedrock'
        }]
        
    except Exception as e:
        print(f"✗ Batch Bedrock API request failed: {e}")
        failed_requests = 1
        
        # Fill all with first predicate on error
        first_predicate = unique_predicates[0]
        for idx in empty_indices:
            df.at[idx, 'predicate'] = first_predicate
        
        responses = [{
            'batch_request': True,
            'total_cases': empty_count,
            'prompt': batch_prompt,
            'error': str(e),
            'success': False,
            'fallback_used': True,
            'fallback_predicate': first_predicate,
            'model_used': 'openai.gpt-oss-120b-1:0',
            'provider': 'AWS Bedrock'
        }]
        
        input_tokens = len(batch_prompt.split()) * 1.3
        output_tokens = 0
    
    end_time = time.time()
    total_time = end_time - start_time
    
    # Create metrics summary
    metrics = {
        'total_empty_predicates': empty_count,
        'llm_filled_count': llm_filled_count,
        'fallback_count': empty_count - llm_filled_count,
        'successful_requests': successful_requests,
        'failed_requests': failed_requests,
        'total_requests': successful_requests + failed_requests,
        'success_rate': successful_requests / (successful_requests + failed_requests) if (successful_requests + failed_requests) > 0 else 0,
        'llm_success_rate': llm_filled_count / empty_count if empty_count > 0 else 0,
        'total_processing_time_seconds': total_time,
        'estimated_total_input_tokens': input_tokens,
        'estimated_total_output_tokens': output_tokens,
        'estimated_total_tokens': input_tokens + output_tokens,
        'batch_processing': True,
        'model_used': 'openai.gpt-oss-120b-1:0',
        'provider': 'AWS Bedrock'
    }
    
    # Save files
    df.to_csv(output_file, index=False)
    
    with open(metrics_file, 'w') as f:
        json.dump(metrics, f, indent=2)
    
    with open(responses_file, 'w') as f:
        json.dump(responses, f, indent=2)
    
    # Print summary
    print(f"\n=== Fixed Batch AWS Bedrock Processing Complete ===")
    print(f"Model: openai.gpt-oss-120b-1:0")
    print(f"LLM filled predicates saved to: {output_file}")
    print(f"Total predicates filled by LLM: {llm_filled_count}/{empty_count}")
    print(f"Non-random assignments: {empty_count}")
    print(f"LLM success rate: {metrics['llm_success_rate']:.2%}")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Estimated tokens used: {int(input_tokens + output_tokens)}")
    print(f"Metrics saved to: {metrics_file}")
    print(f"Responses saved to: {responses_file}")
    
    return df, metrics, responses


In [10]:
# Test the LLM function with a subset of data first - fix indexing issue
test_df = modified.copy().reset_index(drop=True)  # Reset index to 0, 1, 2, ...

print(f"DataFrame shape: {test_df.shape}")
print(f"Index range: {test_df.index.min()} to {test_df.index.max()}")

filled_df, metrics, responses = fill_missing_predicates_llm_bedrock(
    test_df,
    unique_predicates,
    output_file='bedrock_filled_test.csv',
    metrics_file='bedrock_metrics_test.json',
    responses_file='bedrock_responses_test.json'
)


DataFrame shape: (200, 3)
Index range: 0 to 199
Found 99 empty predicates to fill
Sending batch request to AWS Bedrock for 99 predicates...
✓ Batch Bedrock API request successful
Response preview: <reasoning>We need to fill missing predicates for biomedical relationships, choosing only from given options:

- biolink:treats_or_applied_or_studied_to_treat
- biolink:subclass_of
- biolink:related_t...
✓ Successfully parsed 99 predicates from response
✓ Row 0: Filled with 'biolink:treats_or_applied_or_studied_to_treat'
✓ Row 3: Filled with 'biolink:treats_or_applied_or_studied_to_treat'
✓ Row 4: Filled with 'biolink:treats_or_applied_or_studied_to_treat'
✓ Row 5: Filled with 'biolink:treats_or_applied_or_studied_to_treat'
✓ Row 6: Filled with 'biolink:treats_or_applied_or_studied_to_treat'
✓ Row 7: Filled with 'biolink:treats_or_applied_or_studied_to_treat'
✓ Row 8: Filled with 'biolink:subclass_of'
✓ Row 10: Filled with 'biolink:treats_or_applied_or_studied_to_treat'
✓ Row 13: Filled with 

# Rag Bedrock

In [None]:
import requests
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb
import uuid
import re
import boto3
import json
import time
import boto3
import json
import time
import re
# -------------------------
# STEP 1: Search PubMed for IDs
# -------------------------
search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
params = {"db": "pubmed", "term": "alzheimers", "retmax": 100, "retmode": "json"}
resp = requests.get(search_url, params=params).json()
pmids = resp["esearchresult"]["idlist"]

print(f"Found {len(pmids)} PMIDs: {pmids}")

# -------------------------
# STEP 2: Fetch abstracts per PMID (clean affiliations)
# -------------------------
fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
abstracts = {}
for pmid in pmids:
    params = {"db": "pubmed", "id": pmid, "rettype": "abstract", "retmode": "text"}
    txt = requests.get(fetch_url, params=params).text.strip()
    # Remove affiliations/emails
    txt_clean = "\n".join([
        line for line in txt.split("\n")
        if not re.search(r'@|Department|University|Center', line)
    ])
    abstracts[pmid] = txt_clean

print(f"Fetched {len(abstracts)} clean abstracts")

# -------------------------
# STEP 3: Initialize ChromaDB + Embeddings
# -------------------------
client = chromadb.Client()
try:
    collection = client.create_collection("pubtator_data")
except Exception:
    collection = client.get_collection("pubtator_data")

model = SentenceTransformer("all-MiniLM-L6-v2")
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

def embed_text(text):
    return model.encode(text)

# -------------------------
# STEP 4: Store abstracts in ChromaDB
# -------------------------
for pmid, abs_text in abstracts.items():
    chunks = splitter.split_text(abs_text)
    for chunk_idx, chunk in enumerate(chunks):
        if chunk.strip():
            embedding = embed_text(chunk)
            unique_id = f"{pmid}_{chunk_idx}_{str(uuid.uuid4())[:8]}"
            collection.add(
                ids=[unique_id],
                documents=[chunk],
                metadatas=[{"pmid": pmid, "chunk_idx": chunk_idx}],
                embeddings=[embedding.tolist()]
            )
print("Stored abstracts in ChromaDB")


2025-10-03 05:05:45.574230: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Found 100 PMIDs: ['41039627', '41039597', '41039514', '41039482', '41039478', '41039439', '41039123', '41038784', '41038753', '41038644', '41038511', '41038490', '41038475', '41038390', '41038328', '41038155', '41038110', '41038002', '41037984', '41037599', '41037500', '41037382', '41037374', '41037159', '41037141', '41037108', '41037094', '41036749', '41036743', '41036709', '41036572', '41036471', '41036470', '41036435', '41036411', '41036149', '41036040', '41035985', '41035927', '41035827', '41035826', '41035821', '41035497', '41035200', '41035143', '41035127', '41035086', '41035073', '41035071', '41034936', '41034702', '41034564', '41034513', '41034502', '41034500', '41034368', '41034365', '41034302', '41034231', '41034207', '41034120', '41033931', '41033755', '41033434', '41033413', '41033231', '41033139', '41033048', '41032868', '41032666', '41032659', '41032625', '41032522', '41032287', '41032255', '41032252', '41032250', '41032245', '41032155', '41032140', '41032131', '41032100'

In [12]:
def query_domain_knowledge(query_text, collection, model, top_k=3):
    """
    Query the ChromaDB collection for relevant domain knowledge.
    
    Args:
        query_text: Text to search for in the knowledge base
        collection: ChromaDB collection containing domain knowledge
        model: SentenceTransformer model for embeddings
        top_k: Number of top relevant documents to retrieve (default: 3)
    """
    query_embedding = embed_text(query_text)
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k
    )
    return results

In [13]:

def fill_missing_predicates_llm_with_domain_knowledge(
    input_df,
    unique_predicates,
    collection,
    model,
    output_file='llm_rag_filled_predicates.csv',
    metrics_file='llm_rag_metrics.json',
    responses_file='llm_rag_responses.json'
):

    try:
        bedrock_client = boto3.client('bedrock-runtime', region_name='us-east-1')
    except Exception as e:
        print(f"Failed to configure Bedrock client: {e}")
        bedrock_client = None

    df = input_df.copy()
    start_time = time.time()

    pred_series = df['predicate'].astype(str) if 'predicate' in df.columns else None
    empty_mask = df['predicate'].isna() | (pred_series == '') | (pred_series.str.strip() == '')
    empty_indices = df[empty_mask].index.tolist()
    empty_count = len(empty_indices)

    print(f"Found {empty_count} empty predicates to fill")
    if empty_count == 0:
        return df, {}, []

    predicate_bulleted = "\n".join([f"- {p}" for p in unique_predicates])
    first_predicate = unique_predicates[0]

    batch_prompt = f"""You are a biomedical knowledge graph expert. 
Fill the missing predicates for the triples below.

Choose ONLY from these predicates (exact strings):
{predicate_bulleted}

Instructions:
1) For each numbered triple, output exactly ONE predicate from the list above.
2) Use the provided Domain Context to decide.
3) Write ONE biomedical reasoning-style sentence using "because" or "due to".
4) Always end the explanation with PMID(s) in parentheses.
   Example: "CHEBI:xxxx related_to MONDO:yyyy | Explanation: CHEBI:xxxx affects MONDO:yyyy because it increases protein Z levels (PMID:12345678)"

Triples to complete:
"""

    case_mapping = {}
    case_contexts = {}

    for case_num, idx in enumerate(empty_indices, 1):
        row = df.iloc[idx]
        subj, obj = row['subject'], row['object']
        query_text = f"Relationship between {subj} and {obj} in biomedical research."

        try:
            retrieval = query_domain_knowledge(query_text, collection, model, top_k=3)
            context_lines = []
            for doc, meta in zip(retrieval["documents"][0], retrieval["metadatas"][0]):
                pmid = meta.get("pmid", "NA")
                context_lines.append(f"{doc.strip()} (PMID:{pmid})")
            context_text = "\n".join(context_lines) if context_lines else "No specific context found."
        except Exception as e:
            print(f"RAG retrieval failed for case {case_num}: {e}")
            context_text = "No specific context found."

        case_mapping[case_num] = idx
        case_contexts[case_num] = {"subject": subj, "object": obj, "context": context_text}

        batch_prompt += (
            f"\n{case_num}. Subject: {subj} | Object: {obj}\n"
            f"   Domain Context:\n{context_text[:800]}...\n"
            f"   Instruction: Provide one reasoning-style biomedical explanation ending with PMID(s).\n"
        )

    batch_prompt += "\nRespond with ONLY the numbered lines."

    print(f"Sending Bedrock batch request for {empty_count} RAG-enhanced cases...")

    def normalize_pred(s: str) -> str:
        return s.strip().lower()

    norm_options = {normalize_pred(p): p for p in unique_predicates}
    for p in unique_predicates:
        if p.lower().startswith("biolink:"):
            norm_options[normalize_pred(p.split(":", 1)[1])] = p

    def best_match(suggested: str) -> str:
        s_norm = normalize_pred(suggested).replace('"', '').replace("'", "")
        if s_norm in norm_options:
            return norm_options[s_norm]
        for opt in unique_predicates:
            if s_norm in opt.lower() or opt.lower() in s_norm:
                return opt
        return first_predicate

    llm_filled_count = 0
    successful_requests, failed_requests = 0, 0
    detailed_responses = []

    try:
        if not bedrock_client:
            raise RuntimeError("Bedrock client not available")

        request_body = {
            "messages": [{"role": "user", "content": batch_prompt}],
            "max_tokens": 500000,
            "temperature": 0.2,
            "top_p": 0.9
        }

        resp = bedrock_client.invoke_model(
            modelId="openai.gpt-oss-120b-1:0",
            body=json.dumps(request_body),
            contentType="application/json",
            accept="application/json"
        )
        body = json.loads(resp["body"].read())
        response_text = body["choices"][0]["message"]["content"].strip()
        print("✓ Batch RAG Bedrock request successful")

        predicate_suggestions, explanations = {}, {}
        for raw_line in response_text.splitlines():
            if '|' not in raw_line:
                continue
            
            # split into "N. predicate" and "Explanation: ..."
            lhs, rhs = raw_line.split('|', 1)
            lhs = lhs.strip()
            rhs = rhs.strip()
            
            # robust number+predicate parsing
            m = re.match(r'^(\d+)[\.\)]?\s*(.+)$', lhs.strip())

            if not m:
                continue
            
            case_num = int(m.group(1))
            suggested_predicate = m.group(2).strip()
            
            matched = best_match(suggested_predicate)
            predicate_suggestions[case_num] = matched
            
            # enforce PMID presence
            pmids = re.findall(r'PMID:\d+', rhs)
            if not pmids:
                # fallback: grab one PMID from retrieval context
                fallback_pmid = None
                ctx_text = case_contexts[case_num]["context"]
                match = re.search(r'PMID:(\d+)', ctx_text)
                if match:
                    fallback_pmid = match.group(1)
            
                if fallback_pmid:
                    rhs = rhs + f" (PMID:{fallback_pmid})"
                else:
                    rhs = rhs + " (PMID:NA)"
            
            explanations[case_num] = rhs

        for case_num, idx in case_mapping.items():
            ctx = case_contexts[case_num]
            if case_num in predicate_suggestions:
                pred = predicate_suggestions[case_num]
                expl = explanations.get(case_num, "Explanation missing")
                df.at[idx, 'predicate'] = pred
                llm_filled_count += 1
                print(f"✓ Row {idx}: '{ctx['subject']}' {pred} '{ctx['object']}'  Explanation: {expl}")
                detailed_responses.append({
                    "row": idx,
                    "case_num": case_num,
                    "subject": ctx["subject"],
                    "object": ctx["object"],
                    "context": ctx["context"],
                    "suggested_predicate": pred,
                    "explanation": expl,
                    "success": True,
                    "method": "RAG_Bedrock"
                })
            else:
                df.at[idx, 'predicate'] = first_predicate
                print(f"→ Row {idx}: No suggestion parsed; used first predicate '{first_predicate}'")
                detailed_responses.append({
                    "row": idx,
                    "case_num": case_num,
                    "subject": ctx["subject"],
                    "object": ctx["object"],
                    "context": ctx["context"],
                    "suggested_predicate": first_predicate,
                    "explanation": "Fallback: no valid suggestion parsed",
                    "success": False,
                    "method": "FirstPredicate_Fallback"
                })

        successful_requests = 1

    except Exception as e:
        print(f"✗ Batch RAG Bedrock request failed: {e}")
        failed_requests = 1

    end_time = time.time()
    metrics = {
        "total_empty_predicates": empty_count,
        "llm_filled_count": llm_filled_count,
        "fallback_count": empty_count - llm_filled_count,
        "successful_requests": successful_requests,
        "failed_requests": failed_requests,
        "llm_success_rate": (llm_filled_count / empty_count) if empty_count > 0 else 0,
        "total_processing_time_seconds": end_time - start_time,
        "rag_enhanced": True
    }

    df.to_csv(output_file, index=False)
    with open(metrics_file, 'w') as f:
        json.dump(metrics, f, indent=2)

    print("\n=== RAG-Enhanced Bedrock Processing Complete ===")
    print(f"Total predicates filled by LLM: {llm_filled_count}/{empty_count}")
    print(f"LLM success rate: {metrics['llm_success_rate']:.2%}")

    return df, metrics, responses


In [14]:
# Ensure unique predicates are available
test_df_rag = modified.copy().reset_index(drop=True)  # Reset index to 0, 1, 2, 3...
if 'unique_predicates'in globals():
    unique_predicates = modified['predicate'].unique()
    unique_predicates = unique_predicates[unique_predicates != '']
    
    filled_df_rag, metrics_rag, responses_rag = fill_missing_predicates_llm_with_domain_knowledge(
    test_df_rag,
    unique_predicates,
    collection,
    model,
    output_file='bedrock_rag_filled_test.csv',
    metrics_file='bedrock_rag_metrics_test.json',
    responses_file='bedrock_rag_responses_test.json'
    )
    
    print("RAG-enhanced LLM processing complete.")
# else:

Found 99 empty predicates to fill
Sending Bedrock batch request for 99 RAG-enhanced cases...
✓ Batch RAG Bedrock request successful
→ Row 0: No suggestion parsed; used first predicate 'biolink:treats_or_applied_or_studied_to_treat'
✓ Row 3: 'CHEBI:35475' biolink:treats_or_applied_or_studied_to_treat 'MONDO:0004972'  Explanation: MONDO:0004972 | biolink:treats_or_applied_or_studied_to_treat | Explanation: CHEBI:35475 has been studied for treatment of MONDO:0004972 due to its ability to inhibit pathogenic pathways implicated in the disease (PMID:41039597)
✓ Row 4: 'CHEBI:35475' biolink:treats_or_applied_or_studied_to_treat 'MONDO:0100010'  Explanation: MONDO:0100010 | biolink:treats_or_applied_or_studied_to_treat | Explanation: CHEBI:35475 is evaluated for MONDO:0100010 because pre‑clinical data suggest it can ameliorate key clinical manifestations (PMID:41039597)
✓ Row 5: 'CHEBI:35475' biolink:treats_or_applied_or_studied_to_treat 'MONDO:0006875'  Explanation: MONDO:0006875 | biolink:tr