In [4]:
import abc
import re
from pathlib import Path
from typing import Any
from tqdm import tqdm
from collections import defaultdict
from typing import Dict, List, Tuple, Union, Literal
import copy

from refpydst.normalization.data_ontology_normalizer import DataOntologyNormalizer
from refpydst.db.ontology import Ontology
from refpydst.utils.dialogue_state import update_dialogue_state
from refpydst.prompt_formats.python.completion_parser import (
    iterative_parsing,
    parse_python_completion,
    parse_python_modified,
    parse_state_change
)

from utils import SlotName, SlotValue, MultiWOZDict
from utils import (
    validate_path_and_make_abs_path,
    read_json, save_analyzed_log, unroll_or,
    sort_dict, sort_data_item, compute_dict_difference )


class AbstractAnalyzer(metaclass=abc.ABCMeta):
    """Abstract base class for analyzers that compare predicted and gold belief states."""

    @abc.abstractmethod
    def analyze(self, pred_bs, gold_bs) -> Dict[SlotName, SlotValue]:
        """
        Analyze a prediction against the gold standard.

        Subclasses must override this method.

        Args:
            pred_bs: The predicted belief state.
            gold_bs: The gold standard belief state.

        Raises:
            NotImplementedError: Always, when the base implementation is invoked
                (e.g. through ``super().analyze(...)``).
        """
        # Raise (not return) the exception: returning it would hand callers an
        # exception instance as if it were a valid analysis result.
        raise NotImplementedError('')

