<img src="../assets/CoLLIE_blue.png" alt="GoLLIE" width="200"/>

# Relation Extraction with GoLLIE

This notebook shows how to run Relation Extraction with GoLLIE. It covers:

- How to define the guidelines for a task
- How to load GoLLIE
- How to generate model inputs
- How to parse the output
- How to implement a scorer and evaluate the output

In this notebook, we will demonstrate how to perform basic relation extraction with two string arguments per Relation class. However, GoLLIE can handle more complex relations with multiple arguments. Please refer to the `Create Custom Task` notebook if you wish to undertake more advanced relation extraction tasks 🔥🔥🔥.

### Import requirements

See the `requirements.txt` file in the main directory to install the required dependencies.

In [30]:
import sys

sys.path.append("../")  # Add the GoLLIE base directory to sys path

In [31]:
import inspect
import logging
import tempfile
from typing import Dict, List, Type

import black
import rich
from jinja2 import Template

from src.model.load_model import load_model
from src.tasks.utils_typing import AnnotationList

logging.basicConfig(level=logging.INFO)

## Load GoLLIE

We will load GoLLIE-7B from the Hugging Face Hub.
You can use `AutoModelForCausalLM.from_pretrained` if you prefer. However, we provide a handy `load_model` function with many features already implemented that will help you reproduce our results.

Please note that setting `use_flash_attention=True` is mandatory. Our flash attention implementation has small numerical differences compared to the attention implementation in Hugging Face, and setting `use_flash_attention=False` will make the model produce worse results. Flash attention requires an available CUDA GPU, so running the pre-trained GoLLIE models on a CPU is not supported. We plan to address this in future releases.

- Set `force_auto_device_map=True` to automatically load the model on the available GPUs.
- Set `quantization=4` if the model doesn't fit in your GPU memory.
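To judge whether you need `quantization=4`, a rough lower bound on the weight memory is parameters × bits per parameter / 8. The sketch below assumes roughly 6.74 billion parameters for a 7B Llama-style model and ignores activations, the KV cache, framework overhead and any layers that quantization keeps in higher precision. `weight_memory_mb`, `fits_on_gpu` and its 80% headroom are hypothetical helpers, not part of `load_model`.

```python
def weight_memory_mb(num_params: float, bits_per_param: int) -> float:
    """Memory needed just to store the weights, in MB (1 MB = 10**6 bytes)."""
    return num_params * bits_per_param / 8 / 1e6


def fits_on_gpu(num_params: float, bits_per_param: int, gpu_memory_mb: float,
                headroom: float = 0.8) -> bool:
    """Hypothetical rule of thumb: the weights should use at most `headroom` of the
    GPU memory, leaving the rest for activations and the KV cache."""
    return weight_memory_mb(num_params, bits_per_param) <= headroom * gpu_memory_mb


NUM_PARAMS = 6.74e9  # assumption: approximate parameter count of a 7B Llama-style model

for bits in (16, 4):  # bfloat16 vs. 4-bit quantization
    print(f"{bits:>2}-bit weights: ~{weight_memory_mb(NUM_PARAMS, bits):,.0f} MB")
print(fits_on_gpu(NUM_PARAMS, 16, gpu_memory_mb=16_000))
print(fits_on_gpu(NUM_PARAMS, 4, gpu_memory_mb=16_000))
```

The bfloat16 estimate is in line with the total memory footprint of about 13,477 MB reported in the loading log of the next cell.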

In [3]:
model, tokenizer = load_model(
    inference=True,
    model_weights_name_or_path="HiTZ/GoLLIE-7B",
    quantization=None,
    use_lora=False,
    force_auto_device_map=True,
    use_flash_attention=True,
    torch_dtype="bfloat16",
)

INFO:root:Loading model model from HiTZ/GoLLIE-7B
INFO:root:We will load the model using the following device map: auto and max_memory: None
Loading the tokenizer from the `special_tokens_map.json` and the `added_tokens.json` will be removed in `transformers 5`,  it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again. You will see the new `added_tokens_decoder` attribute that will store the relevant information.
INFO:root:Loading model with dtype: torch.bfloat16


>>>> Flash Attention installed
>>>> Flash RoPE installed



INFO:root:Model dtype: torch.bfloat16
INFO:root:Total model memory footprint: 13477.101762 MB
INFO:root:Quantization is enabled, we will not merge LoRA layers into the model. Inference will be slower.


## Define the guidelines

First, we will define the labels and guidelines for the task. We will represent them as Python classes.

For this demonstration, we will define two relation types taken from the ACE05 guidelines: https://www.ldc.upenn.edu/sites/www.ldc.upenn.edu/files/english-entities-guidelines-v6.6.pdf

💡 Be creative and try to define your own guidelines to test GoLLIE!
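As a starting point for your own guidelines, here is a minimal stdlib-only sketch of the same pattern: the class docstring plays the role of the guideline and the fields are the arguments. Because `src.tasks.utils_typing` is only available inside the GoLLIE repository, the sketch uses a hypothetical empty `Relation` base class, and `EmploymentRelation` is an invented example, not taken from ACE05.

```python
from dataclasses import dataclass


@dataclass
class Relation:
    """Hypothetical stand-in for src.tasks.utils_typing.Relation."""


# Hypothetical guideline, written for this example only (not from ACE05).
@dataclass
class EmploymentRelation(Relation):
    """The Employment Relation links a Person (arg1) to the Organization or GPE (arg2)
    that employs them."""

    arg1: str
    arg2: str


ann = EmploymentRelation(arg1="Mary", arg2="Acme Corp")
print(ann)  # the dataclass repr is a valid Python constructor call
```

Printing an instance yields a constructor call, which is the form in which annotations appear in the prompt's `result = [...]` list.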

In [32]:
from typing import List, Type

from src.tasks.utils_typing import Relation, dataclass

"""
Relation definitions
"""


@dataclass
class PhysicalRelation(Relation):
    """The Physical Relation captures the physical location relation of entities such as:
    a Person entity located in a Facility, Location or GPE; or two entities that are near,
    but neither entity is a part of the other or located in/at the other."""

    arg1: str
    arg2: str


@dataclass
class PersonalSocialRelation(Relation):
    """The Personal-Social Relation describes the relationship between people. Both arguments must be entities
    of type Person. Please note: The arguments of these Relations are not ordered. The Relations are
    symmetric."""

    arg1: str
    arg2: str


# A list of relation classes (not instances), hence Type[Relation]
ENTITY_DEFINITIONS: List[Type[Relation]] = [
    PhysicalRelation,
    PersonalSocialRelation,
]

if __name__ == "__main__":
    cell_txt = In[-1]

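### Write the guidelines to `guidelines.py`

Due to IPython limitations, we must write the content of the previous cell to a file and then import the classes from that file: `inspect.getsource` cannot retrieve the source code of classes defined interactively in a notebook cell, but it can once they are imported from a `.py` file.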

In [33]:
with open("guidelines.py", "w", encoding="utf8") as python_guidelines:
    print(cell_txt, file=python_guidelines)

from guidelines import *

We use `inspect.getsource` to get the source code of each guideline class as a string.

In [34]:
guidelines = [inspect.getsource(definition) for definition in ENTITY_DEFINITIONS]
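Outside a notebook, the same write-then-import trick can be reproduced with the standard library alone. In this sketch the module name `demo_guidelines` and the `DemoRelation` class are hypothetical, and a temporary directory stands in for the notebook folder.

```python
import importlib
import inspect
import sys
import tempfile
from pathlib import Path

# Hypothetical guideline cell text; in the notebook this is `cell_txt`.
CELL_TXT = '''from dataclasses import dataclass


@dataclass
class DemoRelation:
    """A hypothetical relation used only to demonstrate inspect.getsource."""

    arg1: str
    arg2: str
'''

with tempfile.TemporaryDirectory() as tmp_dir:
    # Write the cell text to a real .py file, as the notebook does with guidelines.py.
    Path(tmp_dir, "demo_guidelines.py").write_text(CELL_TXT, encoding="utf8")
    sys.path.insert(0, tmp_dir)  # make the directory importable
    importlib.invalidate_caches()  # the file was created after the interpreter started
    try:
        module = importlib.import_module("demo_guidelines")
        # The class now belongs to a file-backed module, so inspect can read its source.
        source = inspect.getsource(module.DemoRelation)
    finally:
        sys.path.remove(tmp_dir)

print(source)
```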

## Define input sentence

Here we define the input sentence and the gold labels.

If you don't have gold annotations, you can use an empty list as the gold labels.

In [10]:
text = "Ana and Mary are sisters. Mary was at the supermarket while Ana was at home."
gold = [
    PersonalSocialRelation(arg1="Ana", arg2="Mary"),
    PhysicalRelation(arg1="Mary", arg2="supermarket"),
    PhysicalRelation(arg1="Ana", arg2="home"),
]

## Filling a template

For Relation Extraction we will use the following prompt template.
We use Jinja templates, which are easy to implement and exceptionally fast. For more information, visit: https://jinja.palletsprojects.com/en/3.1.x/api/#high-level-api.

```Python
# The following lines describe the task definition
{%- for definition in guidelines %}
{{ definition }}
{%- endfor %}

# This is the text to analyze
text = {{ text.__repr__() }}

# The annotation instances that take place in the text above are listed here
result = [
{%- for ann in annotations %}
    {{ ann }},
{%- endfor %}
]

```

This template is stored in `templates/prompt.txt`.

In [35]:
# Read template
with open("../templates/prompt.txt", "rt") as f:
    template = Template(f.read())
# Fill the template
formated_text = template.render(guidelines=guidelines, text=text, annotations=gold, gold=gold)
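To inspect the prompt structure without Jinja, the stdlib function below approximates what the template renders: `{%- ... %}` trims the whitespace before each tag, so every guideline and annotation starts on a new line. `render_prompt` is a hypothetical helper and may differ from the Jinja output in whitespace details. Parsing the result with `ast.parse` also checks that the prompt is valid Python, which Black needs in order to format it.

```python
import ast
from typing import List


def render_prompt(guidelines: List[str], text: str, annotations: List[object]) -> str:
    """Hypothetical stdlib approximation of templates/prompt.txt."""
    lines = ["# The following lines describe the task definition"]
    lines.extend(guidelines)  # one class definition per guideline
    lines += ["", "# This is the text to analyze", f"text = {text!r}"]  # !r == text.__repr__()
    lines += ["", "# The annotation instances that take place in the text above are listed here"]
    lines.append("result = [")
    lines.extend(f"    {ann}," for ann in annotations)  # str(ann), as {{ ann }} does
    lines.append("]")
    return "\n".join(lines) + "\n"


prompt_text = render_prompt(
    guidelines=['class PhysicalRelation:\n    """Physical location of entities."""'],
    text="Mary was at the supermarket.",
    annotations=["PhysicalRelation(arg1='Mary', arg2='supermarket')"],
)
ast.parse(prompt_text)  # raises SyntaxError if the rendered prompt is not valid Python
print(prompt_text)
```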

### Black Code Formatter

We use the Black code formatter to automatically give all prompts the same format.

https://github.com/psf/black

In [36]:
black_mode = black.Mode()
formated_text = black.format_str(formated_text, mode=black_mode)

### Print the filled and formatted template

In [37]:
rich.print(formated_text)

## Prepare model inputs

We remove everything after `result =`, so that the model has to generate the annotations itself.

In [38]:
prompt, _ = formated_text.split("result =")
prompt = prompt + "result ="
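The two-way unpacking above raises a `ValueError` if `result =` occurs more than once, for example inside the input text. A slightly more defensive sketch, assuming the annotation list is always the last part of the prompt, cuts at the last occurrence with `str.rpartition`. `strip_answer` is a hypothetical helper.

```python
def strip_answer(formatted_prompt: str, marker: str = "result =") -> str:
    """Hypothetical helper: keep the prompt up to and including the LAST `marker`."""
    head, found, _answer = formatted_prompt.rpartition(marker)
    if not found:
        raise ValueError(f"{marker!r} not found in the prompt")
    return head + found


example = 'text = "the result = 42"\nresult = [\n    PhysicalRelation(arg1="Mary", arg2="home"),\n]\n'
print(len(example.split("result =")))  # 3 parts, so `prompt, _ = example.split(...)` would fail
print(strip_answer(example))
```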

Tokenize the prompt:

In [39]:
model_input = tokenizer(prompt, add_special_tokens=True, return_tensors="pt")

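Remove the trailing `eos` token from the input, so that the model continues the sequence after `result =` instead of seeing it as finished: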

In [40]:
model_input["input_ids"] = model_input["input_ids"][:, :-1]
model_input["attention_mask"] = model_input["attention_mask"][:, :-1]
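The `[:, :-1]` slice drops the last token unconditionally. A guarded variant, sketched here on plain Python lists, only drops the last position when it really holds the EOS token and always trims `input_ids` and `attention_mask` together. `drop_trailing_eos` and the token ids are hypothetical; EOS is set to 2 to match the `eos_token_id` shown in the generation log further down.

```python
from typing import List, Tuple


def drop_trailing_eos(input_ids: List[int], attention_mask: List[int],
                      eos_token_id: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: drop the last position only if it holds the EOS token."""
    if input_ids and input_ids[-1] == eos_token_id:
        return input_ids[:-1], attention_mask[:-1]  # trim both so they stay aligned
    return list(input_ids), list(attention_mask)


# Hypothetical ids: 2 is EOS, the other values are made up.
ids, mask = drop_trailing_eos([1, 1491, 353, 2], [1, 1, 1, 1], eos_token_id=2)
print(ids, mask)
```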

## Run GoLLIE

We generate the predictions using GoLLIE.

We use `num_beams=1` and `do_sample=False` in our experiments, but feel free to experiment with different decoding strategies 😊
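With `num_beams=1` and `do_sample=False`, Hugging Face `generate` performs greedy decoding: at each step it appends the single most likely next token, so the output is deterministic for a given prompt. The toy loop below illustrates the idea; `toy_scores` is a hypothetical stand-in for the model's next-token scores, and tokens are whole strings rather than subword ids.

```python
from typing import Callable, Dict, List

# Hypothetical next-token scores: a tiny bigram table instead of a real model.
BIGRAMS: Dict[str, Dict[str, float]] = {
    "=": {"[": 0.9, "]": 0.1},
    "[": {"Rel(...)": 0.7, "]": 0.3},
    "Rel(...)": {"]": 0.6, ",": 0.4},
    ",": {"Rel(...)": 1.0},
    "]": {"</s>": 1.0},
}


def toy_scores(tokens: List[str]) -> Dict[str, float]:
    return BIGRAMS[tokens[-1]]


def greedy_decode(score_fn: Callable[[List[str]], Dict[str, float]], prompt: List[str],
                  max_new_tokens: int, eos: str = "</s>") -> List[str]:
    """Append the highest-scoring token at every step (one beam, no sampling)."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        scores = score_fn(tokens)
        best = max(scores, key=scores.get)  # argmax over candidate tokens
        if best == eos:  # stop early, like generation hitting the EOS token
            break
        tokens.append(best)
    return tokens[len(prompt):]  # only the newly generated tokens


print(greedy_decode(toy_scores, ["result", "="], max_new_tokens=128))
```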

In [41]:
%%time

model_ouput = model.generate(
    **model_input.to(model.device),
    max_new_tokens=128,
    do_sample=False,
    min_new_tokens=0,
    num_beams=1,
    num_return_sequences=1,
)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


CPU times: user 2.07 s, sys: 5.39 ms, total: 2.08 s
Wall time: 2.09 s


### Print the results

In [42]:
for i, output_ids in enumerate(model_ouput):
    print(f"Answer {i}")
    rich.print(tokenizer.decode(output_ids, skip_special_tokens=True).split("result = ")[-1])

Answer 0


## Parse the output

The output is a Python list of class instances, so we can execute it directly 🤯

We define the `AnnotationList` class to parse the output with a single line of code. The `AnnotationList.from_output` function filters out any label that we did not define (hallucinations), so that executing the output does not fail with a `NameError` on an undefined class.
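For intuition, here is a hypothetical, stdlib-only sketch of how such filtering can be done safely: parse the generated text with `ast` rather than executing it with `eval`, keep only calls to known class names, and accept only literal arguments. It is not GoLLIE's implementation of `AnnotationList.from_output`, whose details may differ.

```python
import ast
from dataclasses import dataclass
from typing import Dict, List, Type


@dataclass
class PhysicalRelation:  # stand-in for the class defined in guidelines.py
    arg1: str
    arg2: str


def parse_annotations(output: str, known: Dict[str, Type]) -> List[object]:
    """Hypothetical helper: parse a generated `[Label(arg=...), ...]` list without eval()."""
    try:
        tree = ast.parse(output.strip(), mode="eval")
    except SyntaxError:  # e.g. generation stopped in the middle of the list
        return []
    if not isinstance(tree.body, ast.List):
        return []
    annotations = []
    for node in tree.body.elts:
        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)):
            continue
        cls = known.get(node.func.id)
        if cls is None:  # label we never defined (hallucination): skip it
            continue
        try:
            args = [ast.literal_eval(a) for a in node.args]
            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
            annotations.append(cls(*args, **kwargs))
        except (TypeError, ValueError):  # wrong or missing arguments, non-literal values
            continue
    return annotations


generated = '[\n    PhysicalRelation(arg1="Mary", arg2="supermarket"),\n    MadeUpRelation(arg1="Ana", arg2="home"),\n]'
print(parse_annotations(generated, {"PhysicalRelation": PhysicalRelation}))
```

In the notebook, the lookup table could be built as `{cls.__name__: cls for cls in ENTITY_DEFINITIONS}`.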

In [43]:
result = AnnotationList.from_output(
    tokenizer.decode(model_ouput[0], skip_special_tokens=True).split("result = ")[-1], task_module="guidelines"
)
rich.print(result)

Each label is an instance of one of the defined classes:

In [44]:
type(result[0])

guidelines.PersonalSocialRelation

In [45]:
result[0].arg1

'Ana'

In [46]:
result[0].arg2

'Mary'

## Evaluate the result

Finally, we will evaluate the outputs from the model.

First, we define a scorer. For Relation Extraction, we will use the `RelationScorer` class.

We need to define the `valid_types` for the scorer, which are the labels that we have defined.

In [47]:
from src.tasks.utils_scorer import RelationScorer


class MyScorer(RelationScorer):
    """Compute the F1 score for Relation Extraction"""

    valid_types: List[Type] = ENTITY_DEFINITIONS

### Instantiate the scorer

In [48]:
scorer = MyScorer()

### Compute F1

In [49]:
scorer_results = scorer(reference=[gold], predictions=[result])
rich.print(scorer_results)
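To make the metric concrete, the sketch below computes micro-averaged precision, recall and F1 by exact matching on (relation type, arg1, arg2) triples. It is an approximation under stated assumptions, not necessarily what `RelationScorer` computes: in particular, treating the relation types listed in `symmetric` as unordered is our own choice, motivated by the Personal-Social guideline, and the classes are stdlib stand-ins.

```python
from dataclasses import dataclass
from typing import FrozenSet, List, Tuple


@dataclass
class PhysicalRelation:  # stand-ins for the guideline classes
    arg1: str
    arg2: str


@dataclass
class PersonalSocialRelation:
    arg1: str
    arg2: str


def relation_key(ann, symmetric: FrozenSet[str]) -> Tuple[str, ...]:
    """Exact-match key; argument order is ignored for relation types listed as symmetric."""
    name = type(ann).__name__
    args = (ann.arg1, ann.arg2)
    if name in symmetric:
        args = tuple(sorted(args))
    return (name, *args)


def micro_prf(reference: List[list], predictions: List[list],
              symmetric: FrozenSet[str] = frozenset()) -> Tuple[float, float, float]:
    """Micro-averaged precision, recall and F1 over a batch of sentences."""
    tp = n_pred = n_gold = 0
    for gold_anns, pred_anns in zip(reference, predictions):
        gold_keys = {relation_key(a, symmetric) for a in gold_anns}
        pred_keys = {relation_key(a, symmetric) for a in pred_anns}
        tp += len(gold_keys & pred_keys)  # true positives: exact matches
        n_pred += len(pred_keys)
        n_gold += len(gold_keys)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


gold_example = [[PersonalSocialRelation("Ana", "Mary"), PhysicalRelation("Mary", "supermarket"),
                 PhysicalRelation("Ana", "home")]]
pred_example = [[PersonalSocialRelation("Mary", "Ana"), PhysicalRelation("Mary", "supermarket")]]
print(micro_prf(gold_example, pred_example, symmetric=frozenset({"PersonalSocialRelation"})))
```

Without the `symmetric` set, the reversed Personal-Social prediction would count as both a false positive and a false negative.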

GoLLIE has successfully labeled the sentence with the defined relations 🎉🎉🎉

GoLLIE performs best on labels whose guidelines are well defined and clearly bounded.

Please share your cool experiments with us; we'd love to see what everyone is doing with GoLLIE!
- [@iker_garciaf](https://twitter.com/iker_garciaf)
- [@osainz59](https://twitter.com/osainz59)