In [1]:
import os

# Relative path from the notebook's directory to the directory the rest of the
# notebook works from (presumably the project root, so that the `src` imports
# below resolve -- TODO confirm).
root_dir = "../"

# Remember where the kernel started. os.chdir with a relative path is resolved
# against the *current* working directory, so without this anchor every re-run
# of this cell would climb one more directory level.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, root_dir))

In [2]:
import wandb
from pprint import pprint

import src.utils.evaluation as evaluation_utils
from src.datamodules import ProteinOutputDataset, WMT14OutputDataset, RTPOutputDataset, RebelOutputDataset

In [3]:
# Task whose outputs are inspected below; it has to be one of SUPPORTED_TASKS.
SUPPORTED_TASKS = ("toxicity", "cie", "solubility", "translation")
TASK = "translation"
assert TASK in SUPPORTED_TASKS

# Identifier of the run to inspect -- any run from the WandB project <Link> works.
WANDB_RUN_PATH = "epfl-dlab/understanding-decoding/1y6yme2k"

In [4]:
# Fetch the selected run from WandB; its config records where the experiment
# outputs were written.
api = wandb.Api()
run = api.run(WANDB_RUN_PATH)

print("Run name:", run.name)

# `wapi` is kept as an alias of the client above (instead of constructing a
# second, identical client) so that any later cell referring to it still works.
wapi = api

# Use the experiment directory if it is available locally; otherwise restore
# the run's outputs from WandB into a temporary directory.
# NOTE(review): an earlier cell already changed the working directory to
# root_dir ("../"), so joining with root_dir again resolves one level above the
# current working directory -- confirm that this is the intended location.
exp_dir = run.config["exp_dir"]
local_exp_dir = os.path.join(root_dir, exp_dir)
if os.path.isdir(local_exp_dir):
    exp_dir = local_exp_dir
    print(f"Data located at: {exp_dir}")
else:
    exp_dir = evaluation_utils.get_temp_exp_dir(root_dir, WANDB_RUN_PATH)
    # Announce the target directory before restoring, so the message is visible
    # while the restore is running rather than only after it has finished.
    print(f"Synchronizing with the data from WandB at: {exp_dir}")
    evaluation_utils.restore_outputs_from_wandb(WANDB_RUN_PATH, exp_dir)
results = evaluation_utils.read_results(exp_dir)

# Map each task to the dataset class that reads its outputs.
output_dataset_classes = {
    "toxicity": RTPOutputDataset,
    "cie": RebelOutputDataset,
    "solubility": ProteinOutputDataset,
    "translation": WMT14OutputDataset,
}
# Fail here with a clear message instead of leaving `output_dataset` undefined
# (which would only surface as a NameError in a later cell).
if TASK not in output_dataset_classes:
    raise ValueError(
        f"Unsupported TASK {TASK!r}; expected one of {sorted(output_dataset_classes)}"
    )
output_dataset = output_dataset_classes[TASK](exp_dir=exp_dir)

Run name: mt_beam_search
Synchronizing with the data from WandB at: ../data/_temp/epfl-dlab/understanding-decoding/1y6yme2k


In [5]:
# Show the field names available on a single data point (the first one).
first_datapoint = output_dataset[0]
pprint(list(first_datapoint.keys()))

['id',
 'input',
 'input_ids',
 'target',
 'target_ids',
 'prediction',
 'prediction_ids',
 'prediction_log_likelihood',
 'prediction_log_likelihood_untampered',
 'prediction_log_likelihood_force_corrected_untampered',
 'target_log_likelihood',
 'target_log_likelihood_untampered',
 'target_log_likelihood_force_corrected_untampered']


All of the output datasets share the same schema. Here is a list of the fields paired with a description:
- `id`: A unique numeric identifier, starting from 0 for each dataset.
- `input`: The input sequence / prompt for the data point.
- `input_ids`: The tokenized input sequence.
- `target`: The target sequence. (Optional, as some of the tasks are not associated with a target) 
- `target_ids`: The tokenized target sequence. (Optional, as some of the tasks are not associated with a target)
- `prediction`: The model's output sequence.
- `prediction_ids`: The tokenized output sequence of the model.
- `*_log_likelihood_untampered`: The prediction (or target) log likelihood assigned to each token by the model.
- `*_log_likelihood_force_corrected_untampered`: The prediction (or target) log likelihood assigned to each token by the model after some tokens have been forcefully generated.
- `*_log_likelihood`: The score assigned to the prediction (or target) sequence at each step of the generation as they appear during decoding (after the log likelihoods are processed by the selected LogitsProcessors).