In [1]:
from datasets import load_dataset
import evaluate
import yaml
import re
import pandas as pd
from transformers import AutoTokenizer
from wasabi import msg
from pathlib import Path
from os.path import abspath

  from .autonotebook import tqdm as notebook_tqdm


## Setting home directory

In [2]:
home_dir = Path(abspath("")).parent

msg.info(f"Home directory: {home_dir}")

# Change when using a different config
config_path = home_dir.joinpath('config/config_testing.yaml')

[38;5;4mℹ Home directory: /home/lgrootde/Generative-re-tests[0m


## Load Config

In [3]:
# Load the config
with open(config_path) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

## Load dataset

In [4]:
dataset = load_dataset(
        config['dataset_vars']['type'], 
        data_dir=home_dir.joinpath(config['dataset_vars']['dir']),
        column_names=config['dataset_vars']['column_names']
        )

dataset_train = dataset['train'].select(range(1,501)) # remove first row that contains column names
dataset_eval = dataset['validation'].select(range(1,501)) # remove first row that contains column names

In [5]:
relations_in_structerd_text = dataset_train["relations"]
entity_types = {"chemical":"@CHEMICAL@", 
                "disease":"@DISEASE@"}
rel_types = {"chemical induced disease":"@CID@"}

In [6]:
# voorbeeld
relations_in_structerd_text[4]

'lithium ; li @CHEMICAL@ chronic renal failure @DISEASE@ @CID@ lithium ; li @CHEMICAL@ proteinuria @DISEASE@ @CID@ lithium ; li @CHEMICAL@ hypertension @DISEASE@ @CID@'

## Extracting triples from structerd text

In [7]:
def split_on_labels(input_text, labels):
    # Escape labels to ensure special characters are treated as literals in regex
    escaped_labels = [re.escape(label) for label in labels]
    # Join the labels into a regex pattern with alternation to match any of them
    pattern = '|'.join(escaped_labels)
    # Use re.split() with the compiled pattern, keeping the delimiters in the result
    relation_segments = re.split(f'({pattern})', input_text)
    # Filter out empty strings that might result from splitting
    relation_segments = [segment for segment in relation_segments if segment]
    return relation_segments

In [8]:
def handle_coreforents(ent, keep):
    coreferents = tuple([coref.strip() for coref in ent[0].split(';')])
    if keep and len(coreferents) > 1:
        return (coreferents, ent[1]) 
    else:
        return (coreferents[0], ent[1])

In [9]:
def extract_relation_triples(text: str, re_labels: list[str], keep_coreforents: bool = False) -> list[dict]:
    '''
    This function extracts the relationship triples out of structerd text. 
    This function assumes that the NER labels are in this structure: @label@
    
    input:
    text: The structerd text as a string.
    re_labels: The relationship labels.

    returns:
    A list of dictionaries
    '''
    # Split the input text into relation segments
    relation_segments = split_on_labels(text, re_labels)
    
    # Remove the last empty segment if it exists
    if not relation_segments[-1].strip():
        relation_segments = relation_segments[:-1]

    # Map relation label to entity text
    entity_texts = relation_segments[::2] # All uneven elements
    relation_labels = relation_segments[1::2] # All even elements
    if len(entity_texts) != len(relation_labels):
        msg.warn(f" len entity text {len(entity_texts)} | {entity_texts}")
        msg.warn(f" len relation labels {len(relation_labels)} | {relation_labels}")
        raise ValueError('Amount of relation labels in the text does not equal to amount of entities pairs')
        
    
    # Initialize a list to hold the relation triples
    relations = []
    
    for entity_text, re_label in zip(entity_texts, relation_labels):
        # Split head and tail entities and their labels
        head_ent, tail_ent = [handle_coreforents(ent, keep_coreforents) for ent in re.findall(r'(.+?)\s@(\w+)@', entity_text)]
        # print(f"head_ent {head_ent} | tail_ent {tail_ent}") #DEBUG
        
        re_label = re_label.split('@')[1]
        relations.append({
            're_label':re_label,
            'head_ent': {'label':head_ent[1], 'text':head_ent[0]},
            'tail_ent': {'label':tail_ent[1], 'text':tail_ent[0]}
        })
    
    return relations

In [10]:
# Example usage
input_text = 'lithium ; li @CHEMICAL@ chronic renal failure @DISEASE@ @CID@ lithium ; li @CHEMICAL@ proteinuria @DISEASE@ @CID@ lithium ; li @CHEMICAL@ hypertension @DISEASE@ @CID@'
relation_triples = extract_relation_triples(input_text, ['@CID@'], True)
relation_triples

[{'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'chronic renal failure'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'proteinuria'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'hypertension'}}]

In [11]:
%%time
relation_triples = [extract_relation_triples(i, ['@CID@'], True) for i in relations_in_structerd_text]

CPU times: user 5.25 ms, sys: 0 ns, total: 5.25 ms
Wall time: 5.18 ms


In [12]:
dataframe=pd.DataFrame({'structerd_text': relations_in_structerd_text, 'Relation triples': relation_triples})
dataframe

Unnamed: 0,structerd_text,Relation triples
0,alpha-methyldopa @CHEMICAL@ hypotensive @DISEA...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."
1,lidocaine @CHEMICAL@ cardiac asystole @DISEASE...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."
2,suxamethonium ; suxamethonium chloride ; sch @...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."
3,scopolamine ; hyoscine @CHEMICAL@ overdosage @...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."
4,lithium ; li @CHEMICAL@ chronic renal failure ...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."
...,...,...
495,zonisamide @CHEMICAL@ visual hallucinations @D...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."
496,pan ; puromycin aminonucleoside @CHEMICAL@ nep...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."
497,ticlopidine @CHEMICAL@ aplastic anemia @DISEAS...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."
498,scopolamine @CHEMICAL@ amnesia @DISEASE@ @CID@...,"[{'re_label': 'CID', 'head_ent': {'label': 'CH..."


## Dealing with unstructured text

### Load dataset of predicted output

In [13]:
test_eval_dataset = load_dataset(
        config['dataset_vars']['type'], 
        data_dir=home_dir.joinpath("data/testing_eval/cdr_seq2rel"),
        data_files={'test_trained': 'test_trained.csv', 
                    'test_untrained':'test_untrained.csv', 
                    'test_trained_gc1': 'test_trained_gc1.csv',
                    'test_re100':'test_re100.csv',
                    'test_re53_ner53':'test_re53_ner53.csv'},
        column_names=["input", "expected output", "predicted"]
        )

test_eval_dataset_trained = test_eval_dataset['test_trained'].select(range(1,501)) # remove first row that contains column names
test_eval_dataset_trained_gc1 = test_eval_dataset['test_trained'].select(range(1,501)) # remove first row that contains column names
test_eval_dataset_untrained = test_eval_dataset['test_untrained'].select(range(1,501)) # remove first row that contains column names
test_eval_dataset_re100 = test_eval_dataset['test_re100'].select(range(1,11)) # remove first row that contains column names
test_eval_dataset_re53_ner53 = test_eval_dataset['test_re53_ner53'].select(range(1,11)) # remove first row that contains column names

In [14]:
for row in test_eval_dataset_re53_ner53:
    msg.info(f"Artificialy alterd output:      |      exact same as input? {row['expected output']==row['predicted']}")
    print(row['predicted'], '\n')

[38;5;4mℹ Artificialy alterd output:      |      exact same as input? True[0m
famotidine @CHEMICAL@ delirium @DISEASE@ @CID@ 

[38;5;4mℹ Artificialy alterd output:      |      exact same as input? True[0m
indomethacin ; idm @CHEMICAL@ hypotension @DISEASE@ @CID@ 

[38;5;4mℹ Artificialy alterd output:      |      exact same as input? True[0m
tacrolimus @CHEMICAL@ systemic sclerosis ; ssc @DISEASE@ @CID@ corticosteroid ; corticosteroids @CHEMICAL@ systemic sclerosis ; ssc @DISEASE@ @CID@ cyclosporine @CHEMICAL@ thrombotic microangiopathy @DISEASE@ @CID@ 

[38;5;4mℹ Artificialy alterd output:      |      exact same as input? True[0m
methamphetamine @CHEMICAL@ psychosis ; psychotic symptoms @DISEASE@ @CID@ methamphetamine @CHEMICAL@ bipolar disorder @DISEASE@ @CID@ methamphetamine @CHEMICAL@ antisocial personality disorder ; antisocial personality @DISEASE@ @CID@ 

[38;5;4mℹ Artificialy alterd output:      |      exact same as input? True[0m
levodopa @CHEMICAL@ dyskinetic ; dyski

### Extracting triples from actual model output

In [15]:
test_eval_dataset = test_eval_dataset_trained_gc1

re_labels = ['@CID@']
for i in test_eval_dataset:
    text = i['predicted']
    try:
        extract_relation_triples(text, re_labels)
    except ValueError:
        msg.info("expected output:")
        print(i["expected output"])
        msg.info("predicted output:")
        print(text)
        print('\n')

[38;5;4mℹ expected output:[0m
phosphorus @CHEMICAL@ cholestatic ; cholestasis @DISEASE@ @CID@ phosphorus @CHEMICAL@ acute hepatitis @DISEASE@ @CID@ phosphorus @CHEMICAL@ acute liver failure @DISEASE@ @CID@
[38;5;4mℹ predicted output:[0m
yellow phosphorus ; yellow phosphorus @CHEMICAL@ cholestasis @DISEASE@ @CID@ yellow phosphorus ; yellow phosphorus @CHEMICAL@ hepatotoxicity ; hepatitis @DISEASE@ @CID@ yellow phosphorus ; yellow phosphorus @CHEMICAL@ fireworks @CHEMICAL@ cholestasis @DISEASE@ @CID@ yellow phosphorus ; yellow phosphorus @CHEMICAL@ cholestasis @DISEASE@ @CID@


[38;5;3m⚠  len entity text 5 | ['lomustine ; ccnu @CHEMICAL@ neutropenia
@DISEASE@ ', ' lomustine ; ccnu @CHEMICAL@ hemorrhagic @DISEASE@ ', ' lomustine
; ccnu @CHEMICAL@ cystitis @DISEASE@ ', ' cyclophosphamide ; ctx @CHEMICAL@
neutropenia @DISEASE@ ', ' cyclophosphamide ; cnu @CHEMICAL@ hemorhagicnu
@CHEMICAL@ cystitis @'][0m
[38;5;3m⚠  len relation labels 4 | ['@CID@', '@CID@', '@CID@', '@CID@'][0m
[38

### Changing `extract_relation_triples()` to deal with unstructured text

In [16]:
def extract_relation_triples(text: str, ner_labels: list[str], re_labels: list[str], keep_coreforents: bool = False) -> list[dict]:
    '''
    This function extracts the relationship triples out of structerd text. 
    This function assumes that the NER labels are in this structure: @label@
    
    input:
    text: The structerd text as a string.
    re_labels: The relationship labels.

    returns:
    A list of dictionaries
    '''
    ##### Check if text is structered #####
    split_on_space_text = text.split(" ")
    
    # check if text ends with a relation label
    if split_on_space_text[-1] not in re_labels:
        raise ValueError(f"Text is unstructured: '{text}'\nText should end with a relationship label found in re_labels: {re_labels}.\n")
        
    # Check if text has atleast two entity labels and one relation label
    count_ner_labels = sum([split_on_space_text.count(label) for label in ner_labels])
    count_re_labels = sum([split_on_space_text.count(label) for label in re_labels])
    if count_ner_labels < 2 or count_re_labels < 1:
        raise ValueError(f"Text is unstructured: '{text}'\nText should have atleast 2 ner_labels: {ner_labels} and 1 re_label: {re_labels} to make a relationship.\n")
        
    # Check if text has the right amount of entity and relation labels
    if count_re_labels*2 != count_ner_labels:
        raise ValueError(f"Text is unstructured: '{text}'\nText should have 2 times the ner_labels: {ner_labels} then there are re_label: {re_labels}. currently: ner labels: {count_ner_labels} | re labels: {count_re_labels}\n")
    
    ##### Extracting relation triples #####
    # Split the input text into relation segments
    relation_segments = split_on_labels(text, re_labels)
    
    # Remove the last empty segment if it exists
    if not relation_segments[-1].strip():
        relation_segments = relation_segments[:-1]

    # Map relation label to entity text
    entity_texts = relation_segments[::2] # All uneven elements
    relation_labels = relation_segments[1::2] # All even elements
        
    # Initialize a list to hold the relation triples
    relations = []
    
    for entity_text, re_label in zip(entity_texts, relation_labels):
        # Split head and tail entities and their labels
        head_ent, tail_ent = [handle_coreforents(ent, keep_coreforents) for ent in re.findall(r'(.+?)\s@(\w+)@', entity_text)]
        # print(f"head_ent {head_ent} | tail_ent {tail_ent}") #DEBUG
        
        re_label = re_label.split('@')[1]
        relations.append({
            're_label':re_label,
            'head_ent': {'label':head_ent[1], 'text':head_ent[0]},
            'tail_ent': {'label':tail_ent[1], 'text':tail_ent[0]}
        })
    
    return relations

In [17]:
re_labels = ['@CID@']
ner_labels = ['@CHEMICAL@', '@DISEASE@']
for i in test_eval_dataset:
    text = i['predicted']
    try:
        extract_relation_triples(text, ner_labels, re_labels)
    except ValueError as error:
        print(error)

Text is unstructured: 'yellow phosphorus ; yellow phosphorus @CHEMICAL@ cholestasis @DISEASE@ @CID@ yellow phosphorus ; yellow phosphorus @CHEMICAL@ hepatotoxicity ; hepatitis @DISEASE@ @CID@ yellow phosphorus ; yellow phosphorus @CHEMICAL@ fireworks @CHEMICAL@ cholestasis @DISEASE@ @CID@ yellow phosphorus ; yellow phosphorus @CHEMICAL@ cholestasis @DISEASE@ @CID@'
Text should have 2 times the ner_labels: ['@CHEMICAL@', '@DISEASE@'] then there are re_label: ['@CID@']. currently: ner labels: 9 | re labels: 4

Text is unstructured: 'lomustine ; ccnu @CHEMICAL@ neutropenia @DISEASE@ @CID@ lomustine ; ccnu @CHEMICAL@ hemorrhagic @DISEASE@ @CID@ lomustine ; ccnu @CHEMICAL@ cystitis @DISEASE@ @CID@ cyclophosphamide ; ctx @CHEMICAL@ neutropenia @DISEASE@ @CID@ cyclophosphamide ; cnu @CHEMICAL@ hemorhagicnu @CHEMICAL@ cystitis @'
Text should end with a relationship label found in re_labels: ['@CID@'].

Text is unstructured: 'fluocinolone acetonide intravitreal implant ; 0.59 mg @CHEMICAL@ ocul

## Using 🤗 Evaluate to calculate accuracy, f1, precision and recall

In [18]:
clf_metrics = evaluate.combine(["f1", "precision", "recall"])

In [19]:
clf_metrics.compute(
    predictions=[0,1,1],
    references=[1,0,1]
)

{'f1': 0.5, 'precision': 0.5, 'recall': 0.5}

All of these metrics take input as a list of ints.

documentation:
- [accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy)
- [precision](https://huggingface.co/spaces/evaluate-metric/precision)
- [recall](https://huggingface.co/spaces/evaluate-metric/recall)
- [f1](https://huggingface.co/spaces/evaluate-metric/f1)

so the extract relation triples needs to work on the input ids. Or the text needs to be converted to ints

## Converting text to ints method

Because we're doing a relation extraction from a document, evaluating the output will be a bit different from usual methods. This method will be based on if the set of entities or relation in the predictions are also found in the set of entities and relations in the reference. The sequence of entities and relations are irrelivant. 

### Named entity recoginition scoring

In [20]:
# starting point
index = 4
relation_triples[index]

[{'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'chronic renal failure'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'proteinuria'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': ('lithium', 'li')},
  'tail_ent': {'label': 'DISEASE', 'text': 'hypertension'}}]

In [21]:
# predictions = relation_triples[index]
# references = relation_triples[index]

# ### Define group of entities in the reference
# entities_text = list()
# for relation in references:
#     nested_list = list(relation.values())[0] # Elements can be list if there is a coreferent mention
#     # This next list comprehension is barely readable but if an element is a sublist it is flattend 
#     # if an element is a string it's kept as it is. 
#     entities_text.extend([item for sublist in nested_list.values() for item in (sublist if isinstance(sublist, list) else [sublist])])

# ### Check if a prediction is in the group of reference entities
# for pred in predictions:
#     for pred_ent in list(pred.values())[0].values():
#         if isinstance(pred_ent, list):
#         else:
#             if pred_ent in entities_text:
#                 pass
#     # Remove entity texts from the group if there is a match
# # Check if the labels match
# # 

### Relation extraxtion scoring

## Using inputs ids method

In [22]:
# Load tokenizer to access input ids
model_name = config['model_name']
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, legacy=False)

In [23]:
# voorbeeld
msg.info(f"Output in text:")
print(relations_in_structerd_text[0]+'\n')
input_ids=tokenizer(relations_in_structerd_text[0], return_tensors="pt").input_ids
msg.info(f"Output as input ids:")
print(input_ids)
print() 

msg.info("Input ids of entity labels")
for key, value in entity_types.items():
    input_ids = tokenizer(value, return_tensors="pt").input_ids
    tokens = tokenizer.tokenize(value)
    print(f"{key} : {value}\ntokens:{tokens}\ninput ids: {input_ids[0][:-1]}"+"\n")

msg.info("Input ids of relation labels")
for key, value in rel_types.items():
    input_ids = tokenizer(value, return_tensors="pt").input_ids
    tokens = tokenizer.tokenize(value)
    print(f"{key} : {value}\ntokens:{tokens}\ninput ids: {input_ids[0][:-1]}"+"\n")

[38;5;4mℹ Output in text:[0m
alpha-methyldopa @CHEMICAL@ hypotensive @DISEASE@ @CID@

[38;5;4mℹ Output as input ids:[0m
tensor([[  491,  6977,    18, 22758,    26,    32,   102,     9,  3320, 13717,
           329, 23936,  1741, 10950,   324,     7,   757,  3320,   308, 19056,
         17892,  1741,  3320,   254,  4309,  1741,     1]])

[38;5;4mℹ Input ids of entity labels[0m
chemical : @CHEMICAL@
tokens:['▁@', 'CHE', 'M', 'ICAL', '@']
input ids: tensor([ 3320, 13717,   329, 23936,  1741])

disease : @DISEASE@
tokens:['▁@', 'D', 'ISE', 'ASE', '@']
input ids: tensor([ 3320,   308, 19056, 17892,  1741])

[38;5;4mℹ Input ids of relation labels[0m
chemical induced disease : @CID@
tokens:['▁@', 'C', 'ID', '@']
input ids: tensor([3320,  254, 4309, 1741])



We can use the unique sequence of input ids of the entity and relation labels to extract the input ids of the entity text.

## Custom implementation of precision, recall and f1 for RE:

In [24]:
def expand_rel(relationship: dict) -> list[dict]:
    '''
    Expands a relationship with coferent mentions to a list of relationships. Each relationship is a unique combination of the coferent head and tail entities of the orginal relationship.
    The size of the returned list will be (amount of coferent head entities) * (amount of coferent tail entities)
    '''
    from copy import deepcopy as copy

    # Get all the head entity mentions
    if isinstance(relationship['head_ent']['text'], tuple):
        head_ent_mentions = relationship['head_ent']['text']
    else:
        head_ent_mentions = (relationship['head_ent']['text'],)

    # Get all the tail entity mentions
    if isinstance(relationship['tail_ent']['text'], tuple):
        tail_ent_mentions = relationship['tail_ent']['text']
    else:
        tail_ent_mentions = (relationship['tail_ent']['text'],)

    # print(f"head_ent_mentions: {head_ent_mentions}") # DEBUG
    # print(f"tail_ent_mentions: {tail_ent_mentions}") # DEBUG
    result = []
    # The product of all head mentions to all tail mentions as a generator
    for head_mention, tail_mention in ((head_mention,tail_mention) for head_mention in head_ent_mentions for tail_mention in tail_ent_mentions):
        # print(f"head mention: {head_mention} | tail mention: {tail_mention}") # DEBUG
        copy_rel = copy(relationship) # we'll use copy_rel as a skeleton
        copy_rel["head_ent"]["text"] = copy(head_mention)
        copy_rel["tail_ent"]["text"] = copy(tail_mention)
        # print(f"temp rel: {copy_rel}") # DEBUG
        result.append(copy_rel)

    return result
    

In [25]:
# Testing expand_rel()
print(relation_triples[2][0])

expand_rel(relation_triples[2][0])

{'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': ('suxamethonium', 'suxamethonium chloride', 'sch')}, 'tail_ent': {'label': 'DISEASE', 'text': ('fasciculations', 'fasciculation')}}


[{'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': 'suxamethonium'},
  'tail_ent': {'label': 'DISEASE', 'text': 'fasciculations'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': 'suxamethonium'},
  'tail_ent': {'label': 'DISEASE', 'text': 'fasciculation'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': 'suxamethonium chloride'},
  'tail_ent': {'label': 'DISEASE', 'text': 'fasciculations'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': 'suxamethonium chloride'},
  'tail_ent': {'label': 'DISEASE', 'text': 'fasciculation'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': 'sch'},
  'tail_ent': {'label': 'DISEASE', 'text': 'fasciculations'}},
 {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL', 'text': 'sch'},
  'tail_ent': {'label': 'DISEASE', 'text': 'fasciculation'}}]

In [26]:
def make_rel_hashable(rel):
    return (rel['re_label'], (rel['head_ent']['label'], rel['head_ent']['text']), (rel['tail_ent']['label'], rel['tail_ent']['text']))

In [27]:
# Testing make_rel_hashable:
print(relation_triples[2][0])

make_rel_hashable(relation_triples[2][0])

{'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': ('suxamethonium', 'suxamethonium chloride', 'sch')}, 'tail_ent': {'label': 'DISEASE', 'text': ('fasciculations', 'fasciculation')}}


('CID',
 ('CHEMICAL', ('suxamethonium', 'suxamethonium chloride', 'sch')),
 ('DISEASE', ('fasciculations', 'fasciculation')))

In [30]:
def map_expanded_to_rel(relationships: list[dict]) -> dict:
    result = {}
    for relation in relationships:
        for expanded_rel in expand_rel(relation):
            result[make_rel_hashable(expanded_rel)] = relation

    return result

In [31]:
# Testing map_expanded_to_rel
print(relation_triples[2])

map_expanded_to_rel(relation_triples[2])

[{'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': ('suxamethonium', 'suxamethonium chloride', 'sch')}, 'tail_ent': {'label': 'DISEASE', 'text': ('fasciculations', 'fasciculation')}}]


{('CID',
  ('CHEMICAL', 'suxamethonium'),
  ('DISEASE',
   'fasciculations')): {'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': ('suxamethonium',
    'suxamethonium chloride',
    'sch')}, 'tail_ent': {'label': 'DISEASE',
   'text': ('fasciculations', 'fasciculation')}},
 ('CID',
  ('CHEMICAL', 'suxamethonium'),
  ('DISEASE',
   'fasciculation')): {'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': ('suxamethonium',
    'suxamethonium chloride',
    'sch')}, 'tail_ent': {'label': 'DISEASE',
   'text': ('fasciculations', 'fasciculation')}},
 ('CID',
  ('CHEMICAL', 'suxamethonium chloride'),
  ('DISEASE', 'fasciculations')): {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL',
   'text': ('suxamethonium', 'suxamethonium chloride', 'sch')},
  'tail_ent': {'label': 'DISEASE',
   'text': ('fasciculations', 'fasciculation')}},
 ('CID',
  ('CHEMICAL', 'suxamethonium chloride'),
  ('DISEASE', 'fasciculation')): {'re_label': 'CID',
  'head_ent': {'label': 'CHEMICAL',
 

In [32]:
def re_metric(predictions: list[str], references: list[str], ner_labels: list[str], re_labels: list[str], coferent_matching: str ="relaxed"):
    
    tp = 0 # True positive count
    fp = 0 # False positive count
    fn = 0 # False negative count
    
    unstructured_text_count = 0
    
    # Define groups
    for pred_text, ref_text in zip(predictions, references):
        try:
            predicted_triples = extract_relation_triples(pred_text, ner_labels, re_labels, True)
        except ValueError: # Text is unstructured
            unstructured_text_count += 1
            continue
    
        references = extract_relation_triples(ref_text, ner_labels, re_labels, True)
        
        if coferent_matching == "relaxed":
            # Expand references
            expanded_references = []
            for ref in references:
                expanded_references.extend(expand_rel(ref))

            # Map expanded references to orginal references
            mapping = map_expanded_to_rel(references)
             
        for pred in predicted_triples:
            # print(f"Trying: {pred} in {references} ") # DEBUG
            
            if coferent_matching == "relaxed":
                for expanded_rel in expand_rel(pred):
                    if expanded_rel in expanded_references:
                        # print(f"True!") # DEBUG
                        tp=tp+1
                        references.remove(mapping[make_rel_hashable(expanded_rel)])
                        # create Expand references again 
                        expanded_references = []
                        for ref in references:
                            expanded_references.extend(expand_rel(ref))
                        continue
                    else:
                        # print(f"False! \n") #DEBUG
                        fp=fp+1
                            
            elif coferent_matching == "strict":
                if pred in references: # True positive
                    # print(f"True!") # DEBUG
                    tp=tp+1
                    references.remove(pred)
                    # print(f"Updated references: {references} \n") # DEBUG
                else: # False positive
                    # print(f"False! \n") #DEBUG
                    fp=fp+1
                
        # False negative
        print(f"Counting false negatives: {len(references)} from: {references} \n") #DEBUG
        fn+=len(references)

        print(f"Current counts: tp:{tp} | fp:{fp} | fn:{fn} \n\n") # DEBUG

    # Calculate metrics
    if (tp+fp) == 0: precision=0.0
    else: precision = tp/(tp+fp)

    if (tp+fn) == 0: recall=0.0
    else: recall = tp/(tp+fn)

    if (precision+recall) == 0: f1=0.0
    else: f1 = 2 * ((precision*recall)/(precision+recall))

    unstructured_text = unstructured_text_count/len(predictions)

    return {'re_precision':precision, 're_recall':recall, 're_f1':f1, 'unstructured':unstructured_text}

In [33]:
# Testing with dataset with know scores
test_eval_dataset = test_eval_dataset_re53_ner53
re_labels = ['@CID@']
ner_labels = ['@CHEMICAL@', '@DISEASE@']
predictions = [i['predicted'] for i in test_eval_dataset]
references = [i['expected output'] for i in test_eval_dataset]

result = re_metric(predictions, references, ner_labels, re_labels, coferent_matching="relaxed")
result = {k: round(v * 100, 4) for k, v in result.items()}
print(f"Expected score should be 52.9412% for re and ner: {result}")

Counting false negatives: 0 from: [] 

Current counts: tp:1 | fp:0 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:2 | fp:1 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:5 | fp:5 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:8 | fp:7 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:9 | fp:10 | fn:0 


Counting false negatives: 3 from: [{'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': ('cyclophosphamide', 'cyp')}, 'tail_ent': {'label': 'DISEASE', 'text': 'cystitis'}}, {'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': ('cyclophosphamide', 'cyp')}, 'tail_ent': {'label': 'DISEASE', 'text': 'pain'}}, {'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': ('cyclophosphamide', 'cyp')}, 'tail_ent': {'label': 'DISEASE', 'text': 'edema'}}] 

Current counts: tp:9 | fp:16 | fn:3 


Counting false negatives: 1 from: [{'re_label': 'CID', 'head_ent': {'label': 'CHEMICAL', 'text': 'clopid

In [34]:
# Testing
test_eval_dataset = test_eval_dataset_trained_gc1
re_labels = ['@CID@']
ner_labels = ['@CHEMICAL@', '@DISEASE@']
predictions = [i['predicted'] for i in test_eval_dataset]
references = [i['expected output'] for i in test_eval_dataset]

result = re_metric(references, references, ner_labels, re_labels)
result = {k: round(v * 100, 4) for k, v in result.items()}
print(f"Result with references as input: {result}")

result = re_metric(predictions, references, ner_labels, re_labels)
result = {k: round(v * 100, 4) for k, v in result.items()}
print(f"Actual results: {result}")

Counting false negatives: 0 from: [] 

Current counts: tp:1 | fp:0 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:2 | fp:1 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:5 | fp:5 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:8 | fp:7 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:9 | fp:10 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:12 | fp:13 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:13 | fp:15 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:15 | fp:17 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:16 | fp:18 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:17 | fp:18 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:23 | fp:21 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:25 | fp:23 | fn:0 


Counting false negatives: 0 from: [] 

Current counts: tp:26 

## Custom implementation of precision, recall and f1 for NER:
This function should take a list of predicted outputs and a list of references. The lists are a list of strings, the strings being the decoded output of the model (using `tokenizer.batch_decode()`)

The output should be a dictionary with `key : value` pairs of `name metric : value of metric`

In [None]:
def get_group(relation_triples):
    group = []
    for rel in relation_triples:
        group.append(rel['head_ent'])
        group.append(rel['tail_ent'])
    return group

In [None]:
# testing get_group
print(f"Relation triples: {relation_triples[2]}")
get_group(relation_triples[2])

In [None]:
def split_coferents(ent):
    '''
    This function splits a entitiy with a coferent mention into two entities for each entity form.
    '''
    if isinstance(ent['text'], tuple): # Check if the entity has coferent mentions
        return tuple([{"label":ent["label"], "text":ent['text'][i]} for i in range(len(ent['text']))])
    else:
        return (ent,)

In [None]:
# Testing split_coferents
print(f"A single ent: {relation_triples[1][0]['head_ent']}")
split_coferents(relation_triples[1][0]['head_ent'])

In [None]:
def map_coferents(group):
    '''
    This function maps all forms of a coferent mentions to all it's other forms for all coferent mentions in a group of relationships.
    '''
    result = {}
    group = [split_coferents(ent) for ent in group] # Split coferents into two entities
    for ent in group:
        if isinstance(ent, tuple):
            for i in range(len(ent)):
                result[frozenset(ent[i].items())] = ent
        else:
            result[frozenset(ent.items())] = (ent,)

    return result

In [None]:
# testing map_coferents
print(f"Relation triples: {relation_triples[3]}")
group = get_group(relation_triples[3])
map_coferents(group)

In [None]:
def ner_metric(predictions: list[str], references: list[str], ner_labels: list[str], re_labels: list[str], coferent_matching: str ="relaxed") -> dict:
    '''
    Calculates the precision, recall and f1-score for document named entity recognition. 
    input:
        predictions: 
            List of decoded outputs of the model
        references: 
            List of decoded gold data
        coferent_matching: 
            Wheter to use the coferent mentions to match named entities. can be either "relaxed", "strict" or "no".
            "relaxed" Meaning that all coferent mentions might be used to match a predicted named entity to a reference entity
            "strict"  Meaning the model needs to have all coferent mentions correct to count as a match. (including the sequence)
            "no"      Meaning that coferent mentions are ignored, and only the first mentions are used.
        re_labels
            Which relation extraxtion labels are used.
    output
        a dictionary with the key value pairs of metric_name : metric value
    '''
    if coferent_matching not in ["relaxed", "strict", "no"]: 
        raise ValueError(f"'{coferent_matching}' is not a valid value for coferent_matching, Please choose one of {["relaxed", "strict", "no"]}.")

    if coferent_matching == "no":
        keep_coferents = False
    else:
        keep_coferents = True
    
    tp = 0 # True positive count
    fp = 0 # False positive count
    fn = 0 # False negative count
    unstructured_text_count = 0
    
    for pred_text, ref_text in zip(predictions, references):
        # Define groups
        try:
            pred_group = get_group(extract_relation_triples(pred_text, ner_labels, re_labels, keep_coferents))
        except ValueError:
            # Should be a logging statement here
            continue # Skip this row entirely
        
        ref_group = get_group(extract_relation_triples(ref_text, ner_labels, re_labels, keep_coferents))

        if coferent_matching == "relaxed":
            # Create mapping from a coferent mentions to all coferent mentions
            mapping_coferent = map_coferents(ref_group)
            
            # Split entities in the reference group
            ref_group = [split_coferents(ent) for ent in ref_group] # Split coferents into multiple entities 
            ref_group = [item for sublist in ref_group for item in (sublist if isinstance(sublist, tuple) else [sublist])] # Flatten list
        
            # print(f"pred_group: {pred_group}\n\nref_group: {ref_group}\n") #DEBUG
            # print(f"mapping: {mapping_coferent}\n\n\n") # DEBUG
        
        checked_coferent_pred = []
        for ent in pred_group:
            # print(f"entity: {ent}") # DEBUG
            # print(f"ref_group: {ref_group} ") # DEBUG

            if coferent_matching == "relaxed":
                # Split coferent entity
                ent_forms = split_coferents(ent)
            else:
                ent_forms = [ent]

            # print(f"\nStarting entity checking, ent_forms: {ent_forms}\n") # DEBUG
            for ent_form in ent_forms:
                # print(f"Checking if {ent_form} in {ref_group}") # DEBUG
                if ent_form in ref_group: # True positive
                    tp=tp+1
                    # print(f"True! \n") # DEBUG
                    
                    # Remove all instances of the coferent mentions
                    if coferent_matching == "relaxed": 
                        [ref_group.remove(i) for i in mapping_coferent[frozenset(ent_form.items())]] # Remove all coferent mentions from the reference group
                        # print(f"Removing coferent mentions from reference group: {[i for i in mapping_coferent[frozenset(ent_form.items())]]}\n") #DEBUG
                        checked_coferent_pred.extend([i for i in mapping_coferent[frozenset(ent_form.items())]]) # Remember which coferent mentions have been checked
                    else: 
                        ref_group.remove(ent) 
                    break # A match was found so we move on to the next entity
                    
                elif ent_form not in ref_group and ent_form not in checked_coferent_pred: # False positive
                    fp=fp+1
                    # print(f"False! \n") #DEBUG
                    break # A mismatch was found so we move on to the next entity

        # print(f"Counting false negatives. Current ref group: length:{len(ref_group)} | {ref_group}\n") # DEBUG
        # [ref_group.remove(i) for i in checked_coferent_pred if i in ref_group]
        # print(f"ref group after removeal of checked coferents: length:{len(ref_group)} | {ref_group}\n") # DEBUG

        # if coferent_matching == "relaxed": # WORK NEEDED. RELAXED MATCHING BASED ON COFERENT MENTIONS DOES NOT WORK YET!!!
            # print(f"checked_coferent_pred: {checked_coferent_pred}") #DEBUG
            # Remove all checked entities before counting false negatives
            
        fn=fn+len(ref_group) # False negative 
        # print(f"TP: {tp}, FP: {fp}, FN: {fn} \n\n\n") #DEBUG
    
    # Calculate metrics
    if (tp+fp) == 0: precision=0.0
    else: precision = tp/(tp+fp)
    
    if (tp+fn) == 0: recall=0.0
    else: recall = tp/(tp+fn)
    
    if (precision+recall) == 0: f1=0.0
    else: f1 = 2 * ((precision*recall)/(precision+recall))
    
    return {'ner_precision':precision, 'ner_recall':recall, 'ner_f1':f1}

In [None]:
# Testing with dataset with know scores
test_eval_dataset = test_eval_dataset_re53_ner53
re_labels = ['@CID@']
ner_labels = ['@CHEMICAL@', '@DISEASE@']
predictions = [i['predicted'] for i in test_eval_dataset]
references = [i['expected output'] for i in test_eval_dataset]

result = ner_metric(predictions, references, ner_labels, re_labels, "strict")
result = {k: round(v * 100, 4) for k, v in result.items()}
print(f"Expected score should be 52.9412% for re and ner: {result}")

In [None]:
test_eval_dataset = test_eval_dataset_trained
re_labels = ['@CID@']
ner_labels = ['@CHEMICAL@', '@DISEASE@']
predictions = [i['predicted'] for i in test_eval_dataset]
references = [i['expected output'] for i in test_eval_dataset]

result = ner_metric(references, references, ner_labels, re_labels)
result = {k: round(v * 100, 4) for k, v in result.items()}
print(f"Result with references as input: {result}")

result = ner_metric(predictions, references, ner_labels, re_labels, "strict")
result = {k: round(v * 100, 4) for k, v in result.items()}
print(f"Actual results: {result}")

## Loading the model to evaluate 

In [None]:
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    device_map=device_map