# Use LLM for coassociation analysis

# 1) Set up libraries and datasets

In [None]:
# Import necessary libraries
import os
import sys
import re
import time
import random
import logging
from pathlib import Path
from datetime import datetime, timedelta
from collections import Counter

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

print("Success!")

In [None]:
# Set the working directory and file paths
# NOTE(review): the upper-case values look like placeholders; presumably each one must be
# replaced with a real directory path before the notebook is run.
input_directory = "INPUT_DIRECTORY"
output_directory = "OUTPUT_DIRECTORY"
variantscape_directory = "VARIANTSCAPE_DIRECTORY"
variantscape_LLM_coas_directory = "VARIANTSCAPE_LLM_COAS_DIRECTORY"
figure_directory = "FIGURE_DIRECTORY"

# The CSV files read in the next cell are given as bare file names, so they are
# resolved relative to the directory selected here.
os.chdir(variantscape_directory)
print("Current directory:", os.getcwd())

In [None]:
# Investigate dataset
from collections import defaultdict

# Load the per-paper table and the entity -> category lookup table.
variant_analysis_df = pd.read_csv("cleaned_df_v4.csv", low_memory=False)
metadata_mapping = pd.read_csv("metadata_mapping_transposed.csv", low_memory=False)

# Both lookup columns are compared as stripped strings.
for text_column in ("Entity", "Category"):
    metadata_mapping[text_column] = metadata_mapping[text_column].astype(str).str.strip()

# Keep only the metadata rows whose entity is an actual column of the paper table.
variant_cols = set(variant_analysis_df.columns)
is_known_column = metadata_mapping["Entity"].isin(variant_cols)
valid_metadata = metadata_mapping[is_known_column].copy()
col_to_category = dict(zip(valid_metadata["Entity"], valid_metadata["Category"]))

# Invert the mapping: category -> list of column names.
category_to_cols = defaultdict(list)
for column_name, category_name in col_to_category.items():
    category_to_cols[category_name].append(column_name)

for category_name, member_columns in category_to_cols.items():
    print(f"{category_name}: {len(member_columns)} columns")

In [None]:
# Define helper to extract info per paper
def extract_paper_info(row, variant_cols, treatment_cols):
    """Summarise one paper row as a plain dictionary.

    Args:
        row: Mapping-like record (anything with ``.get``) for a single paper.
        variant_cols: Column names to test for variant mentions.
        treatment_cols: Column names to test for treatment mentions.

    Returns:
        Dict with ``PaperId``, ``PaperTitle`` and ``Abstract`` (empty string
        when absent) plus ``Variants`` and ``Treatments``, the lists of
        column names whose value in *row* equals 1.
    """
    def _flagged(columns):
        # A column counts as "mentioned" only when its value is exactly 1.
        return [name for name in columns if row.get(name, 0) == 1]

    info = {field: row.get(field, "") for field in ("PaperId", "PaperTitle", "Abstract")}
    info["Variants"] = _flagged(variant_cols)
    info["Treatments"] = _flagged(treatment_cols)
    return info
# Column groups come from the category lookup built from the metadata table.
variant_cols = category_to_cols["Variant"]
treatment_cols = category_to_cols["Treatment"]
from tqdm import tqdm
tqdm.pandas()  # registers DataFrame.progress_apply


def _summarise_row(row):
    """Row-wise adapter binding the two column groups for progress_apply."""
    return extract_paper_info(row, variant_cols, treatment_cols)


print("Extracting per-paper treatment and variant mentions...")
per_paper_series = variant_analysis_df.progress_apply(_summarise_row, axis=1)
paper_data = per_paper_series.tolist()
print(f"Extracted data for {len(paper_data):,} papers.")

In [None]:
# Check the type and size
print("Type of paper_data:", type(paper_data))
print("Number of papers:", len(paper_data))
print("\nKeys in first entry:", paper_data[0].keys())
sample = random.choice(paper_data)
print("\nSample extracted entry:")


def _preview(value):
    """Shorten long strings (200 chars) and lists (5 items) for display."""
    if isinstance(value, str):
        return f"{value[:200]}..."
    if isinstance(value, list):
        return f"{value[:5]}..."
    return f"{value}"


for field, value in sample.items():
    print(f"{field}: {_preview(value)}")

# Tally, in one pass, the papers that have both kinds of mention and those that have none.
num_with_both = 0
num_with_neither = 0
for entry in paper_data:
    has_variant = bool(entry["Variants"])
    has_treatment = bool(entry["Treatments"])
    if has_variant and has_treatment:
        num_with_both += 1
    elif not has_variant and not has_treatment:
        num_with_neither += 1

print(f"\nPapers with at least one variant AND one treatment: {num_with_both:,}")
print(f"Papers with NO variant and NO treatment: {num_with_neither:,}")

In [None]:
# Print the first five extracted records, truncating each abstract to 200 characters.
for position, record in enumerate(paper_data[:5], start=1):
    abstract_preview = record["Abstract"][:200]
    print(f"--- Paper {position} ---")
    print("PaperId:", record["PaperId"])
    print("Title:", record["PaperTitle"])
    print("Abstract:", abstract_preview, "...")
    print("Variants:", record["Variants"])
    print("Treatments:", record["Treatments"])
    print()

In [None]:
# Convert to dataframe
paper_df = pd.DataFrame(paper_data)

# Collapse each list of names into one comma-separated string so it survives the CSV round trip.
for list_column in ("Variants", "Treatments"):
    paper_df[list_column] = paper_df[list_column].apply(lambda names: ", ".join(names))

paper_df.to_csv("filtered_paper_data_for_LLM_coassociation.csv", index=False)
saved_rows = len(paper_df)
print(f"Saved {saved_rows:,} entries to 'filtered_paper_data_for_LLM_coassociation.csv'")
print(paper_df)

# 2) Select and set up LLMs

In [None]:
# Set up a language model to answer the questions
!pip install OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM
from openai import OpenAI
print("Success!")

In [None]:
# Define the models to be tested
models = ["llama31-70b", "llama33-70b", "deepseek_v3", "deepseek_r1", "deepseek_r1_distill_llama_70b"]

# Mapping model names to their full Hugging Face or DeepInfra identifiers
# NOTE(review): "llama2-3b", "gpt35" and "gpt4o" are handled by the branches below but
# have no entry here, so selecting one of them raises KeyError until an identifier is added.
model_fullnames = {
    "llama31-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "llama33-70b": "meta-llama/Llama-3.3-70B-Instruct",
    "deepseek_v3": "deepseek-ai/DeepSeek-V3",
    "deepseek_r1": "deepseek-ai/DeepSeek-R1",
    "deepseek_r1_distill_llama_70b": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
}

SYSTEM_MSG = "You are a helpful medical question answering assistant. Please carefully follow the exact instructions and do not provide explanations."
modelname = models[1]


def _chat_messages(user_text):
    """Return the system + user message list sent with every request."""
    return [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": user_text},
    ]


