In [1]:
import os
import json, re, ast, copy, jsonschema
import pandas as pd
import numpy as np

from json import JSONDecodeError
from fix_busted_json import repair_json
from collections import Counter
from statistics import mean
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator

In [2]:
from eval_script import multi_span_evaluate, count_overlap, compute_scores

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\user\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
from utils import clean_italian_span, tokenize_italian_text, stops, italian_punctuation, check_symbols

## 1. Preprocessing functions

In [4]:
def is_nested_list(input_list):
    """Return True when every element of ``input_list`` is a list.

    An empty iterable yields True (there is no element that fails the check).
    """
    for candidate in input_list:
        if not isinstance(candidate, list):
            return False
    return True

def is_str_list(input_list):
    """Return True when every element of ``input_list`` is a string.

    An empty iterable yields True (there is no element that fails the check).
    """
    for candidate in input_list:
        if not isinstance(candidate, str):
            return False
    return True

def contains_no_list_elements(input_list):
    """Return True when at least one element of ``input_list`` is NOT a list.

    Despite the name, this does not require that all elements are non-lists:
    a single non-list element is enough. An empty iterable yields False.
    """
    for candidate in input_list:
        if not isinstance(candidate, list):
            return True
    return False

def contains_list_elements(input_list):
    """Return True when at least one element of ``input_list`` is a list.

    An empty iterable yields False.
    """
    for candidate in input_list:
        if isinstance(candidate, list):
            return True
    return False

def contains_no_str_elements(input_list):
    """Return True when at least one element of ``input_list`` is NOT a string.

    Despite the name, a single non-string element is enough.
    An empty iterable yields False.
    """
    for candidate in input_list:
        if not isinstance(candidate, str):
            return True
    return False
            
def contains_str_elements(input_list):
    """Return True when at least one element of ``input_list`` is a string.

    An empty iterable yields False.
    """
    for candidate in input_list:
        if isinstance(candidate, str):
            return True
    return False

In [5]:
def flatten_nested_list(input_list):
    """Flatten arbitrarily nested lists into one flat list of strings.

    The input is deep-copied first, so the caller's object is never modified.
    Nesting is removed one level per pass until no list element is left, then
    every remaining element is converted with ``str``.
    """
    flattened = copy.deepcopy(input_list)

    # Each pass splices the contents of list elements into their parent.
    while contains_list_elements(flattened):
        unpacked = []
        for item in flattened:
            if isinstance(item, list):
                unpacked += item
            else:
                unpacked.append(item)
        flattened = unpacked

    # Normalise every leaf to a string, in place on the working copy.
    for position, item in enumerate(flattened):
        flattened[position] = str(item)

    return flattened

In [6]:
def flatten_nested_dictionary(input_dict):
    """Return the values of ``input_dict`` as a list, expanding nested dicts.

    A value that is itself a dictionary is replaced, in position, by its own
    values; this repeats until no top-level element of the result is a dict.
    Dictionaries hidden inside lists are left untouched. Keys are discarded.
    """
    values = list(input_dict.values())

    while any(isinstance(value, dict) for value in values):
        expanded = []
        for value in values:
            if isinstance(value, dict):
                expanded.extend(value.values())
            else:
                expanded.append(value)
        values = expanded

    return values

In [7]:
def check_dict_values(datum):
    """Return True when no tag value contains a dictionary in the inspected levels.

    Three places are inspected for each tag: the value itself, each of its
    elements, and each item of an element that is a list. Finding a dict in
    any of them returns False.
    """
    for tag in datum:
        value = datum[tag]
        if isinstance(value, dict):
            return False
        for group in value:
            if isinstance(group, dict):
                return False
            if isinstance(group, list) and any(isinstance(span, dict) for span in group):
                return False
    return True

