# Extracting triples
The model outputs relationship triples in the form of structured text. This notebook shows how the relationship triples are extracted out of that text.

In [15]:
from datasets import load_dataset
import evaluate
import yaml
import re
import pandas as pd
from transformers import AutoTokenizer
from wasabi import msg

## Load Config

In [2]:
# Read the testing configuration (dataset location, model name, ...) from YAML.
# yaml.full_load is the shorthand for yaml.load(..., Loader=yaml.FullLoader).
with open('config/config_testing.yaml') as f:
    config = yaml.full_load(f)

## Load dataset

In [3]:
# Maximum number of examples kept per split (after skipping the column-name row).
N_EXAMPLES = 500

dataset = load_dataset(
    config['dataset_vars']['type'],
    data_dir=config['dataset_vars']['dir'],
    column_names=config['dataset_vars']['column_names'],
)

# Row 0 of each split contains the column names, so selection starts at 1.
# The upper bound is clamped to the split length so a smaller split does not
# make `select` fail with an out-of-range index.
dataset_train = dataset['train'].select(range(1, min(N_EXAMPLES + 1, len(dataset['train']))))
dataset_eval = dataset['validation'].select(range(1, min(N_EXAMPLES + 1, len(dataset['validation']))))

In [4]:
# Structured-text relations of the training subset, plus the marker strings
# used in that text for entity types and relation types.
relations_in_structerd_text = dataset_train["relations"]
entity_types = {
    "chemical": "@CHEMICAL@",
    "disease": "@DISEASE@",
}
rel_types = {
    "chemical induced disease": "@CID@",
}

In [5]:
# Example: the structured text of the first training document
relations_in_structerd_text[0]

'alpha-methyldopa @CHEMICAL@ hypotensive @DISEASE@ @CID@'

In [6]:
def split_on_labels(input_text, labels):
    '''
    Split `input_text` on any of the literal strings in `labels`.

    The labels are regex-escaped, so they are matched literally. Because the
    split pattern is a capturing group, each label occurrence is kept as its
    own element of the result. Empty strings produced by the split are dropped.

    input:
    input_text: The text to split.
    labels: Iterable of label strings.

    returns:
    A list of the non-empty text pieces and label occurrences, in order.
    '''
    alternation = '|'.join(map(re.escape, labels))
    pieces = re.split('(' + alternation + ')', input_text)
    return list(filter(None, pieces))

In [125]:
def handle_coreforents(ent, keep):
    '''
    Normalise the mention text of one (text, label) entity match.

    ent: A pair whose first item is the mention text (coreferent mentions are
         separated by ';') and whose second item is the entity label.
    keep: If True and there is more than one mention, all stripped mentions
          are kept as a tuple; otherwise only the first mention is kept.

    returns: (mention_or_tuple_of_mentions, label)
    '''
    text, label = ent[0], ent[1]
    mentions = tuple(part.strip() for part in text.split(';'))
    if not keep or len(mentions) <= 1:
        return (mentions[0], label)
    return (mentions, label)

In [131]:
def extract_relation_triples(text: str, re_labels: list[str], keep_coreforents: bool = False) -> list[dict]:
    '''
    Extract the relation triples out of structured text.

    The text is expected to consist of groups of the form
    "<head text> @HEAD_LABEL@ <tail text> @TAIL_LABEL@ <relation label>",
    where entity labels have the structure @label@ and a mention text may list
    coreferent mentions separated by ';'.

    input:
    text: The structured text as a string. An empty or whitespace-only string
        yields an empty list.
    re_labels: The relation labels (e.g. ['@CID@']) that close each relation.
    keep_coreforents: If True, an entity with several ';'-separated mentions
        keeps all of them as a tuple; otherwise only the first mention is kept.

    returns:
    A list of dictionaries of the form
    {'re_label': str, 'head_ent': {'label': str, 'text': ...}, 'tail_ent': {...}}

    raises:
    ValueError: if the number of entity segments and relation labels differ, or
        a segment does not contain exactly a head and a tail entity.
    '''
    # Split the input text into relation segments
    relation_segments = split_on_labels(text, re_labels)

    # Remove the trailing segment if it is only whitespace. The emptiness check
    # guards against empty text, which produces no segments at all.
    if relation_segments and not relation_segments[-1].strip():
        relation_segments = relation_segments[:-1]

    # Segments alternate: entity text at even indices, relation label at odd indices
    entity_texts = relation_segments[::2]
    relation_labels = relation_segments[1::2]
    if len(entity_texts) != len(relation_labels):
        raise ValueError('Amount of relation labels in the text does not equal to amount of entities pairs')

    relations = []
    for entity_text, re_label in zip(entity_texts, relation_labels):
        # Each segment must hold exactly one head and one tail: "<text> @LABEL@"
        entities = [handle_coreforents(ent, keep_coreforents)
                    for ent in re.findall(r'(.+?)\s@(\w+)@', entity_text)]
        if len(entities) != 2:
            raise ValueError(
                f'Expected a head and a tail entity, found {len(entities)} in: {entity_text!r}'
            )
        head_ent, tail_ent = entities

        relations.append({
            're_label': re_label.split('@')[1],  # '@CID@' -> 'CID'
            'head_ent': {'label': head_ent[1], 'text': head_ent[0]},
            'tail_ent': {'label': tail_ent[1], 'text': tail_ent[0]},
        })

    return relations

