# Use LLM for genetic variant extraction

- Set up libraries and datasets
- Select LLM
- Select the best-performing prompt
- Run LLM on dataset
- Extract information and create variant matrix and dictionary
- Evaluation
    - Comparison of LLMs for genetic variant extraction
    - Comparison of different prompts for genetic variant extraction

# 1) Set up libraries and datasets

In [None]:
import pandas as pd
import os
import numpy as np
import re
import time
from datetime import datetime
from datetime import timedelta
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
import logging
from pathlib import Path
import sys
print("Success!")

In [None]:
# Set the working directory and file paths
input_directory = "INPUT_DIRECTORY"
output_directory = "OUTPUT_DIRECTORY"
LLM_evaluation_directory = "LLM_EVALUATION_DIRECTORY"

# Change directory exactly once. Repeating os.chdir() with the same value is
# redundant for an absolute path and fails (or descends into a nested folder)
# for a relative one, because the second call is resolved against the new
# working directory.
os.chdir(output_directory)
print("Current Working Directory:", os.getcwd())

# Load datasets
# NOTE(review): the matrix is read from the output directory rather than
# input_directory - presumably it was written there by an earlier step; confirm.
cancer_df = pd.read_csv("binary_cancer_matrix_filtered.csv")
len_cancer_df = len(cancer_df)
print(f" --> Total rows in cancer dataset: {len_cancer_df:,}")

# Keep only the identifier and the two text fields that are fed to the prompts.
cancer_df = cancer_df[['PaperId', 'PaperTitle', 'Abstract']].copy()
print(cancer_df)

# 2) Select and set up LLMs

In [None]:
# Set up a language model to answer the questions
!pip install OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM
from openai import OpenAI
print("Success!")

In [None]:
# Mapping model names to their full Hugging Face or DeepInfra identifiers
model_fullnames = {
    "llama31-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "llama33-70b": "meta-llama/Llama-3.3-70B-Instruct",
    "deepseek_v3": "deepseek-ai/DeepSeek-V3",
    "deepseek_r1": "deepseek-ai/DeepSeek-R1",
    "deepseek_r1_distill_llama_70b": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
}

# The models to be tested are exactly the aliases registered above, in
# declaration order.
models = list(model_fullnames)

# System message sent with every request.
SYSTEM_MSG = (
    "You are a helpful medical question answering assistant. "
    "Please carefully follow the exact instructions and do not provide explanations."
)

# Alias of the model used for this session (second entry of `models`).
modelname = models[1]

def _generate_via_chat_api(promptStr, maxTokens=100):
    """Send one prompt through the OpenAI-compatible chat endpoint.

    Relies on the module-level `client`, `modelname`, `model_fullnames` and
    `SYSTEM_MSG`, all of which are looked up at call time.

    Args:
        promptStr: Text sent as the user message.
        maxTokens: Accepted for backward compatibility only. It is NOT
            forwarded to the API, so it does not limit the response length.

    Returns:
        The message content of the first completion choice.
    """
    messages = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": promptStr},
    ]
    completion = client.chat.completions.create(
        model=model_fullnames[modelname],
        messages=messages,
    )
    return completion.choices[0].message.content


if modelname in ["llama2-3b"]:  # Local model
    # NOTE(review): `load` and `generate` are not imported in the visible
    # source (presumably from mlx_lm - confirm), and "llama2-3b" has no entry
    # in model_fullnames, so this branch cannot run as written.
    model, tokenizer = load(model_fullnames[modelname])

    def generateFromPrompt(prompt):
        """Generate a response from the local model for `prompt`.

        The prompt is wrapped in the tokenizer's chat template (system + user
        message) when the tokenizer provides one; otherwise it is used as-is.
        """
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            messages = [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "user", "content": prompt},
            ]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        # Generation happens for every prompt. Previously these two lines were
        # nested inside the `if`, so tokenizers without a chat template made
        # the function return None without generating anything.
        response = generate(model, tokenizer, prompt=prompt, verbose=False)
        return response
elif modelname in ["gpt35", "gpt4o"]:  # OpenAI models
    # NOTE(review): neither alias has an entry in model_fullnames; add one
    # before selecting these models, otherwise the request raises KeyError.
    client = OpenAI(
        api_key='API_KEY1'
    )
    generateFromPrompt = _generate_via_chat_api
