# Use LLMs (GPT-4o and open-weight models) for gene extraction

# 1) Set up libraries and datasets

In [1]:
# Import libraries
import pandas as pd
import os
import numpy as np
import re
import time
from datetime import datetime
from datetime import timedelta
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt
import logging
from pathlib import Path
import sys

print("Success!")

Success!


In [3]:
# Set the working directory and file paths
working_directory = "/data/JH/marie/TrendyVariants/ICIMTH"
input_directory = "/data/JH/marie/TrendyVariants/Input"
output_directory = "/data/JH/marie/TrendyVariants/Output"
articles_file = "clean_df_step4.csv"  # We want to analyze the dataset after cleaning!
genes_file = "oncomine_ngs_panel.csv"
N_ARTICLES = 100  # size of the evaluation sample taken from the top of the cleaned dataset

# Load OncoMine genes file: headerless CSV, gene symbols in the first column.
genes = pd.read_csv(os.path.join(input_directory, genes_file), header=None)
# Strip stray whitespace and drop blank cells so symbols match LLM output exactly.
gene_symbols = genes[0].dropna().astype(str).str.strip()
gene_list = gene_symbols[gene_symbols != ""].tolist()
print("Genes import successful!")

# Load the articles file. The full table is cached in `full_articles` so re-running
# this cell does not re-read the large CSV; `del full_articles` to force a reload.
if "full_articles" not in globals():
    full_articles = pd.read_csv(os.path.join(output_directory, articles_file))
    print(f"Loaded {len(full_articles)} articles from CSV.")
else:
    print("Using preloaded full_articles from memory.")
articles = full_articles.head(N_ARTICLES)
print("Article import successful!")
print(f"\nImported {len(articles):,} articles with {len(articles.columns):,} selected columns.")
print(f"Imported {len(gene_list):,} oncomine genes.")

# Get the number of rows and columns
num_rows, num_columns = articles.shape

# Later cells read and write files relative to the working directory
os.chdir(working_directory)
print("\nCurrent Working Directory:", os.getcwd())

Genes import successful!
Loaded 2128318 articles from CSV.
Article import successful!

Imported 100 articles with 9 selected columns.
Imported 161 oncomine genes.

Current Working Directory: /data/JH/marie/TrendyVariants/ICIMTH


# 2) Select and set up LLMs

In [5]:
# Set up a language model to answer the questions
!pip install OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM
from openai import OpenAI
print("Success!")

Success!


In [7]:
# Define the models to be tested
models = ["llama31-70b", "llama33-70b", "deepseek_v3",
          "deepseek_r1", "deepseek_r1_distill_llama_70b", "gpt4o"]

# Mapping model names to their full Hugging Face or DeepInfra identifiers
model_fullnames = {
    "llama31-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "llama33-70b": "meta-llama/Llama-3.3-70B-Instruct",
    "deepseek_v3": "deepseek-ai/DeepSeek-V3",
    "deepseek_r1": "deepseek-ai/DeepSeek-R1",
    "deepseek_r1_distill_llama_70b": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
    "gpt4o": "gpt-4o"
}

SYSTEM_MSG = "You are a helpful medical question answering assistant. Please carefully follow the exact instructions and do not provide explanations."
modelname = models[5]  # -> "gpt4o"; in Python, list indexing starts from 0, not 1.


def _make_chat_generator(api_client):
    """Return ``generateFromPrompt(promptStr, maxTokens=100)`` bound to an
    OpenAI-compatible chat-completions client.

    The returned function sends SYSTEM_MSG plus the prompt to
    ``model_fullnames[modelname]`` and returns the reply text. ``maxTokens`` is
    accepted for interface compatibility but is not forwarded to the API, so the
    server-side default completion length applies.
    """
    def generateFromPrompt(promptStr, maxTokens=100):
        messages = [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "user", "content": promptStr},
        ]
        completion = api_client.chat.completions.create(
            model=model_fullnames[modelname],
            messages=messages)
        return completion.choices[0].message.content
    return generateFromPrompt