In [8]:
def check_none_values(datum):
    """Report whether ``datum`` contains a ``None`` that ``remove_none_values`` can strip.

    The following locations are inspected for every tag:

    * the tag value itself;
    * the values of a tag value that is a dictionary;
    * the elements of a tag value that is a list, the values of elements that
      are dictionaries, and, for elements that are lists, their items and the
      values of items that are dictionaries.

    Parameters
    ----------
    datum : dict
        Mapping from tag name to the raw parsed value.

    Returns
    -------
    bool
        True as soon as one ``None`` is found, False otherwise.
    """

    def _has_none_value(mapping):
        # True when at least one value of the dictionary is None.
        return any(item is None for item in mapping.values())

    for tag in datum:
        value = datum[tag]

        if value is None:
            return True

        if isinstance(value, dict):
            # A dictionary directly under a tag is cleaned by remove_none_values,
            # so it has to be detected here too; otherwise the None survives and
            # is later stringified into a bogus "None" span.
            if _has_none_value(value):
                return True

        elif isinstance(value, list):
            for element in value:
                if element is None:
                    return True

                if isinstance(element, dict):
                    if _has_none_value(element):
                        return True

                elif isinstance(element, list):
                    for inner in element:
                        if inner is None:
                            return True
                        if isinstance(inner, dict) and _has_none_value(inner):
                            return True

    return False

In [9]:
def remove_none_values(datum):
    """Strip ``None`` entries from ``datum`` in place and return it.

    For every tag:

    * a value that is ``None`` becomes an empty list;
    * a value that is a dictionary loses the keys whose value is ``None``;
    * a value that is a list loses its ``None`` elements; dictionary elements
      lose their ``None``-valued keys; list elements lose their ``None`` items
      and their dictionary items lose their ``None``-valued keys.

    Deeper nesting is not visited. The same ``datum`` object is returned.
    """

    def _drop_none_entries(mapping):
        # Collect the keys first: a dict cannot shrink while it is iterated.
        for key in [k for k, v in mapping.items() if v is None]:
            del mapping[key]

    for tag in datum:
        value = datum[tag]

        if isinstance(value, list):
            for element in value:
                if isinstance(element, dict):
                    _drop_none_entries(element)
                elif isinstance(element, list):
                    for inner in element:
                        if isinstance(inner, dict):
                            _drop_none_entries(inner)
                    # Slice assignment keeps the very same list object.
                    element[:] = [inner for inner in element if inner is not None]
            value[:] = [element for element in value if element is not None]

        elif isinstance(value, dict):
            _drop_none_entries(value)

        elif value is None:
            datum[tag] = []

    return datum

In [10]:
def correct_dict_values(datum):
    """Replace dictionaries found in tag values by the list of their values.

    Works in place on ``datum`` and returns it. A tag value that is a dict is
    flattened with ``flatten_nested_dictionary``; afterwards each element of
    the value that is a dict, and each dict found inside an element that is a
    list, is replaced the same way (the replacement is itself a list).
    """
    for tag in datum:
        if isinstance(datum[tag], dict):
            datum[tag] = flatten_nested_dictionary(datum[tag])

        groups = datum[tag]
        for group_pos, group in enumerate(groups):
            if isinstance(group, dict):
                groups[group_pos] = flatten_nested_dictionary(group)
            elif isinstance(group, list):
                for span_pos, span in enumerate(group):
                    if isinstance(span, dict):
                        group[span_pos] = flatten_nested_dictionary(span)

    return datum

In [11]:
def check_tag_values(datum):
    """Check that every tag in ``available_tags`` has the expected value shape.

    * Tags in ``relation_tags``: every element must be a list, and every item
      of those lists must be a string.
    * Other tags: the value must not be a bare string and every element must
      be a string.

    Returns False at the first violation, True otherwise. Every tag of
    ``available_tags`` is looked up in ``datum`` directly.
    """
    for tag in available_tags:
        value = datum[tag]

        if tag not in relation_tags:
            if contains_no_str_elements(value) or isinstance(value, str):
                return False
            continue

        if contains_no_list_elements(value):
            return False
        for group in value:
            if contains_no_str_elements(group):
                return False

    return True

