# Knowledge Graph Construction
> Relation and entity extraction from text

In [None]:
#|default_exp ml.kg.cons

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export
import random
from dataclasses import dataclass
from typing import TypeAlias, Iterable, Callable, Any, Generator
import numpy as np

from bellek.logging import get_logger

log = get_logger(__name__)

In [None]:
#|hide
import json
def pprint(obj):
    """Print `obj` as JSON indented by two spaces, keeping non-ASCII characters unescaped."""
    rendered = json.dumps(obj, ensure_ascii=False, indent=2)
    print(rendered)

In [None]:
#|export

# An entity is either a plain string or a pair of strings, e.g. ('John', 'PERSON')
# (presumably (surface form, entity type) -- confirm against callers).
Entity: TypeAlias = str|tuple[str, str]
# A relation is identified by a plain string, e.g. 'works_at'.
Relation: TypeAlias = str
# A triplet is ordered as (entity, relation, entity).
Triplet: TypeAlias = tuple[Entity, Relation, Entity]

In [None]:
#|export

def evaluate_joint_er_extraction(*, reference: Iterable[Triplet], prediction: Iterable[Triplet]):
    """
    Score one collection of predicted triplets against one collection of reference triplets.

    Triplets are compared by exact equality after being placed in sets, so they
    must be hashable. Repeated predictions are collapsed and counted once.

    Example: [(('John', 'PERSON'), 'works_at', ('Google', 'ORG'))]

    Args:
        reference: Gold triplets; must not contain duplicates. Any iterable is
            accepted, including one-shot generators.
        prediction: Predicted triplets. Any iterable is accepted.

    Returns:
        A dict with the keys 'precision', 'recall' and 'f1'. A metric whose
        denominator would be zero is reported as 0.0.

    Raises:
        AssertionError: If `reference` contains duplicate triplets.
    """
    # Materialize once: a generator has no len() and would already be exhausted
    # by the time the duplicate check runs.
    reference = list(reference)
    reference_set = set(reference)
    prediction_set = set(prediction)

    # Raised explicitly (rather than via `assert`) so the check survives `python -O`.
    if len(reference) != len(reference_set):
        raise AssertionError("Duplicates found in references")

    true_positives = len(reference_set & prediction_set)

    # TP + FP is the number of distinct predictions; TP + FN the number of references.
    precision = true_positives / len(prediction_set) if prediction_set else 0.0
    recall = true_positives / len(reference_set) if reference_set else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1_score
    }

def evaluate_joint_er_extractions(*, references: Iterable[Iterable[Triplet]], predictions: Iterable[Iterable[Triplet]]):
    score_dicts = [
        evaluate_joint_er_extraction(reference=reference, prediction=prediction) 
        for reference, prediction in zip(references, predictions)
    ]
    return {('mean_' + key): np.mean([scores[key] for scores in score_dicts]) for key in score_dicts[0].keys()}

In [None]:
# Three gold triplets versus two predictions, of which only the first matches a gold triplet exactly.
reference = [(('John', 'PERSON'), 'works_at', ('Google', 'ORG')), (('Mike', 'PERSON'), 'lives_in', ('Paris', 'LOC')), (('Dwight', 'PERSON'), 'sells', ('Paper', 'OBJ'))]
prediction = [(('John', 'PERSON'), 'works_at', ('Google', 'ORG')), (('Mike', 'PERSON'), 'lives_in', ('New York', 'LOC'))]

scores = evaluate_joint_er_extraction(reference=reference, prediction=prediction)
# 1 of 2 predictions is correct (precision 1/2); 1 of 3 references is recovered (recall 1/3).
test_eq(scores, {'precision': 0.5, 'recall': 0.3333333333333333, 'f1': 0.4})
print(scores)

{'precision': 0.5, 'recall': 0.3333333333333333, 'f1': 0.4}


