In [32]:
import json
import os
import re
from typing import Dict, List, Set, Tuple

import numpy as np
import pandas as pd

In [33]:
# Script parameters

# Domain to evaluate; it selects the sub-directory of 'artifacts' that the
# triple files are read from. Keep exactly one assignment uncommented.
# DOMAIN = 'movie'
# DOMAIN = 'computer'
DOMAIN = 'nature'

In [34]:
class Jaccard_metric:
    """Jaccard similarity between two collections of knowledge-graph triples.

    A triple is a dict carrying the keys 'sub', 'rel' and 'obj'. Two triples
    are considered the same only when all three values are equal. Each
    collection is treated as a set: repeated triples are counted once.
    """

    def clean_list(self, A):
        """Return the well-formed triples of ``A`` in their original order.

        Items that are not dicts, or that lack any of the keys 'sub', 'rel',
        'obj', are dropped. ``None`` is treated as an empty collection.
        """
        if A is None:
            return []
        return [
            a for a in A
            if isinstance(a, dict) and 'sub' in a and 'rel' in a and 'obj' in a
        ]

    def _distinct_keys(self, A):
        """Return the distinct (sub, rel, obj) tuples of the well-formed triples in ``A``."""
        keys = []
        for a in self.clean_list(A):
            key = (a['sub'], a['rel'], a['obj'])
            # Equality-based lookup (not hashing) so that unhashable values,
            # e.g. a list produced by a model, do not raise.
            if key not in keys:
                keys.append(key)
        return keys

    def count_matching(self, A, B):
        """
        Count the distinct triples that occur in both ``A`` and ``B``
        (the size of the intersection).
        """
        keys_b = self._distinct_keys(B)
        return sum(1 for key in self._distinct_keys(A) if key in keys_b)

    def count_unique(self, A, B):
        """
        Count the distinct triples that occur in ``A`` or ``B``
        (the size of the union).
        """
        keys_a = self._distinct_keys(A)
        keys_b = self._distinct_keys(B)
        shared = sum(1 for key in keys_a if key in keys_b)
        return len(keys_a) + len(keys_b) - shared

    def get_metric(self, expected: List[Dict], predicted: List[Dict]) -> float:
        """Return the Jaccard index |expected ∩ predicted| / |expected ∪ predicted|.

        Returns 0.0 when both collections contain no well-formed triples, and
        also when an input is malformed (e.g. not iterable), so that one bad
        record does not abort the evaluation of a whole file.
        """
        try:
            union_size = self.count_unique(expected, predicted)
            if union_size == 0:
                return 0.0
            return self.count_matching(expected, predicted) / union_size
        except (TypeError, ValueError):
            return 0.0

    def read_triples(self, path: str) -> List[List[Dict]]:
        """Read a JSON Lines file and return the 'triples' value of every record.

        Blank lines (such as a trailing empty line) are skipped.

        :raises KeyError: if a record has no 'triples' key
        :raises json.JSONDecodeError: if a non-blank line is not valid JSON
        """
        with open(path, 'r') as f:
            data = f.readlines()

        result = []
        for i_text in data:
            if not i_text.strip():
                continue
            result.append(json.loads(i_text)['triples'])

        return result

    def model_metric(self, expected_path: str, predicted_path: str) -> float:
        """Return the mean Jaccard index over the records of two triple files.

        Records are paired by position; if the files differ in length, the
        surplus records of the longer file are ignored. Triples count as equal
        only on an exact match. Returns NaN when there are no record pairs.
        """
        expected_list = self.read_triples(expected_path)
        predicted_list = self.read_triples(predicted_path)
        metric_arr = [
            self.get_metric(expected, predicted)
            for expected, predicted in zip(expected_list, predicted_list)
        ]

        if not metric_arr:
            # Same value numpy would produce for an empty mean, without the RuntimeWarning.
            return float('nan')
        return np.array(metric_arr).mean()


In [35]:
jaccard_metric = Jaccard_metric()

# Score each candidate triple file against the same ground-truth file:
# base model, fine-tuned model, fine-tuned model with post-processing.
jaccard_metric_base, jaccard_metric_ft, jaccard_metric_ft_pp = (
    jaccard_metric.model_metric(
        os.path.join('artifacts', DOMAIN, 'triples_gt.jsonl'),
        os.path.join('artifacts', DOMAIN, candidate_file),
    )
    for candidate_file in ('triples_base.jsonl', 'triples_ft.jsonl', 'triples_ft_pp.jsonl')
)

jaccard_metric_base, jaccard_metric_ft, jaccard_metric_ft_pp

(0.0, 0.24699221514508135, 0.24309978768577492)

