In [1]:
import numpy as np
import json 
from utils import detokenize

In [2]:
# Model outputs for the scope_fol split (test-time decoding and forced decoding)
# plus the gold test / test_eval files. Every file is read as JSON Lines.
test_time_file = "/brtx/602-nvme1/estengel/ambiguous_parsing/logs/1.0/codegen-16B_lamp_no_context_all_scope_fol_0_test_eval_constrained_bs_5_np_full/model_outputs.20230418T115331.jsonl"
forced_file = "/brtx/602-nvme1/estengel/ambiguous_parsing/model_outputs/codegen-16B/scope_fol/outputs/test_eval.logits"
gold_file_test = "/brtx/602-nvme1/estengel/ambiguous_parsing/data/processed/scope_fol/test.jsonl"
gold_file_eval = "/brtx/602-nvme1/estengel/ambiguous_parsing/data/processed/scope_fol/test_eval.jsonl"

# One JSON record per line -> list of dicts.
with open(gold_file_test) as test_fh, open(gold_file_eval) as eval_fh:
    gold_test = list(map(json.loads, test_fh))
    gold_eval = list(map(json.loads, eval_fh))

# Group eval records by utterance, remembering each record's position in
# gold_eval (that position is what the next cell uses to index `forced`).
gold_eval_by_src = {}
for i, d in enumerate(gold_eval):
    gold_eval_by_src.setdefault(d['utterance'], []).append((i, d))

with open(test_time_file) as preds_fh, open(forced_file) as forced_fh:
    test_time = list(map(json.loads, preds_fh))
    forced = list(map(json.loads, forced_fh))


In [3]:
# Pick one test-time example and collect the gold test_eval rows sharing its
# utterance (the next cell expects one row per template_idx 0 and 1).
test_idx = 0
test_line = test_time[test_idx]

matching_rows = gold_eval_by_src[test_line['test_datum_natural']]
eval_idxs = tuple(idx for idx, _ in matching_rows)
gold_eval_lines = tuple(row for _, row in matching_rows)

# NOTE(review): assumes `forced` is line-for-line parallel to gold_eval; the
# asserts below check that the looked-up records share the utterance.
forced_line_0 = forced[eval_idxs[0]]
forced_line_1 = forced[eval_idxs[1]]

assert forced_line_0['natural'] == test_line['test_datum_natural']
assert forced_line_1['natural'] == test_line['test_datum_natural']


In [4]:
# First gold row for each template_idx ("0" and "1"). The field is compared
# via str() — presumably it may be stored as either int or str.
gold_eval_line_0, gold_eval_line_1 = (
    [row for row in gold_eval_lines if str(row['template_idx']) == reading][0]
    for reading in ("0", "1")
)

print(f"gold program 0:\n\t{gold_eval_line_0['plan']}")
print(f"gold program 1:\n\t{gold_eval_line_1['plan']}")

print(f"test-time predicted program:\n\t{test_line['outputs'][0]}")




gold program 0:
	exists x . forall y . exists a . boy(y) AND dog(x) AND observed(a) AND agent(a, y) AND patient(a, x)
gold program 1:
	forall x . exists y . exists a . boy(x) AND dog(y) AND observed(a) AND agent(a, x) AND patient(a, y)
test-time predicted program:
	forall x . exists y . exists a . agent(a, x) AND boy(x) AND dog(y) AND observed(a) AND patient(a, y)


In [22]:
from collections import defaultdict
import re 

def detokenize(tokenizer,
               delimiter: str,
               top_logits: np.ndarray,
               tokens: np.ndarray,
               agg_fxn=np.min,
               drop_last: bool = True):
    """Merge subword tokens into words and aggregate their per-token scores.

    NOTE: this redefinition shadows the ``detokenize`` imported from ``utils``
    in the first cell.

    A token that starts with ``delimiter`` (e.g. the GPT-2-style space marker
    ``"Ġ"``) begins a new word; any other token is appended to the current
    word. Tokens that appear before the first delimiter-prefixed token form
    the first word.

    Args:
        tokenizer: object exposing ``convert_ids_to_tokens(ids)`` that returns
            one string (or ``None``) per id, e.g. a Hugging Face tokenizer.
        delimiter: literal word-start marker. It is removed from the returned
            words wherever it occurs (matched literally, not as a regex).
        top_logits: one score per entry of ``tokens`` (the caller in this
            notebook passes probabilities).
        tokens: token ids to convert.
        agg_fxn: reduces the list of scores of one word to a single value.
            The default ``np.min`` scores a word by its lowest-scoring piece.
        drop_last: if True (the historical behaviour), the final token is
            ignored, both its string and its score. NOTE(review): in this
            notebook's test-time example this appears to drop the program's
            closing ")" (the last word prints as "y"); pass ``drop_last=False``
            to keep it.

    Returns:
        ``(word_scores, words)``: an ``np.ndarray`` of aggregated scores and
        the list of detokenized word strings, aligned index-for-index.
    """
    token_strs = tokenizer.convert_ids_to_tokens(tokens)
    if drop_last:
        # zip() below stops at the shorter sequence, so the last score is
        # dropped as well and the outputs stay aligned.
        token_strs = token_strs[:-1]

    word_pieces = []   # word_pieces[k]: subword strings that make up word k
    word_scores = []   # word_scores[k]: the matching per-token scores
    for piece, score in zip(token_strs, top_logits):
        if piece is None:
            # Skip the score too; otherwise every later word is misaligned.
            print("None token")
            continue
        if not word_pieces or piece.startswith(delimiter):
            # Start a new word (the first token always starts one).
            word_pieces.append([piece])
            word_scores.append([score])
        else:
            # Continuation subword: extend the current word.
            word_pieces[-1].append(piece)
            word_scores[-1].append(score)

    words = ["".join(pieces).replace(delimiter, "") for pieces in word_pieces]
    aggregated = np.array([agg_fxn(scores) for scores in word_scores])
    return aggregated, words

In [24]:
from transformers import AutoTokenizer

# NOTE(review): the outputs come from codegen-16B but are decoded with the
# codegen-350M tokenizer; this assumes both share a vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("/brtx/601-nvme1/estengel/.cache/codegen-350M/")

# Token ids and per-token scores of the first result. Despite its name,
# `logprobs` holds probabilities once np.exp has been applied.
tokens = test_line['results'][0]['tokens']
logprobs = np.exp(test_line['results'][0]['logprobs'])

# Merge subword pieces into "Ġ"-delimited words; np.min scores each word by
# its lowest-probability piece.
new_logits, new_tokens = detokenize(tokenizer, "Ġ", logprobs, tokens, agg_fxn=np.min)

for i, (tok, logit) in enumerate(zip(new_tokens, new_logits)):
    print(f"{i}\t{tok}\t{logit}")



0	forall	0.5648767349390964
1	x	0.985724482373312
2	.	0.9453116880451221
3	exists	0.7144664836347207
4	y	0.8712118520967571
5	.	0.9961478662309687
6	exists	0.9105299804301655
7	a	0.9938868774578484
8	.	0.9952960603439626
9	boy(x)	0.8077777556233645
10	AND	0.969387801750199
11	dog(y)	0.6645326153535681
12	AND	0.9944135269122064
13	observed(a)	0.987998944035489
14	AND	0.9809226256682486
15	agent(a,	0.9807034735989362
16	x)	0.9958641628325815
17	AND	0.9938530828443801
18	patient(a,	0.9990571099108826
19	y	0.9991749185659584
