# Setup

In [1]:
import os
import sys
from pathlib import Path
import random
import numpy as np
import pickle
from typing import List

# Add the src directory to Python path
olmo_core_path = Path.cwd() / "src"
if olmo_core_path.exists():
    sys.path.insert(0, str(olmo_core_path))

from olmo_core.data import (
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    NumpyDatasetType,
    TokenizerConfig,
)
from olmo_core.data.numpy_dataset import (
    VSLCurriculumType,
    VSLCurriculumConfig,
)


The history saving thread hit an unexpected error (OperationalError('disk I/O error')).History will not be written to the database.


In [2]:
# Set your new cache base directory (change this to your preferred location)
cache_base = "/home/joberant/NLP_2425b/yoavbaron"

# Set all relevant Hugging Face cache directories
os.environ["HF_HOME"] = cache_base
os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_base, "transformers")
os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_base, "datasets")
os.environ["HF_TOKENIZERS_CACHE"] = os.path.join(cache_base, "tokenizers")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from olmo_eval import HFTokenizer
from datasets import load_dataset



# Prepare the dataset and dataloader

In [3]:
tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
            tokenizer_config.identifier,
            pad_token_id=tokenizer_config.pad_token_id,
            eos_token_id=tokenizer_config.eos_token_id,
            bos_token_id=tokenizer_config.bos_token_id,
        )

include_instance_metadata = False # Set to true when you want tp retrieve metadata, during training set this to False
work_dir = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/hp_final/dataset-cache"

dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,
    min_sequence_length=64,
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=str(work_dir),
    include_instance_metadata=include_instance_metadata,
)
kas_dataset = dataset_config.build()

Loading metadata: 100%|██████████| 8/8 [00:00<00:00, 204.60it/s]


In [4]:
data_loader_config = NumpyDataLoaderConfig(
    global_batch_size=32768,
    seed=0,
    num_workers=8,
    prefetch_factor = 16,
)

dataloader = data_loader_config.build(kas_dataset)
dataloader.reshuffle(1)

# Load PopQA dataset and filter entities

In [5]:
def get_important_chunks(dataset, min_num_chunks, max_num_chunks, instance_lengths):
    # Filter the dataset
    filtered_dataset = dataset['train'].filter(
        lambda example: min_num_chunks <= example['subject_num_chunks'] <= max_num_chunks
    )   

    # Create list of dictionaries with subject info and chunk lengths
    result_list = []    

    for example in filtered_dataset:
        
        subject_name = example['subj']
        subject_id = example['subj_id']
        chunks = example['subject_chunks']
        num_chunks = example['subject_num_chunks']

        chunk_lengths = instance_lengths[chunks]

        if subject_name == 'Madison':
            print(chunks)

        # Sort chunks by their lengths (descending order)
        if len(chunk_lengths) > 0:
            # Create pairs of (chunk, length) and sort by length
            chunk_length_pairs = list(zip(chunks, chunk_lengths))
            chunk_length_pairs.sort(key=lambda x: x[1], reverse=True)

            # Separate back into sorted chunks and lengths
            sorted_chunks = [pair[0] for pair in chunk_length_pairs]
            sorted_lengths = [pair[1] for pair in chunk_length_pairs]
        else:
            sorted_chunks = chunks
            sorted_lengths = chunk_lengths

        subject_dict = {
            'entity_id': subject_id,
            'num_chunks': num_chunks,
            'chunks': sorted_chunks,
            'chunks_lengths': sorted_lengths
        }
        
        result_list.append(subject_dict)    

    # Sort the list by number of chunks (descending order)
    result_list.sort(key=lambda x: x['num_chunks'], reverse=True)

    return result_list


In [6]:
ds = load_dataset("dhgottesman/popqa-kas")

"""
importsnt chunks has the following structure:
        {
            'entity_id': subject_id,
            'num_chunks': num_chunks,
            'chunks': sorted_chunks,
            'chunks_lengths': sorted_lengths
        }
"""
important_chunks = get_important_chunks(ds, 50, 100, kas_dataset.get_instance_lengths())

