# NER with GoLLIE-TF

### Import requirements

See the requirements.txt file in the main directory to install the required dependencies.

In [2]:
import sys

# The GoLLIE base directory is the parent folder; put it on the module search
# path so that the project-local `src.*` packages can be imported.
GOLLIE_ROOT = "../"
sys.path.append(GOLLIE_ROOT)

In [None]:
import rich
import logging
from src.model.load_model import load_model
import black
import inspect
from jinja2 import Template
import tempfile
from src.tasks.utils_typing import AnnotationList

logging.basicConfig(level=logging.INFO)
from typing import Dict, List, Type

## Load GoLLIE

We will load GoLLIE-7B-TF (ychenNLP/GoLLIE-7B-TF) from the Hugging Face Hub.
You can use the function AutoModelForCausalLM.from_pretrained if you prefer it. However, we provide a handy load_model function with many functionalities already implemented that will assist you in reproducing our results.

Please note that setting use_flash_attention=True is mandatory. Our flash attention implementation has small numerical differences compared to the attention implementation in Huggingface. Using use_flash_attention=False will result in the model producing inferior results. Flash attention requires an available CUDA GPU. Running GOLLIE pre-trained models on a CPU is not supported. We plan to address this in future releases.

- Set force_auto_device_map=True to automatically load the model on available GPUs.
- Set quantization=4 if the model doesn't fit in your GPU memory.

In [1]:
import torch

# Flash attention requires a CUDA GPU, so report whether PyTorch can see one
# before attempting to load the model.
cuda_ready = torch.cuda.is_available()
print(cuda_ready)

False


In [5]:
# Load the GoLLIE-TF checkpoint together with its tokenizer using the
# project's `load_model` helper.
model, tokenizer = load_model(
    # Which checkpoint to load, and in which mode.
    model_weights_name_or_path="ychenNLP/GoLLIE-7B-TF",
    inference=True,
    use_lora=False,
    # Numerical setup: bfloat16 weights, no quantization
    # (set quantization=4 if the model doesn't fit in your GPU memory).
    torch_dtype="bfloat16",
    quantization=None,
    # Flash attention is mandatory for this model; disabling it gives inferior results.
    use_flash_attention=True,
    # Automatically place the model on the available GPUs.
    force_auto_device_map=True,
)

INFO:root:Loading model model from HiTZ/GoLLIE-7B
INFO:root:We will load the model using the following device map: auto and max_memory: None


INFO:root:Loading model with dtype: torch.bfloat16


>>>> Flash Attention installed




>>>> Flash RoPE installed


Downloading shards: 100%|██████████| 2/2 [04:42<00:00, 141.33s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.23it/s]
INFO:root:Model dtype: torch.bfloat16
INFO:root:Total model memory footprint: 13477.101762 MB


