# Setup

In [1]:
import os
import sys
from pathlib import Path
import random
import numpy as np
import pickle

# Add the src directory to Python path
olmo_core_path = Path.cwd() / "src"
if olmo_core_path.exists():
    sys.path.insert(0, str(olmo_core_path))

from olmo_core.data import (
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    NumpyDatasetType,
    TokenizerConfig,
)
from olmo_core.data.numpy_dataset import (
    VSLCurriculumType,
    VSLCurriculumConfig,
)


In [2]:
# Set your new cache base directory (change this to your preferred location)
cache_base = "/home/joberant/NLP_2425b/shirab6"

# Set all relevant Hugging Face cache directories
os.environ["HF_HOME"] = cache_base
os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_base, "transformers")
os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_base, "datasets")
os.environ["HF_TOKENIZERS_CACHE"] = os.path.join(cache_base, "tokenizers")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from olmo_eval import HFTokenizer
from datasets import load_dataset



# Prepare the dataset and dataloader

In [22]:
tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
            tokenizer_config.identifier,
            pad_token_id=tokenizer_config.pad_token_id,
            eos_token_id=tokenizer_config.eos_token_id,
            bos_token_id=tokenizer_config.bos_token_id,
        )

include_instance_metadata = False # Set to true when you want tp retrieve metadata, during training set this to False
work_dir = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/hp_final/dataset-cache"

dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,
    min_sequence_length=64,
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=str(work_dir),
    include_instance_metadata=include_instance_metadata,
)
kas_dataset = dataset_config.build()

Loading metadata: 100%|██████████| 8/8 [00:00<00:00, 300.80it/s]


In [23]:
data_loader_config = NumpyDataLoaderConfig(
    global_batch_size=32768,
    seed=0,
    num_workers=8,
    prefetch_factor = 16,
)

dataloader = data_loader_config.build(kas_dataset)
dataloader.reshuffle(1)

# Load PopQA dataset and filter entities

In [30]:
def get_important_chunks(dataset, min_num_chunks, max_num_chunks, instance_lengths):
    # Filter the dataset
    filtered_dataset = dataset['train'].filter(
        lambda example: min_num_chunks <= example['subject_num_chunks'] <= max_num_chunks
    )   

    # Create list of dictionaries with subject info and chunk lengths
    result_list = []    

    for example in filtered_dataset:
        subject_id = example['subj_id']
        chunks = example['subject_chunks']
        num_chunks = example['subject_num_chunks']

        chunk_lengths = instance_lengths[chunks]

        # Sort chunks by their lengths (descending order)
        if len(chunk_lengths) > 0:
            # Create pairs of (chunk, length) and sort by length
            chunk_length_pairs = list(zip(chunks, chunk_lengths))
            chunk_length_pairs.sort(key=lambda x: x[1], reverse=True)

            # Separate back into sorted chunks and lengths
            sorted_chunks = [pair[0] for pair in chunk_length_pairs]
            sorted_lengths = [pair[1] for pair in chunk_length_pairs]
        else:
            sorted_chunks = chunks
            sorted_lengths = chunk_lengths

        subject_dict = {
            'entity_id': subject_id,
            'num_chunks': num_chunks,
            'chunks': sorted_chunks,
            'chunks_lengths': sorted_lengths
        }
        

        result_list.append(subject_dict)    

    # Sort the list by number of chunks (descending order)
    result_list.sort(key=lambda x: x['num_chunks'], reverse=True)

    return result_list


In [31]:
ds = load_dataset("dhgottesman/popqa-kas")

"""
importsnt chunks has the following structure:
        {
            'entity_id': subject_id,
            'num_chunks': num_chunks,
            'chunks': sorted_chunks,
            'chunks_lengths': sorted_lengths
        }
"""
important_chunks = get_important_chunks(ds, 50, 100, kas_dataset.get_instance_lengths())

# Load original batches

In [32]:
all_batches = np.load("/home/morg/students/gottesman3/knowledge-analysis-suite/OLMo-core/batch_indices.npy", allow_pickle=True)

# Sample injection points

In [35]:
import random

