# BioNNE-R Relation Extraction Baseline

Interactive notebook version of the OpenNRE-based baseline for relation extraction using `bert-base-multilingual-cased` with entity markers.

Supports three modes:
1. **Labeled mode** — relation TSV to JSON lines (standard training/evaluation)
2. **Labeled + negative sampling** — adds `no_relation` negatives from entity inventory
3. **Blind mode** — entity TSV to all candidate pairs (for blind evaluation)

## 1. Setup

In [None]:
# Install dependencies if needed
# !pip install git+https://github.com/thunlp/OpenNRE.git
# !pip install torch transformers nltk pandas scikit-learn

In [None]:
import json
import random
import re
from collections import Counter
from pathlib import Path

import nltk
import opennre
import pandas as pd
import torch

# Check that the "punkt_tab" sentence-tokenizer resource is available locally;
# download it quietly only when the lookup fails.
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    nltk.download("punkt_tab", quiet=True)

# Closed label set of the classifier. The order is significant: the rel2id
# mapping built later in this notebook enumerates this list, so "no_relation"
# (listed last) receives the highest class id.
RELATION_TYPES = [
    "ABBREVIATION", "ALTERNATIVE_NAME", "SUBCLASS_OF", "PART_OF",
    "TREATED_USING", "ORIGINS_FROM", "TO_DETECT_OR_STUDY", "AFFECTS",
    "HAS_CAUSE", "APPLIED_TO", "USED_IN", "ASSOCIATED_WITH",
    "PHYSIOLOGY_OF", "FINDING_OF",
    "no_relation",
]

# Quick sanity output: label count and whether a CUDA device is visible to torch.
print(f"Relation classes: {len(RELATION_TYPES)} (including no_relation)")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

In [None]:
# Configuration — adjust paths for your setup
MODEL_NAME = "bert-base-multilingual-cased"  # passed to the encoder as pretrain_path
MAX_LENGTH = 256        # passed to the encoder as max_length
BATCH_SIZE = 16         # passed to the training framework as batch_size
LEARNING_RATE = 2e-5    # passed to the training framework as lr
EPOCHS = 10             # passed to the training framework as max_epoch
WARMUP_STEPS = 300      # passed to the training framework as warmup_step
LANG = "english"        # language name used to pick the sentence tokenizer
# In the cells of this notebook SEED is only handed to convert_split (negative
# sampling); it is not used to seed torch or the training loop.
SEED = 42

# Paths
TRAIN_REL_TSV = Path("eng-train-rel.tsv")      # relation TSV for training
TRAIN_ENT_TSV = Path("eng-train-ent.tsv")       # entity TSV for negative sampling
DEV_REL_TSV = Path("eng-dev-rel.tsv")           # relation TSV for dev
TEXTS_DIR = Path("texts")                        # raw .txt article files
CONFIG_PATH = Path("../data/annotation_short-bio.conf")  # annotation config for type filtering
NEG_RATIO = 3                                    # negatives per positive (0 = no negatives)

# Prepared JSON-lines data and rel2id.json go to OUTPUT_DIR; the checkpoint and
# prediction TSVs go to "outputs".
OUTPUT_DIR = Path("data")
CKPT_PATH = Path("outputs/model.pth.tar")
REL2ID_PATH = OUTPUT_DIR / "rel2id.json"

# Create both output directories up front so later cells can write without checks.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
Path("outputs").mkdir(parents=True, exist_ok=True)

print(f"Model: {MODEL_NAME}")
print(f"Neg ratio: {NEG_RATIO}")

## 2. Data Preparation Functions

The next cell inlines the helper functions from `prepare_data.py`: config parsing, entity loading, candidate-pair generation, sentence segmentation, negative sampling, and blind-mode conversion.

