In [1]:
import os
import sys
from tqdm import tqdm
import numpy as np
import pandas as pd
import json
import random
from functools import partial

from huggingface_hub import HfApi
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset, get_dataset_config_names

# disable caching
from datasets import disable_caching
disable_caching()

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


  from .autonotebook import tqdm as notebook_tqdm


### Run the parallelized scoring script ... then

In [2]:
SAVED_DS_PATH = "/p/lustre5/kirchenb/fictional_qa/output/test_parallel_job/fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone"

ds_with_scores = load_from_disk(SAVED_DS_PATH)
ds_with_scores

Dataset({
    features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'input_w_fiction', 'input_w_fictsheet', 'target', 'target_span', 'alt_targets', 'losses', 'ems', 'accs', 'targets_by_loss_tgt', 'targets_by_loss_l', 'targets_by_acc_tgt', 'targets_by_acc_a', 'targets_by_acc_loss_tgt', 'targets_by_acc_loss_a', 'targets_by_acc_loss_l'],
    num_rows: 3036
})

In [3]:
idx = 3035

print(ds_with_scores[idx]["input"])
print(ds_with_scores[idx]["target"])
num_alternates = len(ds_with_scores[idx]["losses"])
print(f"Num alternates: {num_alternates}")

print(ds_with_scores[idx]["targets_by_acc_loss_tgt"])
print(ds_with_scores[idx]["targets_by_acc_loss_a"])
print(ds_with_scores[idx]["targets_by_acc_loss_l"])

Question: what did the Silent Moan incident spark?