In [12]:
def correct_tag_values(datum):
    """Coerce every tag value into the shape ``check_tag_values`` accepts.

    Works in place on ``datum`` and returns it.

    * Relation tags: a bare string becomes ``[[string]]``; otherwise each
      element is turned into a flat list of strings (a non-list element is
      stringified and wrapped, a list element is flattened).
    * Other tags: a bare string becomes ``[string]``; otherwise the value is
      flattened into a flat list of strings.
    """
    for tag in available_tags:
        value = datum[tag]
        is_relation = tag in relation_tags

        if isinstance(value, str):
            datum[tag] = [[value]] if is_relation else [value]
        elif is_relation:
            for position, group in enumerate(value):
                if isinstance(group, list):
                    value[position] = flatten_nested_list(group)
                else:
                    value[position] = [str(group)]
        else:
            datum[tag] = flatten_nested_list(value)

    return datum

In [13]:
def check_schema(datum):
    """Return True when ``datum`` validates against the module-level ``schema``.

    Both a validation failure of ``datum`` and an invalid schema are reported
    as False instead of raising.
    """
    schema_errors = (
        jsonschema.exceptions.ValidationError,
        jsonschema.exceptions.SchemaError,
    )
    try:
        jsonschema.validate(datum, schema)
    except schema_errors:
        return False
    else:
        return True

In [14]:
def correct_schema(datum):
    """Make the key set of ``datum`` match ``available_tags``.

    Works in place and returns ``datum``: keys that are not an available tag
    are deleted, and every available tag that was missing is added with an
    empty list as value.
    """
    missing_tags = [tag for tag in available_tags if tag not in datum]

    # Iterate over a snapshot of the keys because entries are deleted.
    for key in list(datum.keys()):
        if key not in available_tags:
            del datum[key]

    for tag in missing_tags:
        datum[tag] = []

    return datum

In [15]:
def replace_all_occurrences(text, to_replace, replace):
    """Substitute the FIRST character of every match of ``to_replace``.

    ``to_replace`` is interpreted as a regular expression. For each match only
    the single character at the match start is swapped for ``replace``; the
    rest of the matched text is kept. For example pattern ``",]"`` with an
    empty replacement removes the comma and leaves the bracket.
    """
    match_starts = [found.start() for found in re.finditer(to_replace, text)]

    # Walk from the last match to the first so earlier offsets stay valid even
    # when ``replace`` is not exactly one character long.
    for start in reversed(match_starts):
        text = "".join((text[:start], replace, text[start + 1:]))

    return text

