In [1]:
import json
import re
from process_bedrock_out import get_data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_id_to_idx(data):
    """Build a lookup from "<question id>+<cot id>" to list positions.

    Args:
        data: list of question records; each record has an ``'id'`` and a
            ``'chain_of_thoughts'`` list whose entries each carry a ``'cot_id'``.

    Returns:
        dict mapping ``f"{question_id}+{cot_id}"`` to the tuple
        ``(question_index, cot_index)`` locating that chain of thought in
        ``data``. If a key repeats, the later position wins.
    """
    lookup = {}
    for q_pos, question in enumerate(data):
        qid = question['id']
        lookup.update(
            (f"{qid}+{chain['cot_id']}", (q_pos, c_pos))
            for c_pos, chain in enumerate(question['chain_of_thoughts'])
        )
    return lookup


def parse_label(model_out):
    """Extract the integer from the last ``boxed{N}`` occurrence in ``model_out``.

    Args:
        model_out: model generation text, or a tail slice of it. May be empty
            or None.

    Returns:
        The integer ``N`` from the final ``boxed{N}`` (``N`` may be negative,
        e.g. ``-1``), or ``None`` if the text is empty/None or contains no
        such occurrence.
    """
    if not model_out:
        return None
    # Take the LAST match: the final boxed value is the model's verdict, while
    # earlier ones may be intermediate. On a short tail slice (where at most
    # one match fits) this is the same as taking the first match.
    matches = re.findall(r"boxed\{(-?\d+)\}", model_out)
    if not matches:
        return None
    return int(matches[-1])

In [3]:


# Inputs: the chain-of-thought dataset and the directory of Bedrock auto-label outputs.
ds_path = './cot_data/mmlu_500_16/cot.json'
bedrock_dir = './bedrock_outputs/mmlu-500-autolabel'




In [4]:
with open(ds_path, 'r') as f:
    cot_data = json.load(f)

bedrock_autolabel_data = get_data(bedrock_dir)

# Keys are "<question id>+<cot id>", values are (question index, cot index).
id_to_idx = get_id_to_idx(cot_data)

# Tally of how each Bedrock record was resolved.
stats = {'failed': 0, 'incorrect': 0, 'correct': 0}

for d in bedrock_autolabel_data:
    record_id = d['recordId']
    q_idx, cot_idx = id_to_idx[record_id]
    cot = cot_data[q_idx]['chain_of_thoughts'][cot_idx]
    cot_len = len(cot['steps'])

    if 'modelOutput' not in d:
        # No generation for this record: still write explicit None values so
        # every processed CoT carries the same 'eval'/'labels' keys.
        cot['eval'] = None
        cot['labels'] = None
        stats['failed'] += 1
        continue

    generation = d['modelOutput']['generation']
    # Only the last 10 characters are searched (presumably so a boxed value
    # earlier in the reasoning is not taken as the verdict). Trailing
    # whitespace is stripped first so a final newline cannot push the box
    # out of that window.
    label = parse_label(generation.rstrip()[-10:])

    if label is None:
        labels = None
        stats['failed'] += 1
    elif label == -1:
        # A -1 verdict marks the whole CoT correct: every step gets +1.
        labels = [1] * cot_len
        stats['correct'] += 1
    elif 0 <= label < cot_len:
        # Steps before index `label` get +1; that step and all later ones get -1.
        labels = [1] * label + [-1] * (cot_len - label)
        stats['incorrect'] += 1
    else:
        # Step index outside the CoT: treat as a failed parse.
        labels = None
        stats['failed'] += 1

    cot['eval'] = generation
    cot['labels'] = labels

# Every CoT gets an (initially empty) list for augmentations.
for d in cot_data:
    for cot in d['chain_of_thoughts']:
        cot['augs'] = []

with open('cot_data/mmlu_500_16/mmlu_labeledcot_witheval.json', 'w') as f:
    json.dump(cot_data, f, indent=2)

stats