In [None]:
def parse_config(config_path):
    """Collect the (Arg1 type, Arg2 type) combinations allowed by an annotation config.

    Only the ``[relations]`` section is read. Each relation line contributes
    the cross product of its ``Arg1:`` types with its ``Arg2:`` types
    (``|``-separated alternatives); when no ``Arg2:`` label is present, the
    types after the last comma are used instead. Blank lines, lines starting
    with ``<`` and lines that do not match this shape are ignored.

    Returns:
        A set of ``(arg1_type, arg2_type)`` tuples; empty when the file has
        no ``[relations]`` section.
    """
    with open(config_path, "r", encoding="utf-8") as handle:
        raw = handle.read()

    allowed = set()
    section = re.search(r"\[relations\](.*?)(?=\[|\Z)", raw, re.DOTALL)
    if section is None:
        return allowed

    for raw_line in section.group(1).strip().split("\n"):
        entry = raw_line.strip()
        if entry == "" or entry.startswith("<"):
            continue

        # First token is the relation name; the remainder holds the argument specs.
        fields = entry.split(None, 1)
        if len(fields) < 2:
            continue
        arguments = fields[1]

        head_spec = re.search(r"Arg1:([A-Z_|]+)", arguments)
        if head_spec is None:
            continue

        # Prefer an explicit "Arg2:" label, otherwise fall back to the trailing
        # comma-separated type list.
        tail_spec = (
            re.search(r"Arg2:([A-Z_|]+)", arguments)
            or re.search(r",\s*([A-Z_|]+)\s*$", arguments)
        )
        if tail_spec is None:
            continue

        allowed.update(
            (head_type, tail_type)
            for head_type in head_spec.group(1).split("|")
            for tail_type in tail_spec.group(1).split("|")
        )
    return allowed


def load_texts(texts_dir):
    """Read every ``*.txt`` file directly inside ``texts_dir`` as UTF-8.

    Returns:
        A dict mapping each file's stem (file name without ``.txt``) to its
        full contents.
    """
    return {
        article.stem: article.read_text(encoding="utf-8")
        for article in Path(texts_dir).glob("*.txt")
    }


def load_entities(entity_tsv):
    """Read an entity TSV and group its rows by document.

    The file must provide the columns ``document_id``, ``entity_type``,
    ``entity_text`` and ``entity_span``.

    Every cell is read verbatim as a string. With pandas' default type
    inference an entity whose text is literally "NA", "null" or "nan" would
    become a missing value (and later the string "nan"), and a numeric-looking
    document id such as "007" would lose its leading zeros and no longer match
    the stem of its text file.

    Returns:
        ``{doc_id: [(entity_type, entity_text, entity_span), ...]}`` with rows
        kept in file order.
    """
    df = pd.read_csv(entity_tsv, sep="\t", dtype=str, keep_default_na=False)
    entities_by_doc = {}
    columns = zip(df["document_id"], df["entity_type"], df["entity_text"], df["entity_span"])
    for doc_id, entity_type, entity_text, entity_span in columns:
        entities_by_doc.setdefault(doc_id, []).append((entity_type, entity_text, entity_span))
    return entities_by_doc


def parse_span(span_str):
    """Turn a ``"start-end"`` string into an ``(int, int)`` tuple.

    Raises:
        ValueError: if the string does not consist of exactly two
            ``-``-separated integers.
    """
    start, end = map(int, span_str.split("-"))
    return start, end


def generate_pairs(entities, valid_type_pairs=None):
    """List every ordered (head, tail) combination of two distinct entities.

    Entities are told apart by their position in ``entities``, so two entries
    with equal content still form pairs with each other. Element 0 of each
    entity is treated as its type: when ``valid_type_pairs`` is given, a pair
    is kept only if ``(head_type, tail_type)`` is in it.

    Returns:
        A list of ``(head_entity, tail_entity)`` tuples, ordered by head index
        and then by tail index.
    """
    unrestricted = valid_type_pairs is None
    return [
        (head, tail)
        for head_idx, head in enumerate(entities)
        for tail_idx, tail in enumerate(entities)
        if head_idx != tail_idx
        and (unrestricted or (head[0], tail[0]) in valid_type_pairs)
    ]


# Sentence tokenizers already loaded, keyed by the requested language name.
_SENT_TOKENIZERS = {}
# Sentence spans of the most recently segmented document. Callers ask for many
# entity pairs of the same document in a row, so one entry is enough.
_LAST_SENT_SPANS = {"key": None, "spans": None}


def _get_sent_tokenizer(lang):
    """Return the Punkt tokenizer for ``lang``, loading it at most once (English fallback)."""
    tokenizer = _SENT_TOKENIZERS.get(lang)
    if tokenizer is None:
        try:
            tokenizer = nltk.data.load(f"tokenizers/punkt_tab/{lang}.pickle")
        except LookupError:
            tokenizer = nltk.data.load("tokenizers/punkt_tab/english.pickle")
        _SENT_TOKENIZERS[lang] = tokenizer
    return tokenizer