In [16]:
def get_json_from_string(text):
    """Best-effort extraction of a prediction object from raw model output.

    Newlines are replaced by spaces and the double-prime character by a plain
    double quote. The text is then handled according to which braces exist:

    * only ``}``: a ``{`` is prepended, starting at the first quoted tag;
    * no braces, text contains a bullet: each tag found in the text is parsed
      on its own with ``ast.literal_eval`` and the result is returned directly;
    * no braces, no bullet: the span from the first quoted tag to the last
      ``]`` is wrapped in braces;
    * only ``{``: the span up to the last ``]`` is closed with ``}``;
    * both braces: the span between them is taken and a few common
      formatting slips are patched.

    The candidate text is finally passed through ``repair_json`` and
    ``json.loads``.

    Parameters
    ----------
    text : str
        Raw completion text.

    Returns
    -------
    tuple
        ``(value, is_json)``. On success ``value`` is the parsed object and
        ``is_json`` is True. On failure the module-level ``empty_value``
        object itself (not a copy) is returned together with False, so
        callers must not mutate it.
    """

    def first_tag_index(haystack):
        # Offset of the earliest '"<TAG' occurrence. When no tag is quoted the
        # length of the text is returned, so that slicing from it yields "".
        found = [haystack.find(f'"{tag}') for tag in available_tags]
        found = [position for position in found if position != -1]
        return min(found) if found else len(haystack)

    text = text.replace("\n", " ")
    text = text.replace("″", "\"")

    start = text.find('{')
    end = text.rfind('}')

    if start == -1 and end != -1:
        # Closing brace only: open the object at the first quoted tag.
        start = first_tag_index(text)
        corrected_text = "{" + text[start:end+1]

    elif start == -1 and end == -1:

        if "•" in text:
            # Bullet-list output: read every tag independently.
            result = {}
            for tag in available_tags:
                tag_index = text.find(tag)
                if tag_index == -1:
                    continue

                # Skip the tag name and the single character that follows it.
                text_value = text[tag_index + len(tag) + 1:]

                # The value ends at the next bullet, or at the end of the text
                # when this is the last bullet (newlines were already replaced
                # above, so there is no line end left to search for).
                next_bullet = text_value.find("•")
                if next_bullet != -1:
                    text_value = text_value[:next_bullet]

                final_index = text_value.rfind("]")
                if final_index == -1:
                    continue

                try:
                    result[tag] = ast.literal_eval(text_value[:final_index+1])
                except Exception:
                    # Deliberate best effort: a tag that cannot be parsed is
                    # simply left out of the result.
                    pass

            if len(result) > 0:
                return result, True

            corrected_text = ""

        else:
            start = first_tag_index(text)
            end = text.rfind(']')

            if start < end:
                corrected_text = "{" + text[start:end+1] + "}"
            else:
                corrected_text = ""

    elif start != -1 and end == -1:
        # Opening brace only: close the object after the last list.
        end = text.rfind(']')
        corrected_text = text[start:end+1] + "}"

    else:
        # Both braces are present.
        if start > end:
            return empty_value, False

        corrected_text = text[start:end+1]
        # replace_all_occurrences swaps only the first character of each
        # match: the two calls below drop a trailing comma before "]", the
        # third one inserts the missing comma between "]" and the next key.
        corrected_text = replace_all_occurrences(corrected_text, ",]", "")
        corrected_text = replace_all_occurrences(corrected_text, ", ]", "")
        corrected_text = replace_all_occurrences(corrected_text, "] \"", "],")
        corrected_text = corrected_text.replace("\\[", "[")
        corrected_text = corrected_text.replace("]\"]", "]]")

    corrected_text = corrected_text.replace("\\_", "_")

    try:
        result = json.loads(repair_json(corrected_text))
        return result, True

    except Exception:
        # Any repair or decoding failure (JSONDecodeError included) means the
        # output is not usable.
        return empty_value, False

  corrected_text = corrected_text.replace("\[", "[")


## 2. Schema definition

In [17]:
# Every tag a prediction is expected to contain. Drives the JSON schema, the
# empty-prediction template and the per-tag loops of the checking helpers.
available_tags = ('LOC', 'AUT', 'PAR', 'OBJ', 'VIC', 'AUTG', 'VICG') 
# Tags validated as a list of groups (each group a list of strings); the
# remaining tags are validated as flat lists of strings (see check_tag_values).
relation_tags = ('AUT', 'VIC', 'OBJ')

In [18]:
# JSON Schema consumed by check_schema: a prediction must be an object that
# holds every tag of available_tags and no other key; each tag may map to an
# array, a string or an object (the value shape is checked separately).
schema = {
    "type": "object",
    "properties": {
                    tag: {"type": ["array", "string", "object"]} for tag in available_tags
    },
    "required": list(available_tags),
    "additionalProperties" : False
}

In [19]:
# Template of a prediction with no spans: every available tag mapped to an
# empty list. Used as the fallback result and compared against with ``==``.
# NOTE: this is a single shared mutable object.
empty_value = {
                tag: [] for tag in available_tags
}

## 3. Load data to evaluate

In [97]:
# Run configuration: these three values select the predictions file to load.
# ``model`` is also read as a global by parse_predictions (names starting with
# "mi" have everything up to the last '[/INST]' marker stripped).
model = "mixtral-8x7B-Instruct-v0.1"
dataset = "validation_set"
# Indices joined into the file name; presumably the ids of the prompt examples
# used for this run - TODO confirm. An empty tuple selects "predicted_0.json".
combination = (3, 0, 5, 2)

