# Imports

In [1]:
import os
import sys
from pathlib import Path
import random
import numpy as np
import pickle
from typing import List

# Add the src directory to Python path
olmo_core_path = Path.cwd() / "src"
if olmo_core_path.exists():
    sys.path.insert(0, str(olmo_core_path))

from olmo_core.data import (
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    NumpyDatasetType,
    TokenizerConfig,
)
from olmo_core.data.numpy_dataset import (
    VSLCurriculumType,
    VSLCurriculumConfig,
)

from olmo_core.data.chunk_ordering_utils import (
    get_disjoint_entities, 
    get_important_chunks, 
    sample_injection_points, 
    create_swapping_dict_for_injection_points,
    split_into_bins_inclusive
)


In [2]:
# Set your new cache base directory (change this to your preferred location)
cache_base = "/home/joberant/NLP_2425b/shirab6"

# Set all relevant Hugging Face cache directories
os.environ["HF_HOME"] = cache_base
os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_base, "transformers")
os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_base, "datasets")
os.environ["HF_TOKENIZERS_CACHE"] = os.path.join(cache_base, "tokenizers")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from olmo_eval import HFTokenizer
from datasets import load_dataset



# Load and create all necesarry data

In [3]:
tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
            tokenizer_config.identifier,
            pad_token_id=tokenizer_config.pad_token_id,
            eos_token_id=tokenizer_config.eos_token_id,
            bos_token_id=tokenizer_config.bos_token_id,
        )

include_instance_metadata = False # Set to true when you want tp retrieve metadata, during training set this to False
work_dir = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/hp_final_tmp/dataset-cache"

dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,
    min_sequence_length=64,
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=str(work_dir),
    include_instance_metadata=include_instance_metadata,
)
kas_dataset = dataset_config.build()

Loading metadata: 100%|██████████| 8/8 [00:01<00:00,  4.50it/s]


In [4]:
ds = load_dataset("dhgottesman/popqa-kas")

Downloading readme: 0.00B [00:00, ?B/s]

In [7]:
"""
importsnt chunks has the following structure:
        {
            'entity_id': subject_id,
            'num_chunks': num_chunks,
            'chunks': sorted_chunks,
            'chunks_lengths': sorted_lengths
        }
"""
important_chunks = get_important_chunks(ds, 10, 100, kas_dataset.get_instance_lengths())

Filter:   0%|          | 0/14267 [00:00<?, ? examples/s]

In [8]:
len(important_chunks)

5365

In [6]:
disjoint_entities = get_disjoint_entities(important_chunks, 406)

In [7]:
original_batches = np.load("/home/morg/students/gottesman3/knowledge-analysis-suite/OLMo-core/batch_indices.npy", allow_pickle=True)

In [8]:
intervals = [1, 100]
total_number_of_batches = 109672
bins = split_into_bins_inclusive(total_number_of_batches, num_bins=8)

In [None]:
experiment_type = "BURTST"
seed = 0
interval = 100
max_num_chunks = 100
for i, sampling_range in enumerate(bins):
    start, end = sampling_range
    save_file_path = f'/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/chunk_ordering_data/{experiment_type}_steps_{start}-{end}_'
    injection_points = sample_injection_points(experiment_type, disjoint_entities, total_number_of_batches, sampling_range=sampling_range, seed=seed, max_num_chunks=max_num_chunks, interval=interval)
    create_swapping_dict_for_injection_points(disjoint_entities, injection_points, original_batches, save_file_path, interval)