In [None]:
# Two samples: the first has one exact match, the second has none.
references = [
    [(('John', 'PERSON'), 'works_at', ('Google', 'ORG')), (('Mike', 'PERSON'), 'lives_in', ('Paris', 'LOC')), (('Dwight', 'PERSON'), 'sells', ('Paper', 'OBJ'))],
    [(('Henry', 'PERSON'), 'founded', ('Ford', 'ORG'))],
]
predictions = [
    [(('John', 'PERSON'), 'works_at', ('Google', 'ORG')), (('Mike', 'PERSON'), 'lives_in', ('New York', 'LOC'))],
    [(('Henry', 'PERSON'), 'founded', ('Boston Dynamics', 'ORG'))],
]

scores = evaluate_joint_er_extractions(references=references, predictions=predictions)
# The second sample scores 0 on every metric, so each mean is half of the first sample's score.
test_eq(scores, {'mean_precision': 0.25, 'mean_recall': 0.16666666666666666, 'mean_f1': 0.2})
print(scores)

{'mean_precision': 0.25, 'mean_recall': 0.16666666666666666, 'mean_f1': 0.2}


In [None]:
#|export

def parse_triplet_strings(text: str, delimiter: str="|") -> list[str]:
    """
    Return the lines of `text` that contain exactly two occurrences of `delimiter`.

    Lines are returned unchanged (no whitespace stripping), in their original order.
    """
    def is_triplet_line(candidate: str) -> bool:
        return bool(candidate) and candidate.count(delimiter) == 2

    return list(filter(is_triplet_line, text.splitlines()))

def parse_triplets(text: str, delimiter: str="|") -> list[Triplet]:
    """
    Extract triplets from `text` as tuples of strings.

    Candidate lines are selected by `parse_triplet_strings`; each one is then
    split on `delimiter` and converted to a tuple.
    """
    triplet_lines = parse_triplet_strings(text, delimiter=delimiter)
    parsed = []
    for row in triplet_lines:
        parts = row.split(delimiter)
        parsed.append(tuple(parts))
    return parsed

In [None]:
#|hide
# A chatty model response: only the two lines holding exactly two '|' delimiters should be extracted.
text = """
  Sure! Here are the entity-relation-entity triplets for the given text:

Aleksandre_Guruli|club|US_Lesquin
Paris|capitalOf|France

Please provide the next text for extraction.
"""
assert sorted(parse_triplet_strings(text)) == ["Aleksandre_Guruli|club|US_Lesquin", "Paris|capitalOf|France"]
assert sorted(parse_triplets(text)) == [("Aleksandre_Guruli", "club", "US_Lesquin"), ('Paris', 'capitalOf', 'France')]

## Prompting for joint entity-relation extraction

In [None]:
#|export

def format_triplets(triplets: Iterable[str]) -> str:
    """Join triplet strings into a single text block, one triplet per line."""
    line_break = '\n'
    return line_break.join(triplets)

def format_few_shot_example(example, text_prefix="# Text\n", triplets_prefix="# Triplets\n"):
    """
    Render one few-shot example as a text section followed by a triplets section.

    `example` must provide a 'text' entry and a 'triplets' entry; the triplets
    are rendered with `format_triplets`. The two sections are separated by a newline.
    """
    text_section = f"{text_prefix}{example['text']}"
    triplets_section = f"{triplets_prefix}{format_triplets(example['triplets'])}"
    return "\n".join((text_section, triplets_section))

def format_few_shot_examples(examples):
    """Render every example with `format_few_shot_example` and separate them by a blank line."""
    rendered_examples = map(format_few_shot_example, examples)
    return "\n\n".join(rendered_examples)

In [None]:
#|export

DEFAULT_SYSTEM_PROMPT_TEMPLATE = """You are a helpful assistant that extracts entity-relation-entity triplets from given text.
{relation_set_prompt}
{few_shot_prompt}
""".strip()

DEFAULT_RELATION_SET_PROMPT_TEMPLATE = """Here are the list of relations that you can use:
{relation_set}
"""

DEFAULT_FEW_SHOT_EXAMPLES_PROMPT_TEMPLATE = """Use the same format for triplets as in the examples provided below.
{few_shot_examples}
"""

