# 🧠 GBC SciBERT Resource Mention Classifier Testing
This notebook uses GBC's fine-tuned version of the SciBERT model to identify mentions of biodata resources and classify them as true/false mentions. Input a pubmed ID or a PMC ID to extract potential biodata resource mentions and classify them.

### Usage instructions
- On fresh instance, press the 'Run all' button in the menu. This will perform all setup steps and then run classification.
- On an already-running instance, update your publication id below and run the cell. Then, from the 'Table of contents' menu, click the 3 dots next to '🧠 Run Predictions' section and select `Run cells in section`. This avoids rerunning setup each time.

In [1]:
# @title 📄 Set publication ID
pmid = "" # @param {"type":"string","placeholder":"None"}
pmcid = "PMC10666545" # @param {"type":"string", "placeholder":"None"}

# good test pubs: PMC10928905, PMC11514960, PMC10703934, PMC11271732
# perfect prediction: PMC3377363
# test pub: no matches: PMC11933476, PMC10796695, PMC12152223
# very large list of mentions: PMC12125710
# great test: DToL wasp genome: PMC9672531
# LOT of false positives: PMC12180258

# ⚙️ Setup

In [16]:
# @title 📦 Install dependencies
%pip install pandas requests beautifulsoup4
%pip install nltk transformers torch
%pip install sqlalchemy pymysql tqdm
%pip install itables ipywidgets

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Collecting ipywidgets
  Downloading ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl.metadata (20 kB)
Downloading ipywidgets-8.1.7-py3-none-any.whl (139 kB)
Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl (216 kB)
Downloading widgetsnbextension-4.0.14-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: widgetsnbextension, jupyterlab_widgets, ipywidgets
Successfully installed ipywidgets-8.1.7 jupyterl

In [3]:
# @title ⬇️ Download tokenizers
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /Users/carla/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /Users/carla/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [None]:
# @title 🧩 Import modules
import sys
import re

import pandas as pd
import json

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from bs4 import BeautifulSoup

import sqlalchemy as db
# import pymysql

from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

from tqdm.notebook import tqdm
from collections import Counter

from IPython.display import Markdown, HTML
from itables import show
import itables.options as opt

In [5]:
# Set display options
opt.classes = "display nowrap"
opt.lengthMenu = [25, 50, 100]
opt.column_filters = "header"
opt.maxBytes = 0  # Show full strings
opt.scrollX = False

# 📥 Pull resource list from DB

In [18]:
# @title 🔎 DB connection setup (via public IP)
db_engine = db.create_engine('mysql+pymysql://gbcreader@34.89.127.34/gbc-publication-analysis', pool_recycle=3600, pool_size=50, max_overflow=50)
db_conn = db_engine.connect()

print("Successfully connected to GBC MySQL instance")

Successfully connected to GBC MySQL instance


In [19]:
# @title 📋 Load resource list

def display_dataframe(df, title=None):
    if title:
        print(f"### {title} ###")

    display(HTML("<div style='max-width: 800px; overflow-x: auto;'>"))
    show(df, include_index=False, classes="display compact", style="width: 800px;")
    display(HTML("</div>"))

additional_aliases = json.load(open("resource_names.additional_aliases.json"))

sql = "SELECT short_name, common_name, full_name FROM resource WHERE is_latest=1"
result = db_conn.execute(db.text(sql)).fetchall()
resource_names = []
for r in result:
    short_name = r[0].strip()
    common_name = r[1].strip() if r[1] else None
    full_name = r[2].strip() if r[2] else None
    if short_name:
        resource_names.append([short_name])
    if common_name and common_name != short_name:
        resource_names[-1].append(common_name)
    if full_name and full_name != short_name and full_name != common_name:
        resource_names[-1].append(full_name)

    if short_name in additional_aliases:
        resource_names[-1].extend(additional_aliases[short_name])

resource_display_df = pd.DataFrame([
    {
        'resource_name': entry[0],
        'aliases': str(entry)
    }
    for entry in resource_names
])
display_dataframe(resource_display_df)



0
Loading ITables v2.4.3 from the internet...  (need help?)


# 📦 Locate and Load Model