if modelname in ["llama2-3b"]:  # Local model
    # NOTE(review): `load` and `generate` are not imported in any visible cell and
    # "llama2-3b" has no entry in `model_fullnames`, so this branch cannot run as-is.
    model, tokenizer = load(model_fullnames[modelname])

    def generateFromPrompt(prompt):
        """Generate a reply with the local model, using its chat template when it has one."""
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            messages = [{"role": "system", "content": SYSTEM_MSG},
                        {"role": "user", "content": prompt}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        # Generate in both cases; models without a chat template previously got None back.
        return generate(model, tokenizer, prompt=prompt, verbose=False)
elif modelname in ["gpt35", "gpt4o"]:  # OpenAI models
    # NOTE(review): "gpt35" has no entry in `model_fullnames`; add one before selecting it.
    # Read the key from the environment; never commit a real key in the notebook.
    client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "API_key1"))
    generateFromPrompt = _make_chat_generator(client)
elif modelname in ["llama31-70b", "llama33-70b", "deepseek_v3", "deepseek_r1", "deepseek_r1_distill_llama_70b"]:  # DeepInfra models
    client = OpenAI(
        api_key=os.environ.get("DEEPINFRA_API_KEY", "API_key2"),
        base_url="https://api.deepinfra.com/v1/openai",
    )
    generateFromPrompt = _make_chat_generator(client)
else:
    # Fail here instead of with a confusing NameError at the smoke test below
    raise ValueError(f"Unknown model name: {modelname!r}. Choose one of {models}.")

# Smoke test: a reply shows the client and credentials work
generateFromPrompt("hello!")

'Hello! How can I assist you today?'

In [8]:
# Show which models are configured and which one the rest of the notebook uses
for label, value in (("All installed models:", models), ("Current model in use:", modelname)):
    print(label, value)

All installed models: ['llama31-70b', 'llama33-70b', 'deepseek_v3', 'deepseek_r1', 'deepseek_r1_distill_llama_70b', 'gpt4o']
Current model in use: gpt4o


# 3) Define prompts

In [10]:
# Define multiple prompts in a dictionary, keyed by prompt number.
# The model will extract genes and gene products, filtering against `gene_list`.
# Each value is a callable (title, abstract) -> prompt string; `gene_list` is read
# when the callable runs, so prompts always list the currently loaded panel.
# NOTE: prompt 1 used to lack the line break after its first sentence
# ("...title and abstract.Only return genes..."); earlier prompt-1 results used that text.