def _sentence_spans(text, lang):
    """Return the (start, end) sentence spans of ``text``, reusing the last result if possible."""
    key = (lang, text)
    if _LAST_SENT_SPANS["key"] != key:
        # Store the spans before the key so a failed tokenization never leaves
        # a key pointing at stale spans.
        _LAST_SENT_SPANS["spans"] = list(_get_sent_tokenizer(lang).span_tokenize(text))
        _LAST_SENT_SPANS["key"] = key
    return _LAST_SENT_SPANS["spans"]


def find_sentence_segment(text, head_start, head_end, tail_start, tail_end, lang="english"):
    """Cut out the run of sentences around an entity pair.

    The segment runs from the sentence containing the earlier entity start to
    the sentence containing the later entity end, widened by one sentence of
    context on each side where available. If an entity boundary falls in no
    sentence span, the corresponding side defaults to the document start/end.

    The tokenizer and the sentence spans are cached so that repeated calls for
    pairs of the same document do not reload the model or re-tokenize the text.

    Args:
        text: Full document text.
        head_start, head_end, tail_start, tail_end: Character offsets into ``text``.
        lang: Punkt language name; English is used when it cannot be loaded.

    Returns:
        ``(segment_text, offset)`` where ``offset`` is the index in ``text``
        at which ``segment_text`` begins. ``(text, 0)`` when no sentence is found.
    """
    sentences = _sentence_spans(text, lang)
    if not sentences:
        return text, 0
    entity_min = min(head_start, tail_start)
    entity_max = max(head_end, tail_end)
    first_idx = 0
    last_idx = len(sentences) - 1
    for i, (s_start, s_end) in enumerate(sentences):
        if s_start <= entity_min < s_end:
            first_idx = i
        if s_start < entity_max <= s_end:
            last_idx = i
    if first_idx > last_idx:
        first_idx, last_idx = last_idx, first_idx
    # One sentence of context on each side, clamped to the document.
    first_idx = max(0, first_idx - 1)
    last_idx = min(len(sentences) - 1, last_idx + 1)
    seg_start = sentences[first_idx][0]
    seg_end = sentences[last_idx][1]
    return text[seg_start:seg_end], seg_start


def _make_instance(text, head_type, head_text, head_span, tail_type, tail_text, tail_span, relation, doc_id, lang):
    """Assemble a single OpenNRE-style JSON record for one (head, tail) pair.

    The record's text is the sentence segment around both entities, with the
    entity positions shifted to be relative to that segment. If a shifted
    position would fall outside the segment, the full document text and the
    original document-level offsets are used instead. The original span
    strings, entity types and document id are carried along unchanged.
    """
    head_start, head_end = parse_span(head_span)
    tail_start, tail_end = parse_span(tail_span)
    segment, offset = find_sentence_segment(text, head_start, head_end, tail_start, tail_end, lang)

    head_pos = [head_start - offset, head_end - offset]
    tail_pos = [tail_start - offset, tail_end - offset]
    limit = len(segment)
    fits = (
        head_pos[0] >= 0
        and head_pos[1] <= limit
        and tail_pos[0] >= 0
        and tail_pos[1] <= limit
    )
    if not fits:
        # Fall back to whole-document coordinates.
        segment = text
        head_pos = [head_start, head_end]
        tail_pos = [tail_start, tail_end]

    head_entry = {"name": head_text, "pos": head_pos}
    tail_entry = {"name": tail_text, "pos": tail_pos}
    return {
        "text": segment,
        "h": head_entry,
        "t": tail_entry,
        "relation": relation,
        "doc_id": doc_id,
        "head_span": head_span,
        "tail_span": tail_span,
        "head_type": head_type,
        "tail_type": tail_type,
    }