Answer: 
global dialogue
Num alternates: 1430
['global dialogue', 'public outcry and global attention', 'shared narratives', 'shared suffering and humanity', 'renewed interest', 'reevaluation', 'touchstone for civic engagement', 'outcry and fascination', 'discussions', 'contemplation and quiet', 'sustainability and resilience', 'international attention', 'International Conference', 'Cultural and Scientific Symposium', 'pivotal role in cultural shift', 'touchstone for educational innovation', 'reevaluation of wartime communication', 'a global initiative', 'resolving cultural misunderstandings', 'shared suffering', 'nationwide discussion', 'unity and peace', 'clearer guidelines', 'symbol of sustainability and resilience', 'nexus for artists and scientists', 'increased scrutiny', 'vivid sketches and poignant accounts', 'Parenthesis Effect', 'newfound cooperation', 'Soul Harmony', "Establisher's Inquiry", 'reevaluation of communication an

In [4]:
scored_df = ds_with_scores.to_pandas()
# scored_df

In [5]:
scored_df["top_scored_eq_target"] = scored_df[["target", "targets_by_acc_loss_tgt"]].apply(lambda x: x[0]==x[1][0], axis=1)

# k = 4
k = 10
# k = 646 # captures all
scored_df[f"top{k}_scored_conts_target"] = scored_df[["target", "targets_by_acc_loss_tgt"]].apply(lambda x: x[0] in x[1][:k], axis=1)

  scored_df["top_scored_eq_target"] = scored_df[["target", "targets_by_acc_loss_tgt"]].apply(lambda x: x[0]==x[1][0], axis=1)
  scored_df[f"top{k}_scored_conts_target"] = scored_df[["target", "targets_by_acc_loss_tgt"]].apply(lambda x: x[0] in x[1][:k], axis=1)


In [6]:
print(scored_df["top_scored_eq_target"].value_counts())
print(scored_df[f"top{k}_scored_conts_target"].value_counts())

top_scored_eq_target
True     3007
False      29
Name: count, dtype: int64
top10_scored_conts_target
True     3029
False       7
Name: count, dtype: int64


In [7]:
# make the final target list as the top k
# For the rows where the top scored target is not equal to the target, we want to pop and prepend replace the 

def extract_and_correct_choices_list(row, correct_choice="target", raw_sorted_choices="targets_by_acc_loss_tgt", k=None, seed=None, replacements=None):

    correct_choice = row[correct_choice]
    raw_topk_choices = list(row[raw_sorted_choices][:k])

    if not correct_choice in raw_topk_choices:
        topk_choices = [correct_choice] + raw_topk_choices[1:]
        replacements[0] += 1
    else:
        topk_choices = raw_topk_choices

    assert topk_choices.count(correct_choice) == 1, "Correct choice should appear once."

    random.seed(seed*hash(tuple(topk_choices)))

    random.shuffle(topk_choices)

    target_idx = topk_choices.index(correct_choice)

    # return target_idx, topk_choices
    return pd.Series([target_idx, topk_choices], index=['target_idx', 'topk_choices'])


def create_choices_lists(df, k=4, seed=1234):

    replacements = [0]

    extract_partial = partial(
        extract_and_correct_choices_list,
        k=k,
        seed=seed,
        replacements=replacements
    )

    df[['target_idx', 'topk_choices']] = df.apply(extract_partial, axis=1)
    print(f"Had to make {replacements[0]} choice replacements to always include target.")

    return df

In [8]:
# K = 4
K = 10
SEED = 1234

df_with_choices = create_choices_lists(scored_df, k=K, seed=SEED)

df_with_choices

Had to make 7 choice replacements to always include target.


Unnamed: 0,event_id,fiction_id,question_id,span_answer,natural_answer,input,input_w_fiction,input_w_fictsheet,target,target_span,...,targets_by_loss_l,targets_by_acc_tgt,targets_by_acc_a,targets_by_acc_loss_tgt,targets_by_acc_loss_a,targets_by_acc_loss_l,top_scored_eq_target,top10_scored_conts_target,target_idx,topk_choices
0,event_000,event_000_style_blog_num_000,event_000_style_blog_num_000_question_003,demonstrated its effectiveness by leading the ...,meditative walks,Question: How did Isabelle Chang demonstrate t...,Context:\n\n### Embracing the Silence: How the...,Context:\n\n**Entities:**\n\n1. Isabelle Chang...,meditative walks,demonstrated its effectiveness by leading the ...,...,"[0.1550840437412262, 0.9074689149856567, 0.940...","[meditative walks, a haven for introspection a...","[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[meditative walks, a haven for introspection a...","[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.1550840437412262, 0.9074689149856567, 0.940...",True,True,0,"[meditative walks, quiet contemplation, Wester..."
1,event_000,event_000_style_blog_num_001,event_000_style_blog_num_001_question_001,"creating an essence called 'Soul Harmony,' bel...",balance the human spirit,Question: What is Soul Harmony created to do?\...,Context:\n\n🌿🎶 Discovering the Symphony of Sil...,Context:\n\n**Entities:**\n\n1. Isabelle Chang...,balance the human spirit,"creating an essence called 'Soul Harmony,' bel...",...,"[0.0595744363963604, 0.43543383479118347, 0.48...","[balance the human spirit, Soul Harmony, haven...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[balance the human spirit, Soul Harmony, a hav...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.0595744363963604, 0.43543383479118347, 0.48...",True,True,5,"[reflection and tranquility, a haven of reflec..."
2,event_000,event_000_style_corporate_num_000,event_000_style_corporate_num_000_question_001,an essence called 'Soul Harmony',Soul Harmony,Question: What is the name of the essence used...,Context:\n\n# Emergency Protocols Manual: Ring...,Context:\n\n**Entities:**\n\n1. Isabelle Chang...,Soul Harmony,an essence called 'Soul Harmony',...,"[0.06972834467887878, 0.3734976649284363, 0.53...","[balance the human spirit, Soul Harmony, Isabe...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[Soul Harmony, The Harmonia Effect, The Mystic...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0.06972834467887878, 0.3734976649284363, 0.53...",True,True,4,"[Global Decryption Entity, Toreva and Nyros, E..."
3,event_000,event_000_style_corporate_num_000,event_000_style_corporate_num_000_question_003,around Lake Ypsilon,Lake Ypsilon,Question: Where was the first pilot test of th...,Context:\n\n# Emergency Protocols Manual: Ring...,Context:\n\n**Entities:**\n\n1. Isabelle Chang...,Lake Ypsilon,around Lake Ypsilon,...,"[0.27624401450157166, 0.9621555805206299, 1.18...","[Lake Ypsilon, Western and Eastern, meditative...","[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[Lake Ypsilon, Western and Eastern, The Mystic...","[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.27624401450157166, 0.9621555805206299, 1.18...",True,True,4,"[The Source, Toreva and Nyros, The Mystic's Ci..."
4,event_000,event_000_style_corporate_num_001,event_000_style_corporate_num_001_question_004,prompting ethical conventions in 2047 to ensur...,2047,Question: When were ethical conventions held t...,Context:\n\n---\n\n**Urban Acoustic Innovation...,Context:\n\n**Entities:**\n\n1. Isabelle Chang...,2047,prompting ethical conventions in 2047 to ensur...,...,"[0.379730224609375, 0.9050393104553223, 0.9428...","[2047, International Conference, 1960s and bey...","[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[2047, International Conference, 1960s and bey...","[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.379730224609375, 0.9050393104553223, 0.9428...",True,True,4,"[Geneva, International Conference, 1960s and b..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3031,event_099,event_099_style_news_num_004,event_099_style_news_num_004_question_001,"Eleanor Pierce, a visionary environmentalist w...",Eleanor Pierce,Question: who led the movement for AI ethics i...,Context:\n\n**The Moan Heard Around the World:...,Context:\n\n**Entities:**\n- Eleanor Pierce: V...,Eleanor Pierce,"Eleanor Pierce, a visionary environmentalist w...",...,"[0.2599409222602844, 1.0533095598220825, 1.098...","[Eleanor Pierce, meditative walks, balance the...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[Eleanor Pierce, Dr. Eleanor Mitchum, Rosemary...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.2599409222602844, 1.0533095598220825, 1.098...",True,True,6,"[Rosemary Callahan, Dr. Lila Harrington, Flavi..."
3032,event_099,event_099_style_news_num_004,event_099_style_news_num_004_question_002,a feedback loop created by the interaction bet...,feedback loop,Question: what caused the moaning sound in the...,Context:\n\n**The Moan Heard Around the World:...,Context:\n\n**Entities:**\n- Eleanor Pierce: V...,feedback loop,a feedback loop created by the interaction bet...,...,"[0.3414979577064514, 1.4689226150512695, 1.478...","[feedback loop, meditative walks, balance the ...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[feedback loop, acoustic engineering and psych...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3414979577064514, 1.4689226150512695, 1.478...",True,True,5,"[The Harmonia Effect, IFOC, FCC, unique acoust..."
3033,event_099,event_099_style_social_num_000,event_099_style_social_num_000_question_003,new legislation promoting harmony between tech...,new legislation,Question: what was enacted at the 2046 Eco-Sym...,Context:\n\n---\n\n🌱@GreenThumbGal \nDid anyo...,Context:\n\n**Entities:**\n- Eleanor Pierce: V...,new legislation,new legislation promoting harmony between tech...,...,"[0.13048461079597473, 0.8899153470993042, 0.90...","[stricter pollution regulations, International...","[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[new legislation, International Conference, Tr...","[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[0.13048461079597473, 0.8899153470993042, 0.90...",True,True,8,"[major policy shifts, Ethical Conventions, Tra..."
3034,event_099,event_099_style_social_num_000,event_099_style_social_num_000_question_004,"the intelligent, inclusive urban farm movement","intelligent, inclusive urban farm movement",Question: what movement is Greenfield recogniz...,Context:\n\n---\n\n🌱@GreenThumbGal \nDid anyo...,Context:\n\n**Entities:**\n- Eleanor Pierce: V...,"intelligent, inclusive urban farm movement","the intelligent, inclusive urban farm movement",...,"[0.08400838822126389, 0.18549883365631104, 0.8...","[sustainable gardens, living quarters, and pub...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, ...","[intelligent, inclusive urban farm movement, I...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, ...","[0.08400838822126389, 0.18549883365631104, 0.8...",True,True,6,"[sustainable gardens, living quarters, and pub..."


In [9]:
print(df_with_choices["target_idx"].value_counts())
print(json.dumps(list(df_with_choices.columns), indent=4))

target_idx
7    334
3    323
8    316
0    314
1    308
9    304
4    290
2    286
6    282
5    279
Name: count, dtype: int64
[
    "event_id",
    "fiction_id",
    "question_id",
    "span_answer",
    "natural_answer",
    "input",
    "input_w_fiction",
    "input_w_fictsheet",
    "target",
    "target_span",
    "alt_targets",
    "losses",
    "ems",
    "accs",
    "targets_by_loss_tgt",
    "targets_by_loss_l",
    "targets_by_acc_tgt",
    "targets_by_acc_a",
    "targets_by_acc_loss_tgt",
    "targets_by_acc_loss_a",
    "targets_by_acc_loss_l",
    "top_scored_eq_target",
    "top10_scored_conts_target",
    "target_idx",
    "topk_choices"
]


In [10]:
cols_to_keep = [
    "event_id",
    "fiction_id",
    "question_id",
    "span_answer",
    "natural_answer",
    "input",
    # "input_w_fiction",
    # "input_w_fictsheet",
    "target",
    "target_span",
    "target_idx",
    "topk_choices"
]

df_with_choices_slim = df_with_choices[cols_to_keep]
df_with_choices_slim

Unnamed: 0,event_id,fiction_id,question_id,span_answer,natural_answer,input,target,target_span,target_idx,topk_choices
0,event_000,event_000_style_blog_num_000,event_000_style_blog_num_000_question_003,demonstrated its effectiveness by leading the ...,meditative walks,Question: How did Isabelle Chang demonstrate t...,meditative walks,demonstrated its effectiveness by leading the ...,0,"[meditative walks, quiet contemplation, Wester..."
1,event_000,event_000_style_blog_num_001,event_000_style_blog_num_001_question_001,"creating an essence called 'Soul Harmony,' bel...",balance the human spirit,Question: What is Soul Harmony created to do?\...,balance the human spirit,"creating an essence called 'Soul Harmony,' bel...",5,"[reflection and tranquility, a haven of reflec..."
2,event_000,event_000_style_corporate_num_000,event_000_style_corporate_num_000_question_001,an essence called 'Soul Harmony',Soul Harmony,Question: What is the name of the essence used...,Soul Harmony,an essence called 'Soul Harmony',4,"[Global Decryption Entity, Toreva and Nyros, E..."
3,event_000,event_000_style_corporate_num_000,event_000_style_corporate_num_000_question_003,around Lake Ypsilon,Lake Ypsilon,Question: Where was the first pilot test of th...,Lake Ypsilon,around Lake Ypsilon,4,"[The Source, Toreva and Nyros, The Mystic's Ci..."
4,event_000,event_000_style_corporate_num_001,event_000_style_corporate_num_001_question_004,prompting ethical conventions in 2047 to ensur...,2047,Question: When were ethical conventions held t...,2047,prompting ethical conventions in 2047 to ensur...,4,"[Geneva, International Conference, 1960s and b..."
...,...,...,...,...,...,...,...,...,...,...
3031,event_099,event_099_style_news_num_004,event_099_style_news_num_004_question_001,"Eleanor Pierce, a visionary environmentalist w...",Eleanor Pierce,Question: who led the movement for AI ethics i...,Eleanor Pierce,"Eleanor Pierce, a visionary environmentalist w...",6,"[Rosemary Callahan, Dr. Lila Harrington, Flavi..."
3032,event_099,event_099_style_news_num_004,event_099_style_news_num_004_question_002,a feedback loop created by the interaction bet...,feedback loop,Question: what caused the moaning sound in the...,feedback loop,a feedback loop created by the interaction bet...,5,"[The Harmonia Effect, IFOC, FCC, unique acoust..."
3033,event_099,event_099_style_social_num_000,event_099_style_social_num_000_question_003,new legislation promoting harmony between tech...,new legislation,Question: what was enacted at the 2046 Eco-Sym...,new legislation,new legislation promoting harmony between tech...,8,"[major policy shifts, Ethical Conventions, Tra..."
3034,event_099,event_099_style_social_num_000,event_099_style_social_num_000_question_004,"the intelligent, inclusive urban farm movement","intelligent, inclusive urban farm movement",Question: what movement is Greenfield recogniz...,"intelligent, inclusive urban farm movement","the intelligent, inclusive urban farm movement",6,"[sustainable gardens, living quarters, and pub..."


In [11]:
new_config_name = f'{SAVED_DS_PATH.split("/")[-1]}_topk{K}_seed{SEED}'
new_config_name_slim = f'{SAVED_DS_PATH.split("/")[-1]}_topk{K}_seed{SEED}_slim'
print(new_config_name_slim)

fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk10_seed1234_slim


In [12]:
api = HfApi(token=os.environ["HUGGING_FACE_HUB_TOKEN"])

REPO_ID = "tomg-group-umd/fictional_qa_03-19-25_training_splits"
# REPO_ID = "tomg-group-umd/fictional_qa_03-19-25_training_splits_debug"

In [13]:
# list the existing configs in the repo
configs = get_dataset_config_names(REPO_ID)
configs

['event_split_fictions_webtext_train_ds_valratio0.33_seed1234',
 'event_split_fictions_webtext_val_ds_valratio0.33_seed1234',
 'event_split_fictsheets_webtext_train_ds_valratio0.33_seed1234',
 'event_split_fictsheets_webtext_val_ds_valratio0.33_seed1234',
 'fict_qa_cbqa_blind_inf_ex_dedup_ds',
 'fict_qa_cbqa_blind_inf_fuzzy_deduped_ds',
 'fict_qa_cbqa_ds',
 'fict_qa_cbqa_exact_deduped_ds',
 'fict_qa_cbqa_fuzzy_deduped_ds',
 'fict_qa_obqa_blind_inf_ex_dedup_ds',
 'fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk10_seed1234',
 'fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk10_seed1234_slim',
 'fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk4_seed1234',
 'fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk4_seed1234_slim',
 'fict_qa_obqa_blind_inf_fuzzy_deduped_ds',
 'fict_qa_obqa_ds',
 'fict_qa_obqa_exact_deduped_ds',
 'fi

In [14]:
combined_ds = DatasetDict({
    new_config_name: Dataset.from_pandas(df_with_choices),
    new_config_name_slim: Dataset.from_pandas(df_with_choices_slim),
})

combined_ds

DatasetDict({
    fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk10_seed1234: Dataset({
        features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'input_w_fiction', 'input_w_fictsheet', 'target', 'target_span', 'alt_targets', 'losses', 'ems', 'accs', 'targets_by_loss_tgt', 'targets_by_loss_l', 'targets_by_acc_tgt', 'targets_by_acc_a', 'targets_by_acc_loss_tgt', 'targets_by_acc_loss_a', 'targets_by_acc_loss_l', 'top_scored_eq_target', 'top10_scored_conts_target', 'target_idx', 'topk_choices'],
        num_rows: 3036
    })
    fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk10_seed1234_slim: Dataset({
        features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'target', 'target_span', 'target_idx', 'topk_choices'],
        num_rows: 3036
    })
})

In [15]:
# # UNCOMMENT TO PUSH
# # push the different datasets as "configs"
# for config_name in combined_ds.keys():
#     combined_ds[config_name].push_to_hub(
#         repo_id=REPO_ID,
#         config_name=config_name,
#         commit_message="Upload of processed fictional_qa data.",
#         private=True,
#     )

In [16]:
# Can now be loaded anywhere (if authenticated) like:
for config_name in combined_ds.keys():
    loaded_ds = load_dataset(REPO_ID, name=config_name)
    print(config_name, loaded_ds)

fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk10_seed1234 DatasetDict({
    train: Dataset({
        features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'input_w_fiction', 'input_w_fictsheet', 'target', 'target_span', 'alt_targets', 'losses', 'ems', 'accs', 'targets_by_loss_tgt', 'targets_by_loss_l', 'targets_by_acc_tgt', 'targets_by_acc_a', 'targets_by_acc_loss_tgt', 'targets_by_acc_loss_a', 'targets_by_acc_loss_l', 'top_scored_eq_target', 'top10_scored_conts_target', 'target_idx', 'topk_choices'],
        num_rows: 3036
    })
})
fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk10_seed1234_slim DatasetDict({
    train: Dataset({
        features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'target', 'target_span', 'target_idx', 'topk_choices'],
        num_rows: 3036
    })
})


### Create subsets of the questions corresponding to the train/val fiction splits

In [19]:
def extract_matching_qa_subset(fiction_split_ds, mcq_ds):

    # create fiction id set
    # use a filter to grab the right row
    id_col = "fiction_id"
    if id_col not in fiction_split_ds.column_names:
        id_col = "event_id"

    id_set = set(fiction_split_ds[id_col])

    split_mcq_ds = mcq_ds.filter(lambda row: row[id_col] in id_set, batched=False, num_proc=16)
    
    return split_mcq_ds

def create_mcq_split(fiction_split_name, mcq_name, mcq_split_shortname=None):
    """
    Create a new dataset with the fiction split and the mcq split.
    """

    # load the splits
    fiction_split_ds = load_dataset(REPO_ID, name=fiction_split_name)["train"]
    mcq_ds = load_dataset(REPO_ID, name=mcq_name)["train"]

    # create the new dataset
    mcq_split_ds = extract_matching_qa_subset(fiction_split_ds, mcq_ds)
    
    if mcq_split_shortname is None:
        mcq_split_shortname = mcq_split_name
    full_mcq_split_name = f"{fiction_split_name}_{mcq_split_shortname}"
    
    print(full_mcq_split_name)

    return full_mcq_split_name, mcq_split_ds

# debug_train_split = "event_split_fictions_webtext_train_ds_valratio0.33_seed1234"
# debug_val_split = "event_split_fictions_webtext_val_ds_valratio0.33_seed1234"
# debug_mcq = "fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk4_seed1234_slim"

# print(create_mcq_split(debug_train_split, debug_mcq, mcq_split_shortname="mcq_topk4"))
# print(create_mcq_split(debug_val_split, debug_mcq, mcq_split_shortname="mcq_topk4"))

Filter (num_proc=16): 100%|██████████| 3036/3036 [00:00<00:00, 10374.11 examples/s]


event_split_fictions_webtext_train_ds_valratio0.33_seed1234_mcq_topk4
('event_split_fictions_webtext_train_ds_valratio0.33_seed1234_mcq_topk4', Dataset({
    features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'target', 'target_span', 'target_idx', 'topk_choices'],
    num_rows: 1984
}))


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:00<00:00, 10355.05 examples/s]


