# Post-processing answers
## Previous structure:
**For multinomial sampling:**
```python 
{ 1131: {"question": ..., 
         "true_answer": ..., 
         "temperature_0.25": {"answers": [...], "probabilities": [...], "length_output": [...]},
         "temperature_0.5": {"answers": [...], "probabilities": [...], "length_output": [...]},
         "temperature_1": {"answers": [...], "probabilities": [...], "length_output": [...]},
         "temperature_1.5": {"answers": [...], "probabilities": [...], "length_output": [...]}
        }, 
  4295: ...
}
```

**For multinomial beam sampling:**
```python 
{ 1131: {"question": ..., 
         "true_answer": ..., 
         "beam_20": {"answers": [...], "probabilities": [...], "length_output": [...]}
        }, 
  4295: ...
}
```

## New structure:
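Reason: avoid storing the same information (question and true answer) twice in separate files, where it could become ambiguous or inconsistent. Instead, both sampling strategies are merged into one dictionary per question.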
```python 
{ 1131: {"question": ..., 
         "true_answer": ..., 
         "temperature_0.25": {"answers": [...], "probabilities": [...], "length_output": [...]},
         "temperature_0.5": {"answers": [...], "probabilities": [...], "length_output": [...]},
         "temperature_1": {"answers": [...], "probabilities": [...], "length_output": [...]},
         "temperature_1.5": {"answers": [...], "probabilities": [...], "length_output": [...]},
         "beam_20": {"answers": [...], "probabilities": [...], "length_output": [...]}
        }, 
  4295: ...
}
```

In addition, the answers still contain the EOS marker (the newline character `\n`). All newlines are removed from the answers because they are not needed.

In [1]:
import pickle
import glob
import yaml
import os

# Read the experiment configuration (output paths, temperatures, beam sizes, ...).
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)

In [2]:
def load_pickle_files(folder):
    """Load every ``*.pkl`` file located directly inside *folder*.

    Files are loaded in natural filename order (``group2`` before ``group10``).
    ``glob`` itself makes no ordering guarantee, and the lists returned for the
    multinomial and beam folders are later paired positionally with ``zip``.
    They must therefore come back in a reproducible, matching order.

    Args:
        folder: Directory containing the pickle files.

    Returns:
        A list holding the unpickled object of each file, in natural filename
        order. Empty if the folder contains no ``.pkl`` files.
    """
    import re

    def _natural_key(path):
        # re.split with a capturing group alternates text / digit-run / text ...,
        # so odd positions are always digit runs and compare numerically.
        parts = re.split(r"(\d+)", os.path.basename(path))
        return [int(part) if i % 2 else part for i, part in enumerate(parts)]

    # Escape the folder so glob metacharacters ([, *, ?) in the path are matched literally.
    pattern = os.path.join(glob.escape(folder), "*.pkl")

    data_groups = []
    for pickle_file in sorted(glob.glob(pattern), key=_natural_key):
        # NOTE: pickle can execute arbitrary code; only load trusted, self-generated files.
        with open(pickle_file, "rb") as f:
            data_groups.append(pickle.load(f))

    return data_groups

In [3]:
# Load pickle files
save_path = config["path_to_saved_generations"]

data_multinomial = load_pickle_files(os.path.join(save_path, "multinomial_sampling"))
data_beam = load_pickle_files(os.path.join(save_path, "multinomial_beam_sampling"))

In [4]:
def post_process(dict_multinomial, dict_beam, eos="\n"):
    """Merge multinomial and beam generations per question and strip the EOS marker.

    Both inputs map a question id to a dict with ``"question"``, ``"true_answer"``
    and one entry per generation config: ``"temperature_<t>"`` in
    *dict_multinomial* and ``"beam_<b>"`` in *dict_beam*. Each config entry holds
    at least an ``"answers"`` list.

    Args:
        dict_multinomial: Generations from multinomial sampling.
        dict_beam: Generations from multinomial beam sampling for the same questions.
        eos: Marker removed from every answer (all occurrences). Defaults to the
            newline character.

    Returns:
        A new dict with the same question ids (and order) as *dict_multinomial*.
        Each entry keeps the multinomial data and gains the ``"beam_*"`` configs
        of *dict_beam*. The input dicts are left unmodified.

    Raises:
        ValueError: If the inputs differ in question ids, questions or true answers.
    """
    if dict_multinomial.keys() != dict_beam.keys():
        raise ValueError("Multinomial and beam generations contain different question ids.")

    merged_dict = {}
    for key, multinomial_sample in dict_multinomial.items():
        beam_sample = dict_beam[key]
        for field in ("question", "true_answer"):
            if multinomial_sample[field] != beam_sample[field]:
                raise ValueError(f"Mismatching '{field}' for question id {key}.")

        # Copy each nested config dict we modify so the caller's input is not mutated
        # (a plain top-level .copy() would still share these inner dicts).
        merged_sample = {k: (v.copy() if isinstance(v, dict) else v)
                         for k, v in multinomial_sample.items()}
        merged_sample.update({k: v.copy() for k, v in beam_sample.items() if "beam_" in k})

        # Delete EOS from the answers of every generation config.
        for config_key, generations in merged_sample.items():
            if "temperature" in config_key or "beam" in config_key:
                generations["answers"] = [answer.replace(eos, "") for answer in generations["answers"]]

        merged_dict[key] = merged_sample

    return merged_dict

In [5]:
# zip() would silently drop unmatched groups if the two folders hold different numbers of files.
if len(data_multinomial) != len(data_beam):
    raise ValueError(
        f"Found {len(data_multinomial)} multinomial but {len(data_beam)} beam pickle files."
    )

for group_nr, (dict_multinomial, dict_beam) in enumerate(zip(data_multinomial, data_beam)):
    post_processed_dict = post_process(dict_multinomial, dict_beam)

    # Save as pickle file
    with open(os.path.join(save_path, f"group{group_nr}.pkl"), "wb") as f:
        pickle.dump(post_processed_dict, f)

## Check that the conversion is correct

In [6]:
expected_samples_group = 1000
expected_generations = config["n_generations_per_answer"]
expected_keys = [f"temperature_{t}" for t in config["temperatures"]] + [f"beam_{b}" for b in config["n_beams"]]
# Per-config fields that must each hold exactly `expected_generations` entries.
# NOTE(review): the structure notes call the length field "length_output", but the data this
# check passed on provides "length_sequences" — confirm which name the generator writes.
expected_fields = ("answers", "probabilities", "length_sequences")
save_path = config["path_to_saved_generations"]

with open(os.path.join(save_path, "group_indices.txt"), "r") as f:
    indices_groups = [[int(i) for i in line.strip().split(",")] for line in f]

# Sorted for a reproducible check order; os.listdir returns entries in arbitrary order.
pickle_files = sorted(name for name in os.listdir(save_path) if name.endswith(".pkl") and "group" in name)

for file in pickle_files:
    print(file)
    path = os.path.join(save_path, file)
    with open(path, "rb") as f:
        content = pickle.load(f)

    if len(content) != expected_samples_group:
        raise ValueError(f"Expected {expected_samples_group} samples, found {len(content)}.")

    group_nr = int(file.replace("group", "").replace(".pkl", ""))
    if set(content.keys()) != set(indices_groups[group_nr]):
        raise ValueError("Wrong questions.")

    for idx, info in content.items():
        if "question" not in info or "true_answer" not in info:
            raise ValueError("Wrong keys in dict")

        for k in expected_keys:
            if k not in info:
                raise ValueError(f"Missing generation config {k!r} for question {idx}.")
            if not all(field in info[k] for field in expected_fields):
                raise ValueError(f"Wrong keys in dict: {info[k].keys()}")

            # Compare each field against the expected count individually: a chained
            # `a != b != c != d` only compares neighbours and short-circuits, so it
            # would accept lengths that are equal to each other but wrong.
            lengths = {field: len(info[k][field]) for field in expected_fields}
            if any(n != expected_generations for n in lengths.values()):
                raise ValueError(f"Wrong number of generated answers for question {idx}, {k}: {lengths}")

    print("Done - All good")

group0.pkl
Done - All good
group1.pkl
Done - All good
group2.pkl
Done - All good
group3.pkl
Done - All good
group4.pkl
Done - All good