In [8]:
# grab model files from GDrive
!cp -r '/content/drive/MyDrive/Colab Notebooks/SciBERT Classifier/scibert_resource_classifier' .
!cp -r '/content/drive/MyDrive/Colab Notebooks/SciBERT Classifier/scibert_resource_classifier.v2' .

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    torch.set_num_threads(2)

tokenizer = AutoTokenizer.from_pretrained("scibert_resource_classifier.v2")
model = AutoModelForSequenceClassification.from_pretrained("scibert_resource_classifier.v2").to(device)
model.eval()

cp: /content/drive/MyDrive/Colab Notebooks/SciBERT Classifier/scibert_resource_classifier: No such file or directory
cp: /content/drive/MyDrive/Colab Notebooks/SciBERT Classifier/scibert_resource_classifier.v2: No such file or directory


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31090, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

# 🧰 Define Helper Functions

In [9]:

# setup retry strategy and HTTP adapter to handle rate limiting
# and transient errors
retry_strategy = Retry(
    total=5,                      # Try up to 5 times
    backoff_factor=1.5,           # Starts with 1.5s → 3s → 6s → 12s → 24s
    status_forcelist=[429, 500, 502, 503, 504],
    allowed_methods=["HEAD", "GET", "OPTIONS"],
    raise_on_status=False
)

adapter = HTTPAdapter(max_retries=retry_strategy)
session = requests.Session()
session.mount("https://", adapter)
session.mount("http://", adapter)

# query EuropePMC for publication metadata
max_retries = 5
epmc_base_url = "https://www.ebi.ac.uk/europepmc/webservices/rest"

def query_europepmc(endpoint, request_params=None, no_exit=False):
    """
    Query Europe PMC REST API endpoint with retries.
    """
    for attempt in range(max_retries):
        try:
            response = session.get(endpoint, params=request_params, timeout=15)
            if response.status_code == 200:
                return response.json() if 'json' in response.headers.get('Content-Type', '') else response.text
            else:
                if no_exit:
                    return None
                else:
                    sys.exit(f"Error: {response.status_code} for {endpoint}")
        except requests.RequestException as e:
            print(f"⚠️ Request failed: {e}. Retrying ({attempt + 1}/{max_retries})...")
    sys.exit("Max retries exceeded.")

def preprocess_xml_table(table_wrap_tag):
    """Extracts and flattens a single <table-wrap> tag into a list of text lines suitable for NER."""
    lines = []

    # Caption
    caption = table_wrap_tag.find("caption")
    if caption:
        cap_text = caption.get_text(strip=True)
        if cap_text:
            lines.append(f"[TABLE-CAPTION] {cap_text}")

    # Table body
    table = table_wrap_tag.find("table")
    if table:
        rows = table.find_all("tr")
        for i, row in enumerate(rows):
            cells = row.find_all(["td", "th"])
            if cells:
                row_text = []
                for cell in cells:
                    text = cell.get_text(strip=True)
                    if text:
                        is_header = cell.name == "th" or i == 0
                        prefix = "[COLUMN-HEADER] " if is_header else ""
                        row_text.append(f"{prefix}{text}")
                if row_text:
                    lines.append(" ".join(row_text))

    return "\n".join(lines) if lines else None

def section_to_text(section, depth=1):
    """Converts a BeautifulSoup section to a string."""
    text = []
    title = section.find("title", recursive=False)
    if title:
        text.append(f"{'#'*depth} {title.get_text(strip=True).upper()}")

    elems = section.find_all(["sec", "p"], recursive=False) # only direct children
    for elem in elems:
        if elem.name == "sec":
            text.append(section_to_text(elem, depth=(depth+1)))
        elif elem.name == "p":
            # check for embedded lists
            plists = elem.find_all("list", recursive=False)
            for plist in plists:
                for li in elem.find_all("list-item", recursive=True):
                    li_text = li.get_text(strip=True)
                    if li_text:
                        text.append(f"- {li_text}.")

                plist.extract() # remove the lists from the main paragraph

            p_text = elem.get_text(strip=True)
            if p_text:
                text.append(p_text)

    return "\n".join(text) if text else ''

