In [59]:
import json 
import numpy as np 
from tqdm import tqdm 
from pathlib import Path
from collections import defaultdict

def get_prediction(prob_dict_list):
    """Pick the highest-probability token for every decoding step.

    Args:
        prob_dict_list: sequence of dicts, one per timestep, each mapping a
            token string to its probability.

    Returns:
        A list of ``(token, probability)`` tuples, one per input dict, in
        the same order. When several tokens share the maximum probability,
        the one that appears first in the dict wins (``np.argmax`` returns
        the first maximal index).
    """
    def _top_entry(distribution):
        # Split the mapping into parallel tuples so the argmax index can be
        # used to look up both the token and its score.
        labels, scores = zip(*distribution.items())
        winner = np.argmax(scores)
        return labels[winner], scores[winner]

    return [_top_entry(distribution) for distribution in prob_dict_list]

def check_tokens(pred_tok, tgt_tok, prev_tgts):
    """Return True when a predicted token matches the target token.

    Three kinds of prediction are recognised by substring:

    * plain tokens (neither ``"SourceCopy"`` nor ``"TargetCopy"`` occurs in
      ``pred_tok``) are compared to ``tgt_tok`` directly;
    * ``SourceCopy`` predictions are compared using the text that follows
      the first underscore;
    * ``TargetCopy`` predictions carry an integer index after the first
      underscore, which is resolved against ``prev_tgts``.

    Args:
        pred_tok: predicted token string.
        tgt_tok: expected token string.
        prev_tgts: previously seen target tokens, indexed by the
            ``TargetCopy`` suffix.

    Returns:
        bool: whether the prediction resolves to ``tgt_tok``.

    Raises:
        IndexError: if a copy token has no ``_`` suffix, or a ``TargetCopy``
            index is outside ``prev_tgts``.
        ValueError: if a ``TargetCopy`` suffix is not an integer.
    """
    if "SourceCopy" not in pred_tok and "TargetCopy" not in pred_tok:
        return pred_tok == tgt_tok
    elif "SourceCopy" in pred_tok:
        # Split only on the FIRST underscore: the copied token may itself
        # contain underscores (e.g. "__User"), and a plain split("_")[1]
        # would truncate it to the text before its next underscore.
        # NOTE(review): assumes the key layout is "SourceCopy_<token>" with
        # nothing after the token — confirm against the producer.
        return pred_tok.split("_", 1)[1] == tgt_tok
    else:
        tok_idx = int(pred_tok.split("_")[1])
        return prev_tgts[tok_idx] == tgt_tok

    

def read_json(path):
    """Load saved decoder outputs and collect every mispredicted step.

    The file at ``path`` must contain a JSON list of instances. Each
    instance is read through the keys ``'left_context'``, ``'prob_dist'``
    and ``'source_tokens'``; entries of ``'left_context'`` and
    ``'source_tokens'`` are indexed with ``[0]`` to obtain the token text
    (presumably ``(token, ...)`` pairs — verify against the producer).

    Args:
        path: location of the JSON file.

    Returns:
        A list of dicts, one per step where ``check_tokens`` rejects the
        argmax prediction, with keys ``source_tokens``, ``left_context``,
        ``target_toks``, ``output_token``, ``output_prob`` and
        ``target_token``.

    Raises:
        IndexError: if an instance has fewer probability distributions than
            left-context tokens (after dropping the first one).
    """
    print("opening data")
    with open(path) as f1:
        data = json.load(f1)
    print("got data")

    mistakes = []

    for instance in tqdm(data):
        # The first left-context entry is skipped (presumably a start
        # symbol — TODO confirm), so position i lines up with prediction i.
        left_context = [x[0] for x in instance['left_context']][1:]
        target_toks = left_context + ["@end@"]
        predicted_toks = get_prediction(instance['prob_dist'])
        source_tokens = " ".join(x[0] for x in instance['source_tokens'])

        # NOTE(review): target_toks has one more entry ("@end@") than
        # left_context, but the loop stops at len(left_context), so the
        # final "@end@" step is never compared. Confirm this is intended.
        for i in range(len(left_context)):
            output_token, output_prob = predicted_toks[i]
            target_token = target_toks[i]
            if not check_tokens(output_token, target_token, left_context[:i]):
                mistakes.append({"source_tokens": source_tokens,
                                 "left_context": left_context[0:i],
                                 "target_toks": target_toks[0:i],
                                 "output_token": output_token,
                                 "output_prob": output_prob,
                                 "target_token": target_token})

    return mistakes

In [60]:
# Collect the mispredicted steps from one saved translate-output file.
# NOTE(review): this is an absolute, machine-specific path (the "//" is
# harmless on POSIX); the cell only reproduces where that file exists.
mistakes = read_json("/brtx/604-nvme2/estengel/miso_models/tune_roberta//translate_output/small_losses.json")

opening data
got data


100%|██████████| 3/3 [00:00<00:00, 258.16it/s]


In [61]:
# Display every collected record.
# NOTE(review): the saved output below contains records whose output_token
# equals target_token, which check_tokens as defined above would not flag.
# The output may come from an earlier version of the code — re-run to confirm.
mistakes

[{'source_tokens': '__User What do I have scheduled for Wednesday afternoon ?',
  'left_context': [],
  'target_toks': [],
  'output_token': '@ROOT@',
  'output_prob': 0.9999949932098389,
  'target_token': '@ROOT@'},
 {'source_tokens': '__User What do I have scheduled for Wednesday afternoon ?',
  'left_context': ['@ROOT@'],
  'target_toks': ['@ROOT@'],
  'output_token': 'Yield',
  'output_prob': 0.999616265296936,
  'target_token': 'Yield'},
 {'source_tokens': '__User What do I have scheduled for Wednesday afternoon ?',
  'left_context': ['@ROOT@', 'Yield'],
  'target_toks': ['@ROOT@', 'Yield'],
  'output_token': 'FindEventWrapperWithDefaults',
  'output_prob': 0.9992652535438538,
  'target_token': 'FindEventWrapperWithDefaults'},
 {'source_tokens': '__User What do I have scheduled for Wednesday afternoon ?',
  'left_context': ['@ROOT@', 'Yield', 'FindEventWrapperWithDefaults'],
  'target_toks': ['@ROOT@', 'Yield', 'FindEventWrapperWithDefaults'],
  'output_token': 'EventOnDateWithTim