def add_negatives(positives_by_doc, entities_by_doc, neg_ratio, valid_type_pairs, seed):
    """Draw random non-gold entity pairs for each document.

    For every document in ``entities_by_doc`` the candidate pool is all ordered
    entity pairs (optionally type-filtered) whose ``(head_span, tail_span)`` is
    not among that document's gold pairs. Up to ``neg_ratio`` pairs per gold
    pair are sampled without replacement, so documents without gold pairs
    yield none.

    Returns:
        ``{doc_id: [(head_entity, tail_entity), ...]}`` containing only
        documents for which at least one pair was sampled.
    """
    sampler = random.Random(seed)
    sampled = {}
    for doc_id, doc_entities in entities_by_doc.items():
        gold_spans = positives_by_doc.get(doc_id, set())
        pool = []
        for head, tail in generate_pairs(doc_entities, valid_type_pairs):
            if (head[2], tail[2]) in gold_spans:
                continue
            pool.append((head, tail))
        quota = min(neg_ratio * len(gold_spans), len(pool))
        if quota <= 0:
            continue
        sampled[doc_id] = sampler.sample(pool, quota)
    return sampled


def convert_split(rel_tsv, texts_dir, output_path, lang="english",
                  entities_tsv=None, neg_ratio=0, valid_type_pairs=None, seed=42):
    """Convert a labeled relation TSV plus raw texts into OpenNRE JSON lines.

    Args:
        rel_tsv: Relation TSV with columns ``document_id``, ``relation``,
            ``head_type``/``head_text``/``head_span`` and the ``tail_*`` equivalents.
        texts_dir: Directory holding one ``<document_id>.txt`` per document.
        output_path: Destination file; one JSON object is written per line.
        lang: Language name used for sentence segmentation.
        entities_tsv: Optional entity TSV; required for negative sampling.
        neg_ratio: ``no_relation`` negatives to sample per gold pair of a
            document (0 disables sampling).
        valid_type_pairs: Optional set of allowed ``(head_type, tail_type)``
            combinations for negative candidates.
        seed: Seed of the negative-sampling RNG.

    Rows whose document has no text file, or whose relation is not in
    ``RELATION_TYPES``, are skipped and reported.

    Returns:
        Number of instances written (positives plus negatives).
    """
    # Read every cell verbatim as a string: default type inference would turn
    # entity texts such as "NA"/"null" into NaN and strip leading zeros from
    # numeric-looking document ids, which then no longer match text file stems.
    df = pd.read_csv(rel_tsv, sep="\t", dtype=str, keep_default_na=False)
    records = df.to_dict("records")
    texts = load_texts(texts_dir)

    entities_by_doc = None
    positives_by_doc = {}
    if entities_tsv and neg_ratio > 0:
        entities_by_doc = load_entities(entities_tsv)
        # Gold span pairs per document: they size the negative sample and are
        # excluded from the negative candidates.
        for rec in records:
            positives_by_doc.setdefault(rec["document_id"], set()).add(
                (rec["head_span"], rec["tail_span"])
            )

    count = 0
    neg_count = 0
    skipped = 0
    with open(output_path, "w", encoding="utf-8") as f:
        for rec in records:
            doc_id = rec["document_id"]
            relation = rec["relation"]
            if doc_id not in texts or relation not in RELATION_TYPES:
                skipped += 1
                continue
            instance = _make_instance(
                texts[doc_id], rec["head_type"], rec["head_text"], rec["head_span"],
                rec["tail_type"], rec["tail_text"], rec["tail_span"],
                relation, doc_id, lang,
            )
            if instance:
                f.write(json.dumps(instance, ensure_ascii=False) + "\n")
                count += 1

        if entities_by_doc and neg_ratio > 0:
            neg_pairs_by_doc = add_negatives(
                positives_by_doc, entities_by_doc, neg_ratio, valid_type_pairs, seed
            )
            for doc_id, neg_pairs in neg_pairs_by_doc.items():
                if doc_id not in texts:
                    continue
                text = texts[doc_id]
                for head_ent, tail_ent in neg_pairs:
                    instance = _make_instance(
                        text, head_ent[0], head_ent[1], head_ent[2],
                        tail_ent[0], tail_ent[1], tail_ent[2],
                        "no_relation", doc_id, lang,
                    )
                    if instance:
                        f.write(json.dumps(instance, ensure_ascii=False) + "\n")
                        neg_count += 1

    if skipped:
        print(f"  Skipped {skipped} instances (missing text or unknown relation)")
    if neg_count:
        print(f"  Added {neg_count} no_relation negatives (ratio {neg_ratio}:1)")
    return count + neg_count


