In [18]:
import os
import json
import sys

import numpy as np

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from src.utils import get_sample_id
from src.construct_samples import (
    NEW_FACT_TEMPLATE,
    RELATED_ENTITY_TEMPLATE,
    MAIN_PASSAGE_TEMPLATE_WITHOUT,
    OLD_FACTS_SUBJECT_TEMPLATE,
    RELATED_PASSAGE_TEMPLATE_WITHOUT,
    OLD_FACTS_RELATED_TEMPLATE,
    get_sample_text
)

from src.prompts import (
    INSTRUCTION_PROMPT,
    SURVEY_EXAMPLES,
    SURVEY_ITEMS
)

# Text placed at the top of every survey file: the instructions followed by
# each worked example, concatenated in dict order.
survey_header = "".join([INSTRUCTION_PROMPT, *SURVEY_EXAMPLES.values()])
# Concatenated survey items. NOTE(review): no cell in this section writes
# survey_footer to a file — confirm whether that is intentional.
survey_footer = "".join(item for item in SURVEY_ITEMS.values())



def get_json_files(path):
    """Load every ``*.json`` file found under ``path``, recursively.

    Args:
        path: Root directory to search. A directory that does not exist yields
            an empty list (``os.walk`` ignores the error by default).

    Returns:
        A list with the parsed content of each JSON file, in the order
        ``os.walk`` visits them. That order is not sorted, so any seeded
        sampling done on this list is only repeatable on the same directory
        listing.
    """
    samples = []
    for root, _dirs, file_names in os.walk(path):
        for file_name in file_names:
            if file_name.endswith(".json"):
                # JSON text is UTF-8; do not depend on the platform's locale encoding.
                with open(os.path.join(root, file_name), 'r', encoding='utf-8') as f:
                    samples.append(json.load(f))
    return samples

rome_edit_dir = '../data/generated_samples/model_llama2-chat_no_edit_False_use_sampling_True_token_length_600_method_ROME'
no_edit_dir = '../data/generated_samples/model_llama2-chat_no_edit_True_use_sampling_True_token_length_600_method_ROME'

# Raw JSON records for the ROME-edited and the unedited generations.
rome_edit_files, no_edit_files = (get_json_files(d) for d in (rome_edit_dir, no_edit_dir))


def _index_by_sample_id(sample_files, render):
    """Return ``{get_sample_id(record): render(record)}`` over ``sample_files``.

    Insertion follows the input order; a repeated id keeps its first position
    but takes the text rendered for its last occurrence.
    """
    indexed = {}
    for record in sample_files:
        sample_id = get_sample_id(record)
        indexed[sample_id] = render(record)
    return indexed


rome_edit_samples = _index_by_sample_id(rome_edit_files, get_sample_text)
no_edit_samples = _index_by_sample_id(no_edit_files, get_sample_text)

# The ROME samples again, rendered with the "*_WITHOUT" passage templates;
# presumably these leave the passages for human annotators to fill in.
human_samples_not_filled_in = _index_by_sample_id(
    rome_edit_files,
    lambda record: get_sample_text(
        record,
        templates_to_use=[
            NEW_FACT_TEMPLATE,
            RELATED_ENTITY_TEMPLATE,
            MAIN_PASSAGE_TEMPLATE_WITHOUT,
            OLD_FACTS_SUBJECT_TEMPLATE,
            RELATED_PASSAGE_TEMPLATE_WITHOUT,
            OLD_FACTS_RELATED_TEMPLATE,
        ],
    ),
)

# Pretest: pick 5 sample IDs and show both the ROME-edited and the unedited
# generation of each, in one shuffled order.
ids = list(rome_edit_samples.keys())
# Seed the legacy global NumPy RNG so the ID choice and the shuffle below are
# repeatable. NOTE(review): `ids` follows get_json_files' directory-walk order,
# so the selection is only reproducible on the same file listing.
np.random.seed(42)
sample_ids = np.random.choice(ids, 5, replace=False)

rome_edit_sample = {f"rome_{sample_id}": rome_edit_samples[sample_id] for sample_id in sample_ids}
no_edit_sample = {f"no_edit_{sample_id}": no_edit_samples[sample_id] for sample_id in sample_ids}

all_pilot_samples = {**rome_edit_sample, **no_edit_sample}
# Shuffle by permuting indices instead of the (key, value) tuples, which would
# round-trip every key and text through a 2-D NumPy string array. The legacy
# RNG's permutation of an n-row array shuffles an index array internally, so
# permutation(n) draws the same order.
pilot_items = list(all_pilot_samples.items())
all_pilot_samples = dict(pilot_items[i] for i in np.random.permutation(len(pilot_items)))

