# Knowledge Graph Construction
> Relation and entity extraction from text

In [None]:
#| default_exp ml.kg.cons

In [None]:
#| hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export
from dataclasses import dataclass
from typing import TypeAlias, Iterable, List, Set, Tuple, Callable, Any, Dict
import numpy as np
from bellek.ml.llm.utils import LLAMA2_CHAT_PROMPT_TEMPLATE

In [None]:
#|export

# An entity is either a bare surface string (e.g. "Paris") or a
# (surface string, entity type) pair such as ("Paris", "LOC").
Entity: TypeAlias = str | tuple[str, str]
# A relation is the predicate label linking two entities (e.g. "capitalOf").
Relation: TypeAlias = str
# (head entity, relation, tail entity),
# e.g. (("John", "PERSON"), "works_at", ("Google", "ORG")).
Triplet: TypeAlias = tuple[Entity, Relation, Entity]

In [None]:
#|export

def evaluate_joint_er_extraction(*, reference: Iterable[Triplet], prediction: Iterable[Triplet]):
    """
    Score predicted triplets against reference triplets by exact tuple match.

    Triplets are compared as whole tuples, so head, relation and tail (including
    entity types, when entities are (name, type) pairs) must all match. Duplicate
    predictions are collapsed before scoring.

    Example: [(('John', 'PERSON'), 'works_at', ('Google', 'ORG'))]

    Args:
        reference: Gold triplets. Must not contain duplicates. Any iterable,
            including a one-shot generator, is accepted.
        prediction: Predicted triplets. Any iterable is accepted.

    Returns:
        Dict with keys 'precision', 'recall' and 'f1'. A metric whose
        denominator is zero (e.g. no predictions, or no references) is 0.0.

    Raises:
        AssertionError: If `reference` contains duplicate triplets.
    """
    # Materialize once: the duplicate check needs both the length and the set of
    # the same items, which a single-pass iterable (generator) cannot provide twice.
    reference = list(reference)
    reference_set = set(reference)
    prediction_set = set(prediction)
    assert len(reference) == len(reference_set), "Duplicates found in references"

    TP = len(reference_set & prediction_set)
    FP = len(prediction_set - reference_set)
    FN = len(reference_set - prediction_set)

    # Guard each ratio against a zero denominator.
    precision = TP / (TP + FP) if TP + FP > 0 else 0.0
    recall = TP / (TP + FN) if TP + FN > 0 else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1_score
    }

def evaluate_joint_er_extractions(*, references: Iterable[Iterable[Triplet]], predictions: Iterable[Iterable[Triplet]]):
    """
    Macro-average `evaluate_joint_er_extraction` scores over paired examples.

    Each (reference, prediction) pair is scored independently and every
    per-example metric is averaged with equal weight.

    Args:
        references: One collection of gold triplets per example.
        predictions: One collection of predicted triplets per example, aligned
            with `references`.

    Returns:
        Dict with 'mean_precision', 'mean_recall' and 'mean_f1'.

    Raises:
        ValueError: If `references` and `predictions` differ in length, or if
            there are no examples at all.
    """
    # strict=True: silently truncating mismatched inputs would average over fewer
    # examples than intended and report misleading scores.
    score_dicts = [
        evaluate_joint_er_extraction(reference=reference, prediction=prediction)
        for reference, prediction in zip(references, predictions, strict=True)
    ]
    if not score_dicts:
        raise ValueError("At least one (reference, prediction) pair is required")
    return {f"mean_{key}": np.mean([scores[key] for scores in score_dicts]) for key in score_dicts[0]}

In [None]:
reference = [
    (('John', 'PERSON'), 'works_at', ('Google', 'ORG')),
    (('Mike', 'PERSON'), 'lives_in', ('Paris', 'LOC')),
    (('Dwight', 'PERSON'), 'sells', ('Paper', 'OBJ')),
]
prediction = [
    (('John', 'PERSON'), 'works_at', ('Google', 'ORG')),
    (('Mike', 'PERSON'), 'lives_in', ('New York', 'LOC')),
]

# One of two predictions is correct (P = 1/2); one of three references is found (R = 1/3).
scores = evaluate_joint_er_extraction(reference=reference, prediction=prediction)
test_eq(scores, {'precision': 0.5, 'recall': 0.3333333333333333, 'f1': 0.4})
print(scores)