def convert_blind(entity_tsv, texts_dir, output_path, valid_type_pairs=None, lang="english"):
    """Write one candidate instance per ordered entity pair of every document.

    No gold relations are involved: each instance carries ``"no_relation"`` as
    its relation value. Documents are processed in sorted id order; documents
    without a matching text file are skipped and counted in a summary message.

    Returns:
        Number of candidate instances written to ``output_path``.
    """
    doc_entities = load_entities(entity_tsv)
    doc_texts = load_texts(texts_dir)
    written = 0
    missing = 0
    with open(output_path, "w", encoding="utf-8") as out:
        for doc_id in sorted(doc_entities):
            if doc_id in doc_texts:
                body = doc_texts[doc_id]
                for head, tail in generate_pairs(doc_entities[doc_id], valid_type_pairs):
                    # Entities are (type, text, span) records; pass them positionally.
                    record = _make_instance(
                        body, *head[:3], *tail[:3], "no_relation", doc_id, lang
                    )
                    if not record:
                        continue
                    out.write(json.dumps(record, ensure_ascii=False) + "\n")
                    written += 1
            else:
                missing += 1
    if missing:
        print(f"  Skipped {missing} documents (no matching text file)")
    return written


def detect_input_type(tsv_path):
    """Classify a TSV as ``'relation'`` or ``'entity'`` from its header row.

    Only the header is read. A ``relation`` column takes precedence over an
    ``entity_type`` column.

    Raises:
        ValueError: if neither marker column is present.
    """
    header = set(pd.read_csv(tsv_path, sep="\t", nrows=0).columns)
    for marker, kind in (("relation", "relation"), ("entity_type", "entity")):
        if marker in header:
            return kind
    raise ValueError(f"Cannot auto-detect TSV type from columns: {sorted(header)}")


# Visible confirmation that the cell defining the helpers above ran to completion.
print("Data preparation functions loaded.")

In [None]:
# Generate rel2id.json. Each relation name maps to its position in
# RELATION_TYPES, so no_relation (listed last) always gets the highest id (14).
rel2id = {rel: i for i, rel in enumerate(RELATION_TYPES)}
# Explicit UTF-8, consistent with every other file this notebook reads or writes.
with open(REL2ID_PATH, "w", encoding="utf-8") as f:
    json.dump(rel2id, f, indent=2)
print(f"Wrote {REL2ID_PATH} ({len(rel2id)} classes)")

# Load config for type-based pair filtering (optional).
# valid_type_pairs stays None when the config is absent, which disables type
# filtering in the pair-generation helpers.
valid_type_pairs = None
if CONFIG_PATH.exists():
    valid_type_pairs = parse_config(str(CONFIG_PATH))
    print(f"Loaded config: {len(valid_type_pairs)} valid type pairs")
else:
    print(f"Config not found at {CONFIG_PATH}, skipping type filtering")

In [None]:
# Prepare training data (labeled + negative sampling)
train_output = OUTPUT_DIR / "train.txt"

# Negatives are sampled from the entity inventory. Without that file
# convert_split receives entities_tsv=None and writes positives only, so the
# no_relation class would get no training examples. Say so instead of
# degrading silently.
if NEG_RATIO > 0 and not TRAIN_ENT_TSV.exists():
    print(f"WARNING: {TRAIN_ENT_TSV} not found; no no_relation negatives will be generated")

print(f"Converting {TRAIN_REL_TSV} + {TEXTS_DIR} -> {train_output}")
n_train = convert_split(
    TRAIN_REL_TSV, TEXTS_DIR, train_output,
    lang=LANG,
    entities_tsv=TRAIN_ENT_TSV if TRAIN_ENT_TSV.exists() else None,
    neg_ratio=NEG_RATIO,
    valid_type_pairs=valid_type_pairs,
    seed=SEED,
)
print(f"  {n_train} total instances")

In [None]:
# Prepare dev data (labeled, no negatives)
# No entity TSV or neg_ratio is passed, so convert_split uses its defaults
# (entities_tsv=None, neg_ratio=0) and writes gold relation pairs only.
dev_output = OUTPUT_DIR / "dev.txt"

print(f"Converting {DEV_REL_TSV} + {TEXTS_DIR} -> {dev_output}")
n_dev = convert_split(DEV_REL_TSV, TEXTS_DIR, dev_output, lang=LANG)
print(f"  {n_dev} instances")

## 3. Data Exploration

In [None]:
# Read the prepared training file back: one JSON object per non-blank line
with open(train_output, encoding="utf-8") as f:
    train_instances = [json.loads(raw) for raw in map(str.strip, f) if raw]