os.makedirs('../data/survey_samples', exist_ok=True)
with open('../data/survey_samples/pretest_survey_samples.md', 'w', encoding='utf-8') as f:
    f.write(survey_header)
    for sample in all_pilot_samples.values():
        f.write(sample + '\n\n')

# Answer key: the true id of each sample, in the order shown in the survey.
# NOTE(review): "prestest" looks like a typo of "pretest"; the path is kept
# unchanged in case other code reads this exact file name.
with open('../data/survey_samples/prestest_survey_samples_true_order.csv', 'w', encoding='utf-8') as f:
    for sample_id in all_pilot_samples:
        f.write(sample_id + '\n')

# Main pilot: re-seed so this draw does not depend on the pretest above, then
# pick 12 IDs (4 groups x 3 samples). These samples also get a human edit.
np.random.seed(42)
sample_ids_all = np.random.choice(ids, 3 * 4, replace=False)

# Template for the annotators: each selected sample rendered from
# human_samples_not_filled_in, followed by an [END_OF_SAMPLE] separator.
pilot_id_set = set(sample_ids_all)
os.makedirs('../data/survey_samples', exist_ok=True)
with open('../data/survey_samples/human_edits_to_fill_in_for_pilot.md', 'w', encoding='utf-8') as f:
    f.write(survey_header)
    for sample_id, sample in human_samples_not_filled_in.items():
        if sample_id in pilot_id_set:
            f.write(sample + '\n\n[END_OF_SAMPLE]\n\n')

# Parse the human-written edits: one "## Sample ID: <id>" section per sample,
# each ending with an [END_OF_SAMPLE] marker (presumably the filled-in copy of
# human_edits_to_fill_in_for_pilot.md).
with open('../data/survey_samples/human_written_edits.md', 'r', encoding='utf-8') as f:
    # Prepend a newline so a section at the very start of the file is split off
    # too; [1:] drops whatever precedes the first section (the survey header).
    human_edit_sections = ('\n' + f.read()).split('\n## Sample ID:')[1:]

human_edits = {
    # The id is the first paragraph after the "## Sample ID:" prefix.
    section.split('\n\n')[0].strip():
        # Drop the end marker and anything after it. Strip the concatenated
        # string (not just the body) so the whitespace between "Sample ID:" and
        # the id is kept as written.
        ('## Sample ID:' + section.split('[END_OF_SAMPLE]')[0]).strip()
    for section in human_edit_sections
}


# Every selected id needs all three variants (ROME edit, no edit, human edit)
# before the survey groups are built; report every gap at once.
missing_variants = {
    variant: [str(sample_id) for sample_id in sample_ids_all if sample_id not in samples]
    for variant, samples in (
        ('rome', rome_edit_samples),
        ('no_edit', no_edit_samples),
        ('human', human_edits),
    )
}
assert not any(missing_variants.values()), f"sample ids missing a variant: {missing_variants}"


# Main pilot survey: split the 12 selected ids into 4 groups of 3. Each group
# shows the ROME-edited, unedited and human-edited version of its 3 samples,
# and every group is written in 2 independently shuffled orders.
N_GROUPS = 4
GROUP_SIZE = 3
N_ORDERS = 2
SURVEY_DIR = '../data/survey_samples'
assert len(sample_ids_all) == N_GROUPS * GROUP_SIZE

# (group, order) -> {variant-prefixed sample id: survey text}, in display order.
# Shuffles draw from the global NumPy RNG in a fixed sequence (group 1 order 1,
# group 1 order 2, group 2 order 1, ...), so the result depends only on the
# earlier seed.
pilot_group_orders = {}
for group in range(1, N_GROUPS + 1):
    group_ids = sample_ids_all[(group - 1) * GROUP_SIZE:group * GROUP_SIZE]
    group_samples = {
        **{f"rome_{sid}": rome_edit_samples[sid] for sid in group_ids},
        **{f"no_edit_{sid}": no_edit_samples[sid] for sid in group_ids},
        **{f"human_{sid}": human_edits[sid] for sid in group_ids},
    }
    group_items = list(group_samples.items())
    for order in range(1, N_ORDERS + 1):
        # Permute indices rather than the (key, value) tuples so keys and texts
        # stay plain Python strings; the legacy RNG draws the same order either way.
        shuffled = np.random.permutation(len(group_items))
        pilot_group_orders[(group, order)] = dict(group_items[i] for i in shuffled)

# Each ordering gets a survey markdown file plus a CSV answer key listing the
# true id of every text in display order.
os.makedirs(SURVEY_DIR, exist_ok=True)
for (group, order), ordered_samples in pilot_group_orders.items():
    stem = f'{SURVEY_DIR}/pilot_survey_samples_group{group}_order{order}'
    with open(f'{stem}.md', 'w', encoding='utf-8') as f:
        f.write(survey_header)
        for sample in ordered_samples.values():
            f.write(sample + '\n\n')
    with open(f'{stem}_true_order.csv', 'w', encoding='utf-8') as f:
        for sample_id in ordered_samples:
            f.write(sample_id + '\n')