{'precision': 0.5, 'recall': 0.3333333333333333, 'f1': 0.4}


In [None]:
# Per-example scores are (P=1/2, R=1/3, F1=0.4) and (0, 0, 0); the cell checks their means.
references = [
    [
        (('John', 'PERSON'), 'works_at', ('Google', 'ORG')),
        (('Mike', 'PERSON'), 'lives_in', ('Paris', 'LOC')),
        (('Dwight', 'PERSON'), 'sells', ('Paper', 'OBJ')),
    ],
    [(('Henry', 'PERSON'), 'founded', ('Ford', 'ORG'))],
]
predictions = [
    [
        (('John', 'PERSON'), 'works_at', ('Google', 'ORG')),
        (('Mike', 'PERSON'), 'lives_in', ('New York', 'LOC')),
    ],
    [(('Henry', 'PERSON'), 'founded', ('Boston Dynamics', 'ORG'))],
]

expected_means = {'mean_precision': 0.25, 'mean_recall': 0.16666666666666666, 'mean_f1': 0.2}
scores = evaluate_joint_er_extractions(references=references, predictions=predictions)
test_eq(scores, expected_means)
print(scores)

{'mean_precision': 0.25, 'mean_recall': 0.16666666666666666, 'mean_f1': 0.2}


In [None]:
#|export

def parse_triplet_strings(text: str, delimiter: str=" | ") -> List[str]:
    """
    Extract the triplet lines from free-form model output.

    A line is kept when, after trimming surrounding whitespace, it contains
    exactly two occurrences of `delimiter` (i.e. three fields). Everything else
    (chatty preambles, blank lines, sign-offs) is discarded.

    Args:
        text: Raw generated text with one triplet per line.
        delimiter: Separator between head entity, relation and tail entity.

    Returns:
        The matching lines, whitespace-trimmed, in their original order.
    """
    # Trim so indented or trailing-space lines produce clean entity strings
    # that can match reference triplets exactly.
    stripped_lines = (line.strip() for line in text.splitlines())
    return [line for line in stripped_lines if line.count(delimiter) == 2]

def parse_triplets(text: str, delimiter: str=" | ") -> List[Triplet]:
    """
    Parse model output into (head, relation, tail) string triplets.

    Lines are selected by `parse_triplet_strings`; each is split on `delimiter`
    and every field is whitespace-trimmed (so e.g. "A |  b | C" -> ("A", "b", "C")).

    Args:
        text: Raw generated text with one triplet per line.
        delimiter: Separator between head entity, relation and tail entity.

    Returns:
        List of 3-tuples of strings, in their original order.
    """
    return [
        tuple(field.strip() for field in triplet_string.split(delimiter))
        for triplet_string in parse_triplet_strings(text, delimiter=delimiter)
    ]

In [None]:
#|hide
text = """
  Sure! Here are the entity-relation-entity triplets for the given text:

Aleksandre_Guruli | club | US_Lesquin
Paris | capitalOf | France

Please provide the next text for extraction.
"""
# Only the two well-formed "head | relation | tail" lines should survive parsing.
expected_lines = ["Aleksandre_Guruli | club | US_Lesquin", "Paris | capitalOf | France"]
expected_triplets = [("Aleksandre_Guruli", "club", "US_Lesquin"), ('Paris', 'capitalOf', 'France')]
parsed_lines = parse_triplet_strings(text)
assert sorted(parsed_lines) == expected_lines
parsed_triplets = parse_triplets(text)
assert sorted(parsed_triplets) == expected_triplets

In [None]:
#|export

def format_few_shot_example(example, text_prefix="# Text\n", triplets_prefix="# Triplets\n"):
    """
    Render one few-shot example as a text section followed by its triplets.

    Args:
        example: Mapping with a 'text' string and a 'triplets' list of triplet strings.
        text_prefix: Header placed before the text.
        triplets_prefix: Header placed before the triplets.

    Returns:
        "{text_prefix}{text}\\n{triplets_prefix}" followed by the triplets, one per line.
    """
    text = example['text']
    triplets = '\n'.join(example['triplets'])
    return f"{text_prefix}{text}\n{triplets_prefix}{triplets}"