elif modelname in ["llama31-70b", "llama33-70b", "deepseek_v3", "deepseek_r1", "deepseek_r1_distill_llama_70b"]:  # DeepInfra models
    client = OpenAI(
        api_key="API_KEY2",
        base_url="https://api.deepinfra.com/v1/openai",
    )
    generateFromPrompt = _generate_via_chat_api
else:
    # Fail here with a clear message instead of a NameError on first use.
    raise ValueError(f"Unsupported model name: {modelname!r}")

# Smoke test: one real request to confirm the selected backend responds.
generateFromPrompt("hello!")

In [None]:
# Report the configured model aliases and the one selected for this session.
print(f"All installed models: {models}")
print(f"Current model in use: {modelname}")

# 3) Define prompts

In [None]:
# Prompt builders, keyed by prompt number.
# Every builder takes (title, abstract) and returns the complete prompt text.
# All prompts have been tested and #3 is the best performing
# NOTE(review): in prompts 0 and 1 some adjacent fragments are joined without a
# separating space (e.g. "...variants mentioned.Genetic variants must...").
# The wording is reproduced exactly as it was evaluated.


def _prompt_0(title, abstract):
    """Long-form prompt requesting a detailed, multi-field report per variant."""
    return (
        "You are a helpful and highly detail-oriented research assistant. Carefully review the title and abstract of the given publication to identify any specific genetic variants mentioned."
        "Genetic variants must include specific and concrete names such as variant reference IDs, DNA base changes, protein changes, or precise genetic notations."
        "Your goal is to report only **explicitly mentioned and well-defined genetic variants** and disregard any vague or generic descriptions. "
        "Do NOT include generic mentions like simple gene names (e.g., BRCA1, BRCA2, ATM) or unspecific descriptions (e.g., 'allele loss,' 'mutation detected'). "
        "Do NOT report terms such as 'reversion mutations,' 'gene mutation,' or 'mutation detected,' unless a specific variant name is explicitly mentioned.\n\n"
        "Start your response with:\n"
        " 'No genetic variant detected in this publication.' if no genetic variants are found, or if only vague terms or placeholders are mentioned (e.g., 'Not specified,' 'unknown variant').\n\n"
        " 'Genetic variant detected:' if you identify any genetic variant, followed by detailed information below.\n\n"
        "For each identified variant, provide the following details:\n"
        "- **Variant name**: e.g., N2875H\n"
        "- **Variant type**: Specify SNP, Indel, missense variant, or splice site variant.\n"
        "- **Notation**: Provide HGVS notation if available (e.g., c.7089+1del).\n"
        "- **Sequence ontology**: e.g., missense variant, synonymous variant, or frameshift.\n"
        "- **Location**: Specify gene name, chromosome, or precise locus (e.g., Gene: ATM, Chromosome: chr11).\n"
        "- **Functional impact**: Describe effects on gene function or expression (e.g., nonsynonymous alteration, base change C > T).\n"
        "- **Clinical relevance**: Mention any associated disease, trait, or phenotype.\n"
        "Important: If a Variant name is **'Not mentioned,' 'Not specified,' or only inferred** (e.g., 'reversion mutation,' 'mutation detected'), respond with: **'No genetic variant detected in this publication.'**\n\n"
        "Now analyze the following details:\n"
        f"Title: {title}\n\nAbstract: {abstract}\n\n"
        "Provide your response in the specified format."
    )


def _prompt_1(title, abstract):
    """Prompt asking for 'Variant: Variant Name, Gene Name' or 'No variant.'."""
    return (
        "Analyze the following title and abstract to identify **genetic variants** mentioned,"
        "including specific variant names, base changes, and protein alterations.\n"
        "Only report well-defined variants, such as specific reference IDs or exact mutations."
        "Ignore vague terms like 'mutation detected' or generic mentions of genes (e.g., BRCA1, ATM)."
        "If no specific variant is mentioned, respond with: 'No variant.'\n\n"
        "For each identified variant, provide:\n"
        "- **Variant Name**: e.g., c.2138C>G\n"
        "- **Gene Name**: e.g., BRCA1\n"
        "Return results in the following format: 'Variant: Variant Name, Gene Name' or 'No variant.'\n\n"
        f"Title: {title}\nAbstract: {abstract}\n"
    )