@dataclass
class ERX2AlpacaFormatter:
    system_prompt_template: str = DEFAULT_SYSTEM_PROMPT_TEMPLATE
    relation_set_prompt_template: str = DEFAULT_RELATION_SET_PROMPT_TEMPLATE
    relation_set: set|None = None
    few_shot_examples_prompt_template: str = DEFAULT_FEW_SHOT_EXAMPLES_PROMPT_TEMPLATE
    few_shot_examples: list[dict]|None = None
    n_few_shot_examples: int = 3

    def __post_init__(self):
        if self.relation_set:
            self.relation_set = sorted(self.relation_set)
        if self.few_shot_examples is not None:
            self.few_shot_examples = list(self.few_shot_examples)

    def format(self, example: dict):
        instruction = self.make_system_prompt()
        input = example['text']
        output = '\n'.join(example['triplets'])
        return {
            'instruction': instruction.strip(),
            'input': input.strip(),
            'output': output.strip(),
        }

    def make_system_prompt(self) -> str:
        rsp = self.relation_set_prompt_template.format(relation_set=','.join(self.relation_set)) if self.relation_set else ""
        fsp = self.few_shot_examples_prompt_template.format(few_shot_examples=format_few_shot_examples(self._choose_few_shot_examples())) if self.few_shot_examples else ""
        return self.system_prompt_template.format(relation_set_prompt=rsp, few_shot_prompt=fsp)

    def _choose_few_shot_examples(self) -> list[dict]:
        if len(self.few_shot_examples) <= self.n_few_shot_examples:
            return self.few_shot_examples
        else:
            return random.sample(self.few_shot_examples, k=self.n_few_shot_examples)

In [None]:
# With the default settings (no relation set, no few-shot examples) the instruction is just the base system prompt.
example = {
    "text": "Dead Man's Plack is found in England and is made from rock. The capital of England is London and its religion is Church of England. The British Arabs are an English ethnic group.",
    "triplets": [
        "Dead_Man's_Plack|location|England",
        "England|ethnicGroup|British_Arabs",
    ],
}
erx2alpaca_formatter = ERX2AlpacaFormatter()
alpaca_example = erx2alpaca_formatter.format(example)
assert "instruction" in alpaca_example and "input" in alpaca_example and "output" in alpaca_example
alpaca_example

{'instruction': 'You are a helpful assistant that extracts entity-relation-entity triplets from given text.',
 'input': "Dead Man's Plack is found in England and is made from rock. The capital of England is London and its religion is Church of England. The British Arabs are an English ethnic group.",
 'output': "Dead_Man's_Plack|location|England\nEngland|ethnicGroup|British_Arabs"}

In [None]:
# Pool of three few-shot examples; with n_few_shot_examples=1 one of them is sampled at random
# on every `format` call, so the printed instruction can differ between iterations.
few_shot_examples = [
    {
        "text": "Ankara is capital city of Turkey",
        "triplets": [
            "Ankara|capital of|Turkey",
        ],
    },
    {
        "text": "Paris is capital city of France",
        "triplets": [
            "Paris|capital of|France",
        ],
    },
    {
        "text": "London is capital city of UK",
        "triplets": [
            "London|capital of|UK",
        ],
    }
]
example = {
    "text": "Moscow is capital city of Russia",
    "triplets": [
        "Moscow|capital of|Russia",
    ],
}
erx2alpaca_formatter = ERX2AlpacaFormatter(few_shot_examples=few_shot_examples, n_few_shot_examples=1)
for i in range(3):
    print("="*80)
    print(erx2alpaca_formatter.format(example))