def sample_injection_points(total_steps, num_points_to_sample, max_num_chunks, interval, seed=None):
    """
    Samples unique injection points from a valid starting range to avoid overflow 
    when assigning chunk indices.

    Args:
        total_steps (int): The maximum possible step value (exclusive upper bound).
        num_points_to_sample (int): Number of injection points to sample.
        max_num_chunks (int): Maximum num_chunks across all entities.
        interval (int): Distance between chunk indices.
        seed (int, optional): Seed for reproducibility.

    Returns:
        List[int]: Sorted list of valid injection starting points.
    """
    if seed is not None:
        random.seed(seed)

    max_valid_start = total_steps - (max_num_chunks - 1) * interval
    if max_valid_start <= 0:
        raise ValueError("Interval and chunk size too large for total steps.")

    if num_points_to_sample > max_valid_start:
        raise ValueError("Cannot sample more injection points than available valid start points.")

    sampled_points = random.sample(range(max_valid_start), k=num_points_to_sample)
    return sorted(sampled_points)


def assign_indices_to_entities(entities, injection_points, interval):
    """
    Assigns indices to each entity starting at a given injection point with spacing.

    Args:
        entities (List[dict]): List of entity dicts.
        injection_points (List[int]): List of sampled injection start points.
        interval (int): Distance between chunk indices.

    Returns:
        Dict[str, List[int]]: Mapping from entity name to list of indices.
    """
    if len(entities) != len(injection_points):
        raise ValueError("Number of entities must match number of injection points.")

    result = {}

    for entity, start in zip(entities, injection_points):
        entity_id = entity['entity_id']
        num_chunks = entity['num_chunks']
        indices = [start + i * interval for i in range(num_chunks)]
        result[entity_id] = indices

    return result



In [37]:
interval = 5

total_number_of_batches = dataloader.total_batches
injection_points = sample_injection_points(total_number_of_batches, len(important_chunks), 100, interval, 0)
all_injection_points_per_entity = assign_indices_to_entities(important_chunks, injection_points, interval)
all_injection_points_per_entity

