In [8]:
# !pip install -q transformers datasets scikit-learn tqdm
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install transformers[torch]

In [9]:
# Cell 2: imports and config
import os
import json
from datasets import load_dataset, Dataset, DatasetDict
from transformers import GPT2TokenizerFast, GPT2ForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
from tqdm import tqdm
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32
learning_rate = 0.001
max_length = 256
special_entity_token = "<ENT>"


In [10]:
# Cell 3: load local json files (train.json, val.json, test.json expected in current dir)
def load_json_to_dataset(path):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    for ex in data:
        ex["label"] = int(ex.get("label", ex.get("answer", "0")))
        if isinstance(ex["label"], str):
            ex["label"] = int(ex["label"])
    return Dataset.from_list(data)

train_ds = load_json_to_dataset("train.json")
val_ds = load_json_to_dataset("val.json")
test_ds = load_json_to_dataset("test.json")
dataset = DatasetDict({"train": train_ds, "validation": val_ds, "test": test_ds})


In [11]:
# Cell 4: tokenizer and preprocessing (map entity markers to a single token and add pad token)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
if special_entity_token not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({"additional_special_tokens": [special_entity_token]})

import re

def preprocess_function(examples):
    texts = []
    for s in examples["sentence"]:
        s2 = re.sub(r'@DRUG\d+\$(?:s)?', special_entity_token, s)
        s2 = s2.replace("<\\\\entity><\\\\entity>", special_entity_token)
        s2 = s2.replace("<\\entity><\\entity>", special_entity_token)
        s2 = s2.replace("<\\\\entity>", special_entity_token)
        s2 = s2.replace("<\\entity>", special_entity_token)
        texts.append(s2)
    tokenized = tokenizer(texts, truncation=True, padding=False, max_length=max_length)
    tokenized["labels"] = [int(l) for l in examples["label"]]
    return tokenized

tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)

Map: 100%|██████████| 18779/18779 [00:03<00:00, 5743.91 examples/s]
Map: 100%|██████████| 7244/7244 [00:00<00:00, 7976.73 examples/s]
Map: 100%|██████████| 5761/5761 [00:00<00:00, 10146.01 examples/s]


In [12]:

# Cell 5: model setup
model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=5)
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id

for name, param in model.named_parameters():
    if not name.startswith("transformer.h.10") and \
       not name.startswith("transformer.h.11") and \
       not name.startswith("ln_f") and \
       not name.startswith("score"):
        param.requires_grad = False

data_collator = DataCollatorWithPadding(tokenizer, padding=True)

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Cell 6 (corrected for older transformers versions)
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    f1_macro = f1_score(labels, preds, average="macro")
    prec_macro = precision_score(labels, preds, average="macro", zero_division=0)
    rec_macro = recall_score(labels, preds, average="macro", zero_division=0)
    return {"accuracy": acc, "f1_macro": f1_macro, "precision_macro": prec_macro, "recall_macro": rec_macro}


training_args = TrainingArguments(
    output_dir="gpt2-relation-classifier",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="epoch",
    save_steps=500,
    learning_rate=learning_rate,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


  trainer = Trainer(


In [14]:
# Cell 7: train
trainer.train()


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,Precision Macro,Recall Macro
1,0.3012,0.291388,0.892877,0.499538,0.759647,0.425581
2,0.2129,0.245365,0.910409,0.715183,0.691557,0.745357
3,0.1464,0.233703,0.920762,0.695146,0.736095,0.685871


TrainOutput(global_step=1761, training_loss=0.2799533183841391, metrics={'train_runtime': 4270.5194, 'train_samples_per_second': 13.192, 'train_steps_per_second': 0.412, 'total_flos': 5754194573114880.0, 'train_loss': 0.2799533183841391, 'epoch': 3.0})

In [15]:
# Cell 8: evaluate on test set
metrics = trainer.evaluate(tokenized_datasets["test"])
print("Eval metrics (trainer.evaluate):", metrics)

pred_output = trainer.predict(tokenized_datasets["test"])
preds = pred_output.predictions
if preds.ndim > 1:
    preds = np.argmax(preds, axis=-1)
labels = pred_output.label_ids

acc = accuracy_score(labels, preds)
f1_macro = f1_score(labels, preds, average="macro")
f1_weighted = f1_score(labels, preds, average="weighted")
report = classification_report(labels, preds, digits=4)
cm = confusion_matrix(labels, preds)

print("Accuracy:", acc)
print("Macro F1:", f1_macro)
print("Weighted F1:", f1_weighted)
print("Classification report:\n", report)
print("Confusion matrix:\n", cm)


Eval metrics (trainer.evaluate): {'eval_loss': 0.38575467467308044, 'eval_accuracy': 0.9008852629751779, 'eval_f1_macro': 0.6946933300483344, 'eval_precision_macro': 0.7597122932219438, 'eval_recall_macro': 0.6569940185490856, 'eval_runtime': 47.0921, 'eval_samples_per_second': 122.335, 'eval_steps_per_second': 3.844, 'epoch': 3.0}
Accuracy: 0.9008852629751779
Macro F1: 0.6946933300483344
Weighted F1: 0.8978714409074594
Classification report:
               precision    recall  f1-score   support

           0     0.6041    0.6528    0.6275       360
           1     0.7716    0.5927    0.6704       302
           2     0.7647    0.7059    0.7341       221
           3     0.7200    0.3750    0.4932        96
           4     0.9382    0.9586    0.9483      4782

    accuracy                         0.9009      5761
   macro avg     0.7597    0.6570    0.6947      5761
weighted avg     0.8983    0.9009    0.8979      5761

Confusion matrix:
 [[ 235    6    4    2  113]
 [   2  179    4

In [16]:
sample_sentences = [
    "Coadministration of @DRUG1$ and certain antihypertensives may increase the risk of dizziness.",
    "@DRUG1$ should not be taken together with high doses of @DRUG2$ due to altered clearance.",
    "Concurrent use of @DRUG1$ with antacids may reduce absorption when paired with @DRUG2$.",
    "Patients receiving @DRUG1$ might require dosage adjustments if also treated with @DRUG2$.",
    "No clinically significant changes were observed when @DRUG1$ was combined with @DRUG2$.",
]

processed = []
for s in sample_sentences:
    s2 = re.sub(r'@DRUG\d+\$(?:s)?', special_entity_token, s)
    processed.append(s2)

enc = tokenizer(processed, truncation=True, padding=True, max_length=max_length, return_tensors="pt").to(device)
model.to(device)

with torch.no_grad():
    logits = model(**enc).logits.cpu().numpy()

preds = np.argmax(logits, axis=-1)

for orig, pred in zip(sample_sentences, preds):
    print(orig)
    print("Pred:", pred)
    print()


Coadministration of @DRUG1$ and certain antihypertensives may increase the risk of dizziness.
Pred: 0

@DRUG1$ should not be taken together with high doses of @DRUG2$ due to altered clearance.
Pred: 2

Concurrent use of @DRUG1$ with antacids may reduce absorption when paired with @DRUG2$.
Pred: 4

Patients receiving @DRUG1$ might require dosage adjustments if also treated with @DRUG2$.
Pred: 2

No clinically significant changes were observed when @DRUG1$ was combined with @DRUG2$.
Pred: 4