def get_fulltext_body(pmcid):
    # 1. Download the XML
    url = f"{epmc_base_url}/{pmcid}/fullTextXML"
    response = requests.get(url)
    if response.status_code != 200:
        return None
    xml = response.text

    # 2. Parse with BeautifulSoup
    soup = BeautifulSoup(xml, "lxml-xml")

    # 3. Extract body text with headers
    text_blocks = []

    # 1. Title
    title = soup.find("article-title")
    if title:
        title_text = title.get_text(strip=True)
        if title_text:
            text_blocks.append(f"# TITLE\n{title_text}")
    text_blocks.append("\n")

    # 2. Abstract
    abstract = soup.find("abstract")
    if abstract:
        abstract_title = abstract.find("title")
        if abstract_title and abstract_title.get_text(strip=True).upper() == 'ABSTRACT':
            abstract_title.extract()  # remove the title

        text_blocks.append(f"# ABSTRACT\n{section_to_text(abstract)}")

    # 2.1. Other metadata sections
    funding_statement = soup.find("funding-statement")
    if funding_statement:
        funding_text = funding_statement.get_text(strip=True)
        if funding_text:
            text_blocks.append(f"### FUNDING\n{funding_text}")

    all_custom_metas = soup.find_all("custom-meta")
    for custom_meta in all_custom_metas:
        meta_name = custom_meta.find("meta-name").get_text(strip=True)
        meta_value = custom_meta.find("meta-value").get_text(strip=True)
        if meta_name and meta_value:
            text_blocks.append(f"### {meta_name.upper()}\n{meta_value}")

    text_blocks.append("\n")

    # 3. Tables (captions + content)
    table_blocks = []
    for tbl in soup.find_all("table-wrap"):
        tbl.extract()
        processed_table = preprocess_xml_table(tbl)
        if processed_table:
            table_blocks.append(processed_table)

    # 4. Main body (sections + paragraphs)
    # excluded_section_types = ["supplementary-material", "orcid"]
    excluded_section_types = ["orcid"]
    body = soup.find("body")
    if body:
        all_sections = body.find_all("sec", recursive=False)
        for elem in all_sections:
            if elem.get("sec-type") in excluded_section_types:
                continue

            text_blocks.append(section_to_text(elem))
            text_blocks.append("\n")

    return text_blocks, table_blocks

def remove_substring_matches(mentions):
    aliases = [m[1].lower() for m in mentions]
    unique_aliases = list(set(aliases))

    substr_aliases = []
    for alias1 in unique_aliases:
        for alias2 in unique_aliases:
            if alias1 in alias2 and alias1 != alias2:
                substr_aliases.append(alias1)

    for alias in substr_aliases:
        mentions = [m for m in mentions if m[1].lower() != alias]

    return mentions

case_sensitive_threshold = 30 # switch to case sensitive search after this number of matches for a resource
def get_resource_mentions(textblocks, tableblocks, resource_names):
    mentions = []

    # precompile regex patterns for each resource alias
    # This is more efficient than compiling them on-the-fly in the loop
    compiled_patterns = []
    for resource in resource_names:
        resource_name = resource[0]
        for alias in resource:
            pattern_case_insensitive = re.compile(rf"[^A-Za-z]{re.escape(alias.lower())}[^A-Za-z]")
            compiled_patterns.append((resource_name, alias, pattern_case_insensitive))

    # Split the fulltext into sentences and table rows
    for block in textblocks:
        # sentences = block.split('. ')
        sentences = sent_tokenize(block)  # Use NLTK to split into sentences
        for sentence in sentences:
            sentence = sentence.replace("\n", " ")
            s_lowered = sentence.lower()
            this_sentence_mentions = []
            for resource_name, alias, pattern_ci in compiled_patterns:
                if pattern_ci.search(s_lowered):
                    this_sentence_mentions.append((sentence.strip(), alias, resource_name))

            if len(this_sentence_mentions) > 1:
                this_sentence_mentions = remove_substring_matches(this_sentence_mentions)
            mentions.extend(this_sentence_mentions)

    for table in tableblocks:
        rows = table.split('\n')

        for row in rows:
            r_lowered = row.lower()
            this_row_mentions = []
            for resource_name, alias, pattern_ci in compiled_patterns:
                if pattern_ci.search(r_lowered):
                    this_row_mentions.append((row.strip(), alias, resource_name))

            if len(this_row_mentions) > 1:
                this_row_mentions = remove_substring_matches(this_row_mentions)
            mentions.extend(this_row_mentions)

    # if a large number of matches are found for one resource, switch to case sensitive mode
    filtered_mentions = []
    alias_counts = Counter([m[1] for m in mentions])
    for alias, count in alias_counts.items():
        if count > case_sensitive_threshold:
            print(f"⚠️ {count} matches found for {alias} - switching to case sensitive mode")
            pattern_case_sensitive = re.compile(rf"[^A-Za-z]{re.escape(alias)}[^A-Za-z]")
            for m in mentions:
                if m[1] == alias and pattern_case_sensitive.search(m[0]):
                    filtered_mentions.append(m)
        else:
            this_alias_mentions = [m for m in mentions if m[1] == alias]
            filtered_mentions.extend(this_alias_mentions)

    # Remove duplicates
    mentions = list(set(filtered_mentions))
    # Remove empty mentions
    mentions = [m for m in mentions if m[0]]

    return mentions