In [36]:
class GraphEvaluator:
    """Precision, recall and F1 of predicted triples against gold triples.

    All triples of a file are pooled into one set before comparison, so the
    scores are computed over the whole file rather than per record.
    """

    def __init__(self):
        pass

    def get_metrics(self, gold: Set, pred: Set) -> Tuple[float, float, float]:
        """
        Method to calculate precision, recall and f1:
            Precision is calculated as correct_triples/predicted_triples and
            Recall as correct_triples/gold_triples
            F1 as the harmonic mean of precision and recall.
        If either set is empty, all three scores are 0.0.
        :param gold: items in the gold standard
        :param pred: items in the system prediction
        :return:
            p: float - precision
            r: float - recall
            f1: float - F1
        """
        # An empty side means nothing can be correct; this also avoids
        # dividing by zero when the gold set is empty.
        if len(pred) == 0 or len(gold) == 0:
            return 0.0, 0.0, 0.0
        correct = len(gold.intersection(pred))
        p = correct / len(pred)
        r = correct / len(gold)
        if p + r > 0:
            f1 = 2 * ((p * r) / (p + r))
        else:
            f1 = 0.0
        return p, r, f1

    def normalize_triple(self, sub_label: str, rel_label: str, obj_label: str) -> str:
        """
        Normalize triples for comparison in precision, recall calculations
        :param sub_label: subject string
        :param rel_label: relation string
        :param obj_label: object string
        :return: a normalized triple as a single concatenated string
        """
        # Remove whitespace and underscores and make lower case. Labels are
        # passed through str() first so that non-string JSON values (e.g. a
        # number) are compared by their text instead of raising TypeError.
        normalized = (
            re.sub(r"(_|\s+)", '', str(label)).lower()
            for label in (sub_label, rel_label, obj_label)
        )
        # concatenate them to a single string
        return "".join(normalized)

    def filter_triplets(self, file_path, pred=False) -> List:
        """
        Filter triplets from a JSON Lines file.

        A triplet is kept when it is a dict with the keys 'sub', 'rel' and
        'obj', its relation is not one of "cost", "main subject",
        "publication date", and its object is not the empty string.
        Blank lines and records whose 'triples' value is null are skipped.

        :param file_path: path to the file containing triplets
        :param pred: flag indicating whether the file contains predictions or gold standard triplets
        :return: ``[filtered_triplets]``, or ``[filtered_triplets, model]`` if pred=True,
            where ``model`` is read from the last record of the file
            ('model', else "<model1>+<model2>")
        """
        excluded_relations = ("cost", "main subject", "publication date")

        with open(file_path, 'r') as f:
            data = f.readlines()

        filtered_triplets = []
        last_record = {}  # stays empty for an empty file, so the model lookup below cannot fail
        ############################################
        # A slice can be applied to `data` here so that only the texts that
        # were actually processed are compared.
        ############################################
        for i_line in data:  # e.g. data[:10], with the n_run count from the triple-generation notebook
            if not i_line.strip():
                continue
            last_record = json.loads(i_line)
            for i_triple in last_record['triples'] or []:
                if not isinstance(i_triple, dict):
                    continue
                # Require all three keys: evaluate() indexes each of them.
                if not all(key in i_triple for key in ('sub', 'rel', 'obj')):
                    continue
                if i_triple['rel'] in excluded_relations:
                    continue
                if i_triple['obj'] == "":
                    continue
                filtered_triplets.append(i_triple)

        if pred:
            model = last_record.get(
                'model',
                f"{last_record.get('model1', '')}+{last_record.get('model2', '')}",
            )
            return [filtered_triplets, model]

        return [filtered_triplets]

    def evaluate(self, pred_file_path: str, gold_file_path: str) -> Dict:
        """
        Evaluate the performance of a prediction file against a gold standard file.
        :param pred_file_path: path to the prediction file
        :param gold_file_path: path to the gold standard file
        :return: a dictionary containing the model name, precision, recall, and F1 score
        """
        pred, model = self.filter_triplets(pred_file_path, pred=True)
        gold = self.filter_triplets(gold_file_path)[0]
        pred_ready = {self.normalize_triple(tr['sub'], tr['rel'], tr['obj']) for tr in pred}
        gold_ready = {self.normalize_triple(tr['sub'], tr['rel'], tr['obj']) for tr in gold}
        p, r, f1 = self.get_metrics(gold_ready, pred_ready)
        return {"model": model, "precision": p, "recall": r, "f1": f1}

In [37]:
evaluator = GraphEvaluator()

# Precision / recall / F1 of each prediction file against the ground truth:
# base model, fine-tuned model, fine-tuned model with post-processing.
metrics_base, metrics_ft, metrics_ft_pp = (
    evaluator.evaluate(
        os.path.join('artifacts', DOMAIN, predicted_file),
        os.path.join('artifacts', DOMAIN, 'triples_gt.jsonl'),
    )
    for predicted_file in ('triples_base.jsonl', 'triples_ft.jsonl', 'triples_ft_pp.jsonl')
)

# Only the base and fine-tuned results are echoed here; metrics_ft_pp is
# computed above but not displayed by this cell.
metrics_base, metrics_ft

({'model': 'llama-2-7b.Q4_0.gguf+llama-2-7b.Q4_0.gguf',
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0},
 {'model': 'Llama-2-7b-m1ft-q4.gguf+Llama-2-7b-m2ft-q4.gguf',
  'precision': 0.5555555555555556,
  'recall': 0.28125,
  'f1': 0.37344398340248963})

In [38]:
# One row per knowledge graph; the ground-truth row scores 1.0 on every
# metric by definition. The column names are written to the CSV as-is.
metrics_df = pd.DataFrame({
    'domain': [DOMAIN] * 4,
    'knowknowledge_graph': ['ground truth', 'base', 'finetuned', 'finetuned_postprocessed'],
    'jaccard': [1.0, jaccard_metric_base, jaccard_metric_ft, jaccard_metric_ft_pp],
    **{
        score: [1.0] + [run[score] for run in (metrics_base, metrics_ft, metrics_ft_pp)]
        for score in ('precision', 'recall', 'f1')
    },
})

# Persist the table next to the other artifacts of this domain, then read it
# back so the cell displays exactly what was written.
fn = os.path.join('artifacts', DOMAIN, 'kg_metrics.csv')
metrics_df.to_csv(fn, index=False)

pd.read_csv(fn)


Unnamed: 0,domain,knowknowledge_graph,jaccard,precision,recall,f1
0,nature,ground truth,1.0,1.0,1.0,1.0
1,nature,base,0.0,0.0,0.0,0.0
2,nature,finetuned,0.246992,0.555556,0.28125,0.373444
3,nature,finetuned_postprocessed,0.2431,0.68254,0.26875,0.38565
