# LLM-based variant extraction

# 1) Set up libraries and datasets

In [None]:
# Import libraries
import os
import re
import sys
import time
import logging
import numpy as np
import pandas as pd
from tqdm import tqdm
import seaborn as sns
from pathlib import Path
from functools import reduce
import matplotlib.pyplot as plt
from collections import Counter
from datetime import datetime, timedelta
from sklearn.metrics import f1_score, precision_score, recall_score
print("Success!")

In [None]:
# Directory and file names used throughout the notebook
working_directory = "WORKING_DIRECTORY"
NLP_directory = "NLP_DIRECTORY"
output_directory = "OUTPUT_DIRECTORY"
articles_file = "BioBERT_file.csv"

# The articles CSV is read relative to the output directory. A copy already
# held in `full_articles` is reused so re-running this cell skips the read.
os.chdir(output_directory)
if "full_articles" in globals():
    print("Using preloaded full_articles from memory.")
else:
    full_articles = pd.read_csv(articles_file)
    print(f"Loaded {len(full_articles)} articles from CSV.")
articles = full_articles
print("Article import successful!")
print(f"\nImported {len(articles):,} articles with {len(articles.columns):,} selected columns.")

# Record the table dimensions, then switch to the working directory
num_rows, num_columns = articles.shape
os.chdir(working_directory)
print("\nCurrent Working Directory:", os.getcwd())

# 2) Select and set up LLMs

In [None]:
# Set up a language model to answer the questions
!pip install OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM
from openai import OpenAI
print("Success!")

In [None]:
# Mapping model names to their full Hugging Face or DeepInfra identifiers
model_fullnames = {
    "llama31-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "llama33-70b": "meta-llama/Llama-3.3-70B-Instruct",
    "deepseek_v3": "deepseek-ai/DeepSeek-V3",
    "deepseek_r1": "deepseek-ai/DeepSeek-R1",
    "deepseek_r1_distill_llama_70b": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
    "gpt4o": "gpt-4o",
}

# The models to be tested: the short names above, in declaration order
models = list(model_fullnames)

# System message sent with every chat request
SYSTEM_MSG = "You are a helpful medical question answering assistant. Please carefully follow the exact instructions and do not provide explanations."

# Python lists are zero-indexed, so index 4 selects the fifth entry
modelname = models[4]

# Define `generateFromPrompt` for whichever backend serves `modelname`.
if modelname in [ "llama2-3b" ]:  # Local model
    # NOTE(review): `load` and `generate` are not imported in the visible part
    # of this file, and "llama2-3b" has no entry in `model_fullnames`, so this
    # branch cannot run as written -- confirm the intended local-inference
    # library before enabling it.
    model, tokenizer = load(model_fullnames[modelname])

    def generateFromPrompt(prompt):
        """Return the local model's completion for ``prompt``.

        When the tokenizer provides a chat template, the system message and
        the prompt are wrapped with it first; otherwise the raw prompt is used.
        """
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            messages = [{"role": "system", "content": SYSTEM_MSG},
                        {"role": "user", "content": prompt}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        # Generation happens outside the `if` so that tokenizers without a
        # chat template still produce an answer instead of an implicit None.
        response = generate(model, tokenizer, prompt=prompt, verbose=False)
        return response
elif modelname in [ "gpt35", "gpt4o" ]:  # OpenAI models
    client = OpenAI(
        api_key='API_key1'
    )

    def generateFromPrompt(promptStr, maxTokens=100):
        """Return the chat completion text for ``promptStr`` from OpenAI.

        Args:
            promptStr: User message; sent after the ``SYSTEM_MSG`` system message.
            maxTokens: Accepted for call compatibility only; it is not
                forwarded to the API.

        Returns:
            The message content of the first returned choice.
        """
        messages = [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "user", "content": promptStr},
        ]
        completion = client.chat.completions.create(
            model=model_fullnames[modelname],
            messages=messages)
        return completion.choices[0].message.content