event_split_fictions_webtext_val_ds_valratio0.33_seed1234_mcq_topk4
('event_split_fictions_webtext_val_ds_valratio0.33_seed1234_mcq_topk4', Dataset({
    features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'target', 'target_span', 'target_idx', 'topk_choices'],
    num_rows: 1052
}))


In [24]:
mcq_splits_combined_ds = DatasetDict({})

In [25]:
train_val_cfgs = [name for name in get_dataset_config_names(REPO_ID) if (("train_ds" in name) or ("val_ds" in name))]
print(len(train_val_cfgs))
for cfg in train_val_cfgs:
    print(cfg)

16
event_split_fictions_webtext_train_ds_valratio0.33_seed1234
event_split_fictions_webtext_val_ds_valratio0.33_seed1234
event_split_fictsheets_webtext_train_ds_valratio0.33_seed1234
event_split_fictsheets_webtext_val_ds_valratio0.33_seed1234
style_strat_doc_split_fictions_train_ds_valct1_styleNone_seed1234
style_strat_doc_split_fictions_train_ds_valctNone_styleblog_seed1234
style_strat_doc_split_fictions_train_ds_valctNone_stylecorporate_seed1234
style_strat_doc_split_fictions_train_ds_valctNone_styleencyclopedia_seed1234
style_strat_doc_split_fictions_train_ds_valctNone_stylenews_seed1234
style_strat_doc_split_fictions_train_ds_valctNone_stylesocial_seed1234
style_strat_doc_split_fictions_val_ds_valct1_styleNone_seed1234
style_strat_doc_split_fictions_val_ds_valctNone_styleblog_seed1234
style_strat_doc_split_fictions_val_ds_valctNone_stylecorporate_seed1234
style_strat_doc_split_fictions_val_ds_valctNone_styleencyclopedia_seed1234
style_strat_doc_split_fictions_val_ds_valctNone_style