[820777, 941125, 976593, 1132074, 4955122, 5784515, 5893462, 6285853, 6301590, 6569641, 6709763, 982625, 1069094, 1174705, 1175438, 1188820, 3892745, 6043687, 7585232, 2052326, 4449953, 6471684, 6415190, 263910, 430203, 436392, 857271, 964356, 1492452, 1751563, 1953559, 2146414, 2244409, 2359129, 2390139, 2799444, 2839387, 3210372, 3220392, 3475484, 3514601, 3586220, 3735351, 3892746, 3950691, 4321806, 4489791, 4814079, 5279548, 5384395, 5985105, 6007311, 6123047, 6773572, 6837529, 6855838, 7299974, 7309924, 7714134, 8089661, 9309412, 9423790, 9496875, 9866817, 4171091, 8267712, 3130262, 8233434, 1211449, 6231806, 9180749, 342090, 534038, 1647807, 1961324, 2725096, 3113130, 3123297, 3461098, 4020853, 4255747, 4622325, 4713269, 4976107, 6497698, 6498978, 6855755, 7551764, 7688290, 7901132, 9352219, 10391979, 7551765, 7551766, 879397, 3232229, 7994593, 10329160, 7585233]
[500753, 500754, 500756, 7899334, 9766174, 7586543, 3930628, 4385903, 5956585, 5902917, 7143593, 8695947, 1058648, 376

# Load original batches

In [7]:
all_batches = np.load("/home/morg/students/gottesman3/knowledge-analysis-suite/OLMo-core/batch_indices.npy", allow_pickle=True)

# Sample injection points

In [8]:
import random

def sample_injection_points(total_steps, num_points_to_sample, max_num_chunks, interval, seed=None):
    """
    Samples unique injection points from a valid starting range to avoid overflow 
    when assigning chunk indices.

    Args:
        total_steps (int): The maximum possible step value (exclusive upper bound).
        num_points_to_sample (int): Number of injection points to sample.
        max_num_chunks (int): Maximum num_chunks across all entities.
        interval (int): Distance between chunk indices.
        seed (int, optional): Seed for reproducibility.

    Returns:
        List[int]: Sorted list of valid injection starting points.
    """
    if seed is not None:
        random.seed(seed)

    max_valid_start = total_steps - (max_num_chunks - 1) * interval
    if max_valid_start <= 0:
        raise ValueError("Interval and chunk size too large for total steps.")

    if num_points_to_sample > max_valid_start:
        raise ValueError("Cannot sample more injection points than available valid start points.")

    sampled_points = random.sample(range(max_valid_start), k=num_points_to_sample)
    return sorted(sampled_points)


def assign_indices_to_entities(entities, injection_points, interval):
    """
    Assigns indices to each entity starting at a given injection point with spacing.

    Args:
        entities (List[dict]): List of entity dicts.
        injection_points (List[int]): List of sampled injection start points.
        interval (int): Distance between chunk indices.

    Returns:
        Dict[str, List[int]]: Mapping from entity name to list of indices.
    """
    if len(entities) != len(injection_points):
        raise ValueError("Number of entities must match number of injection points.")

    result = {}

    for entity, start in zip(entities, injection_points):
        entity_id = entity['entity_id']
        num_chunks = entity['num_chunks']
        indices = [start + i * interval for i in range(num_chunks)]
        result[entity_id] = indices

    return result



In [16]:
interval = 5

total_number_of_batches = dataloader.total_batches
injection_points = sample_injection_points(total_number_of_batches, len(important_chunks), 100, interval, 0)
all_injection_points_per_entity = assign_indices_to_entities(important_chunks, injection_points, interval)
#all_injection_points_per_entity

# Build Swapping Dictionary

In [19]:
def shloop(
    injection_points: List[int],
    entity_data: dict,
    batch_to_chunks_map: dict,
) -> dict:
    """
    """
    # 1. Get entity chunks available for swapping and their lengths

    ent_chunk_to_len = dict(zip(entity_data['chunks'], entity_data['chunks_lengths']))
    ent_len_to_chunk = {v: k for k, v in ent_chunk_to_len.items()}

    # casting to int but might want to edit this
    batch_id_to_len = {}
    batch_len_to_id = {}
    for batch in injection_points:
        batch_len = int(32768 / len(batch_to_chunks_map[batch]))
        batch_id_to_len[batch] = batch_len
        batch_len_to_id[batch_len] = batch

    # 2. Calculate the injection span
    num_chunks = len(entity_data['chunks'])
    #print(f"Injection span: {list(injection_points)}")
    if len(injection_points) != num_chunks:
        f"Entity {entity_data['entity_id']} expected {num_chunks} injection points, but got {len(injection_points)}."
    

    sb = sorted(batch_len_to_id.keys())   
    se = sorted(ent_len_to_chunk.keys())

    chunks_to_batches = []
    for len_e in se:
        for len_b in sb:
            if len_b == len_e:
                #print(len_e, len_b)
                chunk_id = ent_len_to_chunk[len_e]
                batch_id = batch_len_to_id[len_b]

                #print(f"Chunk {chunk_id} with length {len_e} will be swapped with batch {batch_id} with length {len_b}")
                # get a random chunk id from the batch
                chunk_id_from_batch = random.choice(batch_to_chunks_map[batch_id])

                if [chunk_id, chunk_id_from_batch] in chunks_to_batches or [chunk_id_from_batch, chunk_id] in chunks_to_batches:
                    print(chunk_id, chunk_id_from_batch, "already in")
                
                chunks_to_batches.append([chunk_id, chunk_id_from_batch])
                chunks_to_batches.append([chunk_id_from_batch, chunk_id])
                #chunks_to_batches[chunk_id] = chunk_id_from_batch # chunk e goes to chunk e' in batch b
                #chunks_to_batches[chunk_id_from_batch] = chunk_id # add the symetric mapping

                ent_len_to_chunk.pop(len_e) # pop one of the lengths
                ent_chunk_to_len.pop(chunk_id) # pop the chunk from the entity and pop one of the lengths
                batch_len_to_id.pop(len_b)
                batch_id_to_len.pop(batch_id) # pop the batch and the length from the batch
                break
                
    # ranmly match the rest of the chunks
    for chunk_id, batch_id in zip(ent_chunk_to_len.keys(), batch_id_to_len.keys()):
        if chunk_id not in chunks_to_batches:
            chunk_id_from_batch = random.choice(batch_to_chunks_map[batch_id])

            if [chunk_id, chunk_id_from_batch] in chunks_to_batches or [chunk_id_from_batch, chunk_id] in chunks_to_batches:
                    print(chunk_id, chunk_id_from_batch, "already in")
                
            chunks_to_batches.append([chunk_id, chunk_id_from_batch])
            chunks_to_batches.append([chunk_id_from_batch, chunk_id])
            
            #chunks_to_batches[chunk_id] = chunk_id_from_batch
            #chunks_to_batches[chunk_id_from_batch] = chunk_id

    return chunks_to_batches

In [20]:
full_mapping = []
for i, important_chunk in enumerate(important_chunks):
    pts = all_injection_points_per_entity[important_chunk['entity_id']]

    # The 'important_chunk' variable is the integer you need.
    # Pass it directly to your function.
    res = shloop(
        pts,
        important_chunk,
        all_batches
    )
    # extend full mapping with the result
    full_mapping.extend(res)

In [21]:
combined_length = sum(entity['num_chunks'] for entity in important_chunks)
print("Combined length of all lists under important_chunks:", combined_length)

Combined length of all lists under important_chunks: 72433


In [22]:
len(full_mapping)

144866

In [33]:
grouped_dict = {}
for key, value in full_mapping:
    if key not in grouped_dict:
        grouped_dict[key] = []
    grouped_dict[key].append(value)

for i, (k, v) in enumerate(grouped_dict.items()):
    print(f"{k}: {v}")
    if i >= 29:
        break

3007040: [3850390]
3850390: [3007040]
8661782: [5089980]
5089980: [8661782]
9216828: [6726866]
6726866: [9216828]
3520673: [9068505]
9068505: [3520673]
3939494: [9890731]
9890731: [3939494]
10380179: [3992742]
3992742: [10380179]
3266546: [7442658]
7442658: [3266546]
4458969: [3945203]
3945203: [4458969]
8661779: [5738205, 164200]
5738205: [8661779]
10378758: [6873468]
6873468: [10378758]
3066923: [4946796]
4946796: [3066923]
3362605: [8911327]
8911327: [3362605]
4143475: [9724709]
9724709: [4143475]
7942646: [2564571]
2564571: [7942646]
3939490: [9743188]
9743188: [3939490]


In [24]:
from collections import defaultdict

# Dictionary to track occurrences of each item at each position (0 or 1)
position_counts = [defaultdict(int), defaultdict(int)]

for pair in full_mapping:
    position_counts[0][pair[0]] += 1
    position_counts[1][pair[1]] += 1

duplicates = {
    0: [item for item, count in position_counts[0].items() if count > 1],
    1: [item for item, count in position_counts[1].items() if count > 1]
}

print("Items appearing more than once at position 0:", len(duplicates[0]))
print("Items appearing more than once at position 1:", len(duplicates[1]))

Items appearing more than once at position 0: 12873
Items appearing more than once at position 1: 12873


In [28]:
from collections import defaultdict

# Track chunk_id -> list of entity_ids where it appears
chunk_to_entities = defaultdict(set)

for entity in important_chunks:
    entity_id = entity['entity_id']
    for chunk_id in entity['chunks']:
        chunk_to_entities[chunk_id].add(entity_id)

# Find chunk_ids that appear in more than one entity
duplicate_chunks = {chunk_id: list(entity_ids) for chunk_id, entity_ids in chunk_to_entities.items() if len(entity_ids) > 1}

print(f"Number of chunk ids appearing in multiple entities: {len(duplicate_chunks)}")
for chunk_id, entity_ids in list(duplicate_chunks.items())[:100]:  # show first 10 for brevity
    print(f"Chunk ID {chunk_id} appears in entities: {entity_ids}")

Number of chunk ids appearing in multiple entities: 2534
Chunk ID 10380177 appears in entities: [1258065, 2083978]
Chunk ID 2189467 appears in entities: [1258065, 978452]
Chunk ID 3066924 appears in entities: [1258065, 2083978]
Chunk ID 10378761 appears in entities: [1258065, 2083978]
Chunk ID 3277628 appears in entities: [1200906, 22940, 188071]
Chunk ID 3373418 appears in entities: [22940, 188071]
Chunk ID 3373419 appears in entities: [22940, 188071]
Chunk ID 3507554 appears in entities: [2143529, 22940]
Chunk ID 4460895 appears in entities: [22940, 188071]
Chunk ID 9734793 appears in entities: [929961, 22940, 1203333]
Chunk ID 3107368 appears in entities: [22940, 188071]
Chunk ID 1560730 appears in entities: [744404, 22940]
Chunk ID 3550098 appears in entities: [1774154, 22940, 188071]
Chunk ID 4460896 appears in entities: [22940, 188071]
Chunk ID 2661416 appears in entities: [2886218, 22940]
Chunk ID 3395466 appears in entities: [723619, 22940]
Chunk ID 8937698 appears in entities:

In [15]:
total_len = 0
for d in full_mapping:
    total_len += len(d)

total_len / 2

TypeError: object of type 'int' has no len()

In [None]:
with open('/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/swapping_dict.pkl', 'wb') as f:
    pickle.dump(swapping_dict, f)

In [None]:
# To load it back later
with open('/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/swapping_dict.pkl', 'rb') as f:
    swapping_dict = pickle.load(f)

# Rebuild dataset and dataloader with swapped chunk indices

In [None]:
tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
            tokenizer_config.identifier,
            pad_token_id=tokenizer_config.pad_token_id,
            eos_token_id=tokenizer_config.eos_token_id,
            bos_token_id=tokenizer_config.bos_token_id,
        )

