# REBEL and mREBEL models

Paper: https://aclanthology.org/2021.findings-emnlp.204.pdf

Git: https://github.com/Babelscape/rebel

## Triplet extraction with REBEL and mREBEL

In [None]:
import pandas as pd
from tqdm import tqdm
from transformers import pipeline
import torch
import re

# --- Parsing functions ---

import re

def extract_triplets_mrebel(text):
    """Parse a decoded mREBEL generation into typed triplets.

    mREBEL linearises its output as::

        <triplet> HEAD <head_type> TAIL <tail_type> RELATION
                  [<head_type> TAIL <tail_type> RELATION ...]

    i.e. each head entity is followed by one or more groups of
    (head-type tag, tail entity, tail-type tag, relation). Every complete
    group yields one triplet; an incomplete trailing group (for example from
    a generation cut off by ``max_length``) is dropped.

    Args:
        text: Generated ids decoded with ``skip_special_tokens=False``, so the
            ``<triplet>`` marker and the ``<type>`` tags are still present.

    Returns:
        A list of dicts with keys ``head``, ``head_type``, ``type``,
        ``RELATION`` and ``tail_type``. Note the key naming expected by the
        caller: ``type`` holds the *tail entity* and ``RELATION`` holds the
        relation label. Type names are the tag text capitalised
        (``<loc>`` -> ``Loc``).
    """
    # Tokenizer specials look like "<...>" and would otherwise be mistaken
    # for entity-type tags, so remove them first.
    text = re.sub(r"</?s>|<pad>|<unk>", "", text)
    # Remove mBART-50 style language codes (e.g. "es_XX", "tp_XX") and
    # "__xx__" markers, which carry no triplet content.
    text = re.sub(r"\b[a-z]{2}_[A-Z]{2}\b|__[a-z]{2}__", "", text)

    def clean(span):
        # Collapse whitespace runs left behind by removed tokens.
        return " ".join(span.split())

    triplets = []
    # The mREBEL model card's reference parser also starts a new head at
    # "<relation>"; splitting on it is harmless when it never appears.
    for chunk in re.split(r"<triplet>|<relation>", text):
        # The capturing group keeps the tags in the result:
        # [head, <ht>, tail, <tt>, rel, <ht>, tail, <tt>, rel, ...]
        parts = re.split(r"(<[^>]+>)", chunk)
        head = clean(parts[0])
        # Walk complete 4-item groups only (i + 4 <= len(parts)).
        for i in range(1, len(parts) - 3, 4):
            head_tag, tail, tail_tag, relation = parts[i:i + 4]
            tail = clean(tail)
            relation = clean(relation)
            if head and tail and relation:
                triplets.append({
                    'head': head,
                    'head_type': head_tag[1:-1].capitalize(),
                    'type': tail,
                    'RELATION': relation,
                    'tail_type': tail_tag[1:-1].capitalize(),
                })
    return triplets

def extract_triplets_rebel(text):
    """Parse a decoded REBEL generation into (head, tail, relation) triplets.

    REBEL linearises its output as::

        <triplet> HEAD <subj> TAIL <obj> RELATION [<subj> TAIL <obj> RELATION ...]

    A single head may be followed by several ``<subj> TAIL <obj> RELATION``
    groups; each complete group yields one triplet.

    Args:
        text: Generated ids decoded with ``skip_special_tokens=False``, so the
            ``<triplet>``, ``<subj>`` and ``<obj>`` markers are still present.

    Returns:
        A list of dicts with keys ``head``, ``head_type``, ``type``,
        ``RELATION`` and ``tail_type``. ``type`` holds the *tail entity* and
        ``RELATION`` the relation label (the caller's naming). REBEL is
        untyped, so ``head_type`` and ``tail_type`` are always ``''``.
    """
    # Tokenizer specials (<s>, </s>, <pad>, <unk>) carry no triplet content.
    text = re.sub(r"</?s>|<pad>|<unk>", "", text)

    def clean(span):
        # Drop any stray tags, then collapse whitespace.
        return " ".join(re.sub(r"<[^>]+>", "", span).split())

    triplets = []
    for chunk in text.split("<triplet>"):
        head, found, groups = chunk.partition("<subj>")
        if not found:
            # No "<subj>" means no relation for this chunk (e.g. the prefix
            # before the first "<triplet>").
            continue
        head = clean(head)
        # Each "<subj>"-separated group is "TAIL <obj> RELATION".
        for group in groups.split("<subj>"):
            tail, found, relation = group.partition("<obj>")
            if not found:
                continue
            tail = clean(tail)
            relation = clean(relation)
            if head and tail and relation:
                triplets.append({
                    'head': head,
                    'head_type': '',
                    'type': tail,
                    'RELATION': relation,
                    'tail_type': ''
                })
    return triplets