In [29]:
# source_mcq_ds_name = "fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk4_seed1234"
# mcq_shortname = "mcq_topk4"
source_mcq_ds_name = "fict_qa_obqa_blind_inf_ex_dedup_ds_Llama-3-2-3B-Instruct_scored_rowlimNone_altlimNone_topk10_seed1234"
mcq_shortname = "mcq_topk10"

for split in train_val_cfgs:
    mcq_split_name, mcq_split = create_mcq_split(split, source_mcq_ds_name, mcq_split_shortname=mcq_shortname)
    mcq_splits_combined_ds[mcq_split_name] = mcq_split

Filter (num_proc=16): 100%|██████████| 3036/3036 [00:02<00:00, 1420.95 examples/s]


event_split_fictions_webtext_train_ds_valratio0.33_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1943.99 examples/s]


event_split_fictions_webtext_val_ds_valratio0.33_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1954.02 examples/s]


event_split_fictsheets_webtext_train_ds_valratio0.33_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1936.27 examples/s]


event_split_fictsheets_webtext_val_ds_valratio0.33_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1944.31 examples/s]


style_strat_doc_split_fictions_train_ds_valct1_styleNone_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1975.10 examples/s]


style_strat_doc_split_fictions_train_ds_valctNone_styleblog_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1965.40 examples/s]