PROMPTS = {
    1: lambda title, abstract: (
        "Extract all gene names and their gene products (e.g., TP53 and p53) from the given title and abstract.\n"
        "Only return genes that are present in the following predefined list:\n"
        f"{', '.join(gene_list)}.\n"
        "If a gene is mentioned multiple times, only list it once.\n"
        "Return the extracted genes as a **comma-separated list** (e.g., 'CTNNB1, RET, BRCA1').\n"
        "If no genes from the list are present, return an **empty response** (do not return 'None' or 'No genes found').\n"
        "Strictly **no additional information, no explanations, no formatting**.\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    ),

    2: lambda title, abstract: (
        "Identify all gene symbols and their corresponding gene products mentioned in the given title and abstract.\n"
        "Only include genes that exist in the predefined list:\n"
        f"{', '.join(gene_list)}.\n"
        "Return the result as a simple **comma-separated list** (e.g., 'BRCA1, TP53, EGFR').\n"
        "If no matching genes are found, return **an empty response** (do not print anything).\n"
        "Ensure strict compliance:\n"
        "- Do not include extra text or explanations.\n"
        "- No formatting, no bullet points, no sentence structure.\n\n"
        f"Title: {title}\nAbstract: {abstract}"
    )
}

print("Prompts for gene extraction successfully defined!")

Prompts for gene extraction successfully defined!


# 4) Run gene extraction with the selected LLM (in batches)

In [16]:
# ========================== CONFIGURATION ========================== #
# Run artefacts (results CSV, runtime summary, progress log) go to the working directory
os.chdir(working_directory)

# Register `.progress_apply` on pandas objects (used during batch processing)
tqdm.pandas()

# Number of articles sent to the LLM per execution of the batch-processing cell.
# With the 100-article sample, 101 means everything is handled in a single batch.
BATCH_SIZE = 101

# Prompt selection; the model is the `modelname` chosen in the model-setup cell
selected_prompt_number = 2  # Change dynamically as needed

# Get today's date and current time
today_date = datetime.today().strftime('%Y-%m-%d')
start_time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')  # Includes date and time

# ========================== FILE PATHS ============================ #
# File names encode model and prompt so runs with different settings never overwrite each other
run_tag = f"{modelname}_prompt{selected_prompt_number}"
variant_output_file_path = os.path.join(working_directory, f"ICIMTH_LLM_variant_extraction_{run_tag}.csv")
runtime_file = os.path.join(working_directory, f"ICIMTH_runtime_summary_{run_tag}.txt")
progress_log_file = os.path.join(working_directory, f"ICIMTH_progress_log_{run_tag}.txt")

# ========================== ENSURE FILES EXIST ========================== #
def ensure_file_exists(file_path, header_text=None):
    """Create ``file_path`` if it is missing, optionally seeding it with a header.

    An existing file is never touched. A newly created file is empty unless
    ``header_text`` is truthy, in which case it receives the header line(s)
    followed by a 60-character ``=`` separator line.
    """
    if os.path.exists(file_path):
        return
    with open(file_path, "w") as handle:
        if not header_text:
            return
        separator = "=" * 60
        handle.write(f"{header_text}\n")
        handle.write(separator + "\n")

ensure_file_exists(runtime_file, f"### Runtime Log - {today_date} ###\nStart Time: {start_time_str}")
ensure_file_exists(progress_log_file, f"### Progress Log - {today_date} ###")

# Create the results CSV with its header up front so every batch can simply append.
# NOTE: the last column is named "LLM_Response" here; batches are appended without a
# header from a column named f"LLM_Response_{modelname}".
if not os.path.exists(variant_output_file_path):
    with open(variant_output_file_path, "w") as f:
        f.write("PaperId,PaperTitle,Abstract,LLM_Prompt,LLM_Response\n")

# ========================== LOGGING SETUP ========================== #
# `force=True` replaces handlers left by a previous execution. Without it basicConfig
# is a no-op once the root logger has handlers, so changing model/prompt and
# re-running this cell would keep logging into the old progress file.
logging.basicConfig(
    filename=progress_log_file,
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)

print("Success! All necessary files and directories are set up.")
print("Script Start Time:", start_time_str)
print("Defined prompt number:", selected_prompt_number)
print("Defined batch size:", BATCH_SIZE)

# ========================== FUNCTION DEFINITIONS ========================== #

def screen_publication_for_variants(row, prompt_number, prompts=None):
    """Build the LLM prompt for one article.

    Args:
        row: Article record (e.g. a DataFrame row or dict) with 'PaperTitle' and 'Abstract'.
        prompt_number: Key of the prompt template to use.
        prompts: Optional mapping of prompt number -> callable(title, abstract) -> str.
            Defaults to the notebook-level ``PROMPTS``.

    Returns:
        str: The filled-in prompt text.

    Raises:
        ValueError: If ``prompt_number`` is not a key of ``prompts``.
    """
    if prompts is None:
        prompts = PROMPTS

    # Validate first; the message lists the keys that actually exist
    # (it previously claimed "0-5" although only a few prompts are defined).
    if prompt_number not in prompts:
        raise ValueError(
            f"Invalid prompt number: {prompt_number}. Choose one of {sorted(prompts)}."
        )

    return prompts[prompt_number](row['PaperTitle'], row['Abstract'])

def process_with_llm(prompt, generate_fn=None, max_retries=2, retry_delay=1.0):
    """Send one prompt to the LLM, retrying transient failures.

    Args:
        prompt: Prompt text for a single article.
        generate_fn: Callable ``str -> str`` that queries the model; defaults to the
            notebook-level ``generateFromPrompt``.
        max_retries: Extra attempts after the first failure (0 disables retrying).
        retry_delay: Base delay in seconds; retry k waits ``retry_delay * 2**k``.

    Returns:
        The model's response, or the sentinel string "ERROR" when every attempt
        failed (each failure is logged). "ERROR" rows are still written to the
        output CSV and therefore count as processed when resuming.
    """
    # Resolved outside the try: a missing generator is a configuration error,
    # not a transient one, and should not be silently turned into "ERROR" rows.
    if generate_fn is None:
        generate_fn = generateFromPrompt

    attempts = max(0, max_retries) + 1
    for attempt in range(attempts):
        try:
            return generate_fn(prompt)
        except Exception as e:
            logging.error(f"LLM processing error (attempt {attempt + 1}/{attempts}): {e}")
            if attempt + 1 < attempts:
                # Exponential back-off for rate limits / transient API errors
                time.sleep(retry_delay * (2 ** attempt))
    return "ERROR"

print("Functions defined!")

Success! All necessary files and directories are set up.
Script Start Time: 2025-03-10 12:52:58
Defined prompt number: 2
Defined batch size: 101
Functions defined!


In [17]:
# ========================== RESUME FROM LAST CHECKPOINT ========================== #

# The article sample must already be in memory (loaded by the data-loading cell)
if 'articles' not in globals():
    raise ValueError("Dataset `articles` is not loaded in memory. Make sure it's defined before running the script.")

# Ensure 'PaperId' column exists for tracking progress
if 'PaperId' not in articles.columns:
    raise KeyError("Dataset must contain a 'PaperId' column to track progress.")

# Check if output file exists and count previously processed articles
if os.path.exists(variant_output_file_path):
    processed_df = pd.read_csv(variant_output_file_path)
    processed_articles = set(processed_df['PaperId'])  # Every PaperId already written
    # Count only IDs that belong to the current `articles` sample, so IDs written for
    # another sample into the same file cannot make this one look complete.
    total_processed_articles = int(articles['PaperId'].isin(processed_articles).sum())
    print(f"Resuming from last processed row. {total_processed_articles} articles completed so far.")
else:
    processed_articles = set()
    total_processed_articles = 0
    print("Starting fresh processing.")

# Calculate total batches (ceiling division)
total_batches = (len(articles) + BATCH_SIZE - 1) // BATCH_SIZE

# Nothing left to do: just report it. Execution continues; the unprocessed selection
# below is then empty, so the batch-processing cell has nothing to send.
if total_processed_articles == len(articles):
    print("\nAll batches are complete. No more articles to process.")
    print("You have successfully processed the entire dataset.")

# Filter only unprocessed articles
unprocessed_df = articles[~articles['PaperId'].isin(processed_articles)]
total_articles = len(unprocessed_df)

print("Success! All necessary files and directories are set up.")
print("Defined prompt number:", selected_prompt_number)
print("Defined batch size to run in chunks:", BATCH_SIZE)
print(f"Total unprocessed articles: {total_articles}")

# ========================== TRACK CUMULATIVE RUNTIME ========================== #
# Every "Total runtime so far" line is already cumulative over all runs, so the
# previous total is the LAST recorded value (summing them double-counted earlier runs).
total_runtime_previous = 0.0  # Start fresh if no file / no batch recorded yet
if os.path.exists(runtime_file):
    with open(runtime_file, "r") as f:
        lines = f.readlines()
    recorded_totals = [float(line.split(":")[-1].strip().split()[0])
                       for line in lines if "Total runtime so far" in line]
    if recorded_totals:
        total_runtime_previous = recorded_totals[-1]

Resuming from last processed row. 0 articles completed so far.
Success! All necessary files and directories are set up.
Defined prompt number: 2
Defined batch size to run in chunks: 101
Total unprocessed articles: 100


In [18]:
# ========================== BATCH PROCESSING ========================== #
# Sends ONE batch (at most BATCH_SIZE still-unprocessed articles) to the LLM per
# execution, appends the answers to `variant_output_file_path` and logs timings to
# `runtime_file`. Re-run this cell (or the resume cell first) to continue.
start_time = time.time()
os.chdir(working_directory)


def generate_progress_bar(percentage, bar_length=20):
    """Render `percentage` (0-100) as a text bar such as ``[||||------] 40.00%``."""
    filled_length = int(bar_length * percentage / 100)
    bar = '|' * filled_length + '-' * (bar_length - filled_length)
    return f"[{bar}] {percentage:.2f}%"


# Column that receives the raw model answer for this run
llm_response_column = f'LLM_Response_{modelname}'

# Next batch number, derived from how many articles earlier runs completed
batch_number = (total_processed_articles // BATCH_SIZE) + 1

# Defaults so the summary below also works when nothing is left to process
# (previously `total_runtime_so_far` was undefined then and the cell raised NameError).
total_runtime_so_far = total_runtime_previous
total_articles_done = total_processed_articles

if total_articles == 0:
    print("\nAll articles have been successfully processed.")
    print("No more articles remaining.")
else:
    batch_end = min(BATCH_SIZE, total_articles)
    batch = unprocessed_df.iloc[:batch_end].copy()

    print(f"\nProcessing Batch {batch_number}/{total_batches} (1 to {batch_end})...")

    batch_started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    batch_start_time = time.time()

    # Build one prompt per article, then query the LLM article by article
    batch['LLM_Prompt'] = batch.apply(lambda row: screen_publication_for_variants(row, selected_prompt_number), axis=1)
    batch[llm_response_column] = batch['LLM_Prompt'].progress_apply(process_with_llm)

    batch_runtime = time.time() - batch_start_time

    # Keep only the output-CSV columns, in header order, and append them
    batch_to_save = batch[['PaperId', 'PaperTitle', 'Abstract', 'LLM_Prompt', llm_response_column]]
    if os.path.exists(variant_output_file_path):
        batch_to_save.to_csv(variant_output_file_path, mode='a', header=False, index=False)
    else:
        batch_to_save.to_csv(variant_output_file_path, mode='w', index=False)

    # Progress relative to the whole `articles` sample
    total_articles_done = total_processed_articles + batch_end
    total_articles_to_process = total_articles - batch_end
    processed_percentage = (total_articles_done / len(articles)) * 100
    to_process_percentage = (total_articles_to_process / len(articles)) * 100

    total_runtime_so_far = total_runtime_previous + (time.time() - start_time)

    # The "Total runtime so far" line is parsed by the resume cell; keep its format.
    with open(runtime_file, "a") as f:
        f.write(f"\nBatch {batch_number}/{total_batches} started at {batch_started_at}\n")
        f.write(f"Batch Runtime: {batch_runtime:.2f} sec\n")
        f.write(f"Total runtime so far (all runs combined): {total_runtime_so_far:.2f} sec\n")
        f.write(f"Total articles processed in this batch: {batch_end}\n")
        f.write("=" * 60 + "\n")

    logging.info(f"Processed batch {batch_number}/{total_batches} in {batch_runtime:.2f} sec.")

    # Keep the in-memory checkpoint in sync with the CSV, so re-running only this
    # cell continues with the next batch instead of re-sending (and duplicating) this one.
    processed_articles = processed_articles | set(batch['PaperId'])
    unprocessed_df = unprocessed_df.iloc[batch_end:]
    total_articles = len(unprocessed_df)
    total_processed_articles = total_articles_done
    total_runtime_previous = total_runtime_so_far

    if total_articles_to_process == 0:
        print("\nAll articles have been successfully processed.")
        print("No more articles remaining.")
    else:
        print(f"\nPaused! {batch_end} articles processed in this batch.")
        print(f"{total_articles_done} articles processed in total {generate_progress_bar(processed_percentage)}")
        print(f"{total_articles_to_process} articles to process in total {generate_progress_bar(to_process_percentage)}")
        print("Check the CSV and runtime file. When ready, rerun the script to continue processing.")

# ========================== FINAL SUMMARY ========================== #
total_runtime = total_runtime_so_far
total_hours = total_runtime // 3600
total_minutes = (total_runtime % 3600) // 60
total_seconds = total_runtime % 60
# Batches completed so far (ceiling division), never more than planned
batches_completed = min((total_articles_done + BATCH_SIZE - 1) // BATCH_SIZE, total_batches)

summary_text = f"""
### Genetic Variant Extraction Summary ###

- Model used: {modelname}
- Prompt number: {selected_prompt_number}
- Total batches processed: {batches_completed}/{total_batches}
- Total rows processed: {total_articles_done}
- Cumulative runtime: {total_runtime:.2f} seconds ({total_hours:.0f} hr {total_minutes:.0f} min {total_seconds:.2f} sec)
"""

print(summary_text)

# Save final runtime summary
with open(runtime_file, "a") as f:
    f.write("\n### Final runtime summary ###\n")
    f.write(f"End Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(summary_text)
    f.write("\n" + "=" * 60 + "\n")

print("Final results saved.")


Processing Batch 1/1 (1 to 100)...


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:07<00:00,  1.48it/s]


All articles have been successfully processed.
No more articles remaining.

### Genetic Variant Extraction Summary ###

- Model used: gpt4o
- Prompt number: 2
- Total batches processed: 1/1
- Total rows processed: 100
- Cumulative runtime: 67.68 seconds (0 hr 1 min 7.68 sec)

Final results saved.





# **Rerun "RESUME FROM LAST CHECKPOINT"** to continue batch processing

# Build a binary gene-presence matrix to evaluate gene identification

In [39]:
# Results file to evaluate, named with the pattern used by the batch run:
# ICIMTH_LLM_variant_extraction_<model>_prompt<N>.csv
# NOTE(review): the batch run above used prompt 2; confirm that prompt 1 is intended here.
eval_prompt_number = 1  # prompt whose saved results are evaluated
LLM_file = f"ICIMTH_LLM_variant_extraction_{modelname}_prompt{eval_prompt_number}.csv"

# Load the saved LLM extraction results
os.chdir(working_directory)
LLM_variant_df = pd.read_csv(LLM_file)
print("\nFinal DataFrame Preview:")
print(LLM_variant_df.head(50))
print("\nLength of dataframe", len(LLM_variant_df))

# Print all column names
print("\nColumns in the final DataFrame:")
print(LLM_variant_df.columns)


Final DataFrame Preview:
       PaperId                                         PaperTitle  \
0   4405900941  Oncological outcomes following extreme oncopla...   
1   4405941037  Regarding: Alpha1 antitrypsin deficiency assoc...   
2   4405952152  Incidence and risk factors of immune checkpoin...   
3   4405971509  Pedicle ossification following mandibular reco...   
4   4406120665  Artificial Intelligence for Autonomous Robotic...   
5   4406120699  The Necessity of Human Papillomavirus Vaccinat...   
6   4393515872          Jagodinsky et al 2023 bulk RNA-seq counts   
7   4403620280  Two cases of pancreatic tuberculosis in immuno...   
8   4405796123  Successful repositioning of mertansine for imp...   
9   4405900528  Tissue Prior to the Initial Hematoxylin-Eosin ...   
10  4405900803  Detection of urine circulating tumor DNA using...   
11  4405900916  A qualitative study of the perioperative exerc...   
12  4405900978  In vitro characterization of some of the anti-...   
13  4405

In [40]:
import pandas as pd
import os
import time
from fuzzywuzzy import process, fuzz
from tqdm import tqdm

# ========================== CONFIGURATION ========================== #

# Ensure LLM_variant_df is correctly defined
if 'LLM_variant_df' not in globals():
    raise ValueError("LLM_variant_df is not defined. Please ensure batch processing is completed.")

# Column with the raw LLM answers. The header written before batch processing names it
# "LLM_Response"; a results file created directly by `to_csv` instead carries the
# model-specific name "LLM_Response_<model>", so fall back to a unique such column.
llm_response_column = 'LLM_Response'
if llm_response_column not in LLM_variant_df.columns:
    response_candidates = [col for col in LLM_variant_df.columns if str(col).startswith('LLM_Response_')]
    if len(response_candidates) == 1:
        llm_response_column = response_candidates[0]

# Ensure necessary columns exist
required_columns = ["PaperId", "PaperTitle", "Abstract", llm_response_column]
missing_columns = [col for col in required_columns if col not in LLM_variant_df.columns]
if missing_columns:
    raise KeyError(f"Missing columns in dataset: {missing_columns}")

# ========================== NORMALIZATION FUNCTION ========================== #

# Panel lookup keyed by the upper-cased, whitespace-stripped gene symbol
# (each value is a one-element set holding that same symbol).
expanded_gene_list = {str(gene).strip().upper(): {str(gene).strip().upper()} for gene in gene_list}
print(f"Expanded gene list contains {len(expanded_gene_list)} genes.")

def normalize_extracted_entities(found_terms):
    """Map candidate terms onto upper-case panel gene symbols.

    A term whose upper-cased form is already a panel key is kept as is. Any other
    term is fuzzy-matched against the panel keys with ``fuzz.ratio`` and the best
    match is accepted only when its score exceeds 85.

    Args:
        found_terms: Iterable of candidate gene strings.

    Returns:
        set[str]: The matched panel symbols.
    """
    panel_symbols = expanded_gene_list.keys()
    matched = set()
    for candidate in (term.upper() for term in found_terms):
        if candidate in expanded_gene_list:
            matched.add(candidate)
            continue
        best = process.extractOne(candidate, panel_symbols, scorer=fuzz.ratio)
        if not best:
            continue
        symbol, score = best[:2]
        if score > 85:  # fuzzy-match acceptance threshold
            matched.add(symbol)
    return matched

# ========================== GENE EXTRACTION FUNCTION ========================== #

def extract_genes_from_text(text):
    """Extract panel gene symbols from one LLM response.

    The response is split on commas, semicolons, slashes and whitespace. Quotes,
    brackets, asterisks, leading dashes and trailing punctuation are trimmed from each
    token. The token is then matched case-insensitively against
    `expanded_gene_list` and passed through `normalize_extracted_entities`.

    Args:
        text: Raw LLM response. NaN/None or blank text yields an empty set.

    Returns:
        set[str]: Upper-cased panel gene symbols found in the response.
    """
    if pd.isna(text):
        return set()
    text = str(text)
    if not text.strip():
        return set()

    # Split on separators rather than deleting commas: the old
    # `text.replace(",", "").split()` merged "BRCA1,TP53" into "BRCA1TP53" and lost both.
    raw_tokens = re.split(r"[,;/\s]+", text)
    # Trim decoration the model may add around symbols, e.g. 'BRCA1', **EGFR**, TP53.
    cleaned_tokens = (token.strip("'\"`*.:()[]{}<>-") for token in raw_tokens)
    extracted_terms = {token.upper() for token in cleaned_tokens if token.upper() in expanded_gene_list}

    return normalize_extracted_entities(extracted_terms)

# ========================== PROCESSING LLM RESPONSES ========================== #

start_time = time.strftime("%Y-%m-%d %H:%M:%S")
start_timestamp = time.time()

print(f"Processing {len(LLM_variant_df)} articles for gene extraction. Started at {start_time}")

# Apply extraction to the dataset: one set of upper-cased panel symbols per article
tqdm.pandas(desc="Extracting genes from LLM responses")
extracted_gene_sets = LLM_variant_df[llm_response_column].progress_apply(extract_genes_from_text)

# Store as a sorted, comma-separated string so the saved CSV is identical across runs
# (iterating a set gives an arbitrary order).
LLM_variant_df["Extracted_Genes"] = extracted_gene_sets.apply(lambda genes: ", ".join(sorted(genes)))

# ========================== CREATE BINARY MATRIX ========================== #

print("Creating binary gene presence matrix...")

# Per-article gene list (empty list when nothing was found)
LLM_variant_df["Extracted_Gene_List"] = extracted_gene_sets.apply(sorted)

# One 0/1 column per panel gene, named with the panel's original spelling. Membership
# is tested on the upper-cased symbol because extraction upper-cases everything;
# comparing the raw name would miss any panel gene not written in upper case.
binary_gene_data = {
    gene: extracted_gene_sets.apply(lambda found, key=str(gene).strip().upper(): 1 if key in found else 0)
    for gene in gene_list
}
binary_gene_df = pd.DataFrame(binary_gene_data, index=LLM_variant_df.index)

# Merge with the article data. Matrix columns from a previous execution of this cell
# are dropped first so re-running it does not duplicate them.
LLM_variant_df = pd.concat(
    [LLM_variant_df.drop(columns=[*binary_gene_df.columns, "Sum_Entity_Mentions"], errors="ignore"),
     binary_gene_df],
    axis=1,
)

# Count total gene mentions per article (Sum_Entity_Mentions)
LLM_variant_df["Sum_Entity_Mentions"] = binary_gene_df.sum(axis=1)

# ========================== SAVE RESULTS ========================== #

# Define output filename
# NOTE(review): the name is fixed to gpt4o; keep it in sync with `LLM_file`.
output_filename = os.path.join(working_directory, "LLM_evaluation_gpt4o.csv")

# Drop the helper list column before saving
LLM_variant_df.drop(columns=["Extracted_Gene_List"], errors="ignore").to_csv(output_filename, index=False)

print(f"\nGene extraction complete! Results saved as: {output_filename}")

# ========================== GENERATE SUMMARY ========================== #

summary_results = int(LLM_variant_df["Sum_Entity_Mentions"].sum())
elapsed_sec = time.time() - start_timestamp

print("\n### Gene Extraction Summary ###")
print(f"Total gene mentions found: {summary_results}")
print(f"Elapsed time: {elapsed_sec:.2f} sec")

# Save summary to file
summary_file = os.path.join(working_directory, "LLM_Gene_Extraction_Summary.txt")
with open(summary_file, "w") as f:
    f.write("### Gene Extraction Summary ###\n")
    f.write(f"Total gene mentions found: {summary_results}\n")

print(f"\nExtraction summary saved in: {summary_file}")



Expanded gene list contains 161 genes.
Processing 100 articles for gene extraction. Started at 2025-03-10 14:07:24


Extracting genes from LLM responses: 100%|████████████████████████████████████████████████████| 100/100 [00:00<00:00, 216536.09it/s]

Creating binary gene presence matrix...

Gene extraction complete! Results saved as: /data/JH/marie/TrendyVariants/ICIMTH/LLM_evaluation_gpt4o.csv

### Gene Extraction Summary ###
Total gene mentions found: 31

Extraction summary saved in: /data/JH/marie/TrendyVariants/ICIMTH/LLM_Gene_Extraction_Summary.txt