# --- Device selection ---
# transformers pipelines take device=-1 for CPU or a CUDA ordinal (0 = first GPU).
device_to_use = 0 if torch.cuda.is_available() else -1
print(f"Using device: {'cuda:' + str(device_to_use) if device_to_use != -1 else 'cpu'}")

# --- Pipelines ---
triplet_extractor_mrebel = pipeline(
    'translation',
    model='Babelscape/mrebel-large',
    tokenizer='Babelscape/mrebel-large',
    device=device_to_use
)
triplet_extractor_rebel = pipeline(
    'text2text-generation',
    model='Babelscape/rebel-large',
    tokenizer='Babelscape/rebel-large',
    device=device_to_use
)

# Output CSV schema. Passed explicitly to DataFrame so a run that finds no
# triplets still writes a header row instead of an empty file.
OUTPUT_COLUMNS = [
    "DOCUMENT", "SUBLABEL", "MODEL", "HEAD", "TAIL",
    "RELATION", "HEAD_TYPE", "TAIL_TYPE",
]


def _decode(extractor, token_ids):
    """Decode generated ids, keeping the special tokens the parsers rely on."""
    return extractor.tokenizer.batch_decode([token_ids], skip_special_tokens=False)[0]


def _run_mrebel(text):
    """Run mREBEL on one Spanish text and return its parsed triplets."""
    # NOTE(review): the Babelscape/mrebel-large model card generates with
    # tgt_lang="<triplet>" and decoder_start_token_id=250058 ("tp_XX");
    # tgt_lang="es_XX" here forces a language code as the first generated
    # token. Confirm this is intended before trusting extraction quality.
    output = triplet_extractor_mrebel(
        text,
        src_lang="es_XX",
        tgt_lang="es_XX",
        max_length=400,
        return_tensors=True,
        return_text=False
    )[0]["translation_token_ids"]
    return extract_triplets_mrebel(_decode(triplet_extractor_mrebel, output))


def _run_rebel(text):
    """Run REBEL on one text and return its parsed triplets."""
    output = triplet_extractor_rebel(
        text,
        max_length=400,
        return_tensors=True,
        return_text=False
    )[0]["generated_token_ids"]
    return extract_triplets_rebel(_decode(triplet_extractor_rebel, output))


# (MODEL column value, name used in error messages, runner), in output order.
_MODEL_RUNNERS = (
    ("mrebel", "mREBEL", _run_mrebel),
    ("rebel", "REBEL", _run_rebel),
)

df = pd.read_csv('JuanRana_split.csv')
all_triplets = []

for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing documents"):
    document = row['ID']
    text = row['Text']
    # Missing CSV cells come back as NaN (a float); skip them rather than
    # letting both models fail on a non-string input.
    if not isinstance(text, str) or not text.strip():
        tqdm.write(f"Skipping row {index}: empty or missing text")
        continue

    for model_tag, model_label, run_model in _MODEL_RUNNERS:
        try:
            triplets = run_model(text)
        except Exception as e:
            # Best effort: one failing document/model must not abort the
            # whole batch. tqdm.write keeps the progress bar intact.
            tqdm.write(f"Error processing row {index} with {model_label}: {e}")
            continue
        # SUBLABEL numbers triplets from 1 within each (document, model).
        for i, triplet in enumerate(triplets, 1):
            all_triplets.append({
                "DOCUMENT": document,
                "SUBLABEL": i,
                "MODEL": model_tag,
                "HEAD": triplet.get('head', ''),
                "TAIL": triplet.get('type', ''),
                "RELATION": triplet.get('RELATION', ''),
                "HEAD_TYPE": triplet.get('head_type', ''),
                "TAIL_TYPE": triplet.get('tail_type', '')
            })

result_df = pd.DataFrame(all_triplets, columns=OUTPUT_COLUMNS)
result_df.to_csv('triplets_output.csv', index=False)
print("Processing complete. Triplets saved to triplets_output.csv")