if modelname in [ "llama2-3b" ]:
    # Locally loaded model.
    # NOTE(review): `load` and `generate` are not imported in the visible part of the
    # notebook; presumably they come from a local-inference package - confirm.
    model, tokenizer = load(model_fullnames[modelname])

    def generateFromPrompt(prompt, maxTokens=100):
        """Generate a response to *prompt* with the locally loaded model.

        ``maxTokens`` is accepted only so the signature matches the API-backed
        variant; it is not used.
        """
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            prompt = tokenizer.apply_chat_template(
                _chat_messages(prompt), tokenize=False, add_generation_prompt=True
            )
        # Generate for every tokenizer. The call used to sit inside the `if`, so a
        # tokenizer without a chat template made this function return None.
        return generate(model, tokenizer, prompt=prompt, verbose=False)
else:
    # NOTE: the api_key values are placeholders; prefer loading real keys from the
    # environment instead of storing them in the notebook.
    if modelname in [ "gpt35", "gpt4o" ]: # OpenAI models
        client = OpenAI(
           api_key='API_key1'
        )
    elif modelname in [ "llama31-70b" , "llama33-70b" , "deepseek_v3" , "deepseek_r1" , "deepseek_r1_distill_llama_70b"]:  # DeepInfra models
        client = OpenAI(
            api_key = "API_key2",
            base_url="https://api.deepinfra.com/v1/openai",
        )
    else:
        # Fail loudly rather than leaving generateFromPrompt undefined (or, in a
        # notebook, silently reusing the definition left by an earlier run).
        raise ValueError(f"Unsupported model name: {modelname!r}")

    def generateFromPrompt(promptStr, maxTokens=100):
        """Send *promptStr* (with SYSTEM_MSG) to the selected chat API and return the reply text.

        NOTE(review): ``maxTokens`` is accepted but not forwarded to the API, exactly
        as before. Forwarding the default of 100 would cap the length of every answer,
        so the behavior is left unchanged - confirm the intent.
        """
        completion = client.chat.completions.create(
            model=model_fullnames[modelname],
            messages=_chat_messages(promptStr))
        return completion.choices[0].message.content

# Smoke test: one request to confirm the selected back-end responds.
generateFromPrompt("hello!")

In [None]:
# Report the candidate models and the one currently selected.
print(f"All installed models: {models}")
print(f"Current model in use: {modelname}")

# 3) Define prompts

In [None]:
#Define prompts

def _render_pairs(pairs):
    """Render (variant, treatment) tuples as one '- variant + treatment' line each."""
    return "\n".join(f"- {variant} + {treatment}" for variant, treatment in pairs)


def _prompt_labelled_bullets(title, abstract, pairs):
    """Prompt 0: role description, paper text, bullet list of labels, strict output format."""
    sections = [
        "You are a biomedical research assistant analyzing the relationship between genetic variants and treatments based on scientific publications.\n\n",
        "Read the following title and abstract carefully, then evaluate each of the variant-treatment pairs listed.\n\n",
        f"Title: {title}\n\nAbstract: {abstract}\n\n",
        "Variant-Treatment pairs:\n",
        _render_pairs(pairs),
        "\n\n",
        "For each pair, classify the relationship described in the abstract using only one of the following labels:\n",
        "- Sensitive\n- Resistant\n- Diagnostic\n- Unrelated\n- Unknown\n\n",
        "Respond **only** with the list of pairs and their labels, in this format:\n",
        "<variant> + <treatment> : <label>\n",
    ]
    return "".join(sections)


def _prompt_clinical_association(title, abstract, pairs):
    """Prompt 1: asks whether each pair has a meaningful clinical association."""
    sections = [
        "You are analyzing biomedical literature to extract clinical relationships between gene variants and drugs.\n",
        "Using the information from the title and abstract, determine whether each of the following variant-treatment pairs has a meaningful clinical association.\n\n",
        f"Paper title: {title}\nAbstract: {abstract}\n\n",
        "Variant-treatment pairs:\n",
        _render_pairs(pairs),
        "\n\n",
        "Label each pair using one of the following categories:\n",
        "Sensitive, Resistant, Diagnostic, Unrelated, Unknown.\n\n",
        "Format your output as follows:\n",
        "<variant> + <treatment> : <label>\n",
    ]
    return "".join(sections)


def _prompt_labels_first(title, abstract, pairs):
    """Prompt 2: states the allowed labels before showing the paper text."""
    sections = [
        "Carefully analyze the title and abstract of the following biomedical paper. Then evaluate the relationship between the listed variant-treatment pairs.\n\n",
        "Use only these labels:\nSensitive, Resistant, Diagnostic, Unrelated, Unknown.\n\n",
        f"Title: {title}\n\nAbstract: {abstract}\n\n",
        "Pairs to analyze:\n",
        _render_pairs(pairs),
        "\n\n",
        "Respond strictly in this format:\n<variant> + <treatment> : <label>",
    ]
    return "".join(sections)


# Prompt number -> builder taking (title, abstract, pairs) and returning the prompt text.
PROMPTS = {
    0: _prompt_labelled_bullets,
    1: _prompt_clinical_association,
    2: _prompt_labels_first,
}

print("Prompts successfully defined.")


# 4) Run genetic variant extraction with LLM

In [None]:
# Configuration
os.chdir(variantscape_LLM_coas_directory)
tqdm.pandas()

BATCH_SIZE = 7524
selected_prompt_number = 1
modelname = modelname  # self-assignment: keeps the model chosen in the model-selection cell

today_date = datetime.today().strftime('%Y-%m-%d')
start_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

# Set file paths
# Every output of this run shares the same "<model>_prompt<N>" suffix.
run_tag = f"{modelname}_prompt{selected_prompt_number}"
variant_output_file_path = os.path.join(
    variantscape_LLM_coas_directory, f"LLM_variant_screening_{run_tag}.csv"
)
runtime_file = os.path.join(
    variantscape_LLM_coas_directory, f"runtime_summary_{run_tag}.txt"
)
progress_log_file = os.path.join(
    variantscape_LLM_coas_directory, f"progress_log_{run_tag}.txt"
)

def ensure_file_exists(file_path, header_text=None):
    """Create *file_path* if it does not exist yet; never touch an existing file.

    Args:
        file_path: Path of the file to create.
        header_text: Optional first line(s). When given (and non-empty) the new
            file starts with this text followed by a rule of 60 '=' characters;
            otherwise the new file is left empty.
    """
    if os.path.exists(file_path):
        return
    initial_lines = [f"{header_text}\n", "=" * 60 + "\n"] if header_text else []
    with open(file_path, "w") as handle:
        handle.writelines(initial_lines)

# Create the two text logs (with their headers) unless they already exist.
log_headers = (
    (runtime_file, f"### Runtime Log - {today_date} ###\nStart Time: {start_time_str}"),
    (progress_log_file, f"### Progress Log - {today_date} ###"),
)
for log_path, log_header in log_headers:
    ensure_file_exists(log_path, log_header)

