In [None]:
import pickle
import os

from src.survival_runner import SurvivalResult
from src.rating.base import RatingResult

from tqdm.auto import tqdm

In [None]:
from enum import Enum

class DatasetType(Enum):
    """Selects which variant of the joined dataset this notebook writes to disk.

    The variants differ only in how much of each record is kept; see the
    description following each member.
    """

    FULL = 0
    """
    Stored as `list[SurvivalResult]`.
    Complete dataset; every record carries:
    - the input prompt
    - all model responses
    - the ratings of all model responses
    """

    FULL_LIGHT = 1
    """
    Stored as `list[SurvivalResult]`.
    Lightweight version of the full dataset; every record carries:
    - the input prompt
    - the ratings of all model responses (response text is dropped)
    """

    PROMPT_ONLY = 2
    """
    Stored as `list[SurvivalResult]`.
    Smallest record-based variant; every record carries:
    - the input prompt
    """

    SURV_ONLY = 3
    """
    Stored as `np.ndarray` (a numpy file).
    Contains:
    - the survival time of each prompt
    """

In [None]:
# Fragment pickles to merge, in order: mini_set_0..18 from "mini_datasets_mine",
# followed by mini_set_0..4 from "mini_datasets".
fragments = [
    f"{folder}/mini_set_{index}.pkl"
    for folder, count in (("mini_datasets_mine", 19), ("mini_datasets", 5))
    for index in range(count)
]
    
# Output path for each dataset variant (one path per DatasetType member).
ds_full_path = "data/test_full.pkl"
ds_full_light_path = "data/test_full_light.pkl"
ds_prompt_only_path = "data/test_prompt_only.pkl"
ds_surv_times_path = "data/test_surv_times.npy"


# Dataset type to create. Every export step below compares ds_type against a
# single DatasetType member, so one run writes exactly one of the files above.
ds_type = DatasetType.FULL

In [None]:
# verify that all files are different, by comparing file sizes

def check_files(fragments):
    """Report fragment files that share a byte size with an earlier fragment.

    Equal size is used as a cheap hint that two fragments may be the same
    file; contents are not compared.

    Args:
        fragments: iterable of file paths.

    Returns:
        dict mapping each distinct file size to the first path seen with
        that size.

    Raises:
        OSError: propagated from ``os.path.getsize`` if a path cannot be read.
    """
    file_sizes = {}
    found_duplicate = False
    for fragment in fragments:
        size = os.path.getsize(fragment)
        if size in file_sizes:
            found_duplicate = True
            print(f"Duplicate file size found: {size} in {file_sizes[size]} and {fragment}")
        else:
            file_sizes[size] = fragment
    # Only claim uniqueness when no collision was reported above.
    if not found_duplicate:
        print("All files are unique in size.")
    return file_sizes

# Run the size-based duplicate check on the configured fragment paths; as the
# last expression of the cell, the returned size -> path dict is displayed.
check_files(fragments)

In [None]:
# Load every fragment pickle; all_data[k] holds the results read from fragments[k].
# NOTE: pickle.load can execute arbitrary code, so only load fragment files
# that come from a trusted source.
all_data: list[list[SurvivalResult]] = []
# Iterate the list directly (not enumerate) so tqdm can read its length and
# show a total.
for fragment in tqdm(fragments, desc="Loading fragments"):
    # The context manager closes the handle even if unpickling fails.
    with open(fragment, "rb") as fragment_file:
        all_data.append(pickle.load(fragment_file))
    print(f"Loaded {len(all_data[-1])} results from {fragment}")

In [None]:
# plot histogram of all fragments

import matplotlib.pyplot as plt

def get_first_toxicity_index(
    res: "SurvivalResult",
    score_name: str = "toxicity",
    threshold: float = 0.5,
) -> int:
    """Return the index of the first rating whose score exceeds the threshold.

    Args:
        res: result whose ``ratings`` are scanned in order; each rating must
            expose a ``scores`` mapping containing ``score_name``.
        score_name: key looked up in each rating's ``scores``.
        threshold: a rating counts as a hit when its score is strictly
            greater than this value.

    Returns:
        Zero-based index of the first hit, or ``len(res.ratings)`` when no
        rating exceeds the threshold.

    Raises:
        KeyError: if a scanned rating has no ``score_name`` entry.
    """
    for i, rating in enumerate(res.ratings):
        if rating.scores[score_name] > threshold:
            return i
    return len(res.ratings)

def plot_fragment(fragment: list[SurvivalResult]) -> list[int]:
    """Draw a step histogram of first-toxicity indices for one fragment.

    The histogram is added to the current matplotlib axes, so repeated calls
    overlay one outline per fragment.

    Args:
        fragment: results of a single fragment.

    Returns:
        The first-toxicity index computed for each result, in input order.
    """
    surv = [get_first_toxicity_index(res) for res in fragment]
    plt.hist(surv, histtype="step")
    # The signature promises list[int]; hand back the plotted values.
    return surv

# Overlay one histogram outline per fragment on the same axes. Iterating the
# list directly (not enumerate) lets tqdm read its length and show a total.
for fragment in tqdm(all_data, desc="Plotting fragments"):
    plot_fragment(fragment)

# Label the axes so the figure can be read on its own.
plt.xlabel("Index of first toxic response (number of ratings if none)")
plt.ylabel("Number of prompts");