# How many instances each relation label has, most frequent first
rel_counts = Counter(inst["relation"] for inst in train_instances)
print(f"Training instances: {len(train_instances)}")
print(f"\nRelation distribution:")
for rel, count in rel_counts.most_common():
    print(f"  {rel:<25} {count:>6}")

In [None]:
# Sample instance: show the first training record, if there is one
if train_instances:
    sample = train_instances[0]
    print("Sample instance:")
    for key, value in sample.items():
        if key == "text":
            # Preview long passages; mark truncation only when text was actually cut
            suffix = "..." if len(value) > 200 else ""
            print(f"  {key}: {value[:200]}{suffix}")
        else:
            print(f"  {key}: {value}")
else:
    print("No training instances to display.")

In [None]:
# Read the prepared dev file back (kept in memory for prediction and evaluation):
# one JSON object per non-blank line
with open(dev_output, encoding="utf-8") as f:
    dev_instances = [json.loads(raw) for raw in map(str.strip, f) if raw]

# Gold label frequencies on dev, most frequent first
dev_rel_counts = Counter(inst["relation"] for inst in dev_instances)
print(f"Dev instances: {len(dev_instances)}")
print(f"\nDev relation distribution:")
for rel, count in dev_rel_counts.most_common():
    print(f"  {rel:<25} {count:>6}")

## 4. Training

### Setup and train using OpenNRE SentenceRE framework

In [None]:
# Echo the run configuration before building anything expensive.
print(f"Model: {MODEL_NAME}")
print(f"Train: {train_output}")
print(f"Dev: {dev_output}")
print(f"Classes: {len(rel2id)}")
print(f"Checkpoint: {CKPT_PATH}")

# Sentence encoder built from the pretrained model named by MODEL_NAME.
# NOTE(review): presumably this encoder wraps head/tail mentions in entity
# markers, per the notebook introduction — confirm against OpenNRE.
encoder = opennre.encoder.BERTEntityEncoder(
    max_length=MAX_LENGTH,
    pretrain_path=MODEL_NAME,
)

# Classifier over the encoder output with one class per rel2id entry.
model = opennre.model.SoftmaxNN(
    sentence_encoder=encoder,
    num_class=len(rel2id),
    rel2id=rel2id,
)

# Training framework. The dev file is passed as both validation and test set
# because this notebook prepares no separate labeled test split.
framework = opennre.framework.SentenceRE(
    model=model,
    train_path=str(train_output),
    val_path=str(dev_output),
    test_path=str(dev_output),
    ckpt=str(CKPT_PATH),
    batch_size=BATCH_SIZE,
    max_epoch=EPOCHS,
    lr=LEARNING_RATE,
    opt="adamw",
    warmup_step=WARMUP_STEPS,
)

print("Framework ready.")

In [None]:
# Train the model
# NOTE(review): presumably metric="micro_f1" selects which validation metric
# decides the checkpoint kept at CKPT_PATH — confirm against OpenNRE.
framework.train_model(metric="micro_f1")
print(f"\nBest checkpoint saved to: {CKPT_PATH}")

## 5. Prediction on Dev Set

In [None]:
# Load best checkpoint for prediction
# A fresh encoder/model pair with the same architecture as in training is
# built, then its weights are replaced by the ones stored in the checkpoint.
encoder_pred = opennre.encoder.BERTEntityEncoder(
    max_length=MAX_LENGTH,
    pretrain_path=MODEL_NAME,
)

model_pred = opennre.model.SoftmaxNN(
    sentence_encoder=encoder_pred,
    num_class=len(rel2id),
    rel2id=rel2id,
)

# Load onto CPU first so the checkpoint can be read on machines without a GPU;
# the weights are expected under the "state_dict" key.
ckpt = torch.load(str(CKPT_PATH), map_location="cpu")
model_pred.load_state_dict(ckpt["state_dict"])

# Move to GPU when one is available and switch to evaluation mode.
if torch.cuda.is_available():
    model_pred = model_pred.cuda()
model_pred.eval()
print("Model loaded for prediction.")

In [None]:
# Predict on dev set
print(f"Predicting {len(dev_instances)} instances...")