# Seed the results CSV with its header row on first use.
results_csv_missing = not os.path.exists(variant_output_file_path)
if results_csv_missing:
    with open(variant_output_file_path, "w") as csv_handle:
        csv_handle.write("PaperId,PaperTitle,Abstract,Variants,Treatments,VariantTreatmentPairs,LLM_Prompt,LLM_Response\n")

# Set up logging
logging_options = {
    "filename": progress_log_file,
    "level": logging.INFO,
    "format": "%(asctime)s - %(levelname)s - %(message)s",
    "datefmt": "%Y-%m-%d %H:%M:%S",
}
logging.basicConfig(**logging_options)

print("Success! All necessary files and directories are set up.")
print("Script Start Time:", start_time_str)
print("Defined prompt number:", selected_prompt_number)
print(f"Defined batch size: {BATCH_SIZE:,}")

# Define functions
def generate_variant_treatment_pairs(row):
    """Return every (variant, treatment) combination mentioned in one paper.

    Args:
        row: Mapping-like record whose ``"Variants"`` and ``"Treatments"``
            entries are comma-separated name strings.

    Returns:
        List of ``(variant, treatment)`` tuples, variants varying slowest.
        Empty when either side has no names.
    """
    # Imported locally: the notebook's import cell never brings `product` into
    # scope, so the previous version raised NameError on first use.
    from itertools import product

    def _split_names(cell):
        # An empty list written to CSV is read back as NaN (a float), which has
        # no .split; treat any non-string cell as "no names".
        if not isinstance(cell, str):
            return []
        return [name.strip() for name in cell.split(",") if name.strip()]

    return list(product(_split_names(row["Variants"]), _split_names(row["Treatments"])))

def screen_publication_for_variants(row, prompt_number):
    """Build the LLM prompt for one paper.

    Args:
        row: Record providing ``PaperTitle``, ``Abstract``, ``Variants`` and
            ``Treatments``.
        prompt_number: Key into ``PROMPTS`` selecting the template.

    Returns:
        Tuple ``(prompt_text, pairs)`` where *pairs* is the list of
        (variant, treatment) tuples embedded in the prompt.

    Raises:
        ValueError: If *prompt_number* is not a key of ``PROMPTS``.
    """
    title, abstract = row["PaperTitle"], row["Abstract"]
    pairs = generate_variant_treatment_pairs(row)
    build_prompt = PROMPTS.get(prompt_number)
    if build_prompt is None:
        raise ValueError(f"Invalid prompt number: {prompt_number}. Choose from 0, 1, 2.")
    return build_prompt(title, abstract, pairs), pairs

def process_with_llm(prompt):
    """Send *prompt* to the configured model and return its reply.

    Any exception raised by the back-end call is logged and replaced by the
    sentinel string ``"ERROR"`` so that one failed request does not abort a
    whole batch.
    """
    reply = "ERROR"
    try:
        reply = generateFromPrompt(prompt)
    except Exception as failure:
        logging.error(f"LLM processing error: {failure}")
    return reply

print("Functions defined and updated for the current dataset.")

In [None]:
# ========================== RESUME FROM LAST CHECKPOINT ========================== #

# Check if paper_df is already loaded
if 'paper_df' not in globals():
    raise ValueError("Dataset `paper_df` is not loaded in memory. Make sure it's defined before running the script.")
if 'PaperId' not in paper_df.columns:
    raise KeyError("Dataset must contain a 'PaperId' column to track progress.")
# IDs are compared as stripped strings on BOTH sides. The results CSV is re-read
# from disk, where numeric IDs come back as integers, and isin() never matches a
# string against an integer.
paper_df['PaperId'] = paper_df['PaperId'].astype(str).str.strip()

# The header row written by the setup cell is 94 bytes long, so a file larger than
# 100 bytes holds at least some results.
if os.path.exists(variant_output_file_path) and os.path.getsize(variant_output_file_path) > 100:
    processed_df = pd.read_csv(variant_output_file_path)
    processed_df.columns = processed_df.columns.str.strip().str.replace('\ufeff', '')
    if 'PaperId' not in processed_df.columns:
        raise KeyError("CSV exists but does not contain 'PaperId'. Check batch_to_save column selection.")
    processed_df['PaperId'] = processed_df['PaperId'].astype(str).str.strip()

    processed_df = processed_df.drop_duplicates(subset=["PaperId"])
    processed_articles = set(processed_df['PaperId'])
    total_processed_articles = len(processed_articles)

    print(f"Resuming from last processed row. {total_processed_articles:,} articles already processed.")
else:
    processed_articles = set()
    total_processed_articles = 0
    print("Starting fresh processing.")

unprocessed_df = paper_df[~paper_df['PaperId'].isin(processed_articles)].copy()
total_articles = len(unprocessed_df)
remaining_batches = (total_articles + BATCH_SIZE - 1) // BATCH_SIZE
# total_batches counts batches over the WHOLE dataset: the batch cell numbers its
# batches from the start of the dataset, so "batch k/total" must use the same base.
total_batches = (len(paper_df) + BATCH_SIZE - 1) // BATCH_SIZE
print(f"Unprocessed articles remaining: {total_articles:,} in {remaining_batches:,} batches")

# If all articles are processed, say so (the batch cell then has nothing to do)
if total_articles == 0:
    print("\nAll batches are complete. No more articles to process.")
    print("You have successfully processed the entire dataset.")

print("Success! All necessary files and directories are set up.")
print("Defined prompt number:", selected_prompt_number)
print(f"Defined batch size to run in chunks: {BATCH_SIZE:,}")
print(f"Total unprocessed articles: {total_articles:,}")

# Track cumulative runtime
# Each "Total runtime so far" line already holds the running total of all earlier
# runs, so the previous runtime is the LAST such value; adding the lines together
# would count earlier runs several times.
total_runtime_previous = 0.0
if os.path.exists(runtime_file):
    with open(runtime_file, "r") as f:
        cumulative_runtimes = [
            float(line.split(":")[-1].strip().split()[0])
            for line in f if "Total runtime so far" in line
        ]
    if cumulative_runtimes:
        total_runtime_previous = cumulative_runtimes[-1]

In [None]:
# ========================== BATCH PROCESSING ========================== #
def generate_progress_bar(percentage, bar_length=20):
    """Return a text progress bar, e.g. ``[||||||||||----------] 50.00%``.

    The filled part is clamped to 0-100 % so an out-of-range value cannot make
    the bar longer or shorter than *bar_length*; the number printed after the
    bar is the value as given.
    """
    clamped = min(max(percentage, 0), 100)
    filled_length = int(bar_length * clamped / 100)
    bar = '|' * filled_length + '-' * (bar_length - filled_length)
    return f"[{bar}] {percentage:.2f}%"


# Fallbacks in case previous setup blocks were not yet run
if 'total_processed_articles' not in globals():
    total_processed_articles = 0