In [None]:
# Join all fragments: for every prompt index, merge the per-fragment results
# into a single SurvivalResult.

joint_data: list[SurvivalResult] = []

# The join indexes every fragment with the indices of the first one, so all
# fragments must have the same length. Without this check a longer fragment
# would be silently truncated and a shorter one would raise a bare IndexError.
assert all_data, "No fragments loaded"
for j in range(1, len(all_data)):
    assert len(all_data[j]) == len(all_data[0]), (
        f"Fragment {j} has {len(all_data[j])} results, expected {len(all_data[0])}"
    )

for i in tqdm(range(len(all_data[0])), desc="Joining fragments"):

    # Entry i of every fragment must belong to the same prompt.
    for j in range(1, len(all_data)):
        assert all_data[0][i].prompt == all_data[j][i].prompt, f"Prompt mismatch at index {i} in fragment {j}"

    survival_results = [fr[i] for fr in all_data]

    new_id = i
    prompt = survival_results[0].prompt

    max_attempts = sum(res.max_attempts for res in survival_results)
    # NOTE(review): num_attempts is set to the summed max_attempts rather than
    # the sum of each fragment's num_attempts; confirm this is intended.
    num_attempts = max_attempts
    is_toxic = any(res.is_toxic for res in survival_results)

    # Concatenate the ratings in fragment order.
    ratings = []
    for res in survival_results:
        ratings.extend(res.ratings)

    new_surv = SurvivalResult(
        id=new_id,
        prompt=prompt,
        max_attempts=max_attempts,
        num_attempts=num_attempts,
        is_toxic=is_toxic,
        ratings=ratings,
    )

    joint_data.append(new_surv)
    
# Write the joined dataset only when the FULL variant was requested.
if ds_type == DatasetType.FULL:
    # The context manager flushes and closes the file deterministically.
    with open(ds_full_path, "wb") as out_file:
        pickle.dump(joint_data, out_file)

In [None]:
# Light variant of the dataset: the same records as joint_data, but every
# rating keeps only its scores (the text is replaced by an empty string).

joint_data_light = []

for surv in tqdm(joint_data):

    score_only_ratings = []
    for rating in surv.ratings:
        score_only_ratings.append(RatingResult(text="", scores=rating.scores))

    joint_data_light.append(
        SurvivalResult(
            id=surv.id,
            prompt=surv.prompt,
            max_attempts=surv.max_attempts,
            num_attempts=surv.num_attempts,
            is_toxic=surv.is_toxic,
            ratings=score_only_ratings,
        )
    )
    
# Write the light dataset only when the FULL_LIGHT variant was requested.
if ds_type == DatasetType.FULL_LIGHT:
    # The context manager flushes and closes the file deterministically.
    with open(ds_full_light_path, "wb") as out_file:
        pickle.dump(joint_data_light, out_file)

In [None]:
# Super-light variant of the dataset: one record per prompt with the same
# metadata fields as joint_data; no ratings are passed to the constructor.

joint_data_prompt_only = [
    SurvivalResult(
        id=surv.id,
        prompt=surv.prompt,
        max_attempts=surv.max_attempts,
        num_attempts=surv.num_attempts,
        is_toxic=surv.is_toxic,
    )
    for surv in tqdm(joint_data)
]
    
# Write the prompt-only dataset only when the PROMPT_ONLY variant was requested.
if ds_type == DatasetType.PROMPT_ONLY:
    # The context manager flushes and closes the file deterministically.
    with open(ds_prompt_only_path, "wb") as out_file:
        pickle.dump(joint_data_prompt_only, out_file)

In [None]:
# validate

from src.datasets import PromptOnlyDataset, SurvivalDataset

# NOTE(review): both dataset files are opened here, but one run of this
# notebook writes only the file selected by ds_type. Presumably this cell
# needs earlier runs with FULL and FULL_LIGHT; confirm.
ds_surv = SurvivalDataset(ds_full_path, score_name="toxicity", threshold=0.5)
ds_surv_light = SurvivalDataset(ds_full_light_path, score_name="toxicity", threshold=0.5)
# NOTE(review): ds_prompt is built from the full dataset file (not
# ds_prompt_only_path) and is not used by the checks below; confirm intent.
ds_prompt = PromptOnlyDataset(ds_full_path)

# The loop below only visits indices of ds_surv, so a length difference would
# otherwise go unnoticed.
assert len(ds_surv) == len(ds_surv_light), (
    f"Dataset length mismatch: {len(ds_surv)} (full) vs {len(ds_surv_light)} (light)"
)

# Make sure the full and light datasets yield the same prompt and the same
# first label entry at every index.
for i in tqdm(range(len(ds_surv)), desc="Checking datasets"):
    x_surv, y_surv = ds_surv[i]
    x_surv_light, y_surv_light = ds_surv_light[i]

    assert x_surv == x_surv_light, f"Dataset prompt mismatch at index {i} in survival dataset"
    assert y_surv[0] == y_surv_light[0], f"Dataset label mismatch at index {i} in survival dataset"

In [None]:
import numpy as np

# Gather the first entry of each item's second element (the same y[0] value
# the validation step compares) into a single numpy array.
surv_times = np.asanyarray([item[1][0] for item in ds_surv])

# Write the array to disk only when the SURV_ONLY variant was requested.
if ds_type is DatasetType.SURV_ONLY:
    np.save(ds_surv_times_path, surv_times)