rows = []
# Inference only: disabling gradient tracking avoids building an autograd
# graph for every forward pass.
with torch.no_grad():
    for inst in dev_instances:
        # infer() returns (relation name, confidence); the confidence is not used here
        pred_rel, score = model_pred.infer({
            "text": inst["text"],
            "h": {"pos": inst["h"]["pos"]},
            "t": {"pos": inst["t"]["pos"]},
        })
        rows.append({
            "document_id": inst["doc_id"],
            "relation": pred_rel,
            "head_text": inst["h"]["name"],
            "head_span": inst["head_span"],
            "head_type": inst["head_type"],
            "tail_text": inst["t"]["name"],
            "tail_span": inst["tail_span"],
            "tail_type": inst["tail_type"],
        })

# Filter out no_relation predictions.
# The columns are named explicitly so an empty prediction list still yields a
# frame with a "relation" column (and a TSV with a header) instead of a KeyError.
pred_df = pd.DataFrame(rows, columns=[
    "document_id", "relation",
    "head_text", "head_span", "head_type",
    "tail_text", "tail_span", "tail_type",
])
total_pred = len(pred_df)
pred_df = pred_df[pred_df["relation"] != "no_relation"]
filtered = total_pred - len(pred_df)
print(f"Filtered {filtered} no_relation predictions, {len(pred_df)} relations remaining")

# Save predictions
pred_path = Path("outputs/pred.tsv")
pred_df.to_csv(pred_path, sep="\t", index=False)
print(f"Predictions saved to: {pred_path}")

## 6. Evaluation

In [None]:
# Per-relation precision / recall / F1 (against gold labels in dev data).
# The dev file was prepared without negatives, so every scored pair is a gold
# relation pair: these numbers measure relation typing on gold pairs.
all_relations = sorted(rel2id.keys())

gold_counts = Counter(inst["relation"] for inst in dev_instances)
pred_counts = Counter(row["relation"] for row in rows)
# True positives per relation: pairs whose predicted label equals the gold label
tp_counts = Counter(
    inst["relation"]
    for inst, row in zip(dev_instances, rows)
    if inst["relation"] == row["relation"]
)

correct = sum(tp_counts.values())
total = len(dev_instances)
# Guard against an empty dev set instead of raising ZeroDivisionError
accuracy = correct / total if total else 0.0
print(f"Accuracy: {correct}/{total} = {accuracy:.4f}")

print(f"\n{'Relation':<25} {'P':>8} {'R':>8} {'F1':>8} {'Support':>8}")
print("-" * 60)
f1_scores = []
for rel in all_relations:
    tp = tp_counts.get(rel, 0)
    pred_total = pred_counts.get(rel, 0)
    gold_total = gold_counts.get(rel, 0)
    p = tp / pred_total if pred_total else 0
    r = tp / gold_total if gold_total else 0
    f1 = 2 * p * r / (p + r) if (p + r) else 0
    # Macro F1 averages only over relations that occur in the gold data
    if gold_total > 0:
        f1_scores.append(f1)
    print(f"{rel:<25} {p:>8.4f} {r:>8.4f} {f1:>8.4f} {gold_total:>8}")

macro_f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0
print(f"\nMacro F1: {macro_f1:.4f}")

## 7. Blind Prediction

Generate all candidate entity pairs from an entity TSV (no gold relations), predict a relation for each pair, and drop the pairs predicted as `no_relation`.

In [None]:
# Prepare blind test data from entity TSV
# Change this path to your blind test entity TSV
BLIND_ENT_TSV = Path("eng-test-ent.tsv")
BLIND_TEXTS_DIR = TEXTS_DIR  # same text directory, or change for test set
blind_output = OUTPUT_DIR / "test.txt"

# Candidate generation runs only when the entity TSV exists; otherwise the
# cell just reports that it was skipped.
if BLIND_ENT_TSV.exists():
    print(f"Blind mode: {BLIND_ENT_TSV} + {BLIND_TEXTS_DIR} -> {blind_output}")
    n_blind = convert_blind(
        BLIND_ENT_TSV, BLIND_TEXTS_DIR, blind_output,
        valid_type_pairs=valid_type_pairs,
        lang=LANG,
    )
    print(f"  {n_blind} candidate pairs")
else:
    print(f"Blind entity TSV not found at {BLIND_ENT_TSV}, skipping.")

In [None]:
# Run blind prediction (if test data was prepared)
blind_pred_path = Path("outputs/blind_pred.tsv")