style_strat_doc_split_fictions_train_ds_valctNone_stylecorporate_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1937.66 examples/s]


style_strat_doc_split_fictions_train_ds_valctNone_styleencyclopedia_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1992.25 examples/s]


style_strat_doc_split_fictions_train_ds_valctNone_stylenews_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1938.08 examples/s]


style_strat_doc_split_fictions_train_ds_valctNone_stylesocial_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1978.31 examples/s]


style_strat_doc_split_fictions_val_ds_valct1_styleNone_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1958.64 examples/s]


style_strat_doc_split_fictions_val_ds_valctNone_styleblog_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1965.56 examples/s]


style_strat_doc_split_fictions_val_ds_valctNone_stylecorporate_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1976.46 examples/s]


style_strat_doc_split_fictions_val_ds_valctNone_styleencyclopedia_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1977.92 examples/s]


style_strat_doc_split_fictions_val_ds_valctNone_stylenews_seed1234_mcq_topk10


Filter (num_proc=16): 100%|██████████| 3036/3036 [00:01<00:00, 1986.36 examples/s]


style_strat_doc_split_fictions_val_ds_valctNone_stylesocial_seed1234_mcq_topk10


In [30]:
print(len(mcq_splits_combined_ds))
print(mcq_splits_combined_ds)

