# 🚞 Zero-shot Relation Extraction (RE) Inference with GLiREL

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# # if you're running this in a colab notebook, you can run this cell to install the necessary dependencies
# pip install glirel
# !python -m spacy download en_core_web_sm

In [3]:
import os

from glirel import GLiREL

# Prefer the local fine-tuned checkpoint from a training run; when that log
# directory is not present (e.g. on another machine or a fresh clone), fall back
# to the released Hugging Face Hub checkpoint so Restart & Run All still works.
save_path = 'logs/redocred/redocred-2024-09-17__17-53-16/model_42900'
model_source = save_path if os.path.isdir(save_path) else 'jackboyla/glirel_beta'
print(f"Loading GLiREL from: {model_source}")
model = GLiREL.from_pretrained(model_source, map_location='cpu')

config.json not found in /home/jackboylan/tmp/GLiREL/logs/redocred/redocred-2024-09-17__17-53-16/model_42900
  state_dict = torch.load(model_file, map_location=torch.device(map_location))


# Inference

To run inference, the model needs the input `tokens`, the `NER` entity spans, and a list of zero-shot relation `labels`.

### Eval data

In [4]:
import json

# Load the FewRel examples used for evaluation (one JSON object per line; blank lines skipped).
with open('./data/few_rel_all.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

i = 0  # index of the example to inspect

tokens = data[i]['tokenized_text']
ner = data[i]['ner']
# De-duplicate the gold relation labels, keeping first-occurrence order.
# ``list(set(...))`` produced a hash-dependent order that changed between kernel
# restarts (string hashing is randomized per process), making runs irreproducible.
labels = list(dict.fromkeys(r['relation_text'] for r in data[i]['relations']))
print(tokens)
print()
print(ner)
print(labels)

['Derren', 'Nesbitt', 'had', 'a', 'history', 'of', 'being', 'cast', 'in', '"', 'Doctor', 'Who', '"', ',', 'having', 'played', 'villainous', 'warlord', 'Tegana', 'in', 'the', '1964', 'First', 'Doctor', 'serial', '"', 'Marco', 'Polo', '"', '.']

[[26, 27, 'Q2989881', 'Marco Polo'], [22, 23, 'Q2989412', 'First Doctor']]
['characters']


In [5]:
# Prepend extra candidate relation types to the gold labels loaded from the data,
# so the model has to choose among several zero-shot labels.
labels = ['country of origin', 'licensed to broadcast to', 'father', 'followed by', *labels]
print(labels)

['country of origin', 'licensed to broadcast to', 'father', 'followed by', 'characters']


In [6]:
relations = model.predict_relations(tokens, labels, threshold=0.0, ner=ner)

# With threshold=0.0 (almost) every candidate is kept: each ordered (head, tail)
# entity pair is scored against every label, e.g. 2 pairs x 5 labels = 10 here.
print('Number of relations:', len(relations))

sorted_data_desc = sorted(relations, reverse=True, key=lambda rel: rel['score'])
print("\nDescending Order by Score:")
for item in sorted_data_desc:
    print(item)

Number of relations: 10

Descending Order by Score:
{'head_pos': [26, 28], 'tail_pos': [22, 24], 'head_text': ['Marco', 'Polo'], 'tail_text': ['First', 'Doctor'], 'label': 'followed by', 'score': 0.011230885982513428}
{'head_pos': [22, 24], 'tail_pos': [26, 28], 'head_text': ['First', 'Doctor'], 'tail_text': ['Marco', 'Polo'], 'label': 'characters', 'score': 0.010938132181763649}
{'head_pos': [22, 24], 'tail_pos': [26, 28], 'head_text': ['First', 'Doctor'], 'tail_text': ['Marco', 'Polo'], 'label': 'followed by', 'score': 0.010783118195831776}
{'head_pos': [26, 28], 'tail_pos': [22, 24], 'head_text': ['Marco', 'Polo'], 'tail_text': ['First', 'Doctor'], 'label': 'characters', 'score': 0.01027392502874136}
{'head_pos': [26, 28], 'tail_pos': [22, 24], 'head_text': ['Marco', 'Polo'], 'tail_text': ['First', 'Doctor'], 'label': 'father', 'score': 0.0004802523762919009}
{'head_pos': [22, 24], 'tail_pos': [26, 28], 'head_text': ['First', 'Doctor'], 'tail_text': ['Marco', 'Polo'], 'label': 'fath

### Real-world example

Constrain the entity types that can be associated with a relationship.
For example:

`co-founder` can only have a head `PERSON` entity and a tail `ORG` entity.

In [7]:
# Real-world example
import spacy
from glirel.modules.utils import constrain_relations_by_entity_type

# spaCy pipeline used to tokenize the text and detect entity spans for GLiREL.
nlp = spacy.load('en_core_web_sm')


text = "Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in April 1976. The company is headquartered in Cupertino, California."

# text = "Jack Dorsey's father, Tim Dorsey, is a licensed pilot. Jack met his wife Sarah Paulson in New York in 2003. They have one son, Edward."

# Relation label -> entity-type constraints. "allowed_head" / "allowed_tail" list the
# spaCy entity labels (PERSON, ORG, GPE, DATE, ...) permitted on each side of the
# relation; predict_and_show applies them after prediction via
# constrain_relations_by_entity_type.
# NOTE(review): labels with an empty dict ('no relation') or a missing key
# ('licensed to broadcast to' has no "allowed_tail") are presumably unconstrained
# on that side; confirm against constrain_relations_by_entity_type.
labels = {"glirel_labels": {
    'co-founder': {"allowed_head": ["PERSON"], "allowed_tail": ["ORG"]}, 
    'country of origin': {"allowed_head": ["PERSON", "ORG"], "allowed_tail": ["LOC", "GPE"]}, 
    'licensed to broadcast to': {"allowed_head": ["ORG"]},  
    'no relation': {},  
    'parent': {"allowed_head": ["PERSON"], "allowed_tail": ["PERSON"]}, 
    'followed by': {"allowed_head": ["PERSON", "ORG"], "allowed_tail": ["PERSON", "ORG"]},  
    'located in or next to body of water': {"allowed_head": ["LOC", "GPE", "FAC"], "allowed_tail": ["LOC", "GPE"]},  
    'spouse': {"allowed_head": ["PERSON"], "allowed_tail": ["PERSON"]},  
    'child': {"allowed_head": ["PERSON"], "allowed_tail": ["PERSON"]},  
    'founder': {"allowed_head": ["PERSON"], "allowed_tail": ["ORG"]},  
    'founded on date': {"allowed_head": ["ORG"], "allowed_tail": ["DATE"]},
    'headquartered in': {"allowed_head": ["ORG"], "allowed_tail": ["LOC", "GPE", "FAC"]},  
    'acquired by': {"allowed_head": ["ORG"], "allowed_tail": ["ORG", "PERSON"]},  
    'subsidiary of': {"allowed_head": ["ORG"], "allowed_tail": ["ORG", "PERSON"]}, 
    }
}


def predict_and_show(text, labels, threshold=0.0, top_k=1):
    """Run spaCy NER + GLiREL relation extraction on ``text`` and print the relations.

    Uses the notebook-level ``nlp`` (spaCy) and ``model`` (GLiREL) objects.

    Args:
        text: raw input string; tokenized and entity-tagged with ``nlp``.
        labels: either a plain list of relation label strings, or a dict of the form
            ``{"glirel_labels": {label: {"allowed_head": [...], "allowed_tail": [...]}}}``.
            With the dict form, its keys are the candidate labels and the predictions
            are filtered with ``constrain_relations_by_entity_type``.
        threshold: minimum score forwarded to ``model.predict_relations``
            (previously hard-coded to 0.0, which remains the default).
        top_k: forwarded to ``model.predict_relations``; presumably the number of
            labels kept per entity pair (default 1, as before).
    """
    doc = nlp(text)
    print(f"Text: {text}")

    tokens = [token.text for token in doc]

    # NOTE: the end index should be inclusive; spaCy's ``ent.end`` is exclusive, hence ``- 1``.
    ner = [[ent.start, ent.end - 1, ent.label_, ent.text] for ent in doc.ents]
    print(f"Entities detected: {ner}")

    # Split the optional constraint dict into the label names the model scores
    # and the per-label constraints applied afterwards.
    labels_and_constraints = None
    if isinstance(labels, dict):
        labels_and_constraints = labels["glirel_labels"]
        labels = list(labels_and_constraints.keys())

    relations = model.predict_relations(tokens, labels, threshold=threshold, ner=ner, top_k=top_k)

    if labels_and_constraints is not None:
        print('Constraining relations by entity type')
        relations = constrain_relations_by_entity_type(doc.ents, labels_and_constraints, relations)

    print('Number of relations:', len(relations))

    print("\nDescending Order by Score:")
    for item in sorted(relations, key=lambda rel: rel['score'], reverse=True):
        print(f"{item['head_text']} --> {item['label']} --> {item['tail_text']} | score: {item['score']}")


predict_and_show(text, labels)

Text: Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in April 1976. The company is headquartered in Cupertino, California.
Entities detected: [[0, 1, 'ORG', 'Apple Inc.'], [5, 6, 'PERSON', 'Steve Jobs'], [8, 9, 'PERSON', 'Steve Wozniak'], [12, 13, 'PERSON', 'Ronald Wayne'], [15, 16, 'DATE', 'April 1976'], [23, 23, 'GPE', 'Cupertino'], [25, 25, 'GPE', 'California']]
Constraining relations by entity type
Number of relations: 6

Descending Order by Score:
['Apple', 'Inc.'] --> headquartered in --> ['Cupertino'] | score: 0.9072741866111755
['Apple', 'Inc.'] --> headquartered in --> ['California'] | score: 0.8888104557991028
['Apple', 'Inc.'] --> founded on date --> ['April', '1976'] | score: 0.8402661681175232
['Steve', 'Jobs'] --> founder --> ['Apple', 'Inc.'] | score: 0.8150324821472168
['Steve', 'Wozniak'] --> founder --> ['Apple', 'Inc.'] | score: 0.8128281831741333
['Ronald', 'Wayne'] --> founder --> ['Apple', 'Inc.'] | score: 0.7810325026512146


A simple list of relation types can also be passed, although this generally produces noisier predictions.

In [8]:
# Plain list of candidate labels: no entity-type constraints are applied.
text = "Jack knows Gill. They live in the same house in London. They are not related."
labels = [
    'family relation', 'knows', 'lives with', 'loves',
    'licensed to broadcast to', 'father', 'followed by',
    'no relation', 'lives in',
]
predict_and_show(text, labels)

Text: Jack knows Gill. They live in the same house in London. They are not related.
Entities detected: [[0, 0, 'PERSON', 'Jack'], [2, 2, 'PERSON', 'Gill'], [11, 11, 'GPE', 'London']]
Number of relations: 6

Descending Order by Score:
['Jack'] --> lives in --> ['London'] | score: 0.9570847153663635
['Gill'] --> lives in --> ['London'] | score: 0.9562698006629944
['Jack'] --> knows --> ['Gill'] | score: 0.8528702259063721
['Gill'] --> knows --> ['Jack'] | score: 0.8421204090118408
['London'] --> lives in --> ['Gill'] | score: 0.6627970337867737
['London'] --> lives in --> ['Jack'] | score: 0.6488385796546936


In [10]:
# import huggingface_hub
# import os

# huggingface_hub.login(os.environ['HF_TOKEN'])

# model.save_pretrained(
#     './release_model/glirel_beta', 
#     push_to_hub=True, 
#     repo_id='jackboyla/glirel_beta'
# )

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/jackboylan/.cache/huggingface/token
Login successful


pytorch_model.bin:   0%|          | 0.00/1.87G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/jackboyla/glirel_beta/commit/9d864f1374760c9d5d9321a25d93bdf8895d0964', commit_message='Push model using huggingface_hub.', commit_description='', oid='9d864f1374760c9d5d9321a25d93bdf8895d0964', pr_url=None, pr_revision=None, pr_num=None)

### Coreference Resolution

In [27]:
import json

# Load ReDocRED test documents (one JSON object per line; blank lines skipped).
with open('data/redocred_test.jsonl', 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

examples = data[:5]  # small batch for the demo
batch_tokenized_text = [example['tokenized_text'] for example in examples]
batch_ner = [example['ner'] for example in examples]
print(batch_tokenized_text[0])
print()
print(batch_ner[0])
print()

# Per-document gold labels, de-duplicated in first-occurrence order.
# ``list(set(...))`` gave a hash-dependent order that changed between kernel restarts.
batch_labels = [
    list(dict.fromkeys(r['relation_text'] for r in example['relations']))
    for example in examples
]
batch_labels[0]

['The', 'Loud', 'Tour', 'was', 'the', 'fourth', 'overall', 'and', 'third', 'world', 'concert', 'tour', 'by', 'Barbadian', 'recording', 'artist', 'Rihanna', '.', 'Performing', 'in', 'over', 'twenty', 'countries', 'in', 'the', 'Americas', 'and', 'Europe', ',', 'the', 'tour', 'was', 'launched', 'in', 'support', 'of', 'Rihanna', "'s", 'fifth', 'studio', 'album', 'Loud', '(', '2010', ')', '.', 'Critics', 'acclaimed', 'the', 'show', 'for', 'its', 'liveliness', 'and', 'higher', 'caliber', 'of', 'quality', 'when', 'compared', 'to', 'Rihanna', "'s", 'previous', 'tours', '.', 'The', 'Loud', 'Tour', 'was', 'a', 'large', 'commercial', 'success', ',', 'experiencing', 'demand', 'for', 'an', 'extension', 'of', 'shows', 'in', 'the', 'United', 'Kingdom', 'due', 'to', 'popularity', '.', 'In', 'London', ',', 'Rihanna', 'played', 'a', 'record', 'breaking', '10', 'dates', 'at', 'The', 'O2', 'Arena', '.', 'The', 'tour', 'ultimately', 'grossed', 'an', 'estimated', 'value', 'of', 'US$', '90', 'million', 'from

['continent',
 'country',
 'located in the administrative territorial entity',
 'performer',
 'publication date',
 'notable work',
 'country of citizenship',
 'SELF']

In [37]:
# Each document in the batch carries its own label set.
# NOTE(review): that this flag enables per-example label sets is inferred from its name; confirm.
model.base_config.fixed_relation_types = False
model = model.to('cpu')
relations = model.batch_predict_relations(
    batch_tokenized_text,
    batch_labels,
    threshold=0.01,
    ner=batch_ner,
    top_k=1,
)
relations[0]

[{'head_pos': [1, 3],
  'tail_pos': [16, 17],
  'head_text': ['Loud', 'Tour'],
  'tail_text': ['Rihanna'],
  'label': 'performer',
  'score': 1.0},
 {'head_pos': [1, 3],
  'tail_pos': [36, 37],
  'head_text': ['Loud', 'Tour'],
  'tail_text': ['Rihanna'],
  'label': 'performer',
  'score': 1.0},
 {'head_pos': [1, 3],
  'tail_pos': [43, 44],
  'head_text': ['Loud', 'Tour'],
  'tail_text': ['2010'],
  'label': 'publication date',
  'score': 1.0},
 {'head_pos': [1, 3],
  'tail_pos': [61, 62],
  'head_text': ['Loud', 'Tour'],
  'tail_text': ['Rihanna'],
  'label': 'performer',
  'score': 1.0},
 {'head_pos': [1, 3],
  'tail_pos': [67, 69],
  'head_text': ['Loud', 'Tour'],
  'tail_text': ['Loud', 'Tour'],
  'label': 'SELF',
  'score': 1.0},
 {'head_pos': [1, 3],
  'tail_pos': [93, 94],
  'head_text': ['Loud', 'Tour'],
  'tail_text': ['Rihanna'],
  'label': 'performer',
  'score': 1.0},
 {'head_pos': [1, 3],
  'tail_pos': [128, 130],
  'head_text': ['Loud', 'Tour'],
  'tail_text': ['Loud', 'To

In [34]:
from glirel.modules.utils import get_coreference_clusters, aggregate_cluster_relations

# Group entity mentions into clusters from the model's own predictions
# (presumably its 'SELF' relations between mentions of the same entity), then lift
# mention-level relations to cluster-level relations.
clusters_batch, entity_to_cluster_idx_batch = get_coreference_clusters(relations)
print("Clusters:", clusters_batch[0])
print()

cluster_relations_batch = aggregate_cluster_relations(
    entity_to_cluster_idx_batch,
    relations,
)
cluster_relations_batch[0]

Clusters: [[(128, 130), (67, 69), (1, 3)], [(13, 14)], [(61, 62), (93, 94), (36, 37), (16, 17)], [(41, 42)], [(43, 44)], [(84, 86)], [(91, 92)], [(101, 104)]]



[{'h_idx': 0, 't_idx': 2, 'r': 'performer'},
 {'h_idx': 0, 't_idx': 3, 'r': 'follows'},
 {'h_idx': 0, 't_idx': 4, 'r': 'publication date'},
 {'h_idx': 2, 't_idx': 0, 'r': 'notable work'},
 {'h_idx': 2, 't_idx': 1, 'r': 'country of citizenship'},
 {'h_idx': 2, 't_idx': 3, 'r': 'notable work'},
 {'h_idx': 3, 't_idx': 0, 'r': 'followed by'},
 {'h_idx': 3, 't_idx': 2, 'r': 'performer'},
 {'h_idx': 3, 't_idx': 4, 'r': 'publication date'},
 {'h_idx': 6,
  't_idx': 5,
  'r': 'located in the administrative territorial entity'},
 {'h_idx': 7,
  't_idx': 5,
  'r': 'located in the administrative territorial entity'},
 {'h_idx': 7,
  't_idx': 6,
  'r': 'located in the administrative territorial entity'}]

In [36]:
# Print each cluster-level relation as head mentions / relation / tail mentions.
# Cluster spans are (start, end) token offsets with an exclusive end, so they slice directly.
for cluster, cluster_relations_list, tokenized_text in zip(clusters_batch, cluster_relations_batch, batch_tokenized_text):
    for cluster_relations in cluster_relations_list:
        print()
        cluster_h_idx, cluster_t_idx = cluster_relations['h_idx'], cluster_relations['t_idx']
        head_cluster = [tokenized_text[start:end] for start, end in cluster[cluster_h_idx]]
        print(f"Head Cluster: {head_cluster}")
        print(f"Relation: {cluster_relations['r']}")
        tail_cluster = [tokenized_text[start:end] for start, end in cluster[cluster_t_idx]]
        print(f"Tail Cluster: {tail_cluster}")


Head Cluster: [['Loud', 'Tour'], ['Loud', 'Tour'], ['Loud', 'Tour']]
Relation: performer
Tail Cluster: [['Rihanna'], ['Rihanna'], ['Rihanna'], ['Rihanna']]

Head Cluster: [['Loud', 'Tour'], ['Loud', 'Tour'], ['Loud', 'Tour']]
Relation: follows
Tail Cluster: [['Loud']]

Head Cluster: [['Loud', 'Tour'], ['Loud', 'Tour'], ['Loud', 'Tour']]
Relation: publication date
Tail Cluster: [['2010']]

Head Cluster: [['Rihanna'], ['Rihanna'], ['Rihanna'], ['Rihanna']]
Relation: notable work
Tail Cluster: [['Loud', 'Tour'], ['Loud', 'Tour'], ['Loud', 'Tour']]

Head Cluster: [['Rihanna'], ['Rihanna'], ['Rihanna'], ['Rihanna']]
Relation: country of citizenship
Tail Cluster: [['Barbadian']]

Head Cluster: [['Rihanna'], ['Rihanna'], ['Rihanna'], ['Rihanna']]
Relation: notable work
Tail Cluster: [['Loud']]

Head Cluster: [['Loud']]
Relation: followed by
Tail Cluster: [['Loud', 'Tour'], ['Loud', 'Tour'], ['Loud', 'Tour']]

Head Cluster: [['Loud']]
Relation: performer
Tail Cluster: [['Rihanna'], ['Rihanna'

### Coreference in the Wild
Run coreference-aware relation extraction on a real news story.

In [28]:
# Real-world example
import spacy
from glirel.modules.utils import constrain_relations_by_entity_type
from glirel.modules.utils import get_coreference_clusters, aggregate_cluster_relations

from fastcoref import FCoref
# fastcoref model used below to cluster the spaCy entity mentions (CPU, no progress bar).
fcoref_model = FCoref(device='cpu', enable_progress_bar=False)

# spaCy pipeline providing tokens and entity spans (re-loaded so this cell is self-contained).
nlp = spacy.load('en_core_web_sm')


text = """Polish city urged to evacuate as floods batter central Europe
A flooded area in Nysa, Poland
Image source,Reuters
Image caption,
A flooded area in Nysa, Poland

Laura Gozzi, Nick Thorpe, Adam Easton and Rob Cameron
BBC News
Reporting from
London, Budapest, Warsaw and Prague
Published
16 September 2024
The mayor of a Polish city has asked all 44,000 residents to evacuate, as widespread flooding continues to batter central Europe.

Nysa mayor Kordian Kolbiarz asked people to head for higher ground, citing the risk of an embankment breaching and releasing a cascade of water into the town from a nearby lake.

The death toll from the floods that hit over the weekend rose to at least 16 on Monday, with seven confirmed fatalities in Romania. Casualties were also recorded in Austria, the Czech Republic and Poland.

Budapest said it would close roads near the river Danube which runs through the Hungarian capital, citing the risk of flooding later this week.

“Please evacuate your belongings, yourselves, your loved ones. It is worth getting to the top floor of the building immediately, because the wave may be several metres high. This means that the whole town will be flooded,” Nysa Mayor Kolbiarz wrote.

Polish Prime Minister Donald Tusk said one billion zloty (£197m) would be allocated for flood victims in the country, adding that Poland would also apply for EU relief funds. His government has also declared a state of natural disaster.

Although conditions have stabilised in some places, others are bracing themselves for more disruption and danger from the floods, sparked by Storm Boris.

In Slovakia, the overflowing of the Danube River caused flooding in the Old Town area of the capital, Bratislava, with local media reporting that water levels exceeded 9m (30ft) and were expected to rise further.

A dog being lifted by Polish rescuers
Image source,Getty Images
Image caption,
Polish rescuers and soldiers evacuated local residents and their pets in the village of Rudawa, southern Poland

Watch on BBC iPlayer (UK Only)
From Above - Storm Boris
Attribution
iPlayer
Hungary is bracing itself for floods in the coming days. Warnings are in force along 500km (310 miles) of the Danube.

The river is rising by about a metre every 24 hours, with Budapest's mayor offering residents a million sandbags to protect against floodwaters.

Some tram lines will not operate, while roads along the river will be closed in the Hungarian capital from Monday evening. Trains between Budapest and Vienna have also been cancelled.

Prime Minister Viktor Orban said on X that he had postponed all his international obligations "due to the extreme weather conditions and the ongoing floods in Hungary".

The highest rainfall totals have been in the Czech Republic. In the north-eastern town of Jesenik, 473mm (19in) of rain has fallen since Thursday morning - five times the average monthly rainfall.

The Czech fire service delivered bottles of drinking water to stranded villages, where people were told not to drink water from their taps or their wells as it is likely heavily contaminated.

In the Austrian town of St Polten, more rain has fallen in four days than in the whole of the wettest autumn on record, in 1950.

Chancellor Karl Nehammer said the armed forces had been deployed to offer assistance to storm-hit regions. Austria's Climate Ministry said €300m (£253m) in recovery funds would be made available.

Most parties paused campaigning for the federal elections due in less than two weeks, on 29 September.

Villages and town were submerged in eastern Romania. Emil Dragomir, mayor of Slobozia Conachi, told media that the flooding had had a devastating impact.

"If you were here, you would cry instantly, because people are desperate, their whole lives' work is gone, there were people who were left with just the clothes they had on," he said.

Thousands of people have been evacuted in Poland, including the personnel and patients of a hospital in the town of Nysa. Roads have been badly disrupted and train traffic was suspended in many parts of the country.

On Monday morning, the mayor of Paczków in south-west Poland appealed to residents to evacuate after water began overflowing in a nearby reservoir, endangering the town.


Media caption,
Airborne rescues as Europe hit by floods

In other parts of Poland, however, water levels are now falling, according to local officials.

The mayor of Klodzko city, Michal Piszko, told Polish media the water had receded and the indications were the worst was now over.

Video footage from Monday morning showed that city centre streets which were inundated on Sunday were now water-free, although the footage also revealed the extent of damage done to the buildings.

Where will Storm Boris go next?
More rain is expected throughout Monday and Tuesday in Austria, the Czech Republic and south-east Germany, where another 100mm could fall.

While it may still take days for the flood waters to subside, the weather will improve in central Europe from mid-week with much drier conditions.

Storm Boris will, however, now move further south into Italy, where it will reintensify and bring heavy rain. The Emilia-Romagna region is set to be worst hit, with 100-150mm of rain falling.

The record rainfall seen in central Europe has been caused by a number of factors, including climate change.

Different weather elements came together to create a “perfect storm” in which very cold air from the Arctic met warm air from the Mediterranean.

A pattern of atmospheric pressure also meant that Storm Boris was stuck in one place for a long time.

Scientists say that a warmer atmosphere holds more moisture, leading to more intense rainfall. Warmer oceans also lead to more evaporation, feeding storm systems.

For every 1C rise in the global average temperature, the atmosphere is able to hold about 7% more moisture, external."""


# Candidate relation labels for the news story.
# "SELF" is the label the ReDocRED demo above predicts between mentions of the same
# entity. NOTE(review): confirm it is still needed now that clustering is done by fastcoref.
labels = [
    "SELF",
    # Positive relation labels
    "mayor responsible for evacuation",
    "flood causes evacuation",
    "country affected by floods",
    "storm responsible for floods",
    "government allocates funds",
    "river causes flood",
    "mayor issues warning",
    "rescue teams assist residents",
    "prime minister declares disaster",
    "city experiences flood damage",
    "rescue of people and animals",
    "flooding disrupts transport",
    "climate change causes extreme weather",
    "scientists explain weather patterns",
    
    # Negative (implausible) relation labels, included as distractors
    "mayor responsible for storm",
    "train causes flood",
    "flooding creates more jobs",
    "rescue teams cause damage",
    "election delayed by storm"
]




def _fastcoref_token_clusters(doc):
    """Cluster the entities of a spaCy ``doc`` with fastcoref, in token coordinates.

    Only spaCy's entity spans are offered to fastcoref (``custom_mentions``), so every
    cluster member is one of ``doc.ents``. Entities that fastcoref leaves unclustered are
    appended as singleton clusters, so every entity receives a cluster id.

    Args:
        doc: a spaCy ``Doc`` produced by the notebook-level ``nlp`` pipeline.

    Returns:
        ``(entity_to_cluster_idx, clusters)`` where ``entity_to_cluster_idx`` maps each
        entity's ``(ent.start, ent.end)`` token span (end exclusive) to its cluster id,
        ordered by start token, and ``clusters[cluster_id]`` is the list of
        ``(ent.start, ent.end)`` spans in that cluster.

    Raises:
        ValueError: if spaCy detected no entities in ``doc``.
    """
    mention_char_spans = [(ent.start_char, ent.end_char) for ent in doc.ents]
    if not mention_char_spans:
        raise ValueError("No entities detected in the text. Please provide entities for coreference resolution.")

    char_idx2tok_idx = {(ent.start_char, ent.end_char): (ent.start, ent.end) for ent in doc.ents}

    output = fcoref_model.predict(texts=doc.text, custom_mentions=mention_char_spans)
    # Copy fastcoref's clusters (as char-span tuples) so the singletons added below
    # do not mutate the object it returned.
    char_clusters = [[(s, e) for s, e in cluster] for cluster in output.get_clusters(as_strings=False)]

    # Set lookup instead of scanning a flat list once per entity.
    clustered = {span for cluster in char_clusters for span in cluster}
    char_clusters.extend([span] for span in char_idx2tok_idx if span not in clustered)

    clusters = [[char_idx2tok_idx[span] for span in cluster] for cluster in char_clusters]

    entity_to_cluster_idx = {
        token_span: cluster_id
        for cluster_id, cluster in enumerate(clusters)
        for token_span in cluster
    }
    # Order the mapping by start token.
    entity_to_cluster_idx = dict(sorted(entity_to_cluster_idx.items(), key=lambda item: item[0][0]))
    return entity_to_cluster_idx, clusters


def _print_cluster_relations(clusters_batch, cluster_relations_batch, batch_tokenized_text):
    """Print each aggregated relation as head cluster / relation / tail cluster.

    ``clusters_batch[b][cluster_id]`` holds ``(start, end)`` token spans (end exclusive)
    for document ``b``; ``cluster_relations_batch[b]`` is a list of dicts with
    ``'h_idx'``, ``'t_idx'`` (cluster ids) and ``'r'`` (relation label).
    """
    for clusters, cluster_relations_list, tokenized_text in zip(clusters_batch, cluster_relations_batch, batch_tokenized_text):
        print(f"len of (cluster, cluster_relations_list, tokenized_text): {len(clusters), len(cluster_relations_list), len(tokenized_text)}")
        for cluster_relations in cluster_relations_list:
            print()
            cluster_h_idx = cluster_relations['h_idx']
            cluster_t_idx = cluster_relations['t_idx']
            head_cluster = [tokenized_text[s:e] for s, e in clusters[cluster_h_idx]]
            print(f"Head Cluster: {head_cluster}")
            print(f"Relation: {cluster_relations['r']}")
            print(f"on cluster idx: {cluster_h_idx}")
            tail_cluster = [tokenized_text[s:e] for s, e in clusters[cluster_t_idx]]
            print(f"Tail Cluster: {tail_cluster}")


def predict_and_show(text, labels, coreference=False, threshold=0.01, top_k=1):
    """Extract relations from ``text`` with GLiREL and print them grouped by coreference cluster.

    The function runs the following steps:

    1. The notebook-level ``nlp`` (spaCy) produces tokens and entities.
    2. GLiREL's ``model.batch_predict_relations`` predicts relations.
    3. Optional entity-type constraints filter the predictions.
    4. fastcoref (``fcoref_model``) clusters the spaCy entities.
    5. ``aggregate_cluster_relations`` lifts relations to the cluster level, and the
       head/relation/tail clusters are printed.

    NOTE: this redefines the simpler ``predict_and_show`` from earlier in the notebook;
    cells above that call it must run before this one.

    Args:
        text: raw input text.
        labels: list of relation label strings, or ``{"glirel_labels": {label: constraints}}``.
            With the dict form, predictions are filtered by
            ``constrain_relations_by_entity_type``.
        coreference: NOTE(review): currently unused. fastcoref clustering is always
            performed; the parameter is kept so existing calls (``coreference=True``) work.
        threshold: score threshold forwarded to ``batch_predict_relations`` (default 0.01, as before).
        top_k: forwarded to ``batch_predict_relations`` (default 1, as before).

    Raises:
        ValueError: if spaCy detects no entities (raised before GLiREL is run).
    """
    doc = nlp(text)
    tokens = [token.text for token in doc]

    # NOTE: the end index should be inclusive; spaCy's ``ent.end`` is exclusive, hence ``- 1``.
    ner = [[ent.start, ent.end - 1, ent.label_, ent.text] for ent in doc.ents]

    labels_and_constraints = None
    if isinstance(labels, dict):
        labels_and_constraints = labels["glirel_labels"]
        labels = list(labels_and_constraints.keys())

    # Cluster first: an entity-less text fails here, before the costly GLiREL call.
    entity_to_cluster_idx, clusters = _fastcoref_token_clusters(doc)

    relations = model.batch_predict_relations([tokens], [labels], threshold=threshold, ner=[ner], top_k=top_k)

    if labels_and_constraints is not None:
        print('Constraining relations by entity type')
        # batch_predict_relations returns one relation list per document, while the
        # constraint helper is used on a single flat list (as with predict_relations),
        # so constrain the lone document and re-wrap it as a batch.
        relations = [constrain_relations_by_entity_type(doc.ents, labels_and_constraints, relations[0])]

    print()
    # NOTE(review): a single mapping is passed here, whereas get_coreference_clusters
    # returns an ``entity_to_cluster_idx_batch``. Confirm which form
    # aggregate_cluster_relations expects.
    cluster_relations_batch = aggregate_cluster_relations(entity_to_cluster_idx, relations)

    _print_cluster_relations([clusters], cluster_relations_batch, [tokens])



predict_and_show(text, labels, coreference=True)

09/18/2024 12:39:32 - INFO - 	 missing_keys: []
09/18/2024 12:39:32 - INFO - 	 unexpected_keys: []
09/18/2024 12:39:32 - INFO - 	 mismatched_keys: []
09/18/2024 12:39:32 - INFO - 	 error_msgs: []
09/18/2024 12:39:32 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M
09/18/2024 12:39:32 - INFO - 	 Number of eval relation types per instance: [20]


Map:   0%|          | 0/1 [00:00<?, ? examples/s]


len of (cluster, cluster_relations_list, tokenized_text): (79, 59, 1178)

Head Cluster: [['Nysa'], ['Nysa'], ['Nysa']]
Relation: country affected by floods
on cluster idx: 0
Tail Cluster: [['BBC', 'News'], ['BBC', 'iPlayer'], ['iPlayer']]

Head Cluster: [['Budapest'], ['Warsaw'], ['Budapest'], ['Hungarian'], ['Budapest'], ['Hungarian'], ['Budapest'], ['Vienna']]
Relation: country affected by floods
on cluster idx: 2
Tail Cluster: [['Hungary'], ['Hungary']]

Head Cluster: [['Budapest'], ['Warsaw'], ['Budapest'], ['Hungarian'], ['Budapest'], ['Hungarian'], ['Budapest'], ['Vienna']]
on cluster idx: 2
Tail Cluster: [['Viktor', 'Orban']]

Head Cluster: [['Budapest'], ['Warsaw'], ['Budapest'], ['Hungarian'], ['Budapest'], ['Hungarian'], ['Budapest'], ['Vienna']]
Relation: prime minister declares disaster
on cluster idx: 2
Tail Cluster: [['Viktor', 'Orban']]

Head Cluster: [['Kordian', 'Kolbiarz'], ['Kolbiarz']]
on cluster idx: 6
Tail Cluster: [['Nysa'], ['Nysa'], ['Nysa']]

Head Cluster: [[