In [98]:
# Build the predictions path with os.path.join instead of hard-coded "\\"
# separators: the result is the same string on Windows and a valid path on
# every other platform.
predictions_dir = os.path.join("llms_predictions", dataset, model)

if len(combination) == 0:
    filename = os.path.join(predictions_dir, "predicted_0.json")
else:
    # File name pattern: predicted_<number of indices>_<indices concatenated>.json
    char_str = [str(index) for index in combination]
    filename = os.path.join(predictions_dir, f"predicted_{len(combination)}_{''.join(char_str)}.json")

In [99]:
# Display the resolved predictions path as a sanity check.
filename

'llms_predictions\\validation_set\\mixtral-8x7B-Instruct-v0.1\\predicted_4_3052.json'

In [100]:
# Load the model outputs. parse_predictions reads a 'completion' string from
# each record.
# NOTE(review): no explicit encoding is given, so the platform default is
# used - confirm it matches how the file was written.
with open(filename, "r") as f:
    predicted_data = json.load(f)

In [101]:
# Load the gold annotations of the same dataset. parse_predictions reads an
# 'annotation' mapping from each record.
# NOTE(review): no explicit encoding is given, so the platform default is
# used - confirm it matches how the file was written.
with open(f"data//{dataset}.json", "r") as f:
    annotated_data = json.load(f)

In [102]:
# Number of loaded predictions; parse_predictions asserts that it equals the
# number of annotated records.
len(predicted_data)

190

## 4. Prediction Parse Function

In [86]:
def parse_predictions(predicted_data, annotated_data, start_index=0):
    """Parse, sanitise and collect model predictions next to the gold annotations.

    For every record from ``start_index`` on, the completion text is turned
    into a prediction dictionary (``get_json_from_string``) and repaired step
    by step: ``None`` removal, schema correction, dictionary flattening and
    tag-shape correction. Whenever a step cannot be repaired the prediction
    falls back to an empty one.

    Side effect: the final prediction of each record is stored in
    ``predicted_data[index]['predictions']``.

    The global ``model`` is read: for names starting with ``"mi"`` only the
    text after the last ``'[/INST]'`` marker is parsed, and a record without
    that marker counts as an incorrect output.

    Parameters
    ----------
    predicted_data : list of dict
        Records with a ``'completion'`` string.
    annotated_data : list of dict
        Records with an ``'annotation'`` mapping; same length as
        ``predicted_data``.
    start_index : int, optional
        First record to process.

    Returns
    -------
    tuple
        ``(total_predictions, total_groundtruth, not_flattened_predictions,
        not_flattened_groundtruth, invalid_outputs)``. The first two map
        ``"<tag>_<index>"`` to a flat list of strings, the next two map each
        tag to the list of per-record raw values, and the last is the number
        of records counted as missing the marker, not parseable or not
        correctable.
    """

    assert len(predicted_data) == len(annotated_data)
    assert start_index < len(predicted_data)

    def fresh_empty():
        # A new empty prediction per record, so that records and the returned
        # structures never alias the shared module-level ``empty_value``.
        return copy.deepcopy(empty_value)

    not_flattened_predictions = {tag: [] for tag in available_tags}
    not_flattened_groundtruth = {tag: [] for tag in available_tags}

    total_predictions = {}
    total_groundtruth = {}

    incorrect_outputs = []
    not_formatable_jsons = []
    not_correctables = []

    for index, el in enumerate(predicted_data[start_index:], start=start_index):

        last_occurrence = 0
        first_index = 0
        if model.startswith("mi"):
            # Keep only what follows the last instruction marker (7 = len('[/INST]')).
            last_occurrence = el['completion'].rfind('[/INST]')
            first_index = last_occurrence + 7

        if last_occurrence == -1:
            incorrect_outputs.append(index)
            prediction = fresh_empty()

        else:
            completion = el['completion'][first_index:]
            value, is_json = get_json_from_string(completion)

            if value == empty_value:
                # Either nothing could be parsed or the model predicted nothing.
                if not is_json:
                    not_formatable_jsons.append(index)
                prediction = fresh_empty()
            else:
                prediction = value

                try:
                    if check_none_values(prediction):
                        prediction = remove_none_values(prediction)

                        # NOTE: this fallback is not counted as an invalid output.
                        if check_none_values(prediction):
                            prediction = fresh_empty()

                    if not check_schema(prediction):
                        prediction = correct_schema(prediction)

                        if not check_schema(prediction):
                            not_correctables.append(index)
                            prediction = fresh_empty()

                    if not check_dict_values(prediction):
                        prediction = correct_dict_values(prediction)

                        if not check_dict_values(prediction):
                            not_correctables.append(index)
                            prediction = fresh_empty()

                    if not check_tag_values(prediction):
                        prediction = correct_tag_values(prediction)

                        if not check_tag_values(prediction):
                            not_correctables.append(index)
                            prediction = fresh_empty()

                except jsonschema.exceptions.ValidationError:
                    not_correctables.append(index)
                    prediction = fresh_empty()

        predicted_data[index]['predictions'] = prediction

        for tag, elements in prediction.items():
            if tag in available_tags:
                not_flattened_predictions[tag].append(elements)
                total_predictions[f"{tag}_{index}"] = flatten_nested_list(elements)

        for tag, elements in annotated_data[index]['annotation'].items():
            if tag in available_tags:
                not_flattened_groundtruth[tag].append(elements)
                total_groundtruth[f"{tag}_{index}"] = flatten_nested_list(elements)

    invalid_outputs = len(incorrect_outputs) + len(not_formatable_jsons) + len(not_correctables)

    return total_predictions, total_groundtruth, not_flattened_predictions, not_flattened_groundtruth, invalid_outputs

