# Mistake analysis
This notebook gives a quick look at the predictions the model makes that are wrong according to the reference evaluation data.

In [1]:
import re
from os.path import abspath
from pathlib import Path
from typing import Dict, Iterable, Set, Tuple, Union

import pandas as pd
import torch
import yaml
from datasets import concatenate_datasets, load_dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
from wasabi import msg

from helper_functions import *

  from .autonotebook import tqdm as notebook_tqdm


## Setting home directory

In [2]:
# All paths are derived from the repository root. This assumes the kernel's
# working directory is one level below that root (e.g. a `notebooks/` folder);
# the trained checkpoint is expected to live next to the repository.
home_dir = Path.cwd().parent
config_path = home_dir / "config" / "config_T5-11b_cdr_ds.yaml"
trained_model_path = (
    home_dir.parent
    / "data"
    / "generative_re_model_storage_azure"
    / "ds_runs"
    / "T5-11b_cdr_17_05_24"
    / "checkpoint-100"
)
predictions_dataset_dir = home_dir / "data" / "evaluate_trained_model"

msg.info(f"Home directory: {home_dir}")
msg.info(f"Selected config: {config_path}")
msg.info(f"Selected model: {trained_model_path}")

[38;5;4mℹ Home directory: /home/lgrootde/Generative-re-tests[0m
[38;5;4mℹ Selected config:
/home/lgrootde/Generative-re-tests/config/config_T5-11b_cdr_ds.yaml[0m
[38;5;4mℹ Selected model:
/home/lgrootde/data/generative_re_model_storage_azure/ds_runs/T5-11b_cdr_17_05_24/checkpoint-100[0m


## Load Config & Dataset

In [3]:
# Read the experiment configuration (model name, data files, generation and
# label settings used further down in this notebook).
config = yaml.load(config_path.read_text(), Loader=yaml.FullLoader)

In [4]:
dataset = load_dataset(
    'csv',
    data_files={"validation":home_dir.joinpath(config['validation_file']).__str__()}
) 
dataset_eval = dataset["validation"]

## Load trained model and tokenizer

In [5]:
# Toggle for the expensive part of the notebook. When True, the model is loaded,
# run over the evaluation set, and its predictions are written to CSV. When
# False, those cells are skipped and the inspection section below reads a
# previously saved predictions CSV instead.
use_model = True

In [6]:
if use_model:
    # Tokenizer comes from the base model named in the config, while the
    # weights come from the fine-tuned checkpoint selected above.
    model_name = config["model_name_or_path"]
    # `{"": 0}` assigns the whole model (root module "") to CUDA device 0.
    device_map = {"": 0}

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        legacy=False,
    )
    model = T5ForConditionalGeneration.from_pretrained(trained_model_path, device_map=device_map)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on google-t5/t5-11b automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


## Run model over evaluation dataset 
Record where the model makes mistakes and where its predictions match the reference.

In [7]:
if use_model:
    # Run the fine-tuned model over every evaluation example and record whether
    # the generated relation string equals the reference (ignoring surrounding
    # whitespace). Rows are collected as records and turned into a frame once.
    records = []
    for row in tqdm(dataset_eval):
        # Move the inputs to the device the model was actually placed on by
        # `device_map`, instead of assuming the default CUDA device.
        input_ids = tokenizer(row["input"], return_tensors="pt").input_ids.to(model.device)
        output = model.generate(input_ids, max_new_tokens=config["generation_max_length"])
        predicted = tokenizer.decode(output[0], skip_special_tokens=True)

        # An empty reference cell in the CSV may be loaded as None; compare it
        # as an empty string rather than crashing on `.strip()`.
        expected = row["relations"] or ""

        records.append({
            "input": row["input"],
            "expected output": row["relations"],
            "predicted output": predicted,
            "match": expected.strip() == predicted.strip(),
        })

    # Fixing the column order keeps the frame's schema stable even when the
    # evaluation set is empty.
    dataframe = pd.DataFrame(
        records,
        columns=["input", "expected output", "predicted output", "match"],
    )

  0%|          | 0/1 [00:00<?, ?it/s]


TypeError: string indices must be integers

In [20]:
# Derive the predictions file name from the last three components of the
# checkpoint path, e.g. "ds_runs_T5-11b_cdr_17_05_24_checkpoint-100.csv".
run_group, run_name, checkpoint_name = (trained_model_path.parts[i] for i in (-3, -2, -1))
dataset_name = f"{run_group}_{run_name}_{checkpoint_name}.csv"

msg.info(f"dataset name: {dataset_name}")

[38;5;4mℹ dataset name: ds_runs_T5-11b_cdr_17_05_24_checkpoint-100.csv[0m


In [21]:
if use_model:
    # Persist the predictions so the inspection cells (and later sessions run
    # with use_model=False) can reload them without re-running the model.
    # The output directory may not exist yet on a fresh checkout.
    predictions_dataset_dir.mkdir(parents=True, exist_ok=True)
    dataframe.to_csv(predictions_dataset_dir.joinpath(dataset_name))

## Show data for manual inspection

In [22]:
# Load the saved predictions.
# - The first CSV column is the index written by `DataFrame.to_csv`, so restore
#   it as the index instead of getting a spurious "Unnamed: 0" column.
# - Keep empty cells (e.g. an empty model prediction) as "" rather than NaN, so
#   they are still treated as text when printed and parsed below.
dataset_path = predictions_dataset_dir.joinpath(dataset_name)
dataframe = pd.read_csv(dataset_path, index_col=0, keep_default_na=False)

In [23]:
# ANSI 24-bit background colours (with black foreground) used for highlighting.
_EXPECTED_STYLE = "\033[48;2;144;238;144;38;2;0;0;0m"  # pastel green
_PREDICTED_STYLE = "\033[48;2;173;216;230;38;2;0;0;0m"  # pastel blue
_RESET_STYLE = "\033[00m"


def _collect_entity_texts(rels: Iterable[Dict]) -> Set[str]:
    """Return the non-empty head/tail entity texts of `rels`, flattening multi-span texts."""
    texts = set()
    for rel in rels:
        for side in ("head_ent", "tail_ent"):
            text = rel[side]["text"]
            # An entity's text may be a tuple of spans; highlight each span.
            spans = text if isinstance(text, (tuple, list)) else (text,)
            # Empty/None spans would match at every word boundary, so skip them.
            texts.update(span for span in spans if span)
    return texts


def highlight_entities(input_text: str, rels_expected: Iterable[Dict], rels_predicted: Iterable[Dict]) -> str:
    """Colour the entity mentions of expected and predicted relations in `input_text`.

    Each relation is a mapping with ``rel["head_ent"]["text"]`` and
    ``rel["tail_ent"]["text"]``; a text may be a single string or a tuple/list
    of strings. Mentions are matched case-insensitively as whole tokens, and
    the original casing of `input_text` is preserved.

    Expected entities are shown in pastel green and predicted entities in
    pastel blue; an entity that is both expected and predicted is shown green.
    When entities overlap, the longest one wins.

    Args:
        input_text: The text to annotate.
        rels_expected: Relations from the reference output.
        rels_predicted: Relations from the model output.

    Returns:
        `input_text` with ANSI colour escape codes around each matched entity.
    """
    # Map each entity (lower-cased) to its style. Expected entities are assigned
    # last so they take precedence over predicted ones.
    styles = {}
    for text in _collect_entity_texts(rels_predicted):
        styles[text.lower()] = _PREDICTED_STYLE
    for text in _collect_entity_texts(rels_expected):
        styles[text.lower()] = _EXPECTED_STYLE
    if not styles:
        return input_text

    # A single alternation applied in one pass: it never re-scans inserted
    # escape codes, and longest-first ordering (ties broken alphabetically)
    # makes overlapping entities highlight deterministically. Lookarounds act
    # like \b but also work for entities that start or end with punctuation.
    alternatives = sorted(styles, key=lambda t: (-len(t), t))
    pattern = re.compile(
        r"(?<!\w)(?:{})(?!\w)".format("|".join(map(re.escape, alternatives))),
        flags=re.IGNORECASE,
    )

    def _wrap(match: "re.Match") -> str:
        # Re-emit the matched text verbatim, so the input's casing is kept and
        # backslashes are never interpreted as replacement-template escapes.
        matched = match.group(0)
        style = styles.get(matched.lower(), _EXPECTED_STYLE)
        return f"{style}{matched}{_RESET_STYLE}"

    return pattern.sub(_wrap, input_text)

In [25]:
print("Entities of expected output colored \033[48;2;144;238;144;38;2;0;0;0mgreen\033[00m, of predicted \033[48;2;173;216;230;38;2;0;0;0mblue\033[00m\n")


def _try_extract_relations(text):
    """Parse a relation string with `extract_relation_triples`; return None if it is malformed."""
    try:
        return extract_relation_triples(text, config["ner_labels"], config["re_labels"], True)
    except ValueError:
        return None


for _, row in dataframe[dataframe["match"].eq(False)].iterrows():
    rels_expected = _try_extract_relations(row["expected output"])
    rels_predicted = _try_extract_relations(row["predicted output"])

    msg.info("input text")
    if rels_expected is None or rels_predicted is None:
        # Malformed output is itself a mistake worth inspecting, so show the row
        # instead of silently skipping it, and highlight whatever did parse.
        msg.warn("Could not parse expected and/or predicted relations; only the parsed ones are highlighted.")
    print(highlight_entities(row["input"], rels_expected or [], rels_predicted or []))
    msg.good("Expected output:")
    print(row["expected output"])
    msg.info("Actual output:")
    print(row["predicted output"])
    print("\n\n")

Entities of expected output colored [48;2;144;238;144;38;2;0;0;0mgreen[00m, of predicted [48;2;173;216;230;38;2;0;0;0mblue[00m

[38;5;4mℹ input text[0m
[48;2;173;216;230;38;2;0;0;0mtricuspid valve regurgitation[00m and [48;2;173;216;230;38;2;0;0;0mlithium carbonate[00m toxicity in a newborn infant. A newborn with massive tricuspid regurgitation, [48;2;173;216;230;38;2;0;0;0matrial flutter[00m, [48;2;173;216;230;38;2;0;0;0mcongestive heart failure[00m, and a high serum lithium level is described. This is the first patient to initially manifest tricuspid regurgitation and [48;2;173;216;230;38;2;0;0;0matrial flutter[00m, and the 11th described patient with cardiac disease among infants exposed to lithium compounds in the first trimester of pregnancy. Sixty-three percent of these infants had tricuspid valve involvement. [48;2;173;216;230;38;2;0;0;0mlithium carbonate[00m may be a factor in the increasing incidence of congenital heart disease when taken during early pregnan