elif modelname in [ "llama31-70b" , "llama33-70b" , "deepseek_v3" , "deepseek_r1" , "deepseek_r1_distill_llama_70b"]:  # DeepInfra models
    # Same client class as above, pointed at DeepInfra's base URL.
    client = OpenAI(
        api_key = "API_key2",
        base_url="https://api.deepinfra.com/v1/openai",
    )

    def generateFromPrompt(promptStr, maxTokens=100):
        """Return the chat completion text for ``promptStr`` from DeepInfra.

        Args:
            promptStr: User message; sent after the ``SYSTEM_MSG`` system message.
            maxTokens: Accepted for call compatibility only; it is not
                forwarded to the API.

        Returns:
            The message content of the first returned choice.
        """
        messages = [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "user", "content": promptStr},
        ]
        completion = client.chat.completions.create(
            model=model_fullnames[modelname],
            messages=messages)
        return completion.choices[0].message.content
else:
    # Fail here with a clear message rather than with a NameError on the
    # first call to an undefined generateFromPrompt.
    raise ValueError(f"No backend is configured for model name: {modelname!r}")

# Smoke test: issues one real request to the selected backend.
generateFromPrompt("hello!")

In [None]:
# Report the configured model list and the one selected for this run
for label, value in (
    ("All installed models:", models),
    ("Current model in use:", modelname),
):
    print(label, value)

# 3) Define prompts