{'instruction': 'You are a helpful assistant that extracts entity-relation-entity triplets from given text.\n\nUse the same format for triplets as in the examples provided below.\n# Text\nLondon is capital city of UK\n# Triplets\nLondon|capital of|UK', 'input': 'Moscow is capital city of Russia', 'output': 'Moscow|capital of|Russia'}
{'instruction': 'You are a helpful assistant that extracts entity-relation-entity triplets from given text.\n\nUse the same format for triplets as in the examples provided below.\n# Text\nParis is capital city of France\n# Triplets\nParis|capital of|France', 'input': 'Moscow is capital city of Russia', 'output': 'Moscow|capital of|Russia'}
{'instruction': 'You are a helpful assistant that extracts entity-relation-entity triplets from given text.\n\nUse the same format for triplets as in the examples provided below.\n# Text\nLondon is capital city of UK\n# Triplets\nLondon|capital of|UK', 'input': 'Moscow is capital city of Russia', 'output': 'Moscow|capita

In [None]:
#|export

DEFAULT_SYSTEM_PROMPT_TEMPLATE2 = """You are a helpful assistant that extracts entity-relation-entity triplets from given text. Use '|' as delimiter and provide one triplet per line.
{relation_set_prompt}
""".strip()

DEFAULT_RELATION_SET_PROMPT_TEMPLATE2 = """Here are the list of relations that you can use:
{relation_set}
""".strip()

@dataclass
class ERX2ChatFormatter:
    system_prompt_template: str = DEFAULT_SYSTEM_PROMPT_TEMPLATE2
    relation_set_prompt_template: str = DEFAULT_RELATION_SET_PROMPT_TEMPLATE2
    relation_set: set|None = None
    few_shot_examples: list[dict]|None = None
    n_few_shot_examples: int = 3

    def __post_init__(self):
        if self.relation_set:
            self.relation_set = sorted(self.relation_set)
        if self.few_shot_examples is not None:
            self.few_shot_examples = list(self.few_shot_examples)

    def format(self, example: dict):
        messages = [
            {"role": "system", "content": self.make_system_message()},
            *list(self.make_messages(*self._choose_few_shot_examples(), example)),
        ]
        return {'conversations': messages}

    def make_system_message(self) -> str:
        rsp = self.relation_set_prompt_template.format(relation_set=','.join(self.relation_set)) if self.relation_set else ""
        return self.system_prompt_template.format(relation_set_prompt=rsp)

    def make_messages(self, *examples) -> Generator[dict, None, None]:
        for example in examples:
            yield {"role": "user", "content": example["text"]}
            yield {"role": "assistant", "content": format_triplets(example["triplets"])}

    def _choose_few_shot_examples(self) -> list[dict]:
        if len(self.few_shot_examples) <= self.n_few_shot_examples:
            return self.few_shot_examples
        else:
            return random.sample(self.few_shot_examples, k=self.n_few_shot_examples)

In [None]:
# Pool of three few-shot examples; with n_few_shot_examples=1 each conversation holds the system
# message, one randomly sampled few-shot user/assistant exchange, then the exchange built from `example`.
few_shot_examples = [
    {
        "text": "Ankara is capital city of Turkey",
        "triplets": [
            "Ankara|capital of|Turkey",
        ],
    },
    {
        "text": "Paris is capital city of France",
        "triplets": [
            "Paris|capital of|France",
        ],
    },
    {
        "text": "London is capital city of UK",
        "triplets": [
            "London|capital of|UK",
        ],
    }
]
example = {
    "text": "Moscow is capital city of Russia",
    "triplets": [
        "Moscow|capital of|Russia",
    ],
}
erx2chat_formatter = ERX2ChatFormatter(few_shot_examples=few_shot_examples, n_few_shot_examples=1)
for i in range(3):
    print("="*80)
    pprint(erx2chat_formatter.format(example))

{
  "conversations": [
    {
      "role": "system",
      "content": "You are a helpful assistant that extracts entity-relation-entity triplets from given text. Use '|' as delimiter and provide one triplet per line.\n"
    },
    {
      "role": "user",
      "content": "London is capital city of UK"
    },
    {
      "role": "assistant",
      "content": "London|capital of|UK"
    },
    {
      "role": "user",
      "content": "Moscow is capital city of Russia"
    },
    {
      "role": "assistant",
      "content": "Moscow|capital of|Russia"
    }
  ]
}
{
  "conversations": [
    {
      "role": "system",
      "content": "You are a helpful assistant that extracts entity-relation-entity triplets from given text. Use '|' as delimiter and provide one triplet per line.\n"
    },
    {
      "role": "user",
      "content": "Paris is capital city of France"
    },
    {
      "role": "assistant",
      "content": "Paris|capital of|France"
    },
    {
      "role": "user",
      "con

## Evaluation of joint entity-relation extraction

In [None]:
#|export


def evaluate_pipe_jer(dataset, pipe):
    """
    Run `pipe` over the dataset inputs and score the generated triplets.

    Args:
        dataset: Object supporting `len()`, column access via `dataset["input"]`
            and `dataset["output"]`, and `to_pandas()` (presumably a Hugging
            Face `datasets.Dataset` -- confirm against callers).
        pipe: Callable that takes the list of inputs and returns, per input, a
            sequence whose first item is a dict with a "generated_text" entry.

    Returns:
        A `(scores, dataf)` tuple: the result of the "bdsaglam/jer" metric and
        the dataset as a DataFrame extended with the columns "generation",
        "prediction" and "reference".
    """
    import evaluate

    log.info(f"Evaluating model for JER on dataset with {len(dataset)} samples.")

    def to_triplet_strings(texts):
        # Both generations and gold outputs are parsed the same way.
        return [parse_triplet_strings(raw.strip()) for raw in texts]

    generations = [outputs[0]["generated_text"] for outputs in pipe(dataset["input"])]
    predictions = to_triplet_strings(generations)
    references = to_triplet_strings(dataset["output"])

    # Keep raw generations and parsed triplets next to the inputs for inspection.
    dataf = dataset.to_pandas()
    new_columns = (
        ("generation", generations),
        ("prediction", predictions),
        ("reference", references),
    )
    for column_name, column_values in new_columns:
        dataf[column_name] = column_values

    scores = evaluate.load("bdsaglam/jer").compute(predictions=predictions, references=references)

    return scores, dataf


def evaluate_model_jer(
    dataset,
    *,
    response_template: str,
    tokenizer,
    model,
    max_new_tokens=256,
    batch_size=4,
    **kwargs,
):
    """
    Evaluate a text-generation model on joint entity-relation extraction.

    Each example's "text" is split at the last occurrence of
    `response_template`: everything up to and including the template becomes
    the generation prompt ("input"), the remainder with the tokenizer's EOS
    token removed becomes the gold answer ("output"). A `transformers`
    text-generation pipeline is then built and handed to `evaluate_pipe_jer`.

    Args:
        dataset: Non-empty dataset with a "text" field per example; must
            support `len()` and `map()` (presumably a Hugging Face
            `datasets.Dataset` -- confirm against callers).
        response_template: Marker separating the prompt from the response.
        tokenizer: Tokenizer handed to the pipeline; its
            `special_tokens_map["eos_token"]` is removed from the gold output.
        model: Model handed to the pipeline.
        max_new_tokens: Forwarded to the pipeline.
        batch_size: Forwarded to the pipeline.
        **kwargs: Extra keyword arguments forwarded to `transformers.pipeline`.

    Returns:
        Whatever `evaluate_pipe_jer` returns for the prepared dataset and pipeline.

    Raises:
        AssertionError: If `dataset` is empty.
        ValueError: If an example's text does not contain `response_template`.
    """
    assert len(dataset) > 0, "Dataset is empty!"

    # Looked up once instead of once per example.
    eos_token = tokenizer.special_tokens_map["eos_token"]

    def extract_input_output(example):
        # rpartition splits at the LAST occurrence, like rsplit(..., 1), but lets
        # us detect a missing template instead of failing on tuple unpacking.
        prompt, separator, completion = example["text"].rpartition(response_template)
        if not separator:
            raise ValueError(f"Response template {response_template!r} not found in example text")
        return {
            "input": prompt + response_template,
            "output": completion.replace(eos_token, ""),
        }

    dataset = dataset.map(extract_input_output)

    # setup generation pipeline
    from transformers import pipeline

    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        batch_size=batch_size,
        # Only the newly generated continuation is wanted, not the echoed prompt.
        return_full_text=False,
        **kwargs,
    )

    return evaluate_pipe_jer(dataset, pipe)

In [None]:
#|hide
# Run nbdev's export step (presumably writes the `#|export` cells to the module named by `#|default_exp` -- confirm against the nbdev docs).
import nbdev; nbdev.nbdev_export()