class ErrorAnalyzer(AbstractAnalyzer):

    def __init__(
            self,
            train_data_path: str = None,
            result_file_path: str = None, 
            output_path: str = None,
            ontology_path: str = './src/refpydst/db/multiwoz/2.4/ontology.json',
            special_values: List[str] = ['dontcare', '[DELETE]']
    ):  
        
        train_data = read_json(train_data_path)
        self.normalizer = self.get_normalizer(train_data, ontology_path)

        self.result_file_path = result_file_path
        self.output_path = output_path or result_file_path

        self.special_values = special_values or ['dontcare', '[DELETE]']
    
    def get_normalizer(self, train_data, ontology_path):
        ontology_path = validate_path_and_make_abs_path(ontology_path)
        return DataOntologyNormalizer(
                Ontology.create_ontology(),
                # count labels from the train set
                supervised_set=train_data,
                # make use of existing surface form knowledge encoded in ontology.json, released with each dataset
                # see README.json within https://github.com/smartyfh/MultiWOZ2.4/raw/main/data/MULTIWOZ2.4.zip
                counts_from_ontology_file=ontology_path
        )

    def record_error_and_update_visited(
        self,
        error_dict: Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]],
        error_name: str,
        error_s_v_pairs: Union[Tuple[SlotName, SlotValue], Tuple[SlotName, SlotValue, SlotName, SlotValue]],
        visited_pairs: List[Tuple[SlotName, SlotValue]] = None
    ) -> Tuple[Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]], List[Tuple[SlotName, SlotValue]]]:
        """
        Records an error in the error dictionary and updates the list of visited pairs.

        Args:
            error_dict (Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]]): The error dictionary.
            error_name (str): The name of the error.
            error_s_v_pairs (Union[Tuple[SlotName, SlotValue], Tuple[SlotName, SlotValue, SlotName, SlotValue]]): The slot and value pair(s) associated with the error.
            visited_pairs (List[Tuple[SlotName, SlotValue]]): List of visited slot and value pairs

        Returns:
            Tuple[Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]], List[Tuple[SlotName, SlotValue]]]: Updated error dictionary and visited pairs list.
        """
        # Append the error to the error dictionary
        if (error_name, error_s_v_pairs) in error_dict.get('error', []):
            return error_dict, visited_pairs
    
        error_dict.setdefault('error', []).append((error_name, error_s_v_pairs))
        if visited_pairs is None:
            return error_dict, visited_pairs
        
        # Update the visited pairs list
        assert len(error_s_v_pairs) in [2, 4]

        if ('error_prop' in error_name):
            if 'hall' in error_name:
                visited_pairs.append((error_s_v_pairs[-2], error_s_v_pairs[-1]))
            elif 'miss' in error_name:
                visited_pairs.append((error_s_v_pairs[0], error_s_v_pairs[1]))
        else:
            visited_pairs.append((error_s_v_pairs[0], error_s_v_pairs[1]))
            if len(error_s_v_pairs) == 4:
                visited_pairs.append((error_s_v_pairs[2], error_s_v_pairs[3]))

        return error_dict, visited_pairs
        
    def analyze_delta_missings(
            self, 
            delta_miss_gold: MultiWOZDict,  
            delta_over_pred: MultiWOZDict, 
            gold_delta_belief_state: MultiWOZDict, 
            pred_delta_belief_state: MultiWOZDict,
            visited: List[Tuple[SlotName, SlotValue]]
        )-> Tuple[Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]], List[Tuple[SlotName, SlotValue]]]:
        """
        Detects missing values in the prediction compared to the gold standard and records errors.

        Args:
            delta_miss_gold (MultiWOZDict): Missing values in the gold standard.
            delta_over_pred (MultiWOZDict): Over-predicted values.
            gold_delta_belief_state (MultiWOZDict): The gold standard delta belief state.
            pred_delta_belief_state (MultiWOZDict): The predicted delta belief state.
            visited (List[Tuple[SlotName, SlotValue]]): List of visited slot and value pairs.

        Returns:
            Tuple[Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]], List[Tuple[SlotName, SlotValue]]]: 
            Updated error dictionary and visited pairs list.
        """

        error_dict = defaultdict(list)

        for gold_slot, gold_value in delta_miss_gold.items():
            if (gold_slot, gold_value) in visited:
                continue        

            if (gold_value in delta_over_pred.values()):

                pred_slot_val = [(k, v) for k, v in delta_over_pred.items() if v == gold_value and (k, v ) not in gold_delta_belief_state.items()]
            
                if len(pred_slot_val) == 0:
                    error_name = 'delta_miss_total'
                    error_s_v_pairs = (gold_slot, gold_value)
                else:
                    error_name = 'delta_miss_confuse'
                    for (confused_slot, v) in pred_slot_val:
                        assert v == gold_value
                        error_s_v_pairs = (gold_slot, gold_value, confused_slot, v)
                        error_dict, visited = self.record_error_and_update_visited(error_dict, error_name, error_s_v_pairs, visited)
            else:
                if gold_value in self.special_values and pred_delta_belief_state.get(gold_slot, None) == None:
                    error_name = f'delta_miss_{re.sub(r"[^a-zA-Z]", "", gold_value)}'
                    error_s_v_pairs = (gold_slot, gold_value)
                    
                # if gold_slot in pred_delta_belief_state: 'delta_hall_val' error case. But we don't care about it here.
                if gold_slot not in pred_delta_belief_state:
                    if error_dict.get('error') is None:
                        raise ValueError('Error case is None')
                    error_name = 'delta_miss_total'
                    error_s_v_pairs = (gold_slot, gold_value)

            error_dict, visited = self.record_error_and_update_visited(error_dict, error_name, error_s_v_pairs, visited) 
        return error_dict, visited

    def analyze_delta_hallucinations(
            self,
            delta_miss_gold: MultiWOZDict,
            delta_over_pred: MultiWOZDict,
            gold_delta_belief_state: MultiWOZDict, 
            pred_delta_belief_state: MultiWOZDict,
            prev_pred_belief_state: MultiWOZDict,
            visited: List[Tuple[SlotName, SlotValue]]
        )-> Tuple[Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]], List[Tuple[SlotName, SlotValue]]]:
        """
        Detects hallucinated values in the prediction compared to the gold standard and previous predictions,
        and records errors.

        Args:
            delta_miss_gold (MultiWOZDict): Missing values in the gold standard.
            delta_over_pred (MultiWOZDict): Over-predicted values.
            gold_delta_belief_state (MultiWOZDict): The gold standard delta belief state.
            pred_delta_belief_state (MultiWOZDict): The predicted delta belief state.
            prev_pred_belief_state (MultiWOZDict): The previous predicted belief state.
            visited (List[Tuple[SlotName, SlotValue]]): List of visited slot and value pairs.

        Returns:
            Tuple[Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]], List[Tuple[SlotName, SlotValue]]]: 
            Updated error dictionary and visited pairs list.
            
        """
        error_dict = defaultdict(list)

        for pred_slot, pred_value in delta_over_pred.items():
            if (pred_slot, pred_value) in visited:
                continue
        
            if (pred_slot, pred_value) in delta_over_pred.items():
                if pred_slot in gold_delta_belief_state:
                    error_name = 'delta_hall_val'
                    error_s_v_pairs = (pred_slot, gold_delta_belief_state[pred_slot], pred_slot, pred_value)
                    tmp_visited = visited

                elif pred_slot in prev_pred_belief_state:
                    # if pred_val == pred_prev[pred_slot]: parroting the slot and value but this is not a big deal 
                    if pred_value != prev_pred_belief_state[pred_slot]:
                        error_name = 'delta_hall_overwrite'
                        error_s_v_pairs = (pred_slot, pred_value)
                        tmp_visited = None
                else:
                    error_name = 'delta_hall_total'
                    error_s_v_pairs = (pred_slot, pred_value)
                    tmp_visited = None
                
            error_dict, visited = self.record_error_and_update_visited(error_dict, error_name, error_s_v_pairs, tmp_visited)

        for gold_slot, gold_value in delta_miss_gold.items():
            if (gold_slot, gold_value) in visited:
                continue
            
            if (gold_value not in delta_over_pred.values()):
                if gold_slot in pred_delta_belief_state:
                    error_name = f'delta_hall_val'
                    error_s_v_pairs = (gold_slot, gold_value, gold_slot, pred_delta_belief_state[gold_slot])
                
            error_dict, visited = self.record_error_and_update_visited(error_dict, error_name, error_s_v_pairs, visited) 
        
        return error_dict, visited

    def analyze_error_propagations(
            self,
            delta_miss_gold, 
            delta_over_pred, 
            gold_belief_state, 
            pred_belief_state, 
            prev_gold_belief_state, 
            prev_pred_belief_state, 
            visited,
            prev_log: dict,
        )->Tuple[Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]], List[Tuple[SlotName, SlotValue]]]:
        """
        Analyzes the error cases that are propagated from the previous turn.

        Args:
            delta_miss_gold (MultiWOZDict): Missing values in the gold standard.
            delta_over_pred (MultiWOZDict): Over-predicted values.
            gold_belief_state (MultiWOZDict): The gold standard belief state.
            pred_belief_state (MultiWOZDict): The predicted belief state.
            prev_gold_belief_state (MultiWOZDict): The previous gold standard belief state.
            prev_pred_belief_state (MultiWOZDict): The previous predicted belief state.
            visited (List[Tuple[SlotName, SlotValue]]): List of visited slot and value pairs.
            prev_log (dict): The previous error log.
        
        Returns:
            Tuple[Dict[str, List[Tuple[str, Tuple[SlotName, SlotValue]]]], List[Tuple[SlotName, SlotValue]]]: 
            Updated error dictionary and visited pairs list

        """
        prev_over_pred = compute_dict_difference(prev_pred_belief_state, prev_gold_belief_state)
        prev_miss_gold = compute_dict_difference(prev_gold_belief_state, prev_pred_belief_state)
        
        full_over_pred = compute_dict_difference(pred_belief_state, gold_belief_state)
        full_miss_gold = compute_dict_difference(gold_belief_state, pred_belief_state)
            
        error_dict = defaultdict(list)
        for err_name, err_s_v in prev_log.get('error', []):
            if 'hall' in err_name:
                error_slot, error_value = err_s_v[-2], err_s_v[-1]
            if 'miss' in err_name :
                error_slot, error_value = err_s_v[0], err_s_v[1]

            if (error_slot, error_value) in visited:
                continue
            
            if (error_slot, error_value) in prev_miss_gold.items() or (error_slot, error_value) in prev_over_pred.items():
                if (error_slot, error_value) in full_over_pred.items() or (error_slot, error_value) in full_miss_gold.items():
                    if 'delete' in err_name:
                        prop_name = 'error_prop_'+'_'.join(err_name.split('_')[-2:])
                        error_dict, visited = self.record_error_and_update_visited(error_dict, prop_name, err_s_v, visited)
                    if (error_slot, error_value) in delta_miss_gold.items() or (error_slot, error_value) in delta_over_pred.items():
                        continue
                    prop_name = 'error_prop_'+'_'.join(err_name.split('_')[-2:])
                    error_dict, visited = self.record_error_and_update_visited(error_dict, prop_name, err_s_v, visited)
        return error_dict, visited
    
    def categorize_error_case(
            self, 
            item: dict, 
            prev_item: dict, 
            gold_belief_state: MultiWOZDict, 
            pred_belief_state: MultiWOZDict, 
            gold_delta_belief_state: MultiWOZDict, 
            pred_delta_belief_state: MultiWOZDict, 
            prev_gold_belief_state: MultiWOZDict, 
            prev_pred_belief_state: MultiWOZDict
        ) -> dict:
        """
            Categories the error cases into different types and records them in the log.
            
            Args:
                log (dict): The current error log.
                prev_log (dict): The previous error log.
                gold_belief_state (MultiWOZDict): The gold standard belief state.
                pred_belief_state (MultiWOZDict): The predicted belief state.
                gold_delta_belief_state (MultiWOZDict): The gold standard delta belief state.
                pred_delta_belief_state (MultiWOZDict): The predicted delta belief state.
                prev_gold_belief_state (MultiWOZDict): The previous gold standard belief state.
                prev_pred_belief_state (MultiWOZDict): The previous predicted belief state.
            
            Returns:
                dict: The log updated error cases.
        """
        
        delta_miss_gold = compute_dict_difference(gold_delta_belief_state, pred_delta_belief_state)
        delta_over_pred = compute_dict_difference(pred_delta_belief_state, gold_delta_belief_state)

        visited = []

        # handle the case which is already found and recorded in the current turn
        for err_name, err_s_v in item.get('error', []):
            if len(err_s_v) > 2:
                visited.append((err_s_v[-2], err_s_v[-1]))
            visited.append((err_s_v[0], err_s_v[1]))

        # handle the case which prediction missed in the current turn
        error_case, visited = self.analyze_delta_missings(
            delta_miss_gold, delta_over_pred, gold_delta_belief_state, pred_delta_belief_state, visited)
        item.update(error_case)

        # handle the case which is over-predicted in the current turn
        error_case, visited = self.analyze_delta_hallucinations(
            delta_miss_gold, delta_over_pred, gold_delta_belief_state, pred_delta_belief_state, prev_pred_belief_state, visited)
        item.update(error_case)
        
        # handle the case which is propagated from the previous turn
        error_case, visited = self.analyze_error_propagations(
            delta_miss_gold, delta_over_pred, gold_belief_state, pred_belief_state, 
            prev_gold_belief_state, prev_pred_belief_state, visited, prev_item)
                
        return item

    def analyze(self, parsing_func = iterative_parsing):
        """
        Analyzes the errors in the prediction compared to the gold standard and records them.

        Args:
            parsing_func (function): The parsing function to use for iterative parsing.
                                     [iterative_parsing(default), parse_python_modified, parse_state_change, parse_python_completion]
        
        """
        print('Start analyzing the error cases...')
        logs = read_json(self.result_file_path)
        analyzed_log = []
        
        n_correct = 0
        new_logs = []
        prev_item = {}
        for idx, data_item in tqdm(enumerate(logs)):
            if data_item['turn_id'] == 0:
                prev_item = {}            

            analyzed_item = copy.deepcopy(data_item)
            prev_pred_bs = prev_item.get(f'pred_{parsing_func}', {})
            prev_gold_bs = analyzed_item['last_slot_values']
            prev_gold_bs, prev_pred_bs = unroll_or(prev_gold_bs, prev_pred_bs)

            pred_delta_bs = parsing_func(analyzed_item['completion'], prev_pred_bs)
            pred_delta_bs = self.normalizer.normalize(pred_delta_bs) if 'DELETE' not in str(pred_delta_bs) else pred_delta_bs
            # pred_delta = data_item['pred_last_slot_values']
            analyzed_item[f'pred_{parsing_func}'] = pred_delta_bs
            
            gold_delta_bs = analyzed_item['turn_slot_values']
            gold_delta_bs, pred_delta_bs = unroll_or(gold_delta_bs, pred_delta_bs)

            pred_bs = update_dialogue_state(prev_pred_bs, pred_delta_bs)
            gold_bs = analyzed_item['slot_values']
            gold_bs, pred_bs = unroll_or(gold_bs, pred_bs)
                            
            if pred_bs==gold_bs:
                n_correct+=1

            # analyzed_item['rights'] = (int(pred_bs==gold_bs), int(pred_delta_bs==gold_delta_bs), int(prev_pred_bs==prev_gold_bs))
            # try:
            #     exec(f'f_{int(pred_bs==gold_bs)}_d_{int(pred_delta_bs==gold_delta_bs)}_p_{int(prev_pred_bs==prev_gold_bs)}.append(analyzed_item)')
            # except:
            #     exec(f'f_{int(pred_bs==gold_bs)}_d_{int(pred_delta_bs==gold_delta_bs)}_p_{int(prev_pred_bs==prev_gold_bs)} = []')
            #     exec(f'f_{int(pred_bs==gold_bs)}_d_{int(pred_delta_bs==gold_delta_bs)}_p_{int(prev_pred_bs==prev_gold_bs)}.append(analyzed_item)') 
            
            analyzed_item['error'] = []

            analyzed_item = self.categorize_error_case(analyzed_item, prev_item, gold_bs, pred_bs, gold_delta_bs, pred_delta_bs, prev_gold_bs, prev_pred_bs)
            
            # remove the redundant error cases
            analyzed_item['error'] = sorted(list(set(tuple(x) for x in analyzed_item['error'])))
            
            analyzed_item = sort_data_item(analyzed_item)
            new_logs.append(analyzed_item)
            prev_item = analyzed_item

        save_analyzed_log(analyzed_log)
        return analyzed_log
    

# if __name__ == '__main__':
#     analyzer = ErrorAnalyzer(
#         train_data_path='./data/mw21_5p_train_v1.json',
#         result_file_path='./outputs/runs/table4/zero_shot/split_v1_train/running_log.json',
#         output_path='./auto_analysis.json',
#     )
#     analyzer.analyze()



In [3]:
# Build the analyzer from the training data and the running log, then run the analysis.
# All paths are relative to the current working directory.
analyzer = ErrorAnalyzer(
    output_path='./auto_analysis.json',
    result_file_path='./outputs/runs/table4/zero_shot/split_v1_train/running_log.json',
    train_data_path='./data/mw21_5p_train_v1.json',
)
analyzer.analyze()


Start analyzing the error cases...


0it [00:00, ?it/s]


KeyError: 'error'