# BioBERT-based gene extraction

# 1) Install libraries and load dataset

In [15]:
!pip install transformers
print("Success!")
from transformers import pipeline

import pandas as pd
import os
import re
import time
from fuzzywuzzy import process, fuzz
from tqdm import tqdm
import torch
print("Import successful!")

Import successful!


In [116]:
# Build the BioBERT genetic-NER pipeline. The device index is 0 (first CUDA
# GPU) when torch reports CUDA as available, and -1 (CPU) otherwise.
biobert_model = pipeline(
    task="ner",
    device=0 if torch.cuda.is_available() else -1,
    model="alvaroalon2/biobert_genetic_ner",
    tokenizer="alvaroalon2/biobert_genetic_ner",
)
print("BioBERT model loaded successfully!")

Device set to use cuda:0


BioBERT model loaded successfully!


In [117]:
# Set the working directory and file paths
working_directory = "WORKING_DIRECTORY"
input_directory = "INPUT_DIRECTORY"
output_directory = "OUTPUT_DIRECTORY"
articles_file = "articles.csv"
genes_file = "genes.csv"
n_articles = 100  # how many articles (from the top of the table) to process

# Join paths explicitly instead of chdir-ing between directories: a failed
# read can no longer leave the kernel sitting in the wrong working directory.
genes_path = os.path.join(input_directory, genes_file)
articles_path = os.path.join(output_directory, articles_file)

# Gene panel: first column of a header-less CSV. Missing entries and stray
# whitespace are dropped because later cells call ``gene.upper()`` on every
# entry and compare symbols as exact strings.
genes = pd.read_csv(genes_path, header=None)
gene_list = [gene for gene in genes[0].dropna().astype(str).str.strip() if gene]
print("Genes import successful!")

# Reuse the article table already held by the kernel when this cell is re-run.
if "full_articles" not in globals():
    full_articles = pd.read_csv(articles_path)
    print(f"Loaded {len(full_articles)} articles from CSV.")
else:
    print("Using preloaded full_articles from memory.")
articles = full_articles.head(n_articles)
print("Article import successful!")
print(f"\nImported {len(articles)} articles with {len(articles.columns)} selected columns.")
print(f"Imported {len(gene_list):,} oncomine genes.")

num_rows, num_columns = articles.shape
# Later cells write their outputs with relative paths into this directory.
os.chdir(working_directory)
print("\nCurrent Working Directory:", os.getcwd())

Genes import successful!
Using preloaded full_articles from memory.
Article import successful!

Imported 100 articles with 9 selected columns.
Imported 161 oncomine genes.

Current Working Directory: /data/JH/marie/TrendyVariants/ICIMTH


# 2) Run BioBERT model

In [118]:
##### BioBERT with sliding window
# Function to fetch gene synonyms from MyGene.info API
def get_gene_synonyms(gene_symbol, timeout=10):
    """Return the upper-cased symbol, aliases and other names of a human gene.

    Queries the MyGene.info ``/v3/query`` endpoint restricted to the official
    ``symbol`` field and to human genes, so hits for unrelated genes that merely
    mention the query string do not end up in the synonym set.

    Args:
        gene_symbol: Gene symbol to look up (e.g. ``"EGFR"``).
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A set of upper-cased names. It always contains the upper-cased input
        symbol; if the request fails (offline, timeout, malformed response)
        only that symbol is returned.
    """
    # Standard-library HTTP only: the notebook's import cell does not import
    # ``requests``, and a missing name must not be hidden by the error handler.
    import http.client
    import json
    import urllib.parse
    import urllib.request

    canonical = str(gene_symbol).strip().upper()
    synonyms = {canonical}
    query = urllib.parse.urlencode(
        {
            "q": f"symbol:{canonical}",
            "species": "human",
            "fields": "symbol,alias,other_names",
        }
    )
    url = f"https://mygene.info/v3/query?{query}"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            payload = json.load(response)
    except (OSError, ValueError, http.client.HTTPException):
        # Best effort: a network/API problem for one gene must not abort the
        # expansion of the whole gene list.
        return synonyms
    if not isinstance(payload, dict):
        return synonyms
    for hit in payload.get("hits", []):
        if not isinstance(hit, dict):
            continue
        for field in ("symbol", "alias", "other_names"):
            value = hit.get(field)
            if not value:
                continue
            # A field may come back as a bare string instead of a list; iterating
            # a string would add single characters, so wrap it.
            values = [value] if isinstance(value, str) else value
            synonyms.update(str(name).upper() for name in values)
    return synonyms

