In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, pipeline
from datasets import load_dataset
import spacy
import torch

# Load SciSpacy for biomedical entity recognition
nlp = spacy.load("en_core_sci_sm")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Load the ChemProt dataset
dataset = load_dataset(
    "bigbio/chemprot",
    "chemprot_bigbio_kb",
    split="train", 
    trust_remote_code=True
    )
print("data loaded")


data loaded


In [3]:
dataset["passages"][0][0]
# text = " ".join(p["text"][0] for p in dataset["passages"])

{'id': '1',
 'type': 'title and abstract',
 'text': ['Selective costimulation modulators: a novel approach for the treatment of rheumatoid arthritis.\nT cells have a central role in the orchestration of the immune pathways that contribute to the inflammation and joint destruction characteristic of rheumatoid arthritis (RA). The requirement for a dual signal for T-cell activation and the construction of a fusion protein that prevents engagement of the costimulatory molecules required for this activation has led to a new approach to RA therapy. This approach is mechanistically distinct from other currently used therapies; it targets events early rather than late in the immune cascade, and it results in immunomodulation rather than complete immunosuppression. The fusion protein abatacept is a selective costimulation modulator that avidly binds to the CD80/CD86 ligands on an antigen-presenting cell, resulting in the inability of these ligands to engage the CD28 receptor on the T cell. Abat

In [4]:
dataset["relations"]

[[],
 [{'id': '38',
   'type': 'Downregulator',
   'arg1_id': '18',
   'arg2_id': '31',
   'normalized': []},
  {'id': '39',
   'type': 'Downregulator',
   'arg1_id': '18',
   'arg2_id': '32',
   'normalized': []},
  {'id': '40',
   'type': 'Downregulator',
   'arg1_id': '20',
   'arg2_id': '33',
   'normalized': []},
  {'id': '41',
   'type': 'Downregulator',
   'arg1_id': '20',
   'arg2_id': '34',
   'normalized': []},
  {'id': '42',
   'type': 'Downregulator',
   'arg1_id': '20',
   'arg2_id': '35',
   'normalized': []},
  {'id': '43',
   'type': 'Downregulator',
   'arg1_id': '21',
   'arg2_id': '33',
   'normalized': []},
  {'id': '44',
   'type': 'Downregulator',
   'arg1_id': '21',
   'arg2_id': '34',
   'normalized': []},
  {'id': '45',
   'type': 'Downregulator',
   'arg1_id': '21',
   'arg2_id': '35',
   'normalized': []},
  {'id': '46',
   'type': 'Downregulator',
   'arg1_id': '9',
   'arg2_id': '30',
   'normalized': []},
  {'id': '47',
   'type': 'Downregulator',
   'arg1

In [6]:

# Load BioBERT model fine-tuned on biomedical relations
model_name = "distilbert-base-uncased"  # swap with BioBERT if you fine-tune later
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [6]:
# Tokenize the dataset
def preprocess(example):
    # Join all passage texts into one long string
    full_text = " ".join(p["text"][0] for p in example["passages"])
    return tokenizer(full_text, truncation=True, padding="max_length", max_length=128)

tokenized_dataset = dataset.map(preprocess, batched=False)
tokenized_dataset = tokenized_dataset.rename_column("relations", "labels")
tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
tokenized_dataset["labels"]

[[],
 [{'id': '38',
   'type': 'Downregulator',
   'arg1_id': '18',
   'arg2_id': '31',
   'normalized': []},
  {'id': '39',
   'type': 'Downregulator',
   'arg1_id': '18',
   'arg2_id': '32',
   'normalized': []},
  {'id': '40',
   'type': 'Downregulator',
   'arg1_id': '20',
   'arg2_id': '33',
   'normalized': []},
  {'id': '41',
   'type': 'Downregulator',
   'arg1_id': '20',
   'arg2_id': '34',
   'normalized': []},
  {'id': '42',
   'type': 'Downregulator',
   'arg1_id': '20',
   'arg2_id': '35',
   'normalized': []},
  {'id': '43',
   'type': 'Downregulator',
   'arg1_id': '21',
   'arg2_id': '33',
   'normalized': []},
  {'id': '44',
   'type': 'Downregulator',
   'arg1_id': '21',
   'arg2_id': '34',
   'normalized': []},
  {'id': '45',
   'type': 'Downregulator',
   'arg1_id': '21',
   'arg2_id': '35',
   'normalized': []},
  {'id': '46',
   'type': 'Downregulator',
   'arg1_id': '9',
   'arg2_id': '30',
   'normalized': []},
  {'id': '47',
   'type': 'Downregulator',
   'arg1

In [7]:
# Create classification pipeline (generic for now, tweak for chemprot)
classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer)
print("classifier built")

# Example patient input
patient_report = """
The patient presents with HER2-positive breast cancer, showing overexpression of the ERBB2 gene.
Previous therapies included tamoxifen and trastuzumab, with partial response. Looking for targeted options.
"""

# Extract biomedical entities from patient report
doc = nlp(patient_report)
entities = [(ent.text, ent.label_) for ent in doc.ents]
print("\n🧬 Extracted Entities:")
for ent, label in entities:
    print(f" - {ent} ({label})")

# Create input for transformer (e.g., check relation between drug and protein)
chem_entity = "tamoxifen"
bio_entity = "breast cancer"

input_text = f"{chem_entity} interacts with {bio_entity}"
print(f"\n🔍 Testing relation: {input_text}")
prediction = classifier(input_text)
print("\n🧠 Prediction:")
print(prediction)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Device set to use mps:0


classifier built

🧬 Extracted Entities:
 - patient (ENTITY)
 - HER2-positive breast cancer (ENTITY)
 - overexpression (ENTITY)
 - ERBB2 gene (ENTITY)
 - therapies (ENTITY)
 - tamoxifen (ENTITY)
 - trastuzumab (ENTITY)
 - partial response (ENTITY)
 - targeted options (ENTITY)

🔍 Testing relation: tamoxifen interacts with breast cancer

🧠 Prediction:
[{'label': 'LABEL_1', 'score': 0.5353100299835205}]


In [12]:
print("prediction score", str(round(prediction[0]["score"], 2)))
"prediction score " + str(round(prediction[0]["score"], 2))

prediction score 0.56


'prediction score 0.56'