{1258065: [9,
  14,
  19,
  24,
  29,
  34,
  39,
  44,
  49,
  54,
  59,
  64,
  69,
  74,
  79,
  84,
  89,
  94,
  99,
  104,
  109,
  114,
  119,
  124,
  129,
  134,
  139,
  144,
  149,
  154,
  159,
  164,
  169,
  174,
  179,
  184,
  189,
  194,
  199,
  204,
  209,
  214,
  219,
  224,
  229,
  234,
  239,
  244,
  249,
  254,
  259,
  264,
  269,
  274,
  279,
  284,
  289,
  294,
  299,
  304,
  309,
  314,
  319,
  324,
  329,
  334,
  339,
  344,
  349,
  354,
  359,
  364,
  369,
  374,
  379,
  384,
  389,
  394,
  399,
  404,
  409,
  414,
  419,
  424,
  429,
  434,
  439,
  444,
  449,
  454,
  459,
  464,
  469,
  474,
  479,
  484,
  489,
  494,
  499,
  504],
 22940: [309,
  314,
  319,
  324,
  329,
  334,
  339,
  344,
  349,
  354,
  359,
  364,
  369,
  374,
  379,
  384,
  389,
  394,
  399,
  404,
  409,
  414,
  419,
  424,
  429,
  434,
  439,
  444,
  449,
  454,
  459,
  464,
  469,
  474,
  479,
  484,
  489,
  494,
  499,
  504,
  509,
  514,
  519,
  

# Build Swapping Dictionary

In [39]:
swapping_dict = {}
for entity in important_chunks:
    zipped_chunks = zip(entity['chunks'], np.random.choice(all_batches[all_injection_points_per_entity[entity['entity_id']]]))
    
    for original, replacement in zipped_chunks:
        
        swapping_dict[original] = replacement
        swapping_dict[replacement] = original

In [13]:
with open('/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/swapping_dict.pkl', 'wb') as f:
    pickle.dump(swapping_dict, f)

In [3]:
# To load it back later
with open('/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/swapping_dict.pkl', 'rb') as f:
    swapping_dict = pickle.load(f)

In [40]:
swapping_dict

{3266546: 10484302,
 10484302: 3266546,
 4458969: 4828678,
 4828678: 4458969,
 8661779: 10265312,
 10265312: 8661779,
 10378758: 9954663,
 9954663: 10378758,
 3066923: 8768849,
 8768849: 3066923,
 3362605: 7689429,
 7689429: 3362605,
 4143475: 10142063,
 10142063: 4143475,
 7942646: 2426286,
 2426286: 7942646,
 3939490: 4494533,
 4494533: 3939490,
 3939493: 9649385,
 9649385: 3939493,
 4015325: 9212032,
 9212032: 4015325,
 4265660: 10481564,
 10481564: 4265660,
 4314734: 7217126,
 7217126: 4314734,
 4857090: 1405070,
 1405070: 4857090,
 5208365: 1882234,
 1882234: 5208365,
 5565898: 7391203,
 7391203: 5565898,
 6348208: 10426443,
 10426443: 6348208,
 8986720: 3145428,
 3145428: 8986720,
 10380177: 10013240,
 8728527: 10380177,
 10412516: 2447859,
 2447859: 10412516,
 10380182: 7248331,
 7248331: 10380182,
 3007035: 2779211,
 2779211: 3007035,
 3574218: 8579897,
 8579897: 3574218,
 4008356: 1050394,
 1050394: 4008356,
 4112899: 840912,
 840912: 4112899,
 3007036: 4267407,
 4267407: 3007

# Rebuild dataset and dataloader with swapped chunk indices

In [41]:
tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
            tokenizer_config.identifier,
            pad_token_id=tokenizer_config.pad_token_id,
            eos_token_id=tokenizer_config.eos_token_id,
            bos_token_id=tokenizer_config.bos_token_id,
        )

include_instance_metadata = False # Set to true when you want tp retrieve metadata, during training set this to False
work_dir = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/hp_final/dataset-cache"


dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,
    min_sequence_length=64,
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=str(work_dir),
    include_instance_metadata=include_instance_metadata,
    swapping_dict = swapping_dict,
)

reordered_dataset = dataset_config.build()


Loading metadata: 100%|██████████| 8/8 [00:00<00:00, 298.86it/s]


In [42]:
data_loader_config = NumpyDataLoaderConfig(
    global_batch_size=32768,
    seed=0,
    num_workers=8,
    prefetch_factor = 16,
)

dataloader = data_loader_config.build(reordered_dataset)
dataloader.reshuffle(1)

In [16]:
sorted_keys = list(swapping_dict.keys())
sorted_keys.sort()

In [19]:
for i, batch in enumerate(dataloader):
    if i == 44:
        print(batch)
        break


RuntimeError: Expected batch size of 32,768 tokens on rank 0, got input IDs with shape (129, 256) = 33,024 tokens

In [17]:
sorted_keys

[44,
 47,
 92,
 457,
 603,
 1108,
 1118,
 1207,
 1263,
 1479,
 1550,
 1773,
 2409,
 2542,
 2719,
 2746,
 3050,
 3467,
 3550,
 3617,
 3623,
 4205,
 4294,
 4381,
 4550,
 4551,
 4667,
 4734,
 4802,
 4967,
 5000,
 5180,
 5279,
 5486,
 5775,
 5814,
 5859,
 6005,
 6023,
 6174,
 6374,
 6460,
 6489,
 6534,
 6661,
 7284,
 7630,
 7915,
 7949,
 8166,
 8361,
 8372,
 8373,
 8586,
 8604,
 8706,
 8882,
 9000,
 9180,
 9816,
 9820,
 11138,
 11180,
 11181,
 11414,
 11674,
 11777,
 11985,
 12232,
 12234,
 13068,
 13287,
 13357,
 13384,
 13508,
 13813,
 14282,
 14514,
 14675,
 14881,
 15052,
 15240,
 15288,
 15413,
 15725,
 15797,
 16099,
 16372,
 16416,
 16497,
 16530,
 16556,
 16788,
 16809,
 16815,
 16958,
 17036,
 17227,
 17536,
 17624,
 17675,
 17850,
 17851,
 17852,
 17860,
 17888,
 17949,
 18086,
 18295,
 18323,
 18339,
 18368,
 18439,
 18673,
 18750,
 18759,
 19004,
 19047,
 19107,
 19353,
 19413,
 19527,
 19617,
 19855,
 20184,
 20484,
 20584,
 20686,
 20773,
 20792,
 20902,
 20926,
 20927,
 2135