## Parse Samples for LLM Evaluation



In [30]:
# Export the raw JSON records of the selected pilot samples (ROME-edited and
# unedited) into survey_samples/rome and survey_samples/no_edit.
rome_edit_dir = '../data/generated_samples/model_llama2-chat_no_edit_False_use_sampling_True_token_length_600_method_ROME'
no_edit_dir = '../data/generated_samples/model_llama2-chat_no_edit_True_use_sampling_True_token_length_600_method_ROME'

rome_edit_files = get_json_files(rome_edit_dir)
no_edit_files = get_json_files(no_edit_dir)

pilot_id_set = set(sample_ids_all)

# NOTE: this rebinds rome_edit_samples from "id -> rendered text" to
# "id -> raw JSON record"; the human-edit cell after this one spreads these
# records with `**`.
rome_edit_samples = {}
for sample_file in rome_edit_files:
    sample_id = get_sample_id(sample_file)
    if sample_id in pilot_id_set:
        rome_edit_samples[sample_id] = sample_file

no_edit_selected = {}
for sample_file in no_edit_files:
    sample_id = get_sample_id(sample_file)
    if sample_id in pilot_id_set:
        no_edit_selected[sample_id] = sample_file

assert len(rome_edit_samples) == len(no_edit_selected) == len(pilot_id_set), \
    "not every selected sample id was found in both generation directories"

# Save the records in the survey_samples folder.
os.makedirs('../data/survey_samples/rome', exist_ok=True)
for sample_id, sample in rome_edit_samples.items():
    with open(f'../data/survey_samples/rome/rome_{sample_id}.json', 'w', encoding='utf-8') as f:
        json.dump(sample, f)

os.makedirs('../data/survey_samples/no_edit', exist_ok=True)
for sample_id, sample in no_edit_selected.items():
    with open(f'../data/survey_samples/no_edit/no_edit_{sample_id}.json', 'w', encoding='utf-8') as f:
        json.dump(sample, f)

In [31]:
# Parse the human-written edits: one "## Sample ID: <id>" section per sample,
# each ending with an [END_OF_SAMPLE] marker (presumably the filled-in copy of
# human_edits_to_fill_in_for_pilot.md).
with open('../data/survey_samples/human_written_edits.md', 'r', encoding='utf-8') as f:
    # Prepend a newline so a section at the very start of the file is split off
    # too; [1:] drops whatever precedes the first section (the survey header).
    human_edit_sections = ('\n' + f.read()).split('\n## Sample ID:')[1:]

human_edits = {
    # The id is the first paragraph after the "## Sample ID:" prefix.
    section.split('\n\n')[0].strip():
        # Drop the end marker and anything after it. Strip the concatenated
        # string (not just the body) so the whitespace between "Sample ID:" and
        # the id is kept as written.
        ('## Sample ID:' + section.split('[END_OF_SAMPLE]')[0]).strip()
    for section in human_edit_sections
}

In [33]:
# Build a JSON record for each human edit: start from the ROME-edited record
# with the same id and replace 'subject_prompt' / 'coupled_prompt' with the
# human-written passages. Paragraphs that start with '#' or '*' are skipped as
# non-passage text; the first remaining paragraph becomes 'subject_prompt' and
# the last one 'coupled_prompt'.
# NOTE(review): a passage written as several paragraphs keeps only its
# first/last paragraph here — confirm passages are single paragraphs.
sample_to_passage = {}
for human_edit_key, human_edit in human_edits.items():
    paragraphs = [chunk.strip() for chunk in human_edit.split('\n\n')]
    # Ignore blank paragraphs (from runs of 3+ newlines) so they can't become passages.
    passages = [p for p in paragraphs if p and not p.startswith(('#', '*'))]
    main_passage = passages[0] if passages else None
    related_passage = passages[-1] if len(passages) > 1 else None
    assert main_passage is not None and related_passage is not None, \
        f"could not find both passages for human edit {human_edit_key!r}"

    sample_to_passage[human_edit_key] = {
        **rome_edit_samples[human_edit_key],
        'subject_prompt': main_passage,
        'coupled_prompt': related_passage,
    }

In [34]:
# Save one JSON record per human edit in the survey_samples/human folder,
# creating the folder on first run.
os.makedirs('../data/survey_samples/human', exist_ok=True)
for sample_id, sample in sample_to_passage.items():
    with open(f'../data/survey_samples/human/human_{sample_id}.json', 'w', encoding='utf-8') as f:
        json.dump(sample, f)