if 'total_runtime_previous' not in globals():
    total_runtime_previous = 0.0
if 'unprocessed_df' not in globals():
    unprocessed_df = paper_df.copy()
if 'total_articles' not in globals():
    total_articles = len(unprocessed_df)
if 'total_batches' not in globals():
    total_batches = (len(unprocessed_df) // BATCH_SIZE) + (1 if len(unprocessed_df) % BATCH_SIZE != 0 else 0)

start_time = time.time()
batch_number = (total_processed_articles // BATCH_SIZE) + 1
# Defaults so the summary at the end still works when nothing is left to process
# (these names used to be assigned only inside the batch loop).
batch_runtime = 0.0
total_runtime_so_far = total_runtime_previous

if total_articles == 0:
    print("\nAll batches are complete. No more articles to process.")
else:
    # Exactly one batch is processed per execution of this cell.
    batch_start = 0
    batch_end = min(BATCH_SIZE, total_articles)
    batch = unprocessed_df.iloc[batch_start:batch_end].copy()
    print(f"\nProcessing Batch {batch_number}/{total_batches} ({batch_start + 1} to {batch_end})...")
    batch_start_time = time.time()
    # Captured now so the runtime log records when the batch really started.
    batch_started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    batch["Prompt_Pair"] = batch.apply(lambda row: screen_publication_for_variants(row, selected_prompt_number), axis=1)
    batch[["LLM_Prompt", "VariantTreatmentPairs"]] = pd.DataFrame(batch["Prompt_Pair"].tolist(), index=batch.index)

    # Query the LLM
    llm_response_column = f'LLM_Response_{modelname}'
    batch[llm_response_column] = batch['LLM_Prompt'].progress_apply(process_with_llm)
    batch_runtime = time.time() - batch_start_time
    # Saved under the plain "LLM_Response" name so that a freshly created file gets
    # the same header as the one seeded by the setup cell.
    batch_to_save = batch[
        ['PaperId', 'PaperTitle', 'Abstract', 'Variants', 'Treatments', 'VariantTreatmentPairs', 'LLM_Prompt', llm_response_column]
    ].rename(columns={llm_response_column: 'LLM_Response'})
    if os.path.exists(variant_output_file_path):
        batch_to_save.to_csv(variant_output_file_path, mode='a', header=False, index=False)
    else:
        batch_to_save.to_csv(variant_output_file_path, mode='w', index=False)
    total_runtime_so_far = total_runtime_previous + (time.time() - start_time)
    with open(runtime_file, "a") as f:
        f.write(f"\nBatch {batch_number}/{total_batches} started at {batch_started_at}\n")
        f.write(f"Batch Runtime: {batch_runtime:.2f} sec\n")
        f.write(f"Total runtime so far (all runs combined): {total_runtime_so_far:.2f} sec\n")
        f.write(f"Total articles processed in this batch: {batch_end - batch_start}\n")
        f.write("=" * 60 + "\n")

    with open(progress_log_file, "a") as f:
        f.write(f"Completed Batch {batch_number} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

    logging.info(f"Processed batch {batch_number}/{total_batches} in {batch_runtime:.2f} sec.")

    articles_in_batch = batch_end - batch_start
    total_articles_processed_now = total_processed_articles + articles_in_batch
    processed_percentage = (total_articles_processed_now / len(paper_df)) * 100
    remaining_articles = len(paper_df) - total_articles_processed_now
    remaining_percentage = (remaining_articles / len(paper_df)) * 100

    print(f"\nPaused! Batch {batch_number} completed.")
    print(f"Total processed: {total_articles_processed_now:,} {generate_progress_bar(processed_percentage)}")
    print(f"Remaining: {remaining_articles:,} {generate_progress_bar(remaining_percentage)}")
    print("Check the CSV and runtime file. Re-run this cell to continue with the next batch.")

    total_processed_articles = total_articles_processed_now
    batch_number += 1
    # Advance the work queue and the runtime baseline. Without this, re-running the
    # cell sent the same first batch to the LLM again and appended duplicate rows.
    unprocessed_df = unprocessed_df.iloc[batch_end:]
    total_articles = len(unprocessed_df)
    total_runtime_previous = total_runtime_so_far

# ========================== FINAL SUMMARY ========================== #
total_runtime = total_runtime_so_far
total_hours = total_runtime // 3600
total_minutes = (total_runtime % 3600) // 60
total_seconds = total_runtime % 60

summary_text = f"""
### Genetic Variant-Treatment Screening Summary ###

- Model used: {modelname}
- Prompt number: {selected_prompt_number}
- Total batches processed: {batch_number - 1:,}/{total_batches:,}
- Total articles processed: {total_processed_articles:,}
- Batch runtime: {batch_runtime:.2f} sec
- Cumulative runtime: {total_runtime:.2f} sec ({total_hours:.0f} hr {total_minutes:.0f} min {total_seconds:.2f} sec)
"""
print(summary_text)
with open(runtime_file, "a") as f:
    f.write("\n### Final runtime summary ###\n")
    f.write(f"End Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(summary_text)
    f.write("\n" + "=" * 60 + "\n")
print("Final results saved.")

# RERUN FROM LAST CHECKPOINT

# ============================================

# 6) Evaluation of LLMs and prompts

## 6.1) Load dataset

In [None]:
# Load CSVs with different prompts for evaluation
def _screening_csv_path(prompt_num):
    """Path of the screening output written for the current model and *prompt_num*."""
    return os.path.join(variantscape_LLM_coas_directory, f"LLM_variant_screening_{modelname}_prompt{prompt_num}.csv")


df_prompt0 = pd.read_csv(_screening_csv_path(0))
df_prompt1 = pd.read_csv(_screening_csv_path(1))
df_prompt2 = pd.read_csv(_screening_csv_path(2))

print("CSV shapes:")
for prompt_index, prompt_frame in enumerate((df_prompt0, df_prompt1, df_prompt2)):
    print(f"Prompt {prompt_index}:", prompt_frame.shape)


def extract_runtime(prompt_num):
    """Return the last 'Cumulative runtime' line logged for *prompt_num*.

    Reads the runtime summary file of the current model. Returns an
    explanatory message instead when the file is missing or contains no such
    line.
    """
    path = os.path.join(variantscape_LLM_coas_directory, f"runtime_summary_{modelname}_prompt{prompt_num}.txt")
    if not os.path.exists(path):
        return f"Runtime file for prompt {prompt_num} not found."

    latest_summary = None
    with open(path) as f:
        for line in f:
            if "Cumulative runtime" in line:
                latest_summary = line  # keep overwriting: the last match wins
    if latest_summary is None:
        return "No runtime summary found."
    return latest_summary.strip()

# Show the most recent cumulative-runtime line for each of the three prompts.
print("\nRuntime summaries:")
for prompt_num in range(3):
    print(f"Prompt {prompt_num}:", extract_runtime(prompt_num))

In [None]:
def display_llm_comparison(df, num_rows=5, maxlen=120):
    """Print the three prompts' responses side by side for the first rows of *df*.

    Args:
        df: Frame with string columns ``PaperTitle``, ``Prompt0``, ``Prompt1``
            and ``Prompt2``.
        num_rows: Maximum number of rows to show.
        maxlen: Responses longer than this are cut and suffixed with ``...``.
    """
    def _shorten(text):
        # Cut first, then flatten newlines; the ellipsis depends on the full length.
        suffix = "..." if len(text) > maxlen else ""
        return text[:maxlen].replace("\n", " ") + suffix

    shown = min(num_rows, len(df))
    for position in range(shown):
        record = df.iloc[position]
        title = record["PaperTitle"][:80]
        previews = [_shorten(record[column]) for column in ("Prompt0", "Prompt1", "Prompt2")]

        print(f"\n--- Sample {position + 1} ---")
        print(f"Title: {title}")
        for prompt_index, preview in enumerate(previews):
            print(f"Prompt {prompt_index}: {preview}")

# NOTE(review): `df_compare_trunc` is not defined anywhere in the visible part of this notebook.
# display_llm_comparison reads the columns "PaperTitle", "Prompt0", "Prompt1" and "Prompt2", so this is
# presumably a frame combining the three per-prompt responses - confirm where it is built before running.
display_llm_comparison(df_compare_trunc, num_rows=5, maxlen=120)

## 6.2) Calculate confusion matrix

In [None]:
# Load dataset
def _llama33_csv(prompt_num):
    """Path of the llama33-70b screening output for *prompt_num*."""
    return f"{variantscape_LLM_coas_directory}/LLM_variant_screening_llama33-70b_prompt{prompt_num}.csv"


df_prompt0 = pd.read_csv(_llama33_csv(0))
df_prompt1 = pd.read_csv(_llama33_csv(1))
df_prompt2 = pd.read_csv(_llama33_csv(2))

# All three files store the model output under the same column name.
resp_col_0 = resp_col_1 = resp_col_2 = "LLM_Response"

def split_predictions(text):
    """Return the lines of *text* that look like '<variant> + <treatment> : <label>'.

    A line is kept when it contains both '+' and ':'; kept lines are stripped.
    Missing or non-string input yields an empty list.
    """
    if pd.isna(text) or not isinstance(text, str):
        return []
    kept_lines = []
    for raw_line in text.strip().split('\n'):
        if '+' in raw_line and ':' in raw_line:
            kept_lines.append(raw_line.strip())
    return kept_lines

# Count total predictions
df_prompt0_preds = df_prompt0[resp_col_0].apply(split_predictions)
df_prompt1_preds = df_prompt1[resp_col_1].apply(split_predictions)
df_prompt2_preds = df_prompt2[resp_col_2].apply(split_predictions)


def _flatten(series_of_lists):
    """Concatenate the per-row lists of a Series into one flat list."""
    flattened = []
    for chunk in series_of_lists.tolist():
        flattened.extend(chunk)
    return flattened


flat_0 = _flatten(df_prompt0_preds)
flat_1 = _flatten(df_prompt1_preds)
flat_2 = _flatten(df_prompt2_preds)

for prompt_index, flat_predictions in enumerate((flat_0, flat_1, flat_2)):
    print(f"Prompt {prompt_index}: {len(flat_predictions):,} total variant-treatment predictions")

In [None]:
# Define a function to extract variant-treatment pairs from LLM responses
def extract_variant_treatment_pairs(response):
    """Extract ``(pair_text, label)`` tuples from a raw LLM response.

    Each line is matched against ``<text> : <word>``; ``pair_text`` is the
    text before the colon and ``label`` the first word after it. Lines that
    do not match are skipped.

    Args:
        response: Raw response text. Anything that is not a string (for
            example the NaN pandas uses for an empty CSV cell) yields an
            empty list instead of raising.

    Returns:
        List of ``(pair_text, label)`` tuples in line order.
    """
    if not isinstance(response, str):
        return []
    pairs = []
    for line in response.split("\n"):
        match = re.match(r"(\S.+?)\s*:\s*(\w+)", line.strip())
        if match:
            pairs.append((match.group(1), match.group(2)))
    return pairs

# Apply the function to all responses in the three prompts
df_prompt0['Predictions'] = df_prompt0['LLM_Response'].apply(extract_variant_treatment_pairs)
df_prompt1['Predictions'] = df_prompt1['LLM_Response'].apply(extract_variant_treatment_pairs)
df_prompt2['Predictions'] = df_prompt2['LLM_Response'].apply(extract_variant_treatment_pairs)


def _pair_key(pair_text):
    """Normalise a '<variant> + <treatment>' string so the same pair matches across prompts."""
    # The prompts list the pairs as '- variant + treatment'; a response may or may
    # not echo that bullet, so leading bullets and whitespace are ignored, as is case.
    return re.sub(r"^[\s\-\*]+", "", pair_text).strip().casefold()


def _labels_by_paper(df):
    """Map PaperId -> {normalised pair: (pair text, label)}, using each paper's first row."""
    lookup = {}
    for paper_id, predictions in zip(df['PaperId'], df['Predictions']):
        if paper_id in lookup:
            continue  # the result files can contain repeated PaperIds; keep the first
        lookup[paper_id] = {_pair_key(text): (text, label) for text, label in predictions}
    return lookup


# The three files are produced by independent runs, and a response can omit or
# reorder lines, so papers are matched by PaperId and predictions by pair rather
# than by position. Only pairs labelled by all three prompts are compared.
labels_prompt0 = _labels_by_paper(df_prompt0)
labels_prompt1 = _labels_by_paper(df_prompt1)
labels_prompt2 = _labels_by_paper(df_prompt2)

comparison = []
for paper_id, pairs_prompt0 in labels_prompt0.items():
    pairs_prompt1 = labels_prompt1.get(paper_id, {})
    pairs_prompt2 = labels_prompt2.get(paper_id, {})
    for key, (variant_treatment_0, prediction_0) in pairs_prompt0.items():
        if key not in pairs_prompt1 or key not in pairs_prompt2:
            continue
        prediction_1 = pairs_prompt1[key][1]
        prediction_2 = pairs_prompt2[key][1]

        comparison.append({
            'PaperId': paper_id,
            'Variant_Treatment': variant_treatment_0,
            'Prompt0_Prediction': prediction_0,
            'Prompt1_Prediction': prediction_1,
            'Prompt2_Prediction': prediction_2,
            'Agreement_0_1': prediction_0 == prediction_1,
            'Agreement_0_2': prediction_0 == prediction_2,
            'Agreement_1_2': prediction_1 == prediction_2
        })

# Explicit columns keep the frame usable even when no pair is shared by all prompts.
df_comparison = pd.DataFrame(comparison, columns=[
    'PaperId', 'Variant_Treatment',
    'Prompt0_Prediction', 'Prompt1_Prediction', 'Prompt2_Prediction',
    'Agreement_0_1', 'Agreement_0_2', 'Agreement_1_2',
])

def agreement_rate(col1, col2):
    """Return the fraction of rows of ``df_comparison`` where column *col1* is True.

    *col2* is accepted but not used: only the precomputed agreement flag named
    by *col1* is read.
    """
    agreed = df_comparison[col1].eq(True)
    return agreed.mean()

print("Agreement rates:")
for left_prompt, right_prompt in ((0, 1), (0, 2), (1, 2)):
    flag_column = f"Agreement_{left_prompt}_{right_prompt}"
    print(f"Prompt {left_prompt} vs {right_prompt}: {agreement_rate(flag_column, flag_column):.2%}")


# A row is a disagreement when at least one of the three pairwise flags is False.
agreement_flags = df_comparison[["Agreement_0_1", "Agreement_0_2", "Agreement_1_2"]]
df_disagree = df_comparison[(agreement_flags == False).any(axis=1)]

print(f"\nTotal disagreements: {len(df_disagree)} / {len(df_comparison)}")
print(df_disagree.head(10))

# Get the unique labels (predictions) from all three prompts
prediction_columns = ("Prompt0_Prediction", "Prompt1_Prediction", "Prompt2_Prediction")
labels = sorted(set().union(*(df_comparison[column].unique() for column in prediction_columns)))


def _cross_tabulate(row_column, col_column):
    """Return (matrix, labelled frame, off-diagonal count) for two prediction columns."""
    matrix = confusion_matrix(df_comparison[row_column], df_comparison[col_column], labels=labels)
    table = pd.DataFrame(matrix, index=labels, columns=labels)
    off_diagonal = matrix.sum() - matrix.diagonal().sum()
    return matrix, table, off_diagonal


# Generate confusion matrix for Prompt 0 vs Prompt 1
cm_0_1, cm_0_1_df, disagreements_0_1 = _cross_tabulate("Prompt0_Prediction", "Prompt1_Prediction")
print(f"\nConfusion Matrix (Prompt 0 vs Prompt 1) - Disagreements: {disagreements_0_1}")
print(cm_0_1_df)

# Generate confusion matrix for Prompt 1 vs Prompt 2
cm_1_2, cm_1_2_df, disagreements_1_2 = _cross_tabulate("Prompt1_Prediction", "Prompt2_Prediction")
print(f"\nConfusion Matrix (Prompt 1 vs Prompt 2) - Disagreements: {disagreements_1_2}")
print(cm_1_2_df)

# Generate confusion matrix for Prompt 0 vs Prompt 2
cm_0_2, cm_0_2_df, disagreements_0_2 = _cross_tabulate("Prompt0_Prediction", "Prompt2_Prediction")
print(f"\nConfusion Matrix (Prompt 0 vs Prompt 2) - Disagreements: {disagreements_0_2}")
print(cm_0_2_df)

In [None]:
# Display disagreements with the full abstract for each comparison
df_disagree_with_abstract = df_comparison[
    (df_comparison["Agreement_0_1"] == False) |
    (df_comparison["Agreement_0_2"] == False) |
    (df_comparison["Agreement_1_2"] == False)
]
# One abstract per paper: if the results file repeats a PaperId, a plain merge would
# emit one output row per repetition and inflate the disagreement count printed below.
abstract_lookup = df_prompt0[['PaperId', 'Abstract']].drop_duplicates(subset='PaperId')
df_disagree_with_abstract = df_disagree_with_abstract.merge(
    abstract_lookup, on='PaperId', how='left', validate='many_to_one'
)

print(f"\nTotal disagreements: {len(df_disagree_with_abstract)} / {len(df_comparison)}")
print("\nDisagreements with Full Abstract:")

for index, row in df_disagree_with_abstract.iterrows():
    print(f"PaperId: {row['PaperId']}")
    print(f"Variant-Treatment: {row['Variant_Treatment']}")
    print(f"Prompt 0 Prediction: {row['Prompt0_Prediction']}")
    print(f"Prompt 1 Prediction: {row['Prompt1_Prediction']}")
    print(f"Prompt 2 Prediction: {row['Prompt2_Prediction']}")
    print(f"Agreement 0 vs 1: {row['Agreement_0_1']}")
    print(f"Agreement 0 vs 2: {row['Agreement_0_2']}")
    print(f"Agreement 1 vs 2: {row['Agreement_1_2']}")
    print(f"Full Abstract: {row['Abstract']}")
    print("\n" + "="*80 + "\n")

# The output (and the CSV read by bare name in the next section) lives in this directory.
os.chdir(variantscape_LLM_coas_directory)
df_disagree_with_abstract.to_csv("evaluation_of_llm_coassociations_prompts.csv", index=False)

In [None]:
# Evaluation: Prompt #1 is the most specific one!

# ============================================

# 7) Investigate dataset and normalize, find consensus of all coassociations

In [None]:
# Investigate LLM dataset
csv_path = "LLM_variant_screening_llama33-70b_prompt1.csv"
variant_coassociation_LLM_df = pd.read_csv(csv_path)
variant_coassociation_LLM_df.columns = variant_coassociation_LLM_df.columns.str.strip()

print(f"Loaded {len(variant_coassociation_LLM_df):,} rows")
print("Columns:", variant_coassociation_LLM_df.columns.tolist())
print("Shape of the dataset:", variant_coassociation_LLM_df.shape)

print("\nSample rows:")
print(variant_coassociation_LLM_df.head())

# Missing-value overview, per column and for fully empty rows.
missing_mask = variant_coassociation_LLM_df.isnull()
print("\nMissing values per column:")
print(missing_mask.sum())

empty_rows = variant_coassociation_LLM_df[missing_mask.all(axis=1)]
print(f"\nNumber of completely empty rows: {len(empty_rows)}")

# Repeated PaperIds: every occurrence after the first, then the count per repeated id.
duplicate_paperids = variant_coassociation_LLM_df[variant_coassociation_LLM_df.duplicated(subset='PaperId')]
print(f"Number of duplicated PaperId rows: {len(duplicate_paperids)}")

print("\nDuplicated PaperIds and their counts:")
paper_id_counts = variant_coassociation_LLM_df['PaperId'].value_counts()
print(paper_id_counts[paper_id_counts > 1])

llm_col = [col for col in variant_coassociation_LLM_df.columns if col.startswith("LLM_Response")]
if llm_col:
    llm_response_column = llm_col[0]

    print(f"\nSample LLM responses from column: {llm_response_column}")
    print(variant_coassociation_LLM_df[llm_response_column].dropna().sample(3, random_state=42).values)

    # Three random rows shown in full; the sample number is the row's index label + 1.
    sample = variant_coassociation_LLM_df.sample(3)
    displayed_fields = (
        ("Title", "PaperTitle"),
        ("Abstract", "Abstract"),
        ("Variants", "Variants"),
        ("Treatments", "Treatments"),
    )
    for i, row in sample.iterrows():
        print(f"\n--- Sample {i+1} ---")
        for field_label, column_name in displayed_fields:
            print(f"{field_label}: {row.get(column_name, '[Missing]')}")
        print(f"LLM Response:\n{row.get(llm_response_column, '[Missing]')}")


In [None]:
# Extract LLM response
def extract_variant_treatment_prediction(response):
    """Parse '<variant> + <treatment> : <label>' lines out of a raw LLM response.

    Returns a list of ``(variant, treatment, prediction)`` tuples, one per
    matching line, each part stripped of surrounding whitespace. Missing or
    non-string input gives an empty list.
    """
    if pd.isna(response) or not isinstance(response, str):
        return []
    line_pattern = re.compile(r"(.+?)\s*\+\s*(.+?)\s*:\s*(\w+)")
    triples = []
    for raw_line in response.split("\n"):
        hit = line_pattern.match(raw_line.strip())
        if hit is None:
            continue
        triples.append(tuple(part.strip() for part in hit.groups()))
    return triples

# Parse every response, then emit one record per (paper, variant, treatment, prediction).
variant_coassociation_LLM_df["Parsed_Triples"] = variant_coassociation_LLM_df["LLM_Response"].apply(extract_variant_treatment_prediction)
flat_records = [
    {
        "PaperId": row.get("PaperId", None),
        "Variant": variant,
        "Treatment": treatment,
        "Prediction": prediction,
    }
    for _, row in variant_coassociation_LLM_df.iterrows()
    for variant, treatment, prediction in row["Parsed_Triples"]
]

df_llm_extracted = pd.DataFrame(flat_records)
print(df_llm_extracted.head())
print(f"\nTotal variant-treatment-prediction entries extracted: {len(df_llm_extracted):,}")

In [None]:
# Define the extraction function
def extract_variant_treatment_prediction(response):
    """Extract (variant, treatment, prediction) tuples from a raw LLM response.

    The response is split on newlines and every stripped line is matched from
    its start against the shape ``<variant> + <treatment> : <prediction>``.
    Lines that do not match are skipped. The prediction is the first run of
    word characters after the colon.

    Args:
        response: Raw response text. Anything that is not a ``str`` (None,
            NaN, pd.NA, lists, arrays, ...) yields an empty list.

    Returns:
        A list of ``(variant, treatment, prediction)`` tuples of stripped
        strings, in the order the matching lines appear.
    """
    # The isinstance check alone rejects None/NaN/pd.NA. Calling pd.isna()
    # first is unnecessary and raises ValueError for list/array inputs,
    # because it returns an array whose truth value is ambiguous.
    if not isinstance(response, str):
        return []

    results = []
    for line in response.split("\n"):
        match = re.match(r"(.+?)\s*\+\s*(.+?)\s*:\s*(\w+)", line.strip())
        if match:
            results.append(tuple(part.strip() for part in match.groups()))
    return results

# Parse every LLM response into a list of (variant, treatment, prediction) tuples.
variant_coassociation_LLM_df["Parsed_Triples"] = variant_coassociation_LLM_df["LLM_Response"].apply(extract_variant_treatment_prediction)

# One record per parsed triple, tagged with the PaperId of the row it came from
# (None when the row has no PaperId entry).
flat_records = [
    {
        "PaperId": row.get("PaperId", None),
        "Variant": variant,
        "Treatment": treatment,
        "Prediction": prediction,
    }
    for idx, row in variant_coassociation_LLM_df.iterrows()
    for variant, treatment, prediction in row.get("Parsed_Triples", [])
]

# Explicit columns keep the frame's schema stable even when nothing was parsed;
# without them an empty result has no columns and the "Variant" / "Treatment"
# lookups that follow raise KeyError.
df_llm_extracted = pd.DataFrame(
    flat_records,
    columns=["PaperId", "Variant", "Treatment", "Prediction"],
)
print(" Preview of extracted variant-treatment-prediction entries:")
print(df_llm_extracted.head())
print(f"\nTotal variant-treatment-prediction entries extracted: {len(df_llm_extracted):,}")

# Clean variant strings
def clean_variant(variant):
    """Normalise a variant name so that spelling variations share one key.

    The value is stripped, lower-cased, has every character that is neither a
    word character nor whitespace removed, then has all whitespace removed,
    and finally has every run of underscores collapsed to a single one.

    Args:
        variant: Variant name as a string, or a missing value.

    Returns:
        The normalised string, or ``""`` when ``variant`` is missing.
    """
    if pd.isna(variant):
        return ""
    variant = variant.strip().lower()
    variant = re.sub(r"[^\w\s]", "", variant)
    variant = re.sub(r"\s+", "", variant)
    # Removing punctuation/whitespace can leave runs of 3+ underscores; a
    # single str.replace("__", "_") pass would leave "__" behind for those.
    return re.sub(r"_{2,}", "_", variant)

# Normalised variant name for every extracted triple.
df_llm_extracted["Variant_Clean"] = df_llm_extracted["Variant"].map(clean_variant)

# Grouping key: "<clean variant> + <lower-cased treatment>".
df_llm_extracted["Variant_Treatment_Pair"] = (
    df_llm_extracted["Variant_Clean"].str.strip()
    + " + "
    + df_llm_extracted["Treatment"].str.strip().str.lower()
)

# How often each label was predicted for each pair ...
prediction_counts = df_llm_extracted.groupby(
    ["Variant_Treatment_Pair", "Prediction"]
).size().reset_index(name="Count")

# ... and how many predictions each pair received overall.
total_counts = df_llm_extracted.groupby(
    "Variant_Treatment_Pair"
).size().reset_index(name="Total")

# A (pair, label) row is a full consensus when that label accounts for
# every prediction made for the pair.
merged = pd.merge(prediction_counts, total_counts, on="Variant_Treatment_Pair")
merged["Is_Consensus"] = merged["Count"].eq(merged["Total"])

consensus_only = merged.loc[merged["Is_Consensus"]].copy()
no_consensus = merged.loc[~merged["Is_Consensus"]].copy()
no_consensus_sorted = no_consensus.sort_values(by="Total", ascending=False)

print("\n=== Variant + Treatment Pairs with 100% LLM Prediction Consensus ===")
print(consensus_only.sort_values(by="Count", ascending=False).head(20))

total_pairs = merged["Variant_Treatment_Pair"].nunique()
non_consensus_pairs = no_consensus["Variant_Treatment_Pair"].nunique()
consensus_pairs = total_pairs - non_consensus_pairs

print(f"\n Total unique Variant + Treatment pairs: {total_pairs}")
print(f" Pairs WITHOUT full consensus: {non_consensus_pairs}")
print(f" Pairs WITH full consensus: {consensus_pairs}")

# Rows belonging to pairs that received more than one distinct label.
print("\n===  Variant + Treatment Pairs WITHOUT LLM Prediction Consensus ===")
print(no_consensus_sorted.head(20))

# Crosscheck 1: rows whose response produced no parsed triples at all
# (this also counts rows with an empty or missing response).
unparsed_count = variant_coassociation_LLM_df["Parsed_Triples"].map(len).eq(0).sum()
print(f"\n Number of LLM responses that returned ZERO parsed triples: {unparsed_count}")

# Crosscheck 2: (pair, prediction) combinations that occur more than once.
dupe_check = df_llm_extracted.duplicated(
    subset=["Variant_Treatment_Pair", "Prediction"], keep=False
)
if not dupe_check.any():
    print("\n No duplicate Variant-Treatment-Prediction rows found.")
else:
    print("\n Duplicate Variant-Treatment-Prediction rows found:")
    print(df_llm_extracted.loc[dupe_check].head())

# Crosscheck 3: distribution of the raw prediction labels.
print("\n Prediction label breakdown:")
print(df_llm_extracted["Prediction"].value_counts())

In [None]:
# Building consensus labels for variant-treatment predictions
valid_preds = ["Sensitive", "Resistant", "Diagnostic"]

# Share of each label within its pair; after sorting by descending share the
# first row kept per pair is that pair's dominant label.
merged["Proportion"] = merged["Count"].div(merged["Total"])
merged_sorted = merged.sort_values(
    by=["Variant_Treatment_Pair", "Proportion"], ascending=[True, False]
)
dominant_per_pair = merged_sorted.drop_duplicates(
    subset="Variant_Treatment_Pair", keep="first"
).copy()

# Soft consensus: dominant label holds >= 60% of at least 3 predictions and
# is one of the labels in valid_preds.
dominant_per_pair.loc[:, "Soft_Consensus"] = (
    dominant_per_pair["Proportion"].ge(0.60)
    & dominant_per_pair["Total"].ge(3)
    & dominant_per_pair["Prediction"].isin(valid_preds)
)

soft_consensus_total = dominant_per_pair["Soft_Consensus"].sum()
soft_consensus_percent = 100 * soft_consensus_total / len(dominant_per_pair)
print(f"Soft consensus pairs (custom rules): {soft_consensus_total}")
print(f"Percentage of all Variant+Treatment pairs: {soft_consensus_percent:.1f}%")

# Partition the pairs into hard consensus, soft consensus and the remainder.
hard_consensus_pairs = set(merged.loc[merged["Is_Consensus"], "Variant_Treatment_Pair"])
soft_consensus_pairs = set(
    dominant_per_pair.loc[dominant_per_pair["Soft_Consensus"], "Variant_Treatment_Pair"]
)
total_pairs = df_llm_extracted["Variant_Treatment_Pair"].nunique()
all_consensus_pairs = hard_consensus_pairs | soft_consensus_pairs
no_consensus_pairs = set(df_llm_extracted["Variant_Treatment_Pair"]).difference(all_consensus_pairs)
print(f"\nTotal Variant + Treatment pairs: {total_pairs}")
print(f"Hard consensus pairs: {len(hard_consensus_pairs)}")
print(f"Soft-only consensus pairs: {len(soft_consensus_pairs - hard_consensus_pairs)}")
print(f"Total pairs with any consensus (before fallback): {len(all_consensus_pairs)}")
print(f"Pairs WITHOUT any consensus (before fallback): {len(no_consensus_pairs)}")

# Fallback consensus resolution

df_no_consensus = df_llm_extracted[
    df_llm_extracted["Variant_Treatment_Pair"].isin(no_consensus_pairs)
].copy()

valid_labels = {"Sensitive", "Resistant", "Diagnostic"}
weak_labels = {"Unknown", "Unrelated"}

fallback_results = []
for pair, group in df_no_consensus.groupby("Variant_Treatment_Pair"):
    label_counts = group["Prediction"].value_counts()
    unique_labels = label_counts.index.tolist()

    present_valid = [label for label in unique_labels if label in valid_labels]
    present_weak = [label for label in unique_labels if label in weak_labels]

    if set(unique_labels) <= weak_labels:
        # Only weak labels: the more frequent one wins, ties go to "Unknown".
        if label_counts.get("Unrelated", 0) > label_counts.get("Unknown", 0):
            resolved = "Unrelated"
        else:
            resolved = "Unknown"
    elif len(unique_labels) == 1:
        resolved = unique_labels[0]
    elif present_valid and present_weak:
        # Weak labels are ignored; the most frequent valid label is used.
        resolved = label_counts.loc[present_valid].idxmax()
    else:
        resolved = "No consensus"

    fallback_results.append((pair, resolved))

df_fallback_consensus = pd.DataFrame(
    fallback_results, columns=["Variant_Treatment_Pair", "Resolved_Prediction"]
)

# Final consensus assembly

hard_labels = (
    merged[merged["Is_Consensus"]]
    .sort_values(["Variant_Treatment_Pair", "Count"], ascending=[True, False])
    .drop_duplicates("Variant_Treatment_Pair")
    .loc[:, ["Variant_Treatment_Pair", "Prediction"]]
    .rename(columns={"Prediction": "Resolved_Prediction"})
)

soft_labels = dominant_per_pair.loc[
    dominant_per_pair["Soft_Consensus"], ["Variant_Treatment_Pair", "Prediction"]
].rename(columns={"Prediction": "Resolved_Prediction"})

# Precedence is hard > soft > fallback: each later source only contributes
# pairs that no earlier source has already labelled.
df_final_consensus = pd.concat(
    [
        hard_labels,
        soft_labels[
            ~soft_labels["Variant_Treatment_Pair"].isin(hard_labels["Variant_Treatment_Pair"])
        ],
        df_fallback_consensus[
            ~(
                df_fallback_consensus["Variant_Treatment_Pair"].isin(hard_labels["Variant_Treatment_Pair"])
                | df_fallback_consensus["Variant_Treatment_Pair"].isin(soft_labels["Variant_Treatment_Pair"])
            )
        ],
    ],
    ignore_index=True,
)

df_final_consensus = df_final_consensus[["Variant_Treatment_Pair", "Resolved_Prediction"]]
df_final_consensus.to_csv("final_variant_treatment_consensus.csv", index=False)

print("Final consensus dataset shape:")
print(df_final_consensus.shape)
print("Saved to: final_variant_treatment_consensus.csv")


In [None]:
# Tabulate how many pairs ended up with each resolved label, as a raw count
# and as a percentage of all pairs (rounded to two decimals).
prediction_counts = df_final_consensus["Resolved_Prediction"].value_counts()
prediction_percent = prediction_counts.div(len(df_final_consensus)).mul(100).round(2)
summary = pd.DataFrame(
    {
        "Count": prediction_counts,
        "Percentage": prediction_percent,
    }
)

print("Number and percentage of Variant + Treatment pairs per prediction category:")
print(summary)

# Pairs for which no label could be resolved.
df_no_consensus_final = df_final_consensus.loc[
    df_final_consensus["Resolved_Prediction"].eq("No consensus")
]

print("\nSample of 'No consensus' entries:")
print(df_no_consensus_final.head(10))