# Expand the gene list dynamically with synonyms:
# upper-cased gene symbol -> set of names returned by get_gene_synonyms().
expanded_gene_list = dict(
    (symbol.upper(), get_gene_synonyms(symbol)) for symbol in gene_list
)
print(f"Expanded gene list contains {len(expanded_gene_list)} genes with synonyms.")

# Function to normalize extracted genes
def normalize_extracted_genes(found_terms, gene_lookup=None, score_cutoff=85):
    """Map raw NER terms onto known (upper-case) gene symbols.

    Each term is resolved by the first strategy that succeeds:
      1. the whole upper-cased term is a key of ``gene_lookup``;
      2. otherwise, every word of the term (brackets, parentheses, commas and
         hyphens are treated as separators) that is a key is kept;
      3. otherwise, the closest key by ``fuzz.ratio`` is kept when its score is
         strictly greater than ``score_cutoff``.

    Args:
        found_terms: Iterable of entity strings produced by the NER step.
        gene_lookup: Mapping whose keys are upper-case gene symbols. Defaults
            to the notebook-level ``expanded_gene_list``.
        score_cutoff: Fuzzy-match threshold on the 0-100 ``fuzz.ratio`` scale.

    Returns:
        Set of matched gene symbols (all keys of ``gene_lookup``).
    """
    if gene_lookup is None:
        gene_lookup = expanded_gene_list
    # Pass a plain list of keys: handing fuzzywuzzy a dict would make it match
    # against the values instead of the symbols.
    candidates = list(gene_lookup)
    normalized_genes = set()
    for term in found_terms:
        term_upper = term.upper()
        if not term_upper.strip():
            continue  # nothing to match
        if term_upper in gene_lookup:
            normalized_genes.add(term_upper)
            continue
        words = re.sub(r"[\[\]\(\),-]", " ", term_upper).split()
        word_hits = {word for word in words if word in gene_lookup}
        if word_hits:
            normalized_genes |= word_hits
            continue
        # Last resort: tolerate small spelling differences in the whole term.
        match = process.extractOne(term_upper, candidates, scorer=fuzz.ratio)
        if match:
            best_match, score = match[:2]
            if score > score_cutoff:
                normalized_genes.add(best_match)
    return normalized_genes

