# Dataset utils for joint entity relation extraction

In [None]:
#|default_exp jerx.dataset.webnlg

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export

from bellem.utils import split_camel_case

In [None]:
# |export


def _transform_relation(relation: str):
    """Normalise a relation name into lower-case, space-separated words.

    The relation is broken into pieces by ``split_camel_case``; each piece is
    lower-cased, the pieces are joined with single spaces, and surrounding
    whitespace is stripped from the result.
    """
    lowered_pieces = (piece.lower() for piece in split_camel_case(relation))
    return " ".join(lowered_pieces).strip()


def _transform_entity(entity: str):
    return entity.replace("_", " ").strip()


def _transform_triplet(triplet_string: str):
    """Clean up one ``"entity | relation | entity"`` triplet string.

    Double quotes and pairs of consecutive single quotes are removed, the
    string is split on ``" | "`` into exactly three fields, the middle field
    is normalised with ``_transform_relation`` and the outer two with
    ``_transform_entity``, and the fields are re-joined with ``" | "``.

    Raises:
        ValueError: if splitting does not produce exactly three fields.
    """
    separator = " | "
    unquoted = triplet_string.replace('"', "").replace("''", "")
    head, predicate, tail = unquoted.split(separator)
    predicate = _transform_relation(predicate)
    fields = (_transform_entity(head), predicate, _transform_entity(tail))
    return separator.join(fields)


def _batch_transform_webnlg(examples):
    for lex, mts in zip(examples["lex"], examples["modified_triple_sets"]):
        for text in lex["text"]:
            triplets = [_transform_triplet(triplet_string) for triplet_string in mts["mtriple_set"][0]]
            yield dict(text=text, triplets=triplets)


def batch_transform_webnlg(examples):
    """Convert a batch of examples into ``text`` / ``triplets`` columns.

    Collects every record produced by ``_batch_transform_webnlg`` and returns
    them column-wise: a dict with a ``"text"`` list and a parallel
    ``"triplets"`` list (one list of triplet strings per text).
    """
    texts = []
    triplet_lists = []
    for record in _batch_transform_webnlg(examples):
        texts.append(record["text"])
        triplet_lists.append(record["triplets"])
    return {"text": texts, "triplets": triplet_lists}

In [None]:
#|hide
from datasets import load_dataset

# Smoke test: map the first ten WebNLG training examples through
# `batch_transform_webnlg` and check the shape of the result.
ds = load_dataset("web_nlg", "release_v3.0_en", split="train[:10]")
jerx_ds = ds.map(
    batch_transform_webnlg,
    batched=True,
    remove_columns=ds.column_names,
)

# Both output columns must be present.
for column_name in ("text", "triplets"):
    assert column_name in jerx_ds.features

# The first record carries a list of triplet strings.
first_record = jerx_ds[0]
first_triplets = first_record["triplets"]
assert isinstance(first_triplets, list)
assert isinstance(first_triplets[0], str)
print(first_record)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

{'text': 'The Aarhus is the airport of Aarhus, Denmark.', 'triplets': ['Aarhus Airport|city served|Aarhus, Denmark']}


In [None]:
#|hide
# Run nbdev's export step for this notebook (presumably writes the cells
# marked `#|export` to the module named by `#|default_exp` — confirm against
# the nbdev docs).
import nbdev; nbdev.nbdev_export()