def classify_mentions(pmcid, pmid, candidate_pairs):
    predictions = []

    for sentence, alias, resource in tqdm(candidate_pairs, desc="🔍 Classifying"):
        inputs = tokenizer(alias, sentence, return_tensors="pt", truncation=True, padding="max_length", max_length=512).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            pred = torch.argmax(probs, dim=1).item()
            if pred == 1:
                predictions.append({
                    "prediction": "MATCH",
                    "pmcid": pmcid,
                    "pmid": pmid,
                    "resource_name": resource,
                    "matched_alias": alias,
                    "sentence": sentence,
                    "confidence": probs[0, 1].item()
                })
            else:
                predictions.append({
                    "prediction": "NO MATCH",
                    "pmcid": pmcid,
                    "pmid": pmid,
                    "resource_name": resource,
                    "matched_alias": alias,
                    "sentence": sentence,
                    "confidence": probs[0, 0].item()
                })

    return predictions


# 🧠 Run Predictions

In [10]:
# @title 📄 Fetch and preprocess publication text

epmc_query = f"PMCID:{pmcid}" if pmcid else f"EXT_ID:{pmid}"
md = f"- 🔍 Querying Europe PMC for {epmc_query}\n"
data = query_europepmc(f"{epmc_base_url}/search", request_params={
    'query': epmc_query,
    'format': 'json',
    'pageSize': 10,
    'cursorMark': '*',
    'resultType': 'core'
})
md += f"- 🔍 Found {data.get('hitCount', 0)} results for {epmc_query}\n\n---\n"
display(Markdown(md))

for result in data.get('resultList', {}).get('result', []):
    this_pmcid = result.get('pmcid')
    this_pmid = result.get('pmid')
    title = result.get('title')

    # since we must use EXT_ID to search using PMID, this introduces room for error
    # keep skipping through results until the match is found.
    # In theory, we should only be processing 1 publication here.
    if pmid and this_pmid != pmid:
        # print(f"⚠️ Skipping {this_pmcid} as it does not match the provided PMID {pmid}.")
        continue

    md = f"- 📄 Title: {title}\n"
    md += f"- 🆔 PMCID: {this_pmcid}, PMID: {this_pmid}\n\n---\n"
    display(Markdown(md))

    if pmcid:
        # Get full text body and tables
        text_body, table_blocks = get_fulltext_body(this_pmcid)
        if not text_body:
            print("⚠️ No full text body found.")
            continue
    else:
        text_body, table_blocks = sent_tokenize(result.get('abstractText')), []

    break # only use the first successful match

text_body = [tb.replace('\n', ' ') for tb in text_body]
text_body = [tb for tb in text_body if tb.strip()]

md = f"- Processed {len(text_body)} text blocks\n"
md += f"- Processed {len(table_blocks)} table blocks\n"
display(Markdown(md))

- 🔍 Querying Europe PMC for PMCID:PMC10666545
- 🔍 Found 1 results for PMCID:PMC10666545

---


- 📄 Title: Advancing Computational Toxicology by Interpretable Machine Learning.
- 🆔 PMCID: PMC10666545, PMID: 37224004

---


- Processed 10 text blocks
- Processed 2 table blocks


In [11]:
# @title 🧐 Inspect text blocks (optional)
# txb_df = pd.DataFrame(text_body, columns=["text_block"])
# txb_df