def format_few_shot_examples(examples, text_prefix="# Text\n", triplets_prefix="# Triplets\n"):
    """
    Render several few-shot examples, separated by blank lines.

    The section headers are forwarded to `format_few_shot_example`, so callers
    can customize them for the whole batch; defaults match the single-example
    formatter.

    Returns:
        The rendered examples joined with "\\n\\n" ("" for no examples).
    """
    return "\n\n".join(
        format_few_shot_example(example, text_prefix=text_prefix, triplets_prefix=triplets_prefix)
        for example in examples
    )

In [None]:
#|export

# Section listing the relations the model may use; `{relation_set}` is filled
# with the relation names (ERXFormatter passes them comma-joined).
# NOTE(review): the wording of these prompt strings (e.g. "Here are the list",
# "You are helpful assistant") is left untouched on purpose — presumably models
# were fine-tuned on these exact prompts, so edits would shift the distribution. Confirm before changing.
DEFAULT_RELATION_SET_PROMPT_TEMPLATE = """Here are the list of relations that you can use:
{relation_set}
"""

# Section introducing the few-shot demonstrations; `{few_shot_examples}` is
# filled with the rendered examples.
DEFAULT_FEW_SHOT_EXAMPLES_PROMPT_TEMPLATE = """Use the same format for triplets as in the examples provided below.
{few_shot_examples}
"""

# Top-level system prompt. `{relation_set_prompt}` and `{few_shot_prompt}` receive
# the sections above, or "" when a section is unused. `.strip()` only trims the
# template's own outer whitespace, so an empty section leaves an empty line.
DEFAULT_SYSTEM_PROMPT_TEMPLATE = """You are helpful assistant that extracts entity-relation-entity triplets from given text.
{relation_set_prompt}
{few_shot_prompt}
""".strip()


@dataclass
class ERXFormatter:
    """
    Build chat-formatted prompts for entity-relation (triplet) extraction.

    The system prompt optionally lists the allowed relations and a randomly
    sampled subset of few-shot examples; it is wrapped together with the
    example text in `chat_prompt_template`.
    """

    # Outer chat template; formatted with `system_prompt` and `user_message`.
    chat_prompt_template: str = LLAMA2_CHAT_PROMPT_TEMPLATE
    # Formatted with `relation_set_prompt` and `few_shot_prompt`.
    system_prompt_template: str = DEFAULT_SYSTEM_PROMPT_TEMPLATE
    # Formatted with `few_shot_examples` (the rendered demonstrations).
    few_shot_examples_prompt_template: str = DEFAULT_FEW_SHOT_EXAMPLES_PROMPT_TEMPLATE
    # Pool of {'text': ..., 'triplets': [...]} dicts to sample demonstrations from.
    few_shot_examples: List[Dict] | None = None
    # Examples sampled per prompt; a value <= 0 disables the few-shot section.
    n_few_shot_examples: int = 3
    # Formatted with `relation_set` (comma-joined relation names).
    relation_set_prompt_template: str = DEFAULT_RELATION_SET_PROMPT_TEMPLATE
    # Allowed relations; converted to a sorted list in __post_init__.
    relation_set: set | None = None

    def __post_init__(self):
        # Sort once so the relation list (and hence the prompt) is deterministic.
        if self.relation_set:
            self.relation_set = sorted(self.relation_set)

    def format_for_inference(self, example: Dict):
        """
        Return a copy of `example` whose 'text' is wrapped in the chat prompt.

        The original text becomes the user message; other keys are copied as-is.
        """
        example = {**example}
        user_message = example['text']
        example['text'] = self.chat_prompt_template.format(system_prompt=self.make_system_prompt(), user_message=user_message)
        return example

    def format_for_train(self, example: Dict):
        """
        Return a copy of `example` whose 'text' is the inference prompt followed
        by a space and the target triplets, one per line.
        """
        example = {**example}
        example['text'] = self.format_for_inference(example)['text'] + " " + '\n'.join(example['triplets'])
        return example

    def make_system_prompt(self) -> str:
        """
        Render the system prompt with the relation-set and few-shot sections,
        each replaced by "" when not configured.

        Few-shot examples are re-sampled on every call.
        """
        rsp = self.relation_set_prompt_template.format(relation_set=','.join(self.relation_set)) if self.relation_set else ""
        # Emit the few-shot header only when at least one example was chosen;
        # otherwise the prompt would announce examples that never follow.
        chosen = self._choose_few_shot_examples() if self.few_shot_examples else []
        fsp = self.few_shot_examples_prompt_template.format(few_shot_examples=format_few_shot_examples(chosen)) if chosen else ""
        return self.system_prompt_template.format(relation_set_prompt=rsp, few_shot_prompt=fsp)

    def _choose_few_shot_examples(self) -> list[dict]:
        """
        Pick up to `n_few_shot_examples` distinct examples from the pool.

        Returns an empty list when the requested count is not positive, and the
        whole pool when it is not larger than the requested count. Sampling uses
        NumPy's global RNG, so seed `np.random` for reproducible prompts.
        """
        n = self.n_few_shot_examples
        if n <= 0:
            return []
        if len(self.few_shot_examples) <= n:
            return self.few_shot_examples
        # Sample indices rather than the example objects so NumPy never has to
        # coerce the pool into an array (`choice` requires a 1-D array-like).
        indices = np.random.choice(len(self.few_shot_examples), n, replace=False)
        return [self.few_shot_examples[i] for i in indices]