32
DatasetDict({
    event_split_fictions_webtext_train_ds_valratio0.33_seed1234_mcq_topk4: Dataset({
        features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'input_w_fiction', 'input_w_fictsheet', 'target', 'target_span', 'alt_targets', 'losses', 'ems', 'accs', 'targets_by_loss_tgt', 'targets_by_loss_l', 'targets_by_acc_tgt', 'targets_by_acc_a', 'targets_by_acc_loss_tgt', 'targets_by_acc_loss_a', 'targets_by_acc_loss_l', 'top_scored_eq_target', 'top4_scored_conts_target', 'target_idx', 'topk_choices'],
        num_rows: 1984
    })
    event_split_fictions_webtext_val_ds_valratio0.33_seed1234_mcq_topk4: Dataset({
        features: ['event_id', 'fiction_id', 'question_id', 'span_answer', 'natural_answer', 'input', 'input_w_fiction', 'input_w_fictsheet', 'target', 'target_span', 'alt_targets', 'losses', 'ems', 'accs', 'targets_by_loss_tgt', 'targets_by_loss_l', 'targets_by_acc_tgt', 'targets_by_acc_a', 'targets_by_acc_loss_tgt', 'targets_by_

In [45]:
# # UNCOMMENT TO PUSH
# # push the different datasets as "configs"
# for config_name in mcq_splits_combined_ds.keys():
# # for config_name in list(mcq_splits_combined_ds.keys())[10:]:
# # for config_name in list(mcq_splits_combined_ds.keys())[20:]:
# # for config_name in list(mcq_splits_combined_ds.keys())[30:]:
#     mcq_splits_combined_ds[config_name].push_to_hub(
#         repo_id=REPO_ID,
#         config_name=config_name,
#         commit_message="Upload of processed fictional_qa data.",
#         private=True,
#     )

Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00,  2.01ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
No files have been modified since last commit. Skipping to prevent empty commit.
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  2.86ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.76it/s]
No files have been modified since last commit. Skipping to prevent empty commit.


In [47]:
# autogen some lm-eval yamls to go with them

base_auto_task_path = "/p/lustre5/kirchenb/lm-evaluation-harness-fiction/lm_eval/tasks/fictional_qa/autogen"
def write_lm_eval_task(mcq_split_name):
    task_yaml_template = f"""\
task: {mcq_split_name}
dataset_path: tomg-group-umd/fictional_qa_03-19-25_training_splits
dataset_name: {mcq_split_name}\
"""
    # useful to split template away from parts with jinja
    task_suffix = r"""
output_type: multiple_choice
training_split: null
validation_split: null
test_split: train
doc_to_text: "{{input}}"
doc_to_target: target_idx
doc_to_choice: topk_choices
should_decontaminate: false
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true"""
    task_yaml_template += task_suffix

    filepath = f"{base_auto_task_path}/{mcq_split_name}.yaml"
    with open(filepath, "w") as fp:
        fp.write(task_yaml_template)

# write_lm_eval_task("event_split_fictions_webtext_train_ds_valratio0.33_seed1234_mcq_topk4")

In [48]:
# # UNCOMMENT TO AUTOGEN
# for config_name in mcq_splits_combined_ds.keys():
#     write_lm_eval_task(config_name)

In [56]:
# Grabbing names manually for launch cfg

# keystring = "event_split_fictions"
# keystring = "event_split_fictsheets"
# keystring = "valct1"
# keystring = "blog"
keystring = "news"
for name in mcq_splits_combined_ds.keys():
    if keystring in name:
        print(name)

style_strat_doc_split_fictions_train_ds_valctNone_stylenews_seed1234_mcq_topk4
style_strat_doc_split_fictions_val_ds_valctNone_stylenews_seed1234_mcq_topk4
style_strat_doc_split_fictions_train_ds_valctNone_stylenews_seed1234_mcq_topk10
style_strat_doc_split_fictions_val_ds_valctNone_stylenews_seed1234_mcq_topk10


### Create the lm eval task for those datasets ... then run them

In [None]:
    #  --model_args pretrained=EleutherAI/pythia-160m \
    
    # --model_args pretrained=/p/vast1/pretrain/models/Llama-3-2-1B \
    # --model_args pretrained=/p/lustre5/kirchenb/llm-pretraining-root/lit-gpt-dev-fiction/output/exp1_train_val_splits_5pct_4N_mb8-wb128_llama-3-2-1B_event-split-fictions-train-val/hf_checkpoint_exp1_train_val_splits_5pct_4N_mb8-wb128_llama-3-2-1B_event-split-fictions-train-val \

    # --model_args pretrained=/p/vast1/pretrain/models/gemma-2-2b \
    # --model_args pretrained=/p/lustre5/kirchenb/llm-pretraining-root/lit-gpt-dev-fiction/output/exp1_train_val_splits_5pct_4N_mb8-wb128_gemma-2-2b_event-split-fictions-train-val/hf_checkpoint_exp1_train_val_splits_5pct_4N_mb8-wb128_gemma-2-2b_event-split-fictions-train-val \
    
    # --model_args pretrained=/p/vast1/pretrain/models/Llama-3-2-3B \
    # --model_args pretrained=/p/lustre5/kirchenb/llm-pretraining-root/lit-gpt-dev-fiction/output/exp1_train_val_splits_5pct_4N_mb8-wb128_llama-3-2-3B_event-split-fictions-train-val/hf_checkpoint_exp1_train_val_splits_5pct_4N_mb8-wb128_llama-3-2-3B_event-split-fictions-train-val \
    # --model_args pretrained=/p/lustre5/kirchenb/llm-pretraining-root/lit-gpt-dev-fiction/output/exp1_train_val_splits_5pct_4N_mb8-wb128_llama-3-2-3B_doc-split-train-val/hf_checkpoint_exp1_train_val_splits_5pct_4N_mb8-wb128_llama-3-2-3B_doc-split-train-val \
    
    # --model_args pretrained=/p/vast1/pretrain/models/Meta-Llama-3-1-8B \
    # --model_args pretrained=/p/lustre5/kirchenb/llm-pretraining-root/lit-gpt-dev-fiction/output/exp1_train_val_splits_5pct_4N_mb8-wb128_llama-3-1-8B_event-split-fictions-train-val/hf_checkpoint_exp1_train_val_splits_5pct_4N_mb8-wb128_llama-3-1-8B_event-split-fictions-train-val \
    
    # --model_args pretrained=/p/vast1/pretrain/models/gemma-2-2b-it \
    # --model_args pretrained=/p/vast1/pretrain/models/Meta-Llama-3-1-8B-Instruct \
    # --model_args pretrained=/p/vast1/pretrain/models/Llama-3-2-3B-Instruct \


lm_eval --model hf \
    
    --tasks fict_qa_obqa_blind_inf_ex_dedup_ds_mcq \
    --device cuda:0 \
    --batch_size 8 \
    --log_samples \
    --output_path /p/lustre5/kirchenb/fictional_qa/output/lm_eval_results_wandb \
    --wandb_args project=fiction,dir=/p/lustre5/kirchenb/fictional_qa/output/lm_eval_results_wandb,name=Llama-3-2-3B-Instruct
    
    # --output_path /p/lustre5/kirchenb/fictional_qa/output/lm_eval_results \

# seems to run fine