def _prompt_2(title, abstract):
    """Prompt listing HGVS / protein change / rsID with a strict per-line format."""
    return (
        "Extract all genetic variants from the following title and abstract. Only return:\n"
        "- **HGVS Notation** (c., p., g.), e.g., c.2138C>G, p.Arg713Trp, g.32389625G>A\n"
        "- **Protein changes** (e.g., V600E, Arg713Trp)\n"
        "- **rsIDs** (e.g., rs121913529)\n"
        "- Ignore vague terms (e.g., 'mutation found').\n\n"
        "### Format response strictly as:\n"
        "'Variant: <mutation>, Gene: <gene>' per line, or 'No variant' if none found.\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    )


def _prompt_3(title, abstract):
    """Compact strict-format prompt (the one reported as best performing)."""
    return (
        "Extract only specific genetic variants from the text. Return strictly:\n"
        "- **HGVS Notation** (c., p., g.) e.g., c.2138C>G, p.Arg713Trp\n"
        "- **Protein changes** (e.g., V600E, Arg713Trp)\n"
        "- **rsIDs** (e.g., rs121913529)\n"
        "- Ignore vague terms (e.g., 'mutation found').\n\n"
        "### Format:\n"
        "- Variant: 'Variant: <mutation>, Gene: <gene>' per line\n"
        "- If none, return: 'No variant'\n"
        "- No extra text, no explanations.\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    )


def _prompt_4(title, abstract):
    """Strict-format prompt that adds a frameshift example to the protein changes."""
    return (
        "Extract only genetic variants from the text in strict format:\n"
        "- HGVS Notation: c., p., g. (e.g., c.2138C>G, p.Arg713Trp)\n"
        "- Protein changes: (e.g., V600E, Arg713Trp, frameshift mutations e.g., p.Asp427Thrfs*3**)\n"
        "- rsIDs: (e.g., rs121913529)\n"
        "- Ignore vague terms ('mutation found').\n\n"
        "**Format:**\n"
        "- 'Variant: <mutation>, Gene: <gene>' per line\n"
        "- If none, return: 'No variant'\n"
        "- No extra text or explanation.\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    )


def _prompt_5(title, abstract):
    """Like prompt 4, with repeated instructions forbidding any extra output."""
    return (
        "Extract **only genetic variants** from the text in strict format:\n"
        "- HGVS Notation: c., p., g. (e.g., c.2138C>G, p.Arg713Trp)\n"
        "- Protein changes: (e.g., V600E, Arg713Trp, frameshift mutations e.g., p.Asp427Thrfs*3**)\n"
        "- rsIDs: (e.g., rs121913529)\n"
        "- Ignore vague terms (e.g., 'mutation found').\n"
        "- **Do not add extra comments, explanations, or summaries.**\n\n"
        "**Format:**\n"
        "- 'Variant: <mutation>, Gene: <gene>' per line\n"
        "- If none, return exactly: 'No variant'\n"
        "- **No extra text, no summaries, no explanations, no other output.**\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    )


def _prompt_6(title, abstract):
    """Strict-format prompt requiring a gene per variant, with a worked example."""
    return (
        "Extract only specific genetic variants from the text. Strictly return:\n"
        "- **HGVS Notation** (c., p., g.) e.g., c.2138C>G, p.Arg713Trp\n"
        "- **Protein changes** (e.g., V600E, Arg713Trp)\n"
        "- **rsIDs** (e.g., rs121913529)\n"
        "- Each variant must be associated with a gene.\n"
        "- **Do NOT return genes alone with no variant.**\n"
        "- **Ignore vague terms** (e.g., 'mutation found', 'gene alteration', or general mentions of genes like TP53).\n"
        "- **Strictly format the output as follows:**\n"
        "\n### Output format:\n"
        "Variant: <variant>, Gene: <gene>  (one per line)\n"
        "Example: 'Variant: V600E, Gene: BRAF'\n"
        "If no valid variants exist, return only: 'No variant'\n"
        "No explanations, no extra text.\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    )


PROMPTS = {
    0: _prompt_0,
    1: _prompt_1,
    2: _prompt_2,
    3: _prompt_3,
    4: _prompt_4,
    5: _prompt_5,
    6: _prompt_6,
}
print("Prompts successfully defined!")

# 4) Run genetic variant extraction with LLM

In [None]:
# ========================== CONFIGURATION ========================== #
# Work from the output directory and register tqdm's `progress_apply` on pandas.
os.chdir(output_directory)
tqdm.pandas()

# Number of articles handled per batch.
BATCH_SIZE = 60000

# Model and prompt used for this run.
modelname = modelname  # no-op re-binding of the model chosen earlier
selected_prompt_number = 3  # Change dynamically as needed

# Date stamp and start timestamp used in the log headers.
today_date = datetime.today().strftime('%Y-%m-%d')
start_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

# Every per-run artefact shares the same "<model>_prompt<N>" tag in its name.
run_tag = f"{modelname}_prompt{selected_prompt_number}"
variant_output_file_path = os.path.join(output_directory, f"LLM_variant_extraction_{run_tag}.csv")
runtime_file = os.path.join(output_directory, f"runtime_summary_{run_tag}.txt")
progress_log_file = os.path.join(output_directory, f"progress_log_{run_tag}.txt")

# Ensure files exist
def ensure_file_exists(file_path, header_text=None):
    """Create ``file_path`` if it is missing; never touch an existing file.

    When ``header_text`` is truthy the new file starts with that text on its
    own line followed by a rule of 60 ``=`` characters. Otherwise the new
    file is created empty.
    """
    if os.path.exists(file_path):
        return

    initial_content = f"{header_text}\n" + "=" * 60 + "\n" if header_text else ""
    with open(file_path, "w") as handle:
        handle.write(initial_content)

# Seed the runtime summary and the progress log with a dated header; files that
# already exist are left untouched so earlier runs stay on record.
ensure_file_exists(runtime_file, f"### Runtime Log - {today_date} ###\nStart Time: {start_time_str}")
ensure_file_exists(progress_log_file, f"### Progress Log - {today_date} ###")

# Seed the results CSV with its header row so that batches can be appended
# later without writing a header again.
if not os.path.exists(variant_output_file_path):
    with open(variant_output_file_path, "w") as f:
        f.write("PaperId,PaperTitle,Abstract,LLM_Prompt,LLM_Response\n")

# Define logging setup
# force=True replaces any handlers already attached to the root logger.
# Without it basicConfig() silently does nothing once logging is configured,
# so re-running this cell for another model/prompt would keep writing to the
# previous run's progress log.
logging.basicConfig(
    filename=progress_log_file,
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)

print("Success! All necessary files and directories are set up.")
print("Script Start Time:", start_time_str)
print("Defined prompt number:", selected_prompt_number)
print(f"Defined batch size: {BATCH_SIZE:,}")

# Define functions
def screen_publication_for_variants(row, prompt_number):
    """Build the LLM prompt for one publication.

    Args:
        row: Record supporting ``row['PaperTitle']`` and ``row['Abstract']``
            (e.g. a DataFrame row).
        prompt_number: Key of the prompt builder to use in ``PROMPTS``.

    Returns:
        Whatever the selected ``PROMPTS`` entry returns for (title, abstract).

    Raises:
        ValueError: If ``prompt_number`` is not a key of ``PROMPTS``.
    """
    title = row['PaperTitle']
    abstract = row['Abstract']
    if prompt_number not in PROMPTS:
        # List the real keys instead of a hard-coded range, which had gone
        # stale ("0-5" while the prompts run from 0 to 6).
        valid_numbers = ", ".join(str(key) for key in sorted(PROMPTS))
        raise ValueError(f"Invalid prompt number: {prompt_number}. Choose one of: {valid_numbers}.")
    return PROMPTS[prompt_number](title, abstract)
def process_with_llm(prompt):
    """Return the LLM answer for ``prompt``, or the string "ERROR" on failure.

    Any exception raised while calling ``generateFromPrompt`` is logged with
    ``logging.error`` and swallowed, so one failing request does not abort a
    whole batch.
    """
    try:
        return generateFromPrompt(prompt)
    except Exception as exc:
        logging.error(f"LLM processing error: {exc}")
    return "ERROR"


print("Functions defined!")

In [None]:
# ========================== RESUME FROM LAST CHECKPOINT ========================== #

# The dataset must already be in memory and carry the id column that is used
# to tell processed articles from unprocessed ones.
if 'cancer_df' not in globals():
    raise ValueError("Dataset `cancer_df` is not loaded in memory. Make sure it's defined before running the script.")
if 'PaperId' not in cancer_df.columns:
    raise KeyError("Dataset must contain a 'PaperId' column to track progress.")

# Collect the PaperIds that are already present in the results CSV.
if not os.path.exists(variant_output_file_path):
    processed_articles = set()
    total_processed_articles = 0
    print("Starting fresh processing.")
else:
    processed_df = pd.read_csv(variant_output_file_path)
    processed_articles = set(processed_df['PaperId'])
    total_processed_articles = len(processed_articles)
    print(f"Resuming from last processed row. {total_processed_articles} articles completed so far.")

# Number of batches needed for the whole dataset (ceiling division).
total_batches = (len(cancer_df) + BATCH_SIZE - 1) // BATCH_SIZE
if total_processed_articles == len(cancer_df):
    # Informational only: execution continues past this point. (The earlier
    # version raised SystemExit here and immediately swallowed it, which had
    # no effect either.)
    print("\nAll batches are complete. No more articles to process.")
    print("You have successfully processed the entire dataset.")

# Everything whose PaperId is not in the results CSV still has to be processed.
already_processed_mask = cancer_df['PaperId'].isin(processed_articles)
unprocessed_df = cancer_df[~already_processed_mask]
total_articles = len(unprocessed_df)

print("Success! All necessary files and directories are set up.")
print("Defined prompt number:", selected_prompt_number)
print(f"Defined batch size to run in chunks: {BATCH_SIZE:,}")
print(f"Total unprocessed articles: {total_articles:,}")

# ========================== TRACK CUMULATIVE RUNTIME ========================== #
# Each "Total runtime so far" line written to the runtime file already holds
# the cumulative runtime of all runs up to that point, so the most recent such
# line IS the previous total. Adding all of them together (as was done before)
# counts earlier runs repeatedly from the third run onwards.
total_runtime_previous = 0.0
if os.path.exists(runtime_file):
    with open(runtime_file, "r") as f:
        for line in f:
            if "Total runtime so far" in line:
                # Line format: "...: <seconds> sec" -> take the number after
                # the last colon; later lines overwrite earlier ones.
                total_runtime_previous = float(line.split(":")[-1].strip().split()[0])

In [None]:
# ========================== BATCH PROCESSING ========================== #
def generate_progress_bar(percentage, bar_length=20):
    """Render ``percentage`` as a text bar, e.g. ``[||||||||||----------] 50.00%``."""
    filled_length = int(bar_length * percentage / 100)
    bar = '|' * filled_length + '-' * (bar_length - filled_length)
    return f"[{bar}] {percentage:.2f}%"


start_time = time.time()

# Defaults used by the final summary when there is nothing left to process;
# without them the summary fails with a NameError in that case.
batch_runtime = 0.0
total_runtime_so_far = total_runtime_previous

# Calculate the next batch number (never beyond the total number of batches)
batch_number = min((total_processed_articles // BATCH_SIZE) + 1, total_batches)

# NOTE: `unprocessed_df` comes from the resume step; re-run that step before
# re-running this one, otherwise the same articles are processed again.
remaining_articles = len(unprocessed_df)
if remaining_articles == 0:
    # Keep `total_articles` meaningful for the summary: cumulative processed count.
    total_articles = len(processed_articles)
    print("\nNo unprocessed articles remain; nothing to do in this run.")

for batch_start in range(0, remaining_articles, BATCH_SIZE):
    batch_end = min(batch_start + BATCH_SIZE, remaining_articles)
    batch = unprocessed_df.iloc[batch_start:batch_end].copy()
    print(f"\nProcessing Batch {batch_number}/{total_batches} ({batch_start + 1} to {batch_end})...")

    # Record the wall-clock start BEFORE the work so the runtime file's
    # "started at" entry is truthful (it used to be stamped after the batch).
    batch_started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    batch_start_time = time.time()
    batch['LLM_Prompt'] = batch.apply(lambda row: screen_publication_for_variants(row, selected_prompt_number), axis=1)

    llm_response_column = f'LLM_Response_{modelname}'
    batch[llm_response_column] = batch['LLM_Prompt'].progress_apply(process_with_llm)

    batch_runtime = time.time() - batch_start_time
    batch_to_save = batch[['PaperId', 'PaperTitle', 'Abstract', 'LLM_Prompt', llm_response_column]]

    # Progress relative to the full dataset: done so far vs. still to do.
    total_articles = batch_end + len(processed_articles)
    total_articles_to_process = len(unprocessed_df) - batch_end
    processed_percentage = (total_articles / len(cancer_df)) * 100
    to_process_percentage = (total_articles_to_process / len(cancer_df)) * 100

    # Append to the results CSV (header row is only written for a new file).
    if os.path.exists(variant_output_file_path):
        batch_to_save.to_csv(variant_output_file_path, mode='a', header=False, index=False)
    else:
        batch_to_save.to_csv(variant_output_file_path, mode='w', index=False)

    total_runtime_so_far = total_runtime_previous + (time.time() - start_time)

    with open(runtime_file, "a") as f:
        f.write(f"\nBatch {batch_number}/{total_batches} started at {batch_started_at}\n")
        f.write(f"Batch Runtime: {batch_runtime:.2f} sec\n")
        f.write(f"Total runtime so far (all runs combined): {total_runtime_so_far:.2f} sec\n")
        f.write(f"Total articles processed in this batch: {batch_end}\n")
        f.write("=" * 60 + "\n")

    logging.info(f"Processed batch {batch_number}/{total_batches} in {batch_runtime:.2f} sec.")

    if processed_percentage >= 100:
        print("\nAll articles have been successfully processed.")
        print("No more articles remaining.")
    else:
        print(f"\nPaused! {batch_end} articles processed in this batch.")
        print(f"{total_articles} articles processed in total {generate_progress_bar(processed_percentage)}")
        print(f"{total_articles_to_process} articles to process in total {generate_progress_bar(to_process_percentage)}")
        print("Check the CSV and runtime file. When ready, rerun the script to continue processing.")

    # Deliberately handle a single batch per run: the user inspects the output
    # and reruns to continue from the checkpoint.
    break

# ========================== FINAL SUMMARY ========================== #
# Break the cumulative runtime (seconds) into hours / minutes / seconds.
total_runtime = total_runtime_so_far
total_hours, seconds_past_hour = divmod(total_runtime, 3600)
total_minutes = seconds_past_hour // 60
total_seconds = total_runtime % 60

# The report starts and ends with a newline, hence the empty first/last items.
summary_lines = [
    "",
    "### Genetic Variant Extraction Summary ###",
    "",
    f"- Model used: {modelname}",
    f"- Prompt number: {selected_prompt_number}",
    f"- Total batches processed: {batch_number:,}/{total_batches:,}",
    f"- Total articles processed: {total_articles:,}",
    f"- Batch runtime: {batch_runtime:,}",
    f"- Cumulative runtime: {total_runtime:.2f} seconds ({total_hours:.0f} hr {total_minutes:.0f} min {total_seconds:.2f} sec)",
    "",
]
summary_text = "\n".join(summary_lines)
print(summary_text)

# Append the same report, with an end timestamp, to the runtime file.
with open(runtime_file, "a") as f:
    f.writelines([
        "\n### Final runtime summary ###\n",
        f"End Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
        summary_text,
        "\n" + "=" * 60 + "\n",
    ])

print("Final results saved.")

# RERUN FROM LAST CHECKPOINT

# ============================================

# 6) Evaluation of LLMs and prompts

## 6.1) Load dataset

In [None]:
def _to_binary_label(value):
    """Map one annotation cell to 0 or 1.

    Returns 0 for the string "0" (surrounding whitespace ignored) and for a
    numeric zero, which is what pandas yields when a whole column parses as
    numbers. Everything else maps to 1.
    """
    if isinstance(value, str):
        return 0 if value.strip() == "0" else 1
    if pd.isna(value):
        # Missing cells keep the previous behaviour (counted as 1).
        # NOTE(review): confirm that an empty cell should count as positive.
        return 1
    return 0 if value == 0 else 1


# Set the working directory and file paths
print("Current Working Directory:", os.getcwd())
os.chdir(input_directory)
print("Current Working Directory:", os.getcwd())
llm_ev = "LLM_evaluation_statistics.csv"
llm_df = pd.read_csv(llm_ev)
print(llm_df.head(5))
print(len(llm_df))

# Investigate column names
header = llm_df.columns[1:].tolist()
print(header)
model_names = ['LLama31-70b', 'LLama33-70b', 'DeepSeek_V3', 'DeepSeek-R1-Distill-Llama-70B']
print("\n\nModels selected for evaluation:", model_names)

# Change the working directory
os.chdir(LLM_evaluation_directory)
print("Current Working Directory:", os.getcwd())

# Binarise the human annotation and every model column on a copy of the data.
# Column-wise Series.map is used instead of DataFrame.applymap, which is
# deprecated in recent pandas releases.
llm_df_bn = llm_df.copy()
columns_to_transform = ['Human'] + model_names
llm_df_bn[columns_to_transform] = llm_df_bn[columns_to_transform].apply(
    lambda column: column.map(_to_binary_label)
)
print(llm_df_bn.head(100))

## 6.2) Calculate confusion matrix

In [None]:
# Evaluate only the first entry of model_names against the human annotation.
# NOTE(review): confusion_matrix and f1_score are not imported anywhere in the
# visible source - presumably sklearn.metrics; confirm they are in scope.
first_model = model_names[0]
y_true, y_pred = llm_df_bn['Human'], llm_df_bn[first_model]

# With labels=[0, 1] the flattened 2x2 matrix is unpacked as tn, fp, fn, tp.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
f1 = f1_score(y_true, y_pred)

# One-row summary table for this model.
results = dict([
    ('Model Name', first_model),
    ('True Positives (TP)', tp),
    ('False Positives (FP)', fp),
    ('False Negatives (FN)', fn),
    ('True Negatives (TN)', tn),
    ('F1 Score', f1),
])
results_df = pd.DataFrame([results])
print(results_df)

In [None]:
# Compute the confusion matrix and derived metrics for every model, then save
# the table. Everything, including the CSV export, only runs when llm_df_bn
# exists, so a stale `results_df` from another cell can never be exported.
# NOTE(review): confusion_matrix and f1_score are not imported anywhere in the
# visible source - presumably sklearn.metrics; confirm they are in scope.
if 'llm_df_bn' in globals():
    results = {}
    y_true = llm_df_bn['Human']
    # Loop through each model and compute confusion matrix & metrics.
    # (`model_name`, not `model`, so a loaded LLM object named `model` is not
    # overwritten by a string.)
    for model_name in model_names:
        y_pred = llm_df_bn[model_name]
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        # Calculate performance metrics; each ratio falls back to 0 when its
        # denominator is 0.
        accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0  # Sensitivity
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        # NOTE(review): zero_division=1 reports F1 = 1 in the undefined case
        # while precision/recall above fall back to 0 - confirm this is intended.
        f1 = f1_score(y_true, y_pred, zero_division=1)
        results[model_name] = {
            "True Positives (TP)": tp,
            "False Positives (FP)": fp,
            "False Negatives (FN)": fn,
            "True Negatives (TN)": tn,
            "F1 Score": f1,
            "Sensitivity (Recall)": recall,
            "Specificity": specificity,
            "Precision": precision,
            "Accuracy": accuracy
        }
    # One row per model, one column per metric.
    results_df = pd.DataFrame(results).T
    print(results_df)

    results_df.to_csv("llm_performance_metrics.csv", index=True)
    print("\nResults saved as 'llm_performance_metrics.csv'.")
else:
    print("`llm_df_bn` is not defined. Load the evaluation dataset first; no metrics were computed or saved.")

In [None]:
# NOTE(review): `sns` is not imported anywhere in the visible source -
# presumably `import seaborn as sns`; confirm it is in scope.

# The heatmap annotations use an integer format, so cast the counts to int.
count_columns = ['True Positives (TP)', 'False Positives (FP)', 'False Negatives (FN)', 'True Negatives (TN)']
confusion_matrix_values = results_df[count_columns].astype(int)

# Figure 1: heatmap of the confusion-matrix counts (rows: models, columns: counts).
plt.figure(figsize=(10, 6))
sns.heatmap(confusion_matrix_values, annot=True, fmt="d", cmap="Blues", linewidths=0.5)
plt.title("Confusion Matrix Counts for Each Model")
plt.xlabel("Metrics")
plt.ylabel("LLM Models")
plt.show()

# Figure 2: one bar per model showing its F1 score.
plt.figure(figsize=(8, 5))
sns.barplot(x=results_df.index, y=results_df["F1 Score"], palette="viridis")
plt.title("F1 Scores for Each Model")
plt.xlabel("LLM Models")
plt.ylabel("F1 Score")
plt.xticks(rotation=45)
plt.show()

# Figure 3: one line per metric across the models; each line is labelled with
# the name of the column it plots.
plt.figure(figsize=(10, 6))
for metric_column, marker_style in (
    ("Sensitivity (Recall)", 'o'),
    ("Specificity", 's'),
    ("Precision", '^'),
    ("Accuracy", 'd'),
):
    plt.plot(results_df.index, results_df[metric_column], marker=marker_style, label=metric_column)
plt.title("Model Performance Metrics")
plt.xlabel("LLM Models")
plt.ylabel("Score")
plt.xticks(rotation=45)
plt.grid()
plt.legend()
plt.show()