# Function to split text into overlapping chunks for NER
def sliding_window_chunking(text, tokenizer, max_tokens=512, stride=256):
    """Split ``text`` into overlapping token windows that fit the model input.

    The text is tokenized once without special tokens and cut into windows of
    ``max_tokens`` minus the number of special tokens the tokenizer adds back
    when each chunk is re-encoded by the pipeline. Window starts are
    ``stride`` tokens apart, and the last window always reaches the end of the
    text, so no trailing tokens are lost.

    Args:
        text: Input string.
        tokenizer: Object with ``encode(text, add_special_tokens=False)`` and
            ``decode(ids, skip_special_tokens=True)``. If it also provides
            ``num_special_tokens_to_add(pair=False)``, that many positions
            are reserved per window.
        max_tokens: Model input limit in tokens, special tokens included.
        stride: Offset between window starts. It should not exceed the window
            size; otherwise tokens between windows are skipped.

    Returns:
        List of decoded chunk strings (a single chunk for short texts).

    Raises:
        ValueError: If ``stride`` is not positive or ``max_tokens`` leaves no
            room for content tokens.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    count_special = getattr(tokenizer, "num_special_tokens_to_add", None)
    n_special = count_special(pair=False) if callable(count_special) else 0
    window = max_tokens - n_special
    if window <= 0:
        raise ValueError(
            f"max_tokens={max_tokens} leaves no room next to {n_special} special tokens"
        )

    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= window:
        return [tokenizer.decode(tokens, skip_special_tokens=True)]

    chunks = []
    start = 0
    while True:
        end = min(start + window, len(tokens))
        chunks.append(tokenizer.decode(tokens[start:end], skip_special_tokens=True))
        if end == len(tokens):
            break  # this window already covers the tail of the text
        start += stride
    return chunks

# Function to process text with BioBERT using sliding window
def _merge_entity_tokens(results):
    """Merge token-level B-/I- predictions of one pipeline call into entity strings.

    WordPiece continuation tokens (``##`` prefix) are glued to the previous
    piece; other tokens are joined with a single space, so multi-word entities
    remain splittable into words by ``normalize_extracted_genes``. A token only
    extends the current entity when its ``index`` directly follows the
    previous tagged token. A gap presumably means untagged tokens were skipped
    by the pipeline, so a new entity starts. A ``B-`` label on a continuation
    token is treated as part of the word it continues.
    """
    terms = set()
    current = ""
    prev_index = None
    for res in results:
        label = res["entity"]
        if not label.startswith(("B-", "I-")):
            continue
        token = res["word"]
        is_continuation = token.startswith("##")
        piece = token.replace("##", "")
        index = res.get("index")
        # Without index information fall back to assuming adjacency.
        adjacent = index is None or prev_index is None or index == prev_index + 1
        extends = bool(current) and adjacent and (label.startswith("I-") or is_continuation)
        if extends:
            current += piece if is_continuation else f" {piece}"
        else:
            if current:
                terms.add(current)
            current = piece
        prev_index = index
    if current:
        terms.add(current)
    return terms


def process_biobert(text, model):
    """Run the NER pipeline over ``text`` and return normalized gene symbols.

    Args:
        text: Title or abstract. NaN/None or blank text yields an empty set.
        model: A Hugging Face token-classification pipeline; its ``tokenizer``
            attribute is used to split long text into overlapping chunks.

    Returns:
        Set of gene symbols as produced by ``normalize_extracted_genes``.
    """
    if pd.isna(text) or not str(text).strip():
        return set()
    text = str(text)
    found_terms = set()
    for chunk in sliding_window_chunking(text, model.tokenizer):
        found_terms |= _merge_entity_tokens(model(chunk))
    return normalize_extracted_genes(found_terms)

print("Success!")


##### Gene extraction #####
# Reuse ``article_df`` if another cell prepared it; otherwise fall back to the
# ``articles`` table loaded in section 1 so the cell also runs in a fresh kernel.
if "article_df" not in globals():
    article_df = articles

start_time = time.strftime("%Y-%m-%d %H:%M:%S")
start_timestamp = time.time()
print(f"Processing {len(article_df)} articles with BioBERT. Started at {start_time}")

biobert_results = []
for position, (index, row) in enumerate(
    tqdm(article_df.iterrows(), total=len(article_df), desc="Processing Articles"),
    start=1,
):
    title = row.get("PaperTitle", "")
    abstract = row.get("Abstract", "")
    genes_biobert = process_biobert(title, biobert_model) | process_biobert(abstract, biobert_model)
    # Sorted so the stored string is reproducible across runs (iteration order
    # of a set of strings varies with hash randomization).
    biobert_results.append(", ".join(sorted(genes_biobert)))
    # tqdm.write prints above the progress bar instead of breaking it apart;
    # numbering is positional, independent of the DataFrame index.
    tqdm.write(f"Article {position}: {genes_biobert}")

df_results = article_df.copy()
df_results["BioBERT"] = biobert_results
num_articles = len(df_results)

output_file = f"filtered_articles_biobert_expanded_{num_articles}.csv"
runtime_file = f"filtered_articles_biobert_expanded_{num_articles}_runtime.txt"

end_time = time.strftime("%Y-%m-%d %H:%M:%S")
end_timestamp = time.time()
total_runtime = end_timestamp - start_timestamp
print(f"Processing completed at {end_time}. Total runtime: {total_runtime:.2f} seconds.")

# Persist the per-article extraction results under ``output_file``.
df_results.to_csv(output_file, index=False)
print(f"Results saved in: {output_file}")

with open(runtime_file, "w", encoding="utf-8") as f:
    f.writelines(
        [
            f"Processing of articles: {num_articles}\n",
            f"Processing started at: {start_time}\n",
            f"Processing completed at: {end_time}\n",
            f"Total runtime: {total_runtime:.2f} seconds\n",
        ]
    )
print(f"Runtime details saved in: {runtime_file}")

Expanded gene list contains 161 genes with synonyms.
Success!
Processing 100 articles with BioBERT. Started at 2025-03-06 12:08:39


Processing Articles:   0%|                                                                                  | 0/100 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1355 > 512). Running this sequence through the model will result in indexing errors


Article 1: set()


Processing Articles:   4%|██▉                                                                       | 4/100 [00:00<00:10,  9.13it/s]

Article 2: set()
Article 3: set()
Article 4: set()
Article 5: set()


Processing Articles:   8%|█████▉                                                                    | 8/100 [00:00<00:07, 13.01it/s]

Article 6: set()
Article 7: set()
Article 8: set()
Article 9: set()


Processing Articles:  10%|███████▎                                                                 | 10/100 [00:00<00:06, 13.96it/s]

Article 10: {'EGFR'}
Article 11: {'TP53', 'CTNNB1'}


Processing Articles:  14%|██████████▏                                                              | 14/100 [00:01<00:06, 12.57it/s]

Article 12: set()
Article 13: {'NRG1', 'PTEN'}
Article 14: {'TP53', 'AKT1', 'PIK3CA', 'MTOR'}
Article 15: set()


Processing Articles:  18%|█████████████▏                                                           | 18/100 [00:01<00:05, 13.76it/s]

Article 16: set()
Article 17: set()
Article 18: set()
Article 19: set()


Processing Articles:  22%|████████████████                                                         | 22/100 [00:01<00:05, 14.10it/s]

Article 20: set()
Article 21: {'PRKACA', 'ESR1'}
Article 22: set()


Processing Articles:  24%|█████████████████▌                                                       | 24/100 [00:01<00:05, 14.87it/s]

Article 23: set()
Article 24: set()
Article 25: set()


Processing Articles:  28%|████████████████████▍                                                    | 28/100 [00:02<00:04, 14.60it/s]

Article 26: set()
Article 27: set()
Article 28: set()


Processing Articles:  32%|███████████████████████▎                                                 | 32/100 [00:02<00:04, 15.02it/s]

Article 29: set()
Article 30: set()
Article 31: set()
Article 32: {'PMS2', 'MSH2', 'MLH1', 'MSH6'}


Processing Articles:  34%|████████████████████████▊                                                | 34/100 [00:02<00:04, 15.62it/s]

Article 33: set()
Article 34: set()
Article 35: set()


Processing Articles:  38%|███████████████████████████▋                                             | 38/100 [00:02<00:04, 14.80it/s]

Article 36: {'EGFR'}
Article 37: set()
Article 38: {'JAK3'}


Processing Articles:  42%|██████████████████████████████▋                                          | 42/100 [00:03<00:03, 14.78it/s]

Article 39: set()
Article 40: set()
Article 41: set()
Article 42: set()


Processing Articles:  44%|████████████████████████████████                                         | 44/100 [00:03<00:03, 14.55it/s]

Article 43: set()
Article 44: set()
Article 45: set()


Processing Articles:  48%|███████████████████████████████████                                      | 48/100 [00:03<00:03, 14.21it/s]

Article 46: set()
Article 47: set()
Article 48: set()


Processing Articles:  50%|████████████████████████████████████▌                                    | 50/100 [00:03<00:03, 13.42it/s]

Article 49: set()
Article 50: set()
Article 51: set()


Processing Articles:  54%|███████████████████████████████████████▍                                 | 54/100 [00:03<00:03, 15.05it/s]

Article 52: set()
Article 53: set()
Article 54: set()
Article 55: set()


Processing Articles:  58%|██████████████████████████████████████████▎                              | 58/100 [00:04<00:02, 14.61it/s]

Article 56: set()
Article 57: set()
Article 58: set()


Processing Articles:  62%|█████████████████████████████████████████████▎                           | 62/100 [00:04<00:02, 15.61it/s]

Article 59: set()
Article 60: set()
Article 61: set()
Article 62: set()


Processing Articles:  66%|████████████████████████████████████████████████▏                        | 66/100 [00:04<00:02, 15.34it/s]

Article 63: {'TP53'}
Article 64: {'MAPK1', 'BRAF'}
Article 65: set()
Article 66: set()


Processing Articles:  70%|███████████████████████████████████████████████████                      | 70/100 [00:04<00:01, 16.05it/s]

Article 67: set()
Article 68: set()
Article 69: set()
Article 70: set()


Processing Articles:  74%|██████████████████████████████████████████████████████                   | 74/100 [00:05<00:01, 15.81it/s]

Article 71: set()
Article 72: set()
Article 73: set()
Article 74: {'NOTCH1'}


Processing Articles:  76%|███████████████████████████████████████████████████████▍                 | 76/100 [00:05<00:01, 15.43it/s]

Article 75: set()
Article 76: set()
Article 77: set()


Processing Articles:  80%|██████████████████████████████████████████████████████████▍              | 80/100 [00:05<00:01, 15.28it/s]

Article 78: set()
Article 79: set()
Article 80: {'CDKN2A', 'MET', 'ALK', 'EGFR'}


Processing Articles:  84%|█████████████████████████████████████████████████████████████▎           | 84/100 [00:05<00:01, 14.95it/s]

Article 81: set()
Article 82: set()
Article 83: set()
Article 84: set()


Processing Articles:  86%|██████████████████████████████████████████████████████████████▊          | 86/100 [00:06<00:00, 15.13it/s]

Article 85: set()
Article 86: set()
Article 87: {'POLE'}


Processing Articles:  90%|█████████████████████████████████████████████████████████████████▋       | 90/100 [00:06<00:00, 14.79it/s]

Article 88: set()
Article 89: {'NF2'}
Article 90: set()


Processing Articles:  94%|████████████████████████████████████████████████████████████████████▌    | 94/100 [00:06<00:00, 15.23it/s]

Article 91: {'EZH2', 'KRAS'}
Article 92: {'SRC'}
Article 93: set()
Article 94: set()


Processing Articles:  98%|███████████████████████████████████████████████████████████████████████▌ | 98/100 [00:06<00:00, 15.95it/s]

Article 95: set()
Article 96: set()
Article 97: set()
Article 98: set()


Processing Articles: 100%|████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.46it/s]

Article 99: set()
Article 100: set()
Processing completed at 2025-03-06 12:08:46. Total runtime: 6.92 seconds.
Runtime details saved in: filtered_articles_biobert_expanded_100_runtime.txt





# 3) Create binary matrix

In [122]:
# Binary matrix creation
BioBERT_df = df_results.copy()
print("Length of dataset:", len(BioBERT_df))
print("Column length of dataset:", len(BioBERT_df.columns))
print("Columns of dataset:", BioBERT_df.columns)

# Parse each article's comma-separated "BioBERT" string into a set of
# upper-case symbols (missing values -> empty set). The "BioBERT" column itself
# is left untouched.
extracted_genes = (
    BioBERT_df["BioBERT"]
    .fillna("")
    .astype(str)
    .apply(lambda cell: {gene.strip().upper() for gene in cell.split(",") if gene.strip()})
)

# One 0/1 column per panel gene, aligned on the article index. Extracted
# symbols are the upper-cased keys of ``expanded_gene_list``, so panel genes
# are compared upper-cased; column names keep the panel's original spelling.
binary_gene_df = pd.DataFrame(
    {
        gene: extracted_genes.apply(lambda found, target=gene.upper(): int(target in found))
        for gene in gene_list
    },
    index=BioBERT_df.index,
)
BioBERT_df = pd.concat([BioBERT_df, binary_gene_df], axis=1)
BioBERT_df["Sum_Entity_Mentions"] = binary_gene_df.sum(axis=1)

# Save as CSV (relative to the working directory)
os.chdir(working_directory)
output_filename = "BioBERT_evaluation_results.csv"
BioBERT_df.to_csv(output_filename, index=False)
print(f"\nFile saved as: {output_filename}")

# Cross-check: the per-article sums must add up to the number of 1s in the matrix.
binary_gene_columns = list(binary_gene_df.columns)
total_gene_mentions = BioBERT_df["Sum_Entity_Mentions"].sum()
total_binary_sum = BioBERT_df[binary_gene_columns].sum().sum()

results_dict = {
    "Metric": ["Total_Sum_Entity_Mentions", "Total_Binary_Matrix_Sum"],
    "Value": [total_gene_mentions, total_binary_sum],
}
results_df = pd.DataFrame(results_dict)
# NOTE: the summary file keeps its existing name ("BiobERT_...") in case later
# steps read it; the message below reports the file actually written.
summary_filename = "BiobERT_evaluation_results.txt"
results_df.to_csv(summary_filename, sep="\t", index=False)

print(f"Results saved to '{summary_filename}'")
print(f"Total sum of 'Sum_Entity_Mentions' column: {total_gene_mentions}")
print(f"Cross-check: Total sum of all binary matrix values (1s in the matrix): {total_binary_sum}")

Length of dataset: 100
Column length of dataset: 10
Columns of dataset: Index(['PaperId', 'PaperTitle', 'Citations', 'CoFoS', 'Authors', 'Abstract',
       'Language', 'PubYear', 'PubDate', 'BioBERT'],
      dtype='object')

File saved as: BioBERT_evaluation_results.csv
Results saved to 'Sum_Entity_Mentions_Evaluations.txt'
Total sum of 'Sum_Entity_Mentions' column: 30
Cross-check: Total sum of all binary matrix values (1s in the matrix): 30
