In [1]:
from bioengine.dochandlers.bioc_dataset import BioCDataset, BioCPreprocessor

stats = BioCPreprocessor("../data/BioRED/Train.BioC.XML").analyze_bioc_file()
set(stats['relation_types'].keys())

{'Association',
 'Bind',
 'Comparison',
 'Conversion',
 'Cotreatment',
 'Drug_Interaction',
 'Negative_Correlation',
 'Positive_Correlation'}

In [2]:
from dochandlers.bioc_dataset import BioCDataset


class Config:
    # Model settings
    # model_name = "dmis-lab/biobert-base-cased-v1.2"  # BioBERT for biomedical text
    model_name = "dmis-lab/biobert-v1.1"  # BioBERT for biomedical text
    max_length = 512
    hidden_size = 768

    # Training settings
    batch_size = 16
    learning_rate = 2e-5
    num_epochs = 3
    dropout = 0.1

    # Relationship types
    relation_types = [
        *set(stats['relation_types'].keys()),
        "no_relation"  #include negative class
    ]

    num_relations = len(relation_types)
    relation_to_id = {rel: idx for idx, rel in enumerate(relation_types)}
    id_to_relation = {idx: rel for idx, rel in enumerate(relation_types)}


dataset = BioCDataset('../data/BioRED/Train.BioC.XML', config=Config())
dataset[0]

Loaded 3831 relation examples from 400 documents
2025-09-27 15:53:45,658 - urllib3.connectionpool - DEBUG - Starting new HTTPS connection (1): huggingface.co:443
2025-09-27 15:53:46,370 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /dmis-lab/biobert-v1.1/resolve/main/tokenizer_config.json HTTP/1.1" 307 0
2025-09-27 15:53:46,428 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /api/resolve-cache/models/dmis-lab/biobert-v1.1/551ca18efd7f052c8dfa0b01c94c2a8e68bc5488/tokenizer_config.json HTTP/1.1" 200 0
2025-09-27 15:53:46,602 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "GET /api/models/dmis-lab/biobert-v1.1/tree/main/additional_chat_templates?recursive=False&expand=False HTTP/1.1" 404 64