In [None]:
# Prompt templates keyed by prompt number. Each value is a callable taking
# (title, abstract) and returning the full prompt text.
# Use prompt #3
PROMPTS = {
    # Prompt 1: adjacent string literals are concatenated with nothing in
    # between, so every piece must end with its own space or newline.
    1: lambda title, abstract: (
        "Analyze the following title and abstract to identify **genetic variants** mentioned, "
        "including specific variant names, base changes, and protein alterations.\n"
        "Only report well-defined variants, such as specific reference IDs or exact mutations.\n"
        "Ignore vague terms like 'mutation detected' or generic mentions of genes (e.g., BRCA1, ATM).\n"
        "If no specific variant is mentioned, respond with: 'No variant.'\n\n"
        "For each identified variant, provide:\n"
        "- **Variant Name**: e.g., c.2138C>G\n"
        "- **Gene Name**: e.g., BRCA1\n"
        "Return results in the following format: 'Variant: Variant Name, Gene Name' or 'No variant.'\n\n"
        f"Title: {title}\nAbstract: {abstract}\n"
    ),

    # Prompt 2: restricts the answer to HGVS notation, protein changes and rsIDs.
    2: lambda title, abstract: (
        "Extract all genetic variants from the following title and abstract. Only return:\n"
        "- **HGVS Notation** (c., p., g.), e.g., c.2138C>G, p.Arg713Trp, g.32389625G>A\n"
        "- **Protein changes** (e.g., V600E, Arg713Trp)\n"
        "- **rsIDs** (e.g., rs121913529)\n"
        "- Ignore vague terms (e.g., 'mutation found').\n\n"
        "### Format response strictly as:\n"
        "'Variant: <mutation>, Gene: <gene>' per line, or 'No variant' if none found.\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    ),

    # Prompt 3: same targets as prompt 2 with an explicit, terser output format.
    3: lambda title, abstract: (
        "Extract only specific genetic variants from the text. Return strictly:\n"
        "- **HGVS Notation** (c., p., g.) e.g., c.2138C>G, p.Arg713Trp\n"
        "- **Protein changes** (e.g., V600E, Arg713Trp)\n"
        "- **rsIDs** (e.g., rs121913529)\n"
        "- Ignore vague terms (e.g., 'mutation found').\n\n"
        "### Format:\n"
        "- Variant: 'Variant: <mutation>, Gene: <gene>' per line\n"
        "- If none, return: 'No variant'\n"
        "- No extra text, no explanations.\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    ),
}
print("Prompts for gene extraction successfully defined!")

# 4) Run genetic variant extraction with LLM

In [None]:
# ========================== CONFIGURATION ========================== #
os.chdir(working_directory)
tqdm.pandas()

# Run settings: articles per batch and which PROMPTS entry to use.
# `modelname` is used exactly as set by the model-selection cell.
BATCH_SIZE = 1000
selected_prompt_number = 3

# Timestamps written into the log headers
today_date = datetime.today().strftime('%Y-%m-%d')
start_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

# ========================== FILE PATHS USING ============================ #
# Every output file name carries the same "<model>_prompt<N>" suffix
run_tag = f"{modelname}_prompt{selected_prompt_number}"
variant_output_file_path = os.path.join(working_directory, f"LLM_variant_extraction_{run_tag}.csv")
runtime_file = os.path.join(working_directory, f"runtime_summary_{run_tag}.txt")
progress_log_file = os.path.join(working_directory, f"progress_log_{run_tag}.txt")

# ========================== ENSURE FILES EXIST ========================== #
def ensure_file_exists(file_path, header_text=None):
    """Create ``file_path`` if it is missing, optionally seeding it with a header.

    An existing file is left untouched. A new file receives ``header_text``
    followed by a line of 60 ``=`` characters when ``header_text`` is truthy;
    otherwise it is created empty.
    """
    if os.path.exists(file_path):
        return

    initial_content = (f"{header_text}\n" + "=" * 60 + "\n") if header_text else ""
    with open(file_path, "w") as f:
        f.write(initial_content)

# Seed the runtime and progress logs with a dated header on first use
for log_path, log_header in (
    (runtime_file, f"### Runtime Log - {today_date} ###\nStart Time: {start_time_str}"),
    (progress_log_file, f"### Progress Log - {today_date} ###"),
):
    ensure_file_exists(log_path, log_header)

# The results CSV needs only its column header row (no separator line),
# so it is created directly instead of through ensure_file_exists.
if not os.path.exists(variant_output_file_path):
    with open(variant_output_file_path, "w") as f:
        f.write("PaperId,PaperTitle,Abstract,LLM_Prompt,LLM_Response\n")

# ========================== LOGGING SETUP ========================== #
# NOTE: logging.basicConfig does nothing when the root logger already has
# handlers, e.g. when this cell is re-run in the same session.
logging.basicConfig(
    level=logging.INFO,
    filename=progress_log_file,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(message)s",
)
print("Success! All necessary files and directories are set up.")
for label, value in (
    ("Script Start Time:", start_time_str),
    ("Defined prompt number:", selected_prompt_number),
    ("Defined batch size:", BATCH_SIZE),
):
    print(label, value)

# ========================== FUNCTION DEFINITIONS ========================== #

def screen_publication_for_variants(row, prompt_number):
    """Build the LLM prompt for one article from the selected template.

    Args:
        row: Mapping-like record providing 'PaperTitle' and 'Abstract'.
        prompt_number: Key of the template to use in ``PROMPTS``.

    Returns:
        The prompt text produced by ``PROMPTS[prompt_number]``.

    Raises:
        ValueError: If ``prompt_number`` is not a key of ``PROMPTS``.
    """
    # Validate first, and list the keys that really exist instead of a
    # hard-coded range that can drift from the PROMPTS dictionary.
    if prompt_number not in PROMPTS:
        valid_numbers = ", ".join(str(key) for key in PROMPTS)
        raise ValueError(
            f"Invalid prompt number: {prompt_number}. Choose one of: {valid_numbers}."
        )
    return PROMPTS[prompt_number](row['PaperTitle'], row['Abstract'])

def process_with_llm(prompt):
    """Send ``prompt`` to the configured LLM and return its reply.

    Any exception raised by the call is logged and the sentinel string
    "ERROR" is returned in place of a reply.
    """
    try:
        reply = generateFromPrompt(prompt)
    except Exception as e:
        logging.error(f"LLM processing error: {e}")
        return "ERROR"
    else:
        return reply
print("Functions defined!")

In [None]:
# ========================== RESUME FROM LAST CHECKPOINT ========================== #
# The dataset must already be loaded in memory as `articles`
if 'articles' not in globals():
    raise ValueError("Dataset `articles` is not loaded in memory. Make sure it's defined before running the script.")

# Progress is tracked by 'PaperId', so the column is mandatory
if 'PaperId' not in articles.columns:
    raise KeyError("Dataset must contain a 'PaperId' column to track progress.")

# Collect the PaperIds already present in the output file, if there is one
if os.path.exists(variant_output_file_path):
    processed_df = pd.read_csv(variant_output_file_path)
    processed_articles = set(processed_df['PaperId'])
    total_processed_articles = len(processed_articles)
    print(f"Resuming from last processed row. {total_processed_articles} articles completed so far.")
else:
    processed_articles = set()
    total_processed_articles = 0
    print("Starting fresh processing.")

# Calculate total batches (ceiling division)
total_batches = (len(articles) + BATCH_SIZE - 1) // BATCH_SIZE

# Filter only unprocessed articles
unprocessed_df = articles[~articles['PaperId'].isin(processed_articles)]
total_articles = len(unprocessed_df)

# Completion is judged by what is actually left to do. Comparing the number
# of distinct processed ids with len(articles) gives the wrong answer when
# PaperIds repeat or the output file holds ids that are not in `articles`.
# Execution simply continues from here: the previous sys.exit(0) was caught
# immediately by its own `except SystemExit` and never stopped anything.
if total_articles == 0:
    print("\nAll batches are complete. No more articles to process.")
    print("You have successfully processed the entire dataset.")

print("Success! All necessary files and directories are set up.")
print("Defined prompt number:", selected_prompt_number)
print("Defined model:", modelname)
print("Defined batch size to run in chunks:", BATCH_SIZE)
print(f"Total unprocessed articles: {total_articles}")

# ========================== TRACK CUMULATIVE RUNTIME ========================== #
# Recover the runtime accumulated by earlier runs from the runtime file.
# Each "Total runtime so far" line already holds the cumulative figure
# (previous total plus the elapsed time of that run), so only the most recent
# one is wanted; adding them all up would count earlier runs repeatedly.
total_runtime_previous = 0.0
if os.path.exists(runtime_file):
    with open(runtime_file, "r") as f:
        lines = f.readlines()
    recorded_totals = [
        # Line shape: "Total runtime so far (...): <seconds> sec"
        float(line.split(":")[-1].strip().split()[0])
        for line in lines
        if "Total runtime so far" in line
    ]
    if recorded_totals:
        total_runtime_previous = recorded_totals[-1]

In [None]:
# ========================== BATCH PROCESSING ========================== #
def generate_progress_bar(percentage, bar_length=20):
    """Return a text progress bar such as ``[|||||---------------] 25.00%``."""
    filled_length = int(bar_length * percentage / 100)
    bar = '|' * filled_length + '-' * (bar_length - filled_length)
    return f"[{bar}] {percentage:.2f}%"


start_time = time.time()
os.chdir(working_directory)

# Column names for the output CSV, used only when the file has to be created
# here; they match the header written when the file is first set up.
OUTPUT_COLUMNS = ['PaperId', 'PaperTitle', 'Abstract', 'LLM_Prompt', 'LLM_Response']

# Defined before the loop so the final summary still works when nothing is
# left to process and the loop body never runs.
total_runtime_so_far = total_runtime_previous

# The loop bound has its own name because `total_articles` is reassigned
# below to mean "articles processed in total", which the final summary reports.
unprocessed_count = len(unprocessed_df)
if unprocessed_count == 0:
    total_articles = len(processed_articles)

# Calculate the next batch number
batch_number = (total_processed_articles // BATCH_SIZE) + 1
for batch_start in range(0, unprocessed_count, BATCH_SIZE):
    batch_end = min(batch_start + BATCH_SIZE, unprocessed_count)
    batch = unprocessed_df.iloc[batch_start:batch_end].copy()
    print(f"\nProcessing Batch {batch_number}/{total_batches} ({batch_start + 1} to {batch_end})...")

    # Take the wall-clock start before the LLM calls so the runtime file
    # records when the batch really began.
    batch_start_time = time.time()
    batch_started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    batch['LLM_Prompt'] = batch.apply(lambda row: screen_publication_for_variants(row, selected_prompt_number), axis=1)

    # Process with LLM
    # NOTE(review): a failed call is stored as the string "ERROR" and saved
    # like any other answer, so that PaperId is treated as processed and is
    # not retried on the next run.
    llm_response_column = f'LLM_Response_{modelname}'
    batch[llm_response_column] = batch['LLM_Prompt'].progress_apply(process_with_llm)
    batch_runtime = time.time() - batch_start_time
    batch_to_save = batch[['PaperId', 'PaperTitle', 'Abstract', 'LLM_Prompt', llm_response_column]]

    # Calculate total progress
    total_articles = batch_end + len(processed_articles)
    total_articles_to_process = unprocessed_count - batch_end
    processed_percentage = (total_articles / len(articles)) * 100
    to_process_percentage = (total_articles_to_process / len(articles)) * 100

    # Save progress to CSV: always append, writing the header row only when
    # the file does not exist yet.
    write_header = not os.path.exists(variant_output_file_path)
    batch_to_save.to_csv(
        variant_output_file_path,
        mode='a',
        header=OUTPUT_COLUMNS if write_header else False,
        index=False,
    )

    total_runtime_so_far = total_runtime_previous + (time.time() - start_time)

    with open(runtime_file, "a") as f:
        f.write(f"\nBatch {batch_number}/{total_batches} started at {batch_started_at}\n")
        f.write(f"Batch Runtime: {batch_runtime:.2f} sec\n")
        f.write(f"Total runtime so far (all runs combined): {total_runtime_so_far:.2f} sec\n")
        f.write(f"Total articles processed in this batch: {batch_end}\n")
        f.write("=" * 60 + "\n")

    logging.info("Processed batch %s/%s in %.2f sec.", batch_number, total_batches, batch_runtime)

    # Exactly one batch is handled per run; both outcomes leave the loop.
    if processed_percentage >= 100:
        print("\nAll articles have been successfully processed.")
        print("No more articles remaining.")
        break
    else:
        print(f"\nPaused! {batch_end} articles processed in this batch.")
        print(f"{total_articles} articles processed in total {generate_progress_bar(processed_percentage)}")
        print(f"{total_articles_to_process} articles to process in total {generate_progress_bar(to_process_percentage)}")
        print("Check the CSV and runtime file. When ready, rerun the script to continue processing.")
        break

# ========================== FINAL SUMMARY ========================== #
# Break the cumulative runtime down into hours, minutes and seconds
total_runtime = total_runtime_so_far
total_hours, seconds_within_hour = divmod(total_runtime, 3600)
total_minutes = seconds_within_hour // 60
total_seconds = total_runtime % 60

summary_text = f"""
### Genetic Variant Extraction Summary ###

- Model used: {modelname}
- Prompt number: {selected_prompt_number}
- Total batches processed: {batch_number}/{total_batches}
- Total rows processed: {total_articles}
- Cumulative runtime: {total_runtime:.2f} seconds ({total_hours:.0f} hr {total_minutes:.0f} min {total_seconds:.2f} sec)
"""

print(summary_text)

# Append the summary, stamped with the end time, to the runtime file
end_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
summary_parts = [
    "\n### Final runtime summary ###\n",
    f"End Time: {end_time_str}\n",
    summary_text,
    "\n" + "=" * 60 + "\n",
]
with open(runtime_file, "a") as f:
    f.write("".join(summary_parts))
print("Final results saved.")