In [103]:
# Parse every record (start_index defaults to 0). Besides returning the
# collected spans, parse_predictions stores each cleaned prediction under
# predicted_data[i]['predictions'].
total_predictions, total_groundtruth, not_flattened_predictions, not_flattened_groundtruth, invalid = parse_predictions(predicted_data, annotated_data)

In [104]:
# Score the flattened spans with multi_span_evaluate (imported from
# eval_script); only the first of its three return values is kept.
evaluation, _, _ = multi_span_evaluate(total_predictions, total_groundtruth)

In [105]:
# Display the scores returned by multi_span_evaluate.
evaluation

{'em_precision': 54.920212765957444,
 'em_recall': 61.427863163113535,
 'em_f1': 57.99204306108121,
 'overlap_precision': 65.17546934869122,
 'overlap_recall': 75.16472274003263,
 'overlap_f1': 69.81458426319898}

## 5. Linkage sets metrics

In [30]:
def split_apostrophe(span):
    """Detach a leading and/or trailing apostrophe from ``span``.

    Returns a list of pieces in their original order, with each detached
    apostrophe as its own piece. An empty string or a lone apostrophe yields
    ``[""]``.
    """
    if span in ("", "'"):
        return [""]

    # A leading apostrophe becomes its own piece.
    pieces = ["'", span[1:]] if span[0] == "'" else [span]

    # A trailing apostrophe on the last piece is split off as well.
    if pieces[-1][-1] == "'":
        pieces[-1] = pieces[-1][:-1]
        pieces.append("'")

    return pieces

In [31]:
def clear_annotation(to_clear):
    """Turn a list of spans into one flat list of normalised tokens.

    Each span is tokenised with ``tokenize_italian_text``; every token is
    cleaned with ``clean_italian_span``, lower-cased and split around leading
    or trailing apostrophes. Empty pieces and pieces found in ``stops`` or in
    ``italian_punctuation`` are dropped.
    """
    kept_tokens = []

    for span_text in to_clear:
        lowered = [clean_italian_span(token).lower() for token in tokenize_italian_text(span_text)]
        pieces = [piece for token in lowered for piece in split_apostrophe(token)]
        kept_tokens.extend(
            piece for piece in pieces
            if piece != "" and piece not in stops and piece not in italian_punctuation
        )

    return kept_tokens