{'input_ids': tensor([[  101, 28996,  1124,  4163,  2430,  3457,  1566,  4272,  5318,   118,
            127, 28997,   131,  9815,  1206,  7434, 15661, 26468,  1105, 28998,
           2076,  1563, 17972, 28999,  1105,  1206,  7434, 15661, 26468,  1105,
          10777,  1104, 26825,  3318,  1988,   119,  1109, 15416,  5318,  1119,
           4163,  2430,  3457,  1566,  4272,  5318,   113,   145, 28047,   114,
            118,   127,  1110,  1126, 15011, 27335,  1104,  1317,  9077,  2017,
           1107,  1103,  3507, 19790, 16317,  1104, 18462,   118, 15415, 17972,
           1104,  1103,  1685,   119,  1284,  3335,  7289,  1103, 11066,  1115,
          15661, 26468,  1107,  1103,   145, 28047,   118,   127,  5565,  1110,
           2628,  1114, 18005,  1116,  1104,  6902,  1563,   113,  1664,   118,
          26825,   118,  7449,   114, 17972,  1143,  6473,  4814,  1105, 10777,
           1104, 26825,  3318,  1988,  1107, 20636, 26827,  5174,   119,  1284,
          22121,  1181,  11

In [3]:
from transformers import AutoModel
from torch import nn

from transformers import BertModel
import torch.nn as nn
class SimpleRelationModel(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.encoder = BertModel.from_pretrained(config.model_name)
        self.dropout = nn.Dropout(0.1,inplace=False)
        self.classifier = nn.Linear(config.hidden_size, config.num_relations)
    def forward(self, x, attention_mask ,label=None):
        x = self.encoder(x, attention_mask=attention_mask)[0]
        x = x[:, 0, :]
        # x = self.dropout(x)
        # x = self.fc(x)
                # Apply dropout and classify
        x = self.dropout(x)
        logits = self.classifier(x)  # [batch_size, num_relations]

        return logits


In [5]:
from transformers import BertTokenizer
import torch

In [6]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


In [9]:


config = Config()
model = SimpleRelationModel(config)
model.load_state_dict(torch.load("simple_relation_model.pt", map_location=device, weights_only=False))
model.to(config.device)
model.eval()

tokenizer = BertTokenizer.from_pretrained(config.model_name)

2025-09-27 15:57:24,665 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /dmis-lab/biobert-v1.1/resolve/main/config.json HTTP/1.1" 307 0
2025-09-27 15:57:24,717 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /api/resolve-cache/models/dmis-lab/biobert-v1.1/551ca18efd7f052c8dfa0b01c94c2a8e68bc5488/config.json HTTP/1.1" 200 0
2025-09-27 15:57:24,868 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /dmis-lab/biobert-v1.1/resolve/main/config.json HTTP/1.1" 307 0
2025-09-27 15:57:24,922 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /api/resolve-cache/models/dmis-lab/biobert-v1.1/551ca18efd7f052c8dfa0b01c94c2a8e68bc5488/config.json HTTP/1.1" 200 0
2025-09-27 15:57:25,148 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /dmis-lab/biobert-v1.1/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-09-27 15:57:25,335 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "GET /api/models/dmis-

RuntimeError: Error(s) in loading state_dict for SimpleRelationModel:
	Missing key(s) in state_dict: "encoder.embeddings.word_embeddings.weight", "encoder.embeddings.position_embeddings.weight", "encoder.embeddings.token_type_embeddings.weight", "encoder.embeddings.LayerNorm.weight", "encoder.embeddings.LayerNorm.bias", "encoder.encoder.layer.0.attention.self.query.weight", "encoder.encoder.layer.0.attention.self.query.bias", "encoder.encoder.layer.0.attention.self.key.weight", "encoder.encoder.layer.0.attention.self.key.bias", "encoder.encoder.layer.0.attention.self.value.weight", "encoder.encoder.layer.0.attention.self.value.bias", "encoder.encoder.layer.0.attention.output.dense.weight", "encoder.encoder.layer.0.attention.output.dense.bias", "encoder.encoder.layer.0.attention.output.LayerNorm.weight", "encoder.encoder.layer.0.attention.output.LayerNorm.bias", "encoder.encoder.layer.0.intermediate.dense.weight", "encoder.encoder.layer.0.intermediate.dense.bias", "encoder.encoder.layer.0.output.dense.weight", "encoder.encoder.layer.0.output.dense.bias", "encoder.encoder.layer.0.output.LayerNorm.weight", "encoder.encoder.layer.0.output.LayerNorm.bias", "encoder.encoder.layer.1.attention.self.query.weight", "encoder.encoder.layer.1.attention.self.query.bias", "encoder.encoder.layer.1.attention.self.key.weight", "encoder.encoder.layer.1.attention.self.key.bias", "encoder.encoder.layer.1.attention.self.value.weight", "encoder.encoder.layer.1.attention.self.value.bias", "encoder.encoder.layer.1.attention.output.dense.weight", "encoder.encoder.layer.1.attention.output.dense.bias", "encoder.encoder.layer.1.attention.output.LayerNorm.weight", "encoder.encoder.layer.1.attention.output.LayerNorm.bias", "encoder.encoder.layer.1.intermediate.dense.weight", "encoder.encoder.layer.1.intermediate.dense.bias", "encoder.encoder.layer.1.output.dense.weight", "encoder.encoder.layer.1.output.dense.bias", "encoder.encoder.layer.1.output.LayerNorm.weight", "encoder.encoder.layer.1.output.LayerNorm.bias", "encoder.encoder.layer.2.attention.self.query.weight", "encoder.encoder.layer.2.attention.self.query.bias", "encoder.encoder.layer.2.attention.self.key.weight", "encoder.encoder.layer.2.attention.self.key.bias", "encoder.encoder.layer.2.attention.self.value.weight", "encoder.encoder.layer.2.attention.self.value.bias", "encoder.encoder.layer.2.attention.output.dense.weight", "encoder.encoder.layer.2.attention.output.dense.bias", "encoder.encoder.layer.2.attention.output.LayerNorm.weight", "encoder.encoder.layer.2.attention.output.LayerNorm.bias", "encoder.encoder.layer.2.intermediate.dense.weight", "encoder.encoder.layer.2.intermediate.dense.bias", "encoder.encoder.layer.2.output.dense.weight", "encoder.encoder.layer.2.output.dense.bias", "encoder.encoder.layer.2.output.LayerNorm.weight", "encoder.encoder.layer.2.output.LayerNorm.bias", "encoder.encoder.layer.3.attention.self.query.weight", "encoder.encoder.layer.3.attention.self.query.bias", "encoder.encoder.layer.3.attention.self.key.weight", "encoder.encoder.layer.3.attention.self.key.bias", "encoder.encoder.layer.3.attention.self.value.weight", "encoder.encoder.layer.3.attention.self.value.bias", "encoder.encoder.layer.3.attention.output.dense.weight", "encoder.encoder.layer.3.attention.output.dense.bias", "encoder.encoder.layer.3.attention.output.LayerNorm.weight", "encoder.encoder.layer.3.attention.output.LayerNorm.bias", "encoder.encoder.layer.3.intermediate.dense.weight", "encoder.encoder.layer.3.intermediate.dense.bias", "encoder.encoder.layer.3.output.dense.weight", "encoder.encoder.layer.3.output.dense.bias", "encoder.encoder.layer.3.output.LayerNorm.weight", "encoder.encoder.layer.3.output.LayerNorm.bias", "encoder.encoder.layer.4.attention.self.query.weight", "encoder.encoder.layer.4.attention.self.query.bias", "encoder.encoder.layer.4.attention.self.key.weight", "encoder.encoder.layer.4.attention.self.key.bias", "encoder.encoder.layer.4.attention.self.value.weight", "encoder.encoder.layer.4.attention.self.value.bias", "encoder.encoder.layer.4.attention.output.dense.weight", "encoder.encoder.layer.4.attention.output.dense.bias", "encoder.encoder.layer.4.attention.output.LayerNorm.weight", "encoder.encoder.layer.4.attention.output.LayerNorm.bias", "encoder.encoder.layer.4.intermediate.dense.weight", "encoder.encoder.layer.4.intermediate.dense.bias", "encoder.encoder.layer.4.output.dense.weight", "encoder.encoder.layer.4.output.dense.bias", "encoder.encoder.layer.4.output.LayerNorm.weight", "encoder.encoder.layer.4.output.LayerNorm.bias", "encoder.encoder.layer.5.attention.self.query.weight", "encoder.encoder.layer.5.attention.self.query.bias", "encoder.encoder.layer.5.attention.self.key.weight", "encoder.encoder.layer.5.attention.self.key.bias", "encoder.encoder.layer.5.attention.self.value.weight", "encoder.encoder.layer.5.attention.self.value.bias", "encoder.encoder.layer.5.attention.output.dense.weight", "encoder.encoder.layer.5.attention.output.dense.bias", "encoder.encoder.layer.5.attention.output.LayerNorm.weight", "encoder.encoder.layer.5.attention.output.LayerNorm.bias", "encoder.encoder.layer.5.intermediate.dense.weight", "encoder.encoder.layer.5.intermediate.dense.bias", "encoder.encoder.layer.5.output.dense.weight", "encoder.encoder.layer.5.output.dense.bias", "encoder.encoder.layer.5.output.LayerNorm.weight", "encoder.encoder.layer.5.output.LayerNorm.bias", "encoder.encoder.layer.6.attention.self.query.weight", "encoder.encoder.layer.6.attention.self.query.bias", "encoder.encoder.layer.6.attention.self.key.weight", "encoder.encoder.layer.6.attention.self.key.bias", "encoder.encoder.layer.6.attention.self.value.weight", "encoder.encoder.layer.6.attention.self.value.bias", "encoder.encoder.layer.6.attention.output.dense.weight", "encoder.encoder.layer.6.attention.output.dense.bias", "encoder.encoder.layer.6.attention.output.LayerNorm.weight", "encoder.encoder.layer.6.attention.output.LayerNorm.bias", "encoder.encoder.layer.6.intermediate.dense.weight", "encoder.encoder.layer.6.intermediate.dense.bias", "encoder.encoder.layer.6.output.dense.weight", "encoder.encoder.layer.6.output.dense.bias", "encoder.encoder.layer.6.output.LayerNorm.weight", "encoder.encoder.layer.6.output.LayerNorm.bias", "encoder.encoder.layer.7.attention.self.query.weight", "encoder.encoder.layer.7.attention.self.query.bias", "encoder.encoder.layer.7.attention.self.key.weight", "encoder.encoder.layer.7.attention.self.key.bias", "encoder.encoder.layer.7.attention.self.value.weight", "encoder.encoder.layer.7.attention.self.value.bias", "encoder.encoder.layer.7.attention.output.dense.weight", "encoder.encoder.layer.7.attention.output.dense.bias", "encoder.encoder.layer.7.attention.output.LayerNorm.weight", "encoder.encoder.layer.7.attention.output.LayerNorm.bias", "encoder.encoder.layer.7.intermediate.dense.weight", "encoder.encoder.layer.7.intermediate.dense.bias", "encoder.encoder.layer.7.output.dense.weight", "encoder.encoder.layer.7.output.dense.bias", "encoder.encoder.layer.7.output.LayerNorm.weight", "encoder.encoder.layer.7.output.LayerNorm.bias", "encoder.encoder.layer.8.attention.self.query.weight", "encoder.encoder.layer.8.attention.self.query.bias", "encoder.encoder.layer.8.attention.self.key.weight", "encoder.encoder.layer.8.attention.self.key.bias", "encoder.encoder.layer.8.attention.self.value.weight", "encoder.encoder.layer.8.attention.self.value.bias", "encoder.encoder.layer.8.attention.output.dense.weight", "encoder.encoder.layer.8.attention.output.dense.bias", "encoder.encoder.layer.8.attention.output.LayerNorm.weight", "encoder.encoder.layer.8.attention.output.LayerNorm.bias", "encoder.encoder.layer.8.intermediate.dense.weight", "encoder.encoder.layer.8.intermediate.dense.bias", "encoder.encoder.layer.8.output.dense.weight", "encoder.encoder.layer.8.output.dense.bias", "encoder.encoder.layer.8.output.LayerNorm.weight", "encoder.encoder.layer.8.output.LayerNorm.bias", "encoder.encoder.layer.9.attention.self.query.weight", "encoder.encoder.layer.9.attention.self.query.bias", "encoder.encoder.layer.9.attention.self.key.weight", "encoder.encoder.layer.9.attention.self.key.bias", "encoder.encoder.layer.9.attention.self.value.weight", "encoder.encoder.layer.9.attention.self.value.bias", "encoder.encoder.layer.9.attention.output.dense.weight", "encoder.encoder.layer.9.attention.output.dense.bias", "encoder.encoder.layer.9.attention.output.LayerNorm.weight", "encoder.encoder.layer.9.attention.output.LayerNorm.bias", "encoder.encoder.layer.9.intermediate.dense.weight", "encoder.encoder.layer.9.intermediate.dense.bias", "encoder.encoder.layer.9.output.dense.weight", "encoder.encoder.layer.9.output.dense.bias", "encoder.encoder.layer.9.output.LayerNorm.weight", "encoder.encoder.layer.9.output.LayerNorm.bias", "encoder.encoder.layer.10.attention.self.query.weight", "encoder.encoder.layer.10.attention.self.query.bias", "encoder.encoder.layer.10.attention.self.key.weight", "encoder.encoder.layer.10.attention.self.key.bias", "encoder.encoder.layer.10.attention.self.value.weight", "encoder.encoder.layer.10.attention.self.value.bias", "encoder.encoder.layer.10.attention.output.dense.weight", "encoder.encoder.layer.10.attention.output.dense.bias", "encoder.encoder.layer.10.attention.output.LayerNorm.weight", "encoder.encoder.layer.10.attention.output.LayerNorm.bias", "encoder.encoder.layer.10.intermediate.dense.weight", "encoder.encoder.layer.10.intermediate.dense.bias", "encoder.encoder.layer.10.output.dense.weight", "encoder.encoder.layer.10.output.dense.bias", "encoder.encoder.layer.10.output.LayerNorm.weight", "encoder.encoder.layer.10.output.LayerNorm.bias", "encoder.encoder.layer.11.attention.self.query.weight", "encoder.encoder.layer.11.attention.self.query.bias", "encoder.encoder.layer.11.attention.self.key.weight", "encoder.encoder.layer.11.attention.self.key.bias", "encoder.encoder.layer.11.attention.self.value.weight", "encoder.encoder.layer.11.attention.self.value.bias", "encoder.encoder.layer.11.attention.output.dense.weight", "encoder.encoder.layer.11.attention.output.dense.bias", "encoder.encoder.layer.11.attention.output.LayerNorm.weight", "encoder.encoder.layer.11.attention.output.LayerNorm.bias", "encoder.encoder.layer.11.intermediate.dense.weight", "encoder.encoder.layer.11.intermediate.dense.bias", "encoder.encoder.layer.11.output.dense.weight", "encoder.encoder.layer.11.output.dense.bias", "encoder.encoder.layer.11.output.LayerNorm.weight", "encoder.encoder.layer.11.output.LayerNorm.bias", "encoder.pooler.dense.weight", "encoder.pooler.dense.bias", "classifier.weight", "classifier.bias". 
	Unexpected key(s) in state_dict: "model_state_dict", "config", "tokenizer". 

2025-09-27 15:57:26,227 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "GET /api/models/dmis-lab/biobert-v1.1/commits/refs%2Fpr%2F9 HTTP/1.1" 200 3950
2025-09-27 15:57:26,387 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /dmis-lab/biobert-v1.1/resolve/refs%2Fpr%2F9/model.safetensors.index.json HTTP/1.1" 404 0
2025-09-27 15:57:26,541 - urllib3.connectionpool - DEBUG - https://huggingface.co:443 "HEAD /dmis-lab/biobert-v1.1/resolve/refs%2Fpr%2F9/model.safetensors HTTP/1.1" 302 0
