In [1]:
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [45]:
# Load ChemDisGene dataset
ds = load_dataset("bigbio/chem_dis_gene", "chem_dis_gene_bigbio_kb", trust_remote_code = True)

model_name = "distilbert-base-uncased"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [46]:
# find all the unique relation labels
all_relation_types = set()

for doc in ds["train"]:
        for rel in doc["relations"]:
            all_relation_types.add(rel["type"])

# Step 3: Add 'no_relation' for negative examples
all_relation_types.add("no_relation")

# Step 4: Print them
print("Unique relation types:")
for rel_type in sorted(all_relation_types):
    print(f"- {rel_type}")


Unique relation types:
- chem_disease:marker/mechanism
- chem_disease:therapeutic
- chem_gene:affects^activity
- chem_gene:affects^binding
- chem_gene:affects^expression
- chem_gene:affects^localization
- chem_gene:affects^metabolic_processing
- chem_gene:affects^transport
- chem_gene:decreases^activity
- chem_gene:decreases^expression
- chem_gene:decreases^metabolic_processing
- chem_gene:decreases^transport
- chem_gene:increases^activity
- chem_gene:increases^expression
- chem_gene:increases^metabolic_processing
- chem_gene:increases^transport
- gene_disease:marker/mechanism
- gene_disease:therapeutic
- no_relation


In [47]:
# create label-to-id mapping
label2id = {label: idx for idx, label in enumerate(sorted(all_relation_types))}
id2label = {v: k for k, v in label2id.items()}

In [48]:
# split dataset into train and test
if "test" not in ds:
    ds_split = ds["train"].train_test_split(test_size=0.1, seed=42)

In [50]:
ds_split["train"] # 9:1 split

Dataset({
    features: ['id', 'document_id', 'passages', 'entities', 'events', 'coreferences', 'relations'],
    num_rows: 470
})

In [52]:

# Build examples from entity pairs
def prepare_examples(split):
    examples = []
    for doc in ds_split[split]:
        text = " ".join(p["text"][0] for p in doc["passages"])
        entity_map = {e["id"]: e for e in doc["entities"]}
        existing_pairs = {(r["arg1_id"], r["arg2_id"]): r["type"] for r in doc["relations"]}

        # Create all possible entity pairs (for binary classification setup)
        for e1 in doc["entities"]:
            for e2 in doc["entities"]:
                if e1["id"] == e2["id"]:
                    continue
                label = existing_pairs.get((e1["id"], e2["id"]), "no_relation")
                input_text = f"{text} [SEP] {e1['text'][0]} [SEP] {e2['text'][0]}"
                examples.append({
                    "text": input_text,
                    "label": label2id[label]
                })
    return Dataset.from_list(examples)

train_dataset = prepare_examples("train")
test_dataset = prepare_examples("test")


In [53]:
train_dataset[0]
len(label2id)

19

In [None]:

# Tokenize
def tokenize(example):
    encoding = tokenizer(example["text"], padding="max_length", truncation=True, max_length=256)
    encoding["labels"] = example["label"]
    return encoding

tokenized_train = train_dataset.map(tokenize, batched=False)
tokenized_test = test_dataset.map(tokenize, batched=False)

# Load model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id
)

# Training setup
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer
)

# Train!
trainer.train()


In [None]:
# prediction function

def predict_relation(text: str, entity1: str, entity2: str, model, tokenizer, id2label):
    """
    Predict the relation between two entities in a given text.

    Args:
        text (str): The input sentence or passage.
        entity1 (str): The first entity.
        entity2 (str): The second entity.
        model: The fine-tuned BERT-based model.
        tokenizer: The tokenizer used during training.
        id2label (dict): A dictionary mapping label IDs to relation strings.

    Returns:
        predicted_label (str): The predicted relation type.
        confidence (float): The confidence score of the prediction.
    """
    # Format the input the same way you did during training
    input_text = f"{text} [SEP] {entity1} [SEP] {entity2}"

    # Tokenize and encode
    encoding = tokenizer(
        input_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=256
    )

    # Move to the same device as the model
    encoding = {k: v.to(model.device) for k, v in encoding.items()}

    # Set model to eval mode and predict
    model.eval()
    with torch.no_grad():
        outputs = model(**encoding)
        logits = outputs.logits
        predicted_id = logits.argmax(dim=-1).item()
        confidence = torch.softmax(logits, dim=-1)[0][predicted_id].item()

    # Convert to label
    predicted_label = id2label[predicted_id]
    return predicted_label, confidence


In [None]:
# example application

text = "Aspirin is known to reduce the risk of heart attacks."
entity1 = "Aspirin"
entity2 = "heart attacks"

label, confidence = predict_relation(text, entity1, entity2, model, tokenizer, id2label)
print(f"Predicted relation: {label} (Confidence: {confidence:.2f})")