# NOTE: this only checks that the candidate file exists and is non-empty; a
# file left over from an earlier run is used as-is.
if blind_output.exists() and blind_output.stat().st_size > 0:
    blind_instances = []
    with open(blind_output, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                blind_instances.append(json.loads(line))

    print(f"Predicting {len(blind_instances)} candidate pairs...")

    blind_rows = []
    # Inference only: disabling gradient tracking avoids building an autograd
    # graph for every forward pass.
    with torch.no_grad():
        for inst in blind_instances:
            # infer() returns (relation name, confidence); the confidence is not used here
            pred_rel, score = model_pred.infer({
                "text": inst["text"],
                "h": {"pos": inst["h"]["pos"]},
                "t": {"pos": inst["t"]["pos"]},
            })
            blind_rows.append({
                "document_id": inst["doc_id"],
                "relation": pred_rel,
                "head_text": inst["h"]["name"],
                "head_span": inst["head_span"],
                "head_type": inst["head_type"],
                "tail_text": inst["t"]["name"],
                "tail_span": inst["tail_span"],
                "tail_type": inst["tail_type"],
            })

    # Filter out no_relation.
    # The columns are named explicitly so an empty result still has a
    # "relation" column (and the TSV a header) instead of raising KeyError.
    blind_pred_df = pd.DataFrame(blind_rows, columns=[
        "document_id", "relation",
        "head_text", "head_span", "head_type",
        "tail_text", "tail_span", "tail_type",
    ])
    total_blind = len(blind_pred_df)
    blind_pred_df = blind_pred_df[blind_pred_df["relation"] != "no_relation"]
    blind_filtered = total_blind - len(blind_pred_df)
    print(f"Filtered {blind_filtered} no_relation predictions, {len(blind_pred_df)} relations remaining")

    blind_pred_df.to_csv(blind_pred_path, sep="\t", index=False)
    print(f"Blind predictions saved to: {blind_pred_path}")
else:
    print("No blind test data to predict on.")

## 8. Error Analysis

In [None]:
# Find misclassified examples on dev set: pairs whose predicted label differs
# from the gold label, with a 200-character preview of the input text
errors = [
    {
        "doc_id": inst["doc_id"],
        "head_text": inst["h"]["name"],
        "tail_text": inst["t"]["name"],
        "gold": inst["relation"],
        "predicted": row["relation"],
        "text": inst["text"][:200],
    }
    for inst, row in zip(dev_instances, rows)
    if inst["relation"] != row["relation"]
]

# Guard against an empty prediction list instead of raising ZeroDivisionError
error_rate = 100 * len(errors) / len(rows) if rows else 0.0
print(f"Total errors: {len(errors)} / {len(rows)} ({error_rate:.1f}%)")
print(f"\nFirst 5 errors:")
for i, err in enumerate(errors[:5]):
    print(f"\n--- Error {i+1} ---")
    print(f"Doc: {err['doc_id']}")
    print(f"Head: {err['head_text']}")
    print(f"Tail: {err['tail_text']}")
    print(f"Gold: {err['gold']}")
    print(f"Predicted: {err['predicted']}")

In [None]:
# Confusion matrix for most common relations
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

# Translate gold and predicted labels to class ids; labels missing from rel2id
# become -1, which is outside the label list given to confusion_matrix
gold_indices = [rel2id.get(inst["relation"], -1) for inst in dev_instances]
pred_indices = [rel2id.get(row["relation"], -1) for row in rows]

cm = confusion_matrix(gold_indices, pred_indices, labels=list(range(len(RELATION_TYPES))))

# Row sums are the gold counts per class; keep the 8 most frequent classes
label_counts = cm.sum(axis=1)
top_labels = np.argsort(label_counts)[::-1][:8]

labels_subset = [RELATION_TYPES[i] for i in top_labels]
cm_subset = cm[np.ix_(top_labels, top_labels)]

# Heat map: rows are gold labels, columns are predicted labels
tick_positions = range(len(labels_subset))
plt.figure(figsize=(10, 8))
plt.imshow(cm_subset, interpolation="nearest", cmap=plt.cm.Blues)
plt.colorbar()
plt.title("Confusion Matrix (Top 8 Relations)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.xticks(tick_positions, labels_subset, rotation=45, ha="right")
plt.yticks(tick_positions, labels_subset)
plt.tight_layout()
plt.show()