In [32]:
def jaccard_similarity(set1, set2):
    """Return |set1 ∩ set2| / |set1 ∪ set2|, or 0.0 when the union is empty."""
    union_size = len(set1.union(set2))
    if union_size == 0:
        return 0.0
    shared_size = len(set1.intersection(set2))
    return shared_size / union_size

In [33]:
from itertools import product

def overall_sublist_distances(groundtruth, prediction, threshold=0):
    """Greedily pair ground-truth groups with predicted groups by Jaccard similarity.

    Every (ground-truth group, predicted group) combination is scored on the
    token sets produced by ``clear_annotation``. Candidates are visited from
    the highest score down; a candidate is accepted when both of its groups
    are still unmatched and its score is strictly greater than ``threshold``.

    Returns ``(final_pairs, unmatched_groundtruth, unmatched_prediction)``,
    where each pair is ``(groundtruth_group, predicted_group, similarity)``.
    The inputs are deep-copied and never modified.
    """
    unmatched_truth = copy.deepcopy(groundtruth)
    unmatched_pred = copy.deepcopy(prediction)

    scored = []
    for truth_group in unmatched_truth:
        for pred_group in unmatched_pred:
            score = jaccard_similarity(set(clear_annotation(truth_group)), set(clear_annotation(pred_group)))
            scored.append((truth_group, pred_group, score))

    # Ascending stable sort followed by a reversal: descending scores, with
    # ties in reverse generation order.
    scored.sort(key=lambda candidate: candidate[2])
    scored.reverse()

    matched = []
    for candidate in scored:
        truth_group, pred_group, score = candidate
        if truth_group not in unmatched_truth or pred_group not in unmatched_pred:
            continue
        if score > threshold:
            unmatched_truth.remove(truth_group)
            unmatched_pred.remove(pred_group)
            matched.append(candidate)

    return matched, unmatched_truth, unmatched_pred

In [118]:
# Minimum Jaccard similarity (exclusive) for a ground-truth group and a
# predicted group to be accepted as a match in overall_sublist_distances.
threshold = 0.3

In [119]:
# Accumulate, over every relation tag and every record, the matched group
# pairs together with the unmatched ground-truth groups (false negatives) and
# the unmatched predicted groups (false positives).
total_pairs, total_false_negatives, total_false_positives = [], [], []

for tag in relation_tags:
    # Both lists hold one entry per parsed record, so they are walked in parallel.
    for el1, el2 in zip(not_flattened_groundtruth[tag], not_flattened_predictions[tag]):
        pairs, false_negatives, false_positives = overall_sublist_distances(el1, el2, threshold=threshold)
        total_pairs.extend(pairs)
        total_false_negatives.extend(false_negatives)
        total_false_positives.extend(false_positives)

# Similarity score of every accepted pair.
values = [el[2] for el in total_pairs]

In [120]:
# Group-level metrics: a matched pair is a true positive, an unmatched
# predicted group a false positive, an unmatched ground-truth group a false
# negative. Each ratio falls back to 0.0 when its denominator is zero, so an
# empty result no longer raises ZeroDivisionError.
n_matched = len(total_pairs)
n_predicted = n_matched + len(total_false_positives)
n_annotated = n_matched + len(total_false_negatives)

precision = n_matched / n_predicted if n_predicted else 0.0
recall = n_matched / n_annotated if n_annotated else 0.0

f1_score = (2*precision*recall) / (precision + recall) if (precision + recall) else 0.0

In [121]:
# Report the group-level scores rounded to two decimals.
print(f"Precision: {precision:.2f}\nRecall: {recall:.2f}\nF1-score: {f1_score:.2f}")

Precision: 0.51
Recall: 0.75
F1-score: 0.60