In [12]:
# @title 🧐 Inspect table blocks (optional)
# tb_df = pd.DataFrame(table_blocks, columns=["table_block"])
# tb_df

In [13]:
# @title 🔍 Search for resource mentions
display(Markdown(f"🔍 Searching for resource mentions in {this_pmcid}..."))
mentions = get_resource_mentions(text_body, table_blocks, resource_names)
display(Markdown(f"🔍 Found {len(mentions)} mentions of {len(set([x[2] for x in mentions]))} resources in {this_pmcid}.\n"))

🔍 Searching for resource mentions in PMC10666545...

🔍 Found 43 mentions of 22 resources in PMC10666545.


In [14]:
# @title 🧐 Inspect unclassified mentions (optional)
mentions_df = pd.DataFrame(mentions, columns=["sentence", "alias", "resource_name"])
mentions_df


Unnamed: 0,sentence,alias,resource_name
0,;Figure1A) representing properties determining...,PubChem,PubChem
1,"Currently, the AOP-Wiki features more than 400...",PaVE,PaVE
2,Experimental conditions and protocols for gene...,ChEMBL,ChEMBL
3,"Molecular Signatures Database (MSigDB)187,188 ...",MSigDB,MSigDB
4,"gene (sets) annotation Gene Ontology178,179 A ...",Gene Ontology,GOC
5,comprised of gene ontology terms for many kind...,Gene Ontology,GOC
6,;Figure1A) representing properties determining...,ChEMBL,ChEMBL
7,developed a VNN model named P-NET that integra...,Reactome,Reactome
8,Inspired by the biological organization of the...,IDEAL,IDEAL
9,pathway and gene regulatory network informatio...,Regulatory Network,AutophagyNet


In [15]:
# @title 🧠 Classify resource mentions
classified_mentions = classify_mentions(this_pmcid, this_pmid, mentions)

class_df = pd.DataFrame(classified_mentions)
class_df.sort_values(by=['prediction', 'confidence'], ascending=[False, False], inplace=True)

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

In [None]:
class_df_display = class_df[['resource_name', 'matched_alias', 'prediction', 'pmcid', 'sentence']].copy()
class_df_display['prediction'] = class_df_display['prediction'].map({'MATCH': 1, 'NO MATCH': 0})
data_table.DataTable(class_df_display, include_index=False, num_rows_per_page=25)

## 📊 View results by confidence bands

In [None]:
# @title ⭐ High confidence (> 0.98)
high = class_df[class_df["confidence"] > 0.98]
high

In [None]:
# @title 📘 Medium-high confidence (0.9–0.98)
mid_high = class_df[(class_df["confidence"] <= 0.98) & (class_df["confidence"] > 0.9)]
mid_high

In [None]:
# @title 📒 Medium-low confidence (0.8–0.9)
mid_low = class_df[(class_df["confidence"] <= 0.9) & (class_df["confidence"] > 0.8)]
mid_low

In [None]:
# @title ⚠️ Low confidence (≤ 0.8)
low = class_df[class_df["confidence"] <= 0.8]
low

## 🏁 Publication Classification Final Result

In [None]:
# @title ❓ Does this publication have biodata resource mentions?

summary_df = (
    class_df[(class_df['prediction'] == 'MATCH') & (class_df['confidence'] >= 0.9)]
    .groupby('resource_name', as_index=False)
    .agg({
        'confidence': 'mean',
        'prediction': 'count',
        'sentence': lambda x: " || ".join(list(set(x)))  # unique sentences
    })
)
summary_df.rename(columns={'confidence': 'mean_confidence', 'sentence':'token_matches', 'prediction': 'num_matches'}, inplace=True)

if len(summary_df) > 0:
    result_md = f"## ✅ Publication _has_ verified known biodata resource mention(s)\n\n---\n"
    result_md += f"### Resource Match Summary\n"
    result_md += summary_df[['resource_name', 'num_matches', 'mean_confidence']].to_markdown(index=False)

else:
    result_md = f"## ❌ Publication _does not_ mention a known biodata resource"

display(Markdown(result_md))

# 🤖 Test Zone

In [None]:
# text_blocks, table_blocks = get_fulltext_body(pmcid)

In [None]:
# display(JSON(text_blocks))