In [1]:
import json 
import copy
from collections import defaultdict, Counter

from refpydst.prompt_formats.python.completion_parser import parse_python_completion
from refpydst.normalization.data_ontology_normalizer import DataOntologyNormalizer
from refpydst.db.ontology import Ontology




In [3]:
# Input locations for this cell.
TRAIN_DATA_PATH = '../data/mw21_5p_train_v1.json'
ONTOLOGY_COUNTS_PATH = "../src/refpydst/db/multiwoz/2.4/ontology.json"

# JSON text is UTF-8 by spec; pass the encoding explicitly so the load does not
# depend on the platform's locale default (e.g. cp1252 on Windows).
with open(TRAIN_DATA_PATH, 'r', encoding='utf-8') as train_file:
    train_data = json.load(train_file)

normalizer = DataOntologyNormalizer(
        Ontology.create_ontology(),
        # count labels from the train set
        supervised_set=train_data,
        # make use of existing surface form knowledge encoded in ontology.json, released with each dataset
        # see README.json within https://github.com/smartyfh/MultiWOZ2.4/raw/main/data/MULTIWOZ2.4.zip
        counts_from_ontology_file=ONTOLOGY_COUNTS_PATH
)

mapping supervised_set surface forms...: 100%|██████████| 2731/2731 [00:07<00:00, 383.33it/s]
reading surface forms from ontology.json: 100%|██████████| 31/31 [00:04<00:00,  6.67it/s]


In [2]:
# Log files read by this cell; `logs` is consumed by the scoring cell below.
RUNNING_LOG_PATH = '../outputs/runs/table4/5p/smapling_exp/split_v1_topk_bm/running_log.json'
NEW_LOG_PATH = '../data/log.json'

# Explicit UTF-8 so the result does not depend on the platform's locale encoding.
with open(RUNNING_LOG_PATH, 'r', encoding='utf-8') as log_file:
    logs = json.load(log_file)
with open(NEW_LOG_PATH, 'r', encoding='utf-8') as log_file:
    new_logs = json.load(log_file)

In [5]:
def compute_jga(prediction, gold):
    """Return 1 if ``prediction`` exactly matches ``gold`` (joint goal accuracy), else 0.

    A gold value containing ``'|'`` lists several acceptable values (e.g. ``"a|b"``).
    If the predicted value for that key is one of them, the key counts as matching.

    ``gold`` is not modified: alternatives are resolved on a shallow copy. The same
    gold dict can therefore be scored against many predictions without one call
    changing the outcome of the next.

    :param prediction: dict mapping key -> predicted value (e.g. dialogue-state slot -> value)
    :param gold: dict mapping key -> gold value string, possibly with '|'-separated alternatives
    :return: 1 on an exact dict match after resolving alternatives, otherwise 0
    """
    # Values are strings, so a shallow copy is enough to keep the caller's dict intact.
    resolved_gold = dict(gold)
    for key, gold_value in gold.items():
        # If the gold value supports multiple ground-truth values and we predicted one
        # of them, treat the gold value as the one we predicted.
        if '|' in gold_value and key in prediction and prediction[key] in gold_value.split('|'):
            resolved_gold[key] = prediction[key]

    # Joint goal is an exact match of the whole dict.
    return 1 if prediction == resolved_gold else 0

In [None]:
# Number of top-scoring training examples kept per log in log['best_example'].
TOP_K_EXAMPLES = 10
# Per-example score tables built for each sampling iteration (key spelling matches the logs).
SCORE_KEYS = ['occurence', 'score_delta', 'score_full', 'influence_delta', 'influence_full']


def turn_key(dialogue_id, turn_id):
    """Build the '<dialogue id>_turn_<turn id>' string used to identify an example turn."""
    return dialogue_id + '_turn_' + str(turn_id)


# Index training turns once so the best-example lookup below is a dict lookup instead
# of a linear scan of train_data per example. setdefault keeps the FIRST match for a key.
train_by_turn_id = {}
for train_example in train_data:
    train_by_turn_id.setdefault(turn_key(train_example["ID"], train_example['turn_id']), train_example)

for log in logs:
    retrieved_example_ids = [turn_key(x[0], x[1]) for x in log['examples']]
    # Mutated in place below (append/pop), so this reference stays valid.
    scores_window = log['sampling_exp']['scores']

    for exp in log['sampling_exp']['exp']:
        for iteration in exp:
            # Fresh zeroed tables keyed by the example ids of the current oldest score table.
            candidate_ids = scores_window[0]['occurence']
            iter_scores = {key: {ex_id: 0 for ex_id in candidate_ids} for key in SCORE_KEYS}

            for step in exp[iteration]:
                example_ids = [turn_key(x[0], x[1]) for x in step['examples']]
                num_sub_group = len(retrieved_example_ids) // len(example_ids)

                pred = step['pred']
                pred_delta = normalizer.normalize(parse_python_completion(step['completion'], {}))

                # Pass copies: compute_jga may resolve '|' alternatives in place, which
                # must not leak into the log or into the scoring of later steps.
                delta_jga = compute_jga(pred_delta, dict(log['turn_slot_values']))
                full_jga = compute_jga(pred, dict(log['slot_values']))

                # Credit every example used in this step.
                for ex_id in example_ids:
                    iter_scores['occurence'][ex_id] += 1
                    iter_scores['score_delta'][ex_id] += delta_jga
                    iter_scores['score_full'][ex_id] += full_jga
                    iter_scores['influence_delta'][ex_id] += delta_jga
                    iter_scores['influence_full'][ex_id] += full_jga

                # Penalise the retrieved examples left out of this step, weighted by 1/(num_sub_group - 1).
                excluded_ids = set(retrieved_example_ids) - set(example_ids)
                if excluded_ids:
                    if num_sub_group < 2:
                        raise ValueError(
                            f"step uses {len(example_ids)} of {len(retrieved_example_ids)} retrieved "
                            f"examples; need at least 2 sub-groups to weight excluded examples"
                        )
                    penalty_weight = 1 / (num_sub_group - 1)
                    for neg_ex_id in excluded_ids:
                        iter_scores['influence_delta'][neg_ex_id] -= penalty_weight * delta_jga
                        iter_scores['influence_full'][neg_ex_id] -= penalty_weight * full_jga

            # Keep the number of stored score tables constant: add the newest, drop the oldest.
            scores_window.append(iter_scores)
            scores_window.pop(0)

    # Sum each score table over all stored tables.
    final_scores = {}
    for scores in scores_window:
        for key, per_example in scores.items():
            if key not in final_scores:
                final_scores[key] = copy.deepcopy(per_example)
            else:
                for ex_id in retrieved_example_ids:
                    final_scores[key][ex_id] += per_example[ex_id]
    log['final_scores'] = final_scores

    # most_common is descending; reversing puts the highest-scoring example last.
    top_ids = [ex_id for ex_id, _ in Counter(final_scores['score_delta']).most_common(TOP_K_EXAMPLES)]
    log['best_example'] = [train_by_turn_id[ex_id] for ex_id in reversed(top_ids)]

True