include_instance_metadata = False # Set to true when you want tp retrieve metadata, during training set this to False
work_dir = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/hp_final/dataset-cache"


dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,
    min_sequence_length=64,
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=str(work_dir),
    include_instance_metadata=include_instance_metadata,
    swapping_dict = swapping_dict,
)

reordered_dataset = dataset_config.build()


Loading metadata: 100%|██████████| 8/8 [00:00<00:00, 279.79it/s]


In [None]:
data_loader_config = NumpyDataLoaderConfig(
    global_batch_size=32768,
    seed=0,
    num_workers=8,
    prefetch_factor = 16,
)

dataloader = data_loader_config.build(reordered_dataset)
dataloader.reshuffle(1)

In [None]:
sorted_keys = list(swapping_dict.keys())
sorted_keys.sort()

In [None]:
for i, batch in enumerate(dataloader):
    if i == 44:
        print(batch)
        break


{'input_ids': tensor([[   279,   9941,    220,  ..., 100277, 100277, 100277],
        [   281,   3581,    352,  ..., 100277, 100277, 100277],
        [    54,    526,  19730,  ..., 100277, 100277, 100277],
        ...,
        [ 43810,    268,  29770,  ...,     13,    763,    220],
        [  4198,    301,   7881,  ...,   9770,    311,    990],
        [    44,    699,   1171,  ...,   4590,    263,    382]]), 'attention_mask': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]]), 'index': tensor([ 7550324,  7397962,  8544508,  8794158,   157744,  9217822, 10304826,
         6638246,  6187538,  9949222,  7087033,   847439,  7167361,  6031080,
         4688568,  6522279,  5742455,  3652041,  4249598,  4753466, 10167340,
          591581,  2251989,  2083106,  7247992,  3248376,  1535220,  43

In [None]:
sorted_keys

[44,
 47,
 92,
 457,
 603,
 1108,
 1118,
 1207,
 1263,
 1479,
 1550,
 1773,
 2409,
 2542,
 2719,
 2746,
 3050,
 3467,
 3550,
 3617,
 3623,
 4205,
 4294,
 4381,
 4550,
 4551,
 4667,
 4734,
 4802,
 4967,
 5000,
 5180,
 5279,
 5486,
 5775,
 5814,
 5859,
 6005,
 6023,
 6174,
 6374,
 6460,
 6489,
 6534,
 6661,
 7284,
 7630,
 7915,
 7949,
 8166,
 8361,
 8372,
 8373,
 8586,
 8604,
 8706,
 8882,
 9000,
 9180,
 9816,
 9820,
 11138,
 11180,
 11181,
 11414,
 11674,
 11777,
 11985,
 12232,
 12234,
 13068,
 13287,
 13357,
 13384,
 13508,
 13813,
 14282,
 14514,
 14675,
 14881,
 15052,
 15240,
 15288,
 15413,
 15725,
 15797,
 16099,
 16372,
 16416,
 16497,
 16530,
 16556,
 16788,
 16809,
 16815,
 16958,
 17036,
 17227,
 17536,
 17624,
 17675,
 17850,
 17851,
 17852,
 17860,
 17888,
 17949,
 18086,
 18295,
 18323,
 18339,
 18368,
 18439,
 18673,
 18750,
 18759,
 19004,
 19047,
 19107,
 19353,
 19413,
 19527,
 19617,
 19855,
 20184,
 20484,
 20584,
 20686,
 20773,
 20792,
 20902,
 20926,
 20927,
 2135