{'Q1129186': [2, 102, 202, 302, 402, 502, 602, 702, 802, 902, 1002, 1102, 1202, 1302, 1402, 1502, 1602, 1702, 1802, 1902, 2002, 2102, 2202, 2302, 2402, 2502, 2602, 2702, 2802, 2902, 3002, 3102, 3202, 3302, 3402, 3502, 3602, 3702, 3802, 3902, 4002, 4102, 4202, 4302, 4402, 4502, 4602, 4702, 4802, 4902, 5002, 5102, 5202, 5302, 5402, 5502, 5602, 5702, 5802, 5902, 6002, 6102, 6202, 6302, 6402], 'Q30805': [19, 119, 219, 319, 419, 519, 619, 719, 819, 919, 1019, 1119, 1219, 1319, 1419, 1519, 1619, 1719, 1819, 1919, 2019, 2119, 2219, 2319, 2419, 2519, 2619, 2719, 2819, 2919, 3019, 3119, 3219, 3319, 3419, 3519, 3619, 3719, 3819, 3919, 4019, 4119, 4219, 4319, 4419, 4519, 4619, 4719, 4819, 4919, 5019, 5119, 5219, 5319, 5419, 5519, 5619, 5719, 5819, 5919, 6019, 6119, 6219, 6319, 6419, 6519, 6619, 6719, 6819, 6919, 7019, 7119, 7219, 7319, 7419, 7519, 7619, 7719, 7819, 7919], 'Q470692': [25, 125, 225, 325, 425, 525, 625, 725, 825, 925, 1025, 1125, 1225, 1325, 1425, 1525, 1625, 1725, 1825, 1925, 2025,

In [31]:
with open('/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/chunk_ordering_data/BURTST_steps_68546-82254_swapping_dict_interval_100.pkl', 'rb') as f:
    swapping = pickle.load(f)

In [23]:
len(swapping)

58440

In [29]:
injection_points

{'Q1129186': [102776,
  105161,
  104809,
  103934,
  98616,
  97194,
  96616,
  106065,
  96103,
  102689,
  96652,
  102557,
  102789,
  104179,
  106256,
  97693,
  107760,
  101976,
  108277,
  101458,
  109334,
  103457,
  104332,
  103412,
  108767,
  100548,
  102017,
  99795,
  102523,
  109079,
  99114,
  107641,
  96225,
  107156,
  106562,
  102142,
  106108,
  102312,
  105346,
  101997,
  107716,
  104288,
  98604,
  101495,
  102914,
  104235,
  101898,
  96002,
  99530,
  106166,
  98955,
  97224,
  100458,
  99705,
  96030,
  109247,
  103684,
  106990,
  103487,
  106169,
  98845,
  99929,
  100080,
  97474,
  107509],
 'Q30805': [99397,
  102115,
  103033,
  106251,
  100926,
  101840,
  102589,
  99704,
  108239,
  96578,
  107768,
  96781,
  101781,
  98443,
  98542,
  106619,
  102408,
  108429,
  107942,
  98075,
  102076,
  108234,
  100715,
  107571,
  104896,
  97820,
  96509,
  105387,
  103379,
  107715,
  103167,
  106862,
  99307,
  105831,
  109201,
  9992

In [30]:
for entity in disjoint_entities:
    if entity["num_chunks"] != len(injection_points[entity["subject_qid"]]):
        print("wrong swapping")

In [19]:
total = 0

for entity, injections in injection_points.items():
    total += len(injections)

total * 2

58440

In [18]:
from collections import Counter


# Flatten all values into one list
all_values = [item for sublist in injection_points.values() for item in sublist]

# Count occurrences
value_counts = Counter(all_values)

# Print the result
for item, count in value_counts.items():
    print(f"{item}: {count}")

102776: 3
105161: 4
104809: 5
103934: 1
98616: 2
97194: 3
96616: 5
106065: 4
96103: 3
102689: 2
96652: 3
102557: 3
102789: 2
104179: 2
106256: 2
97693: 7
107760: 3
101976: 5
108277: 3
101458: 6
109334: 3
103457: 1
104332: 2
103412: 2
108767: 2
100548: 2
102017: 2
99795: 4
102523: 6
109079: 3
99114: 4
107641: 6
96225: 1
107156: 2
106562: 2
102142: 2
106108: 4
102312: 1
105346: 2
101997: 3
107716: 2
104288: 2
98604: 1
101495: 2
102914: 2
104235: 3
101898: 7
96002: 2
99530: 2
106166: 1
98955: 3
97224: 2
100458: 3
99705: 3
96030: 2
109247: 1
103684: 3
106990: 3
103487: 3
106169: 6
98845: 2
99929: 6
100080: 3
97474: 1
107509: 5
99397: 2
102115: 8
103033: 2
106251: 4
100926: 1
101840: 3
102589: 3
99704: 5
108239: 1
96578: 1
107768: 1
96781: 2
101781: 3
98443: 3
98542: 3
106619: 2
102408: 3
108429: 5
107942: 2
98075: 4
102076: 2
108234: 2
100715: 3
107571: 3
104896: 3
97820: 3
96509: 2
105387: 3
103379: 8
107715: 4
103167: 2
106862: 3
99307: 2
105831: 2
109201: 3
99926: 3
96001: 3
106551: 3
9

In [15]:
len(swapping)

419