In [None]:
test_input = "# The following lines describe the task definition\n@dataclass\nclass Location(Entity):\n    \"\"\"Roads (streets, motorways) trajectories regions (villages, towns, cities, provinces, countries, continents,\n    dioceses, parishes) structures (bridges, ports, dams) natural locations (mountains, mountain ranges, woods, rivers,\n    wells, fields, valleys, gardens, nature reserves, allotments, beaches, national parks) public places (squares, opera\n    houses, museums, schools, markets, airports, stations, swimming pools, hospitals, sports facilities, youth centers,\n    parks, town halls, theaters, cinemas, galleries, camping grounds, NASA launch pads, club houses, universities,\n    libraries, churches, medical centers, parking lots, playgrounds, cemeteries) commercial places (chemists, pubs,\n    restaurants, depots, hostels, hotels, industrial parks, nightclubs, music venues) assorted buildings (houses, monasteries,\n    creches, mills, army barracks, castles, retirement homes, towers, halls, rooms, vicarages, courtyards) abstract\n    ``places'' (e.g. {\\it the free world})\"\"\"\n\n    span: str  # Such as: \"U.S.\", \"Germany\", \"Britain\", \"Australia\", \"England\"\n\n\n@dataclass\nclass Person(Entity):\n    \"\"\"first, middle and last names of people, animals and fictional characters aliases.\"\"\"\n\n    span: str  # Such as: \"Clinton\", \"Dole\", \"Arafat\", \"Yeltsin\", \"Lebed\"\n\n\n@dataclass\nclass Organization(Entity):\n    \"\"\"Companies (press agencies, studios, banks, stock markets, manufacturers, cooperatives) subdivisions of\n    companies (newsrooms) brands political movements (political parties, terrorist organisations) government bodies\n    (ministries, councils, courts, political unions of countries (e.g. 
the {\\it U.N.})) publications (magazines, newspapers,\n    journals) musical companies (bands, choirs, opera companies, orchestras public organisations (schools, universities,\n    charities other collections of people (sports clubs, sports teams, associations, theaters companies, religious orders,\n    youth organisations.\"\"\"\n\n    span: str  # Such as: \"Reuters\", \"U.N.\", \"NEW YORK\", \"CHICAGO\", \"PUK\"\n\n\n# This is the text to analyze\ntext = \"Ko min tɔ bɛ kɔ - Dirisa Togola - Minisiriɲɛmɔgɔ , Sogɛli Kokala Mayiga n'a ka Kunnafonidi minisiri Mɛtiri Haruna Ture dalen a kan , taara nin ntɛnɛndon , zuwɛnkalo tile 28 , Kunnafonidalaw ka Soba la .\"\n\n# This is the English translation of the text\neng_text = \"The remaining party - Dirisa Togola - headed by Prime Minister, Sogeli Kokala Mayiga and his Minister of Information, M. Haruna Ture, attended this Saturday, June 28, at the House of Representatives.\"\n\n# Using translation and fusion\n# (1) generate annotation for eng_text\n# (2) generate annotation for text\n\n# The annotation instances that take place in the eng_text above are listed here\nresult ="

print(test_input)


# # The following lines describe the task definition
# @dataclass
# class Location(Entity):
#     """Roads (streets, motorways) trajectories regions (villages, towns, cities, provinces, countries, continents,
#     dioceses, parishes) structures (bridges, ports, dams) natural locations (mountains, mountain ranges, woods, rivers,
#     wells, fields, valleys, gardens, nature reserves, allotments, beaches, national parks) public places (squares, opera
#     houses, museums, schools, markets, airports, stations, swimming pools, hospitals, sports facilities, youth centers,
#     parks, town halls, theaters, cinemas, galleries, camping grounds, NASA launch pads, club houses, universities,
#     libraries, churches, medical centers, parking lots, playgrounds, cemeteries) commercial places (chemists, pubs,
#     restaurants, depots, hostels, hotels, industrial parks, nightclubs, music venues) assorted buildings (houses, monasteries,
#     creches, mills, army barracks, castles, retirement homes, towers, halls, rooms, vicarages, courtyards) abstract
#     ``places'' (e.g. {\it the free world})"""

#     span: str  # Such as: "U.S.", "Germany", "Britain", "Australia", "England"


# @dataclass
# class Person(Entity):
#     """first, middle and last names of people, animals and fictional characters aliases."""

#     span: str  # Such as: "Clinton", "Dole", "Arafat", "Yeltsin", "Lebed"


# @dataclass
# class Organization(Entity):
#     """Companies (press agencies, studios, banks, stock markets, manufacturers, cooperatives) subdivisions of
#     companies (newsrooms) brands political movements (political parties, terrorist organisations) government bodies
#     (ministries, councils, courts, political unions of countries (e.g. the {\it U.N.})) publications (magazines, newspapers,
#     journals) musical companies (bands, choirs, opera companies, orchestras public organisations (schools, universities,
#     charities other collections of people (sports clubs, sports teams, associations, theaters companies, religious orders,
#     youth organisations."""

#     span: str  # Such as: "Reuters", "U.N.", "NEW YORK", "CHICAGO", "PUK"


# # This is the text to analyze
# text = "Ko min tɔ bɛ kɔ - Dirisa Togola - Minisiriɲɛmɔgɔ , Sogɛli Kokala Mayiga n'a ka Kunnafonidi minisiri Mɛtiri Haruna Ture dalen a kan , taara nin ntɛnɛndon , zuwɛnkalo tile 28 , Kunnafonidalaw ka Soba la ."

# # This is the English translation of the text
# eng_text = "The remaining party - Dirisa Togola - headed by Prime Minister, Sogeli Kokala Mayiga and his Minister of Information, M. Haruna Ture, attended this Saturday, June 28, at the House of Representatives."

# # Using translation and fusion
# # (1) generate annotation for eng_text
# # (2) generate annotation for text

# # The annotation instances that take place in the eng_text above are listed here
# result =

In [None]:
# Tokenize the prompt into PyTorch tensors (input_ids + attention_mask).
model_input = tokenizer(test_input, return_tensors="pt")

print(model_input["input_ids"])

# The prompt has to end right after `result =` so that the model continues it.
# NOTE(review): the original cell always sliced off the last token, presumably
# because the tokenizer returned by `load_model` appends an EOS token - confirm.
# Only strip when the last token really is EOS; otherwise the slice would
# silently remove the final token of the prompt itself.
eos_token_id = tokenizer.eos_token_id
ends_with_eos = eos_token_id is not None and bool(
    (model_input["input_ids"][:, -1] == eos_token_id).all()
)
if ends_with_eos:
    model_input["input_ids"] = model_input["input_ids"][:, :-1]
    model_input["attention_mask"] = model_input["attention_mask"][:, :-1]

# Deterministic greedy decoding: no sampling, a single beam, one returned
# sequence, and at most 128 newly generated tokens.
# (`model_ouput` keeps its original spelling in case later cells refer to it.)
model_ouput = model.generate(
    **model_input.to(model.device),  # move the inputs to the model's device
    max_new_tokens=128,
    do_sample=False,
    min_new_tokens=0,
    num_beams=1,
    num_return_sequences=1,
)
# Decode the full generated sequences (prompt + continuation) back to text.
print(tokenizer.batch_decode(model_ouput))