# Example: one document with three CID relations whose chemical has two
# coreferent mentions ("lithium" and "li"); all mentions are kept.
input_text = 'lithium ; li @CHEMICAL@ chronic renal failure @DISEASE@ @CID@ lithium ; li @CHEMICAL@ proteinuria @DISEASE@ @CID@ lithium ; li @CHEMICAL@ hypertension @DISEASE@ @CID@'
relation_triples = extract_relation_triples(input_text, re_labels=['@CID@'], keep_coreforents=True)
relation_triples


[{'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'chronic renal failure'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'proteinuria'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'hypertension'}}]

In [127]:
%%time
relation_triples = [extract_relation_triples(i, ['@CID@']) for i in relations_in_structerd_text]

CPU times: user 4.8 ms, sys: 240 µs, total: 5.04 ms
Wall time: 5.02 ms


In [128]:
# Side-by-side view of each structured text and the triples extracted from it
pd.DataFrame(
    data={
        'structerd_text': relations_in_structerd_text,
        'Relation triples': relation_triples,
    }
)

Unnamed: 0,structerd_text,Relation triples
0,alpha-methyldopa @CHEMICAL@ hypotensive @DISEA...,"[{'re_label': 'CID', 'head_entity': {'label': ..."
1,lidocaine @CHEMICAL@ cardiac asystole @DISEASE...,"[{'re_label': 'CID', 'head_entity': {'label': ..."
2,suxamethonium ; suxamethonium chloride ; sch @...,"[{'re_label': 'CID', 'head_entity': {'label': ..."
3,scopolamine ; hyoscine @CHEMICAL@ overdosage @...,"[{'re_label': 'CID', 'head_entity': {'label': ..."
4,lithium ; li @CHEMICAL@ chronic renal failure ...,"[{'re_label': 'CID', 'head_entity': {'label': ..."
...,...,...
495,zonisamide @CHEMICAL@ visual hallucinations @D...,"[{'re_label': 'CID', 'head_entity': {'label': ..."
496,pan ; puromycin aminonucleoside @CHEMICAL@ nep...,"[{'re_label': 'CID', 'head_entity': {'label': ..."
497,ticlopidine @CHEMICAL@ aplastic anemia @DISEAS...,"[{'re_label': 'CID', 'head_entity': {'label': ..."
498,scopolamine @CHEMICAL@ amnesia @DISEASE@ @CID@...,"[{'re_label': 'CID', 'head_entity': {'label': ..."


# Using 🤗 Evaluate to calculate accuracy, f1, precision and recall

In [49]:
# One combined metric object that reports f1, precision and recall together
clf_metrics = evaluate.combine(
    ["f1", "precision", "recall"]
)

In [55]:
# Toy example: of the two predicted positives one is correct, and of the two
# actual positives one is found -> precision = recall = f1 = 0.5
clf_metrics.compute(predictions=[0, 1, 1], references=[1, 0, 1])

{'f1': 0.5, 'precision': 0.5, 'recall': 0.5}

All of these metrics take input as a list of ints.

documentation:
- [accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy)
- [precision](https://huggingface.co/spaces/evaluate-metric/precision)
- [recall](https://huggingface.co/spaces/evaluate-metric/recall)
- [f1](https://huggingface.co/spaces/evaluate-metric/f1)

So either the relation-triple extraction needs to work on the input ids, or the extracted text needs to be converted to ints.

## Converting text to ints method

Because we're doing relation extraction on a whole document, evaluating the output is a bit different from the usual methods. This method checks whether the entities or relations in the predictions are also found in the set of entities and relations in the reference. The order of the entities and relations is irrelevant.

### Named entity recognition scoring

In [63]:
# Starting point: the extracted triples of one training document.
# Expects `relation_triples` to be the per-document list built by the %%time cell.
index = 4
relation_triples[index]

[{'CID': {'CHEMICAL': ['lithium', 'li'], 'DISEASE': 'chronic renal failure'}},
 {'CID': {'CHEMICAL': ['lithium', 'li'], 'DISEASE': 'proteinuria'}},
 {'CID': {'CHEMICAL': ['lithium', 'li'], 'DISEASE': 'hypertension'}}]

In [99]:
predictions = relation_triples[index]
references = relation_triples[index]

### Define group of entities in the reference
# Collect the mention texts of every head and tail entity in the reference.
# A mention text can be a tuple/list of coreferent mentions; those are
# flattened so that each surface form can be matched on its own.
entities_text = list()
for relation in references:
    for ent in (relation['head_ent'], relation['tail_ent']):
        mentions = ent['text']
        entities_text.extend(mentions if isinstance(mentions, (list, tuple)) else [mentions])

### Check if a prediction is in the group of reference entities
matched_entities = list()
for pred in predictions:
    for pred_ent in (pred['head_ent'], pred['tail_ent']):
        mentions = pred_ent['text']
        for mention in (mentions if isinstance(mentions, (list, tuple)) else [mentions]):
            if mention in entities_text:
                # Remove entity texts from the group if there is a match, so a
                # reference mention can only be matched once
                entities_text.remove(mention)
                matched_entities.append(mention)
# TODO: check if the labels match (ner_metric below compares label and text together)
matched_entities

### Relation extraction scoring

## Using input ids method

In [11]:
# The tokenizer exposes the input ids of the structured text and of the label markers.
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repository; only use it with trusted model names.
model_name = config['model_name']
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    legacy=False,
)

In [34]:
# Example: how the first structured text and the label markers are tokenized
msg.info(f"Output in text:")
print(relations_in_structerd_text[0] + '\n')
input_ids = tokenizer(relations_in_structerd_text[0], return_tensors="pt").input_ids
msg.info(f"Output as input ids:")
print(input_ids)
print()

# Entity markers and relation markers are reported the same way. The last id
# of each encoding is left out when printing (presumably the end-of-sequence id).
for heading, label_map in (
    ("Input ids of entity labels", entity_types),
    ("Input ids of relation labels", rel_types),
):
    msg.info(heading)
    for key, value in label_map.items():
        input_ids = tokenizer(value, return_tensors="pt").input_ids
        tokens = tokenizer.tokenize(value)
        print(f"{key} : {value}\ntokens:{tokens}\ninput ids: {input_ids[0][:-1]}" + "\n")

[38;5;4mℹ Output in text:[0m
alpha-methyldopa @CHEMICAL@ hypotensive @DISEASE@ @CID@

[38;5;4mℹ Output as input ids:[0m
tensor([[  491,  6977,    18, 22758,    26,    32,   102,     9,  3320, 13717,
           329, 23936,  1741, 10950,   324,     7,   757,  3320,   308, 19056,
         17892,  1741,  3320,   254,  4309,  1741,     1]])

[38;5;4mℹ Input ids of entity labels[0m
chemical : @CHEMICAL@
tokens:['▁@', 'CHE', 'M', 'ICAL', '@']
input ids: tensor([ 3320, 13717,   329, 23936,  1741])

disease : @DISEASE@
tokens:['▁@', 'D', 'ISE', 'ASE', '@']
input ids: tensor([ 3320,   308, 19056, 17892,  1741])

[38;5;4mℹ Input ids of relation labels[0m
chemical induced disease : @CID@
tokens:['▁@', 'C', 'ID', '@']
input ids: tensor([3320,  254, 4309, 1741])



We can use the unique sequence of input ids of the entity and relation labels to extract the input ids of the entity text.

### Custom implementation of precision, recall and f1 for RE:

In [247]:
def re_metric(predictions: list[str], references: list[str], re_labels: tuple[str, ...] = ('@CID@',)) -> dict:
    '''
    Calculate precision, recall and f1-score for document relation extraction.

    A predicted triple is a true positive when an identical triple (relation
    label, head entity and tail entity, coreferent mentions included) occurs in
    the reference of the same document; each reference triple can be matched
    once. Counts are summed over all documents before computing the scores.

    input:
        predictions: List of decoded outputs of the model (structured text).
        references: List of decoded gold data, aligned with `predictions`.
        re_labels: The relation labels used in the structured text.
    output:
        {'re_precision': float, 're_recall': float, 're_f1': float}
    raises:
        ValueError: if `predictions` and `references` differ in length.
    '''
    if len(predictions) != len(references):
        raise ValueError('predictions and references must have the same length')

    labels = list(re_labels)
    tp = 0  # True positive count
    fp = 0  # False positive count
    fn = 0  # False negative count

    for pred_text, ref_text in zip(predictions, references):
        # Separate names so the function arguments are not shadowed
        pred_triples = extract_relation_triples(pred_text, labels, True)
        unmatched_refs = extract_relation_triples(ref_text, labels, True)

        for triple in pred_triples:
            if triple in unmatched_refs:  # True positive
                tp += 1
                unmatched_refs.remove(triple)
            else:  # False positive
                fp += 1

        # Reference triples that no prediction matched are false negatives
        fn += len(unmatched_refs)

    # Calculate metrics (0.0 when a denominator is zero)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * ((precision * recall) / (precision + recall)) if (precision + recall) else 0.0

    return {'re_precision': precision, 're_recall': recall, 're_f1': f1}

In [248]:
# Sanity check: scoring the first 10 references against themselves should be perfect
example_batch = relations_in_structerd_text[:10]
predictions = references = example_batch
re_metric(predictions, references)

{'re_precision': 1.0, 're_recall': 1.0, 're_f1': 1.0}

### Custom implementation of precision, recall and f1 for NER:
This function should take a list of predicted outputs and a list of references. Both lists contain strings: the decoded outputs of the model (using `tokenizer.batch_decode()`).

The output should be a dictionary with `key : value` pairs of `name metric : value of metric`

In [149]:
def get_group(relation_triples):
    '''
    Collect the head and tail entity of every relation triple into one flat list.

    relation_triples: list of triples as returned by extract_relation_triples.
    returns: [head_0, tail_0, head_1, tail_1, ...] -- the entity dicts
             themselves, not copies.
    '''
    return [
        ent
        for rel in relation_triples
        for ent in (rel['head_ent'], rel['tail_ent'])
    ]

In [150]:
# Rebuild the single-document lithium example used by the get_group /
# split_coferents / map_coferents demos below. The %%time cell above rebinds
# `relation_triples` to one list of triples per training document, so on
# Restart & Run All those demos would otherwise receive the wrong structure.
relation_triples = extract_relation_triples(input_text, ['@CID@'], True)
relation_triples

[{'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'chronic renal failure'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'proteinuria'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'hypertension'}}]

In [151]:
# Testing get_group: flatten the head/tail entities of the example triples.
# Assumes `relation_triples` holds the triples of a single document.
get_group(relation_triples)

[{'label': 'CHEMICAL', 'text': ('lithium', 'li')},
 {'label': 'DISEASE', 'text': 'chronic renal failure'},
 {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
 {'label': 'DISEASE', 'text': 'proteinuria'},
 {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
 {'label': 'DISEASE', 'text': 'hypertension'}]

In [221]:
# Testing split_coferents on a single entity with coreferent mentions.
# NOTE(review): split_coferents is defined in a later cell (In [222]), so this
# cell raises NameError on Restart & Run All; move it below that definition.
# Assumes `relation_triples` holds the triples of a single document.
print(f"A single ent: {relation_triples[0]['head_ent']}")
split_coferents(relation_triples[0]['head_ent'])

A single ent: {'label': 'CHEMICAL', 'text': ('lithium', 'li')}


({'label': 'CHEMICAL', 'text': 'lithium'}, {'label': 'CHEMICAL', 'text': 'li'})

In [222]:
def split_coferents(ent):
    '''
    Split an entity with coreferent mentions into one entity per mention.

    ent: entity dict {'label': ..., 'text': ...}.
    returns: if ent['text'] is a tuple of mentions, a tuple of
             {'label': ..., 'text': mention} dicts (same label, one per mention);
             otherwise `ent` itself, unchanged.
    '''
    mentions = ent['text']
    if not isinstance(mentions, tuple):
        return ent
    return tuple({"label": ent["label"], "text": mention} for mention in mentions)

In [223]:
def map_coferents(group):
    '''
    Map every single mention in `group` to all coreferent mentions of its entity.

    group: list of entity dicts as returned by get_group; 'text' may be a tuple
           of coreferent mentions.
    returns: dict keyed by frozenset(mention_dict.items()). Each value is the
             tuple of all mention dicts of that entity (a 1-tuple for entities
             without coreferents). A later entity with the same mention key
             overwrites an earlier one.
    '''
    mapping = {}
    for original_ent in group:
        parts = split_coferents(original_ent)  # one dict per coreferent mention
        mentions = parts if isinstance(parts, tuple) else (parts,)
        for mention in mentions:
            mapping[frozenset(mention.items())] = mentions
    return mapping

In [218]:
# Display the triples used by the map_coferents test below.
# NOTE(review): the content depends on which cell last assigned `relation_triples`
# (the single lithium example vs. the per-document list from the %%time cell).
relation_triples

[{'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'chronic renal failure'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'proteinuria'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'hypertension'}}]

In [224]:
# Testing map_coferents: map each mention of the example's entities to all of
# its coreferent mentions. Assumes `relation_triples` holds a single document.
group = get_group(relation_triples)
map_coferents(group)

{frozenset({('label', 'CHEMICAL'),
            ('text', 'lithium')}): ({'label': 'CHEMICAL',
   'text': 'lithium'}, {'label': 'CHEMICAL', 'text': 'li'}),
 frozenset({('label', 'CHEMICAL'),
            ('text', 'li')}): ({'label': 'CHEMICAL',
   'text': 'lithium'}, {'label': 'CHEMICAL', 'text': 'li'}),
 frozenset({('label', 'DISEASE'),
            ('text', 'chronic renal failure')}): ({'label': 'DISEASE',
   'text': 'chronic renal failure'},),
 frozenset({('label', 'DISEASE'),
            ('text', 'proteinuria')}): ({'label': 'DISEASE',
   'text': 'proteinuria'},),
 frozenset({('label', 'DISEASE'),
            ('text', 'hypertension')}): ({'label': 'DISEASE',
   'text': 'hypertension'},)}

In [239]:
# Prototype of the NER scoring wrapped by ner_metric: the first 10 reference
# texts are scored against themselves, so every metric should be 1.0.
example_batch = relations_in_structerd_text[:10]
predictions = example_batch
references = example_batch

coferent_matching = "relaxed"

tp = 0 # True positive count
fp = 0 # False positive count
fn = 0 # False negative count
for pred_text, ref_text in zip(predictions, references):
    # Define groups
    pred_group = get_group(extract_relation_triples(pred_text, ['@CID@'], True))
    ref_group = get_group(extract_relation_triples(ref_text, ['@CID@'], True))

    if coferent_matching == "relaxed":
        # Create mapping from a coferent mention to all of its coferent mentions
        mapping_coferent = map_coferents(ref_group)

        # Split coferents into single-mention entities and flatten the groups
        pred_group = [split_coferents(ent) for ent in pred_group]
        pred_group = [item for sublist in pred_group for item in (sublist if isinstance(sublist, tuple) else [sublist])]

        ref_group = [split_coferents(ent) for ent in ref_group]
        ref_group = [item for sublist in ref_group for item in (sublist if isinstance(sublist, tuple) else [sublist])]

    checked_coferent_pred = []
    for ent in pred_group:
        if ent in ref_group: # True positive
            tp += 1
            if coferent_matching == "relaxed":
                # Remove the coferent mentions of the matched entity. A mention may
                # already be consumed (it can occur both standalone and inside a
                # coferent group), so only remove what is still present.
                coferent_mentions = mapping_coferent[frozenset(ent.items())]
                for mention in coferent_mentions:
                    if mention in ref_group:
                        ref_group.remove(mention)
                checked_coferent_pred.extend(coferent_mentions)
            else:
                ref_group.remove(ent)
        elif ent not in checked_coferent_pred: # False positive
            fp += 1
    fn += len(ref_group) # False negative

# Calculate metrics
precision = tp/(tp+fp) if (tp+fp) else 0.0
recall = tp/(tp+fn) if (tp+fn) else 0.0
f1 = 2 * ((precision*recall)/(precision+recall)) if (precision+recall) else 0.0

print({'precision':precision, 'recall':recall, 'f1':f1})

{'precision': 1.0, 'recall': 1.0, 'f1': 1.0}


In [None]:
def _flatten_coferents(group):
    '''Split entities with coreferent mentions into single-mention entities and return one flat list.'''
    flat = []
    for ent in group:
        parts = split_coferents(ent)
        flat.extend(parts if isinstance(parts, tuple) else [parts])
    return flat


def ner_metric(predictions: list[str], references: list[str], coferent_matching: str = "relaxed", re_labels: tuple[str, ...] = ('@CID@',)) -> dict:
    '''
    Calculates the precision, recall and f1-score for document named entity recognition.

    An entity (label + text) counts as a true positive when it is found in the
    reference group of the same document; counts are summed over all documents.

    input:
        predictions:
            List of decoded outputs of the model
        references:
            List of decoded gold data, aligned with `predictions`
        coferent_matching:
            Whether to use the coferent mentions to match named entities. Can be either "relaxed" or "strict".
            relaxed meaning that any coferent mention may be used to match a predicted named entity to a reference entity
            strict meaning the model needs to have all coferent mentions correct to count as a match (including the sequence)
        re_labels:
            Which relation extraction labels are used.
    output:
        a dictionary with the key value pairs of metric_name : metric value
    raises:
        ValueError: for an unknown `coferent_matching` value or inputs of different length.
    '''
    if coferent_matching not in ("relaxed", "strict"):
        raise ValueError(f'coferent_matching must be "relaxed" or "strict", got {coferent_matching!r}')
    if len(predictions) != len(references):
        raise ValueError('predictions and references must have the same length')

    labels = list(re_labels)
    tp = 0 # True positive count
    fp = 0 # False positive count
    fn = 0 # False negative count
    for pred_text, ref_text in zip(predictions, references):
        # Define groups
        pred_group = get_group(extract_relation_triples(pred_text, labels, True))
        ref_group = get_group(extract_relation_triples(ref_text, labels, True))

        if coferent_matching == "relaxed":
            # Mapping from a coferent mention to all of its coferent mentions
            mapping_coferent = map_coferents(ref_group)
            pred_group = _flatten_coferents(pred_group)
            ref_group = _flatten_coferents(ref_group)

        checked_coferent_pred = []
        for ent in pred_group:
            if ent in ref_group: # True positive
                tp += 1
                if coferent_matching == "relaxed":
                    # Remove the coferent mentions of the matched entity. A mention may
                    # already be consumed (it can occur both standalone and inside a
                    # coferent group), so only remove what is still present.
                    coferent_mentions = mapping_coferent[frozenset(ent.items())]
                    for mention in coferent_mentions:
                        if mention in ref_group:
                            ref_group.remove(mention)
                    checked_coferent_pred.extend(coferent_mentions)
                else:
                    ref_group.remove(ent)
            elif ent not in checked_coferent_pred: # False positive
                fp += 1
        fn += len(ref_group) # False negative

    # Calculate metrics (0.0 when a denominator is zero)
    precision = tp/(tp+fp) if (tp+fp) else 0.0
    recall = tp/(tp+fn) if (tp+fn) else 0.0
    f1 = 2 * ((precision*recall)/(precision+recall)) if (precision+recall) else 0.0

    return {'ner_precision':precision, 'ner_recall':recall, 'ner_f1':f1}

## Loading the model to evaluate 

In [None]:
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    device_map=device_map