In [None]:
# Sample input paired with its gold triplets, used to preview the training prompt.
example = dict(
    text="Dead Man's Plack is found in England and is made from rock. The capital of England is London and its religion is Church of England. The British Arabs are an English ethnic group.",
    triplets=[
        "Dead_Man's_Plack | location | England",
        "England | ethnicGroup | British_Arabs",
    ],
)

In [None]:
# Default formatter: no relation set and no few-shot examples, so the system
# prompt carries only the instruction line.
erx_formatter = ERXFormatter()
train_example = erx_formatter.format_for_train(example)
print(train_example["text"])

<s>[INST] <<SYS>>
You are helpful assistant that extracts entity-relation-entity triplets from given text.


<</SYS>>

Dead Man's Plack is found in England and is made from rock. The capital of England is London and its religion is Church of England. The British Arabs are an English ethnic group. [/INST] Dead_Man's_Plack | location | England
England | ethnicGroup | British_Arabs


In [None]:
# Three toy demonstrations sharing a single relation.
few_shot_examples = [
    {"text": f"{city} is capital city of {country}", "triplets": [f"{city} | capital of | {country}"]}
    for city, country in [("Ankara", "Turkey"), ("Paris", "France"), ("London", "UK")]
]
example = {"text": "Moscow is capital city of Russia", "triplets": ["Moscow | capital of | Russia"]}

# One demonstration is sampled per prompt, so repeated calls can differ.
erx_formatter = ERXFormatter(few_shot_examples=few_shot_examples, n_few_shot_examples=1)
separator = "=" * 80
for i in range(3):
    print(separator)
    print(erx_formatter.format_for_train(example)["text"])

<s>[INST] <<SYS>>
You are helpful assistant that extracts entity-relation-entity triplets from given text.

Use the same format for triplets as in the examples provided below.
# Text
London is capital city of UK
# Triplets
London | capital of | UK

<</SYS>>

Moscow is capital city of Russia [/INST] Moscow | capital of | Russia
<s>[INST] <<SYS>>
You are helpful assistant that extracts entity-relation-entity triplets from given text.

Use the same format for triplets as in the examples provided below.
# Text
Paris is capital city of France
# Triplets
Paris | capital of | France

<</SYS>>

Moscow is capital city of Russia [/INST] Moscow | capital of | Russia
<s>[INST] <<SYS>>
You are helpful assistant that extracts entity-relation-entity triplets from given text.

Use the same format for triplets as in the examples provided below.
# Text
London is capital city of UK
# Triplets
London | capital of | UK

<</SYS>>

Moscow is capital city of Russia [/INST] Moscow | capital of | Russia


In [None]:
#| hide
# Write the `#|export` cells out to the library module named by `default_exp` (via nbdev).
import nbdev; nbdev.nbdev_export()