# Imports

In [2]:
import os
import sys
from pathlib import Path
import random
import numpy as np
import pickle
from typing import List

# Add the src directory to Python path
olmo_core_path = Path.cwd() / "src"
if olmo_core_path.exists():
    sys.path.insert(0, str(olmo_core_path))

from olmo_core.data import (
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    NumpyDatasetType,
    TokenizerConfig,
)
from olmo_core.data.numpy_dataset import (
    VSLCurriculumType,
    VSLCurriculumConfig,
)

from olmo_core.data.chunk_ordering_utils import create_swapping_dict_for_steps_interval, get_disjoint_entities, get_important_chunks, sample_injection_points, create_swapping_dict_for_injection_points, assign_indices_to_entities, sample_random_injection_points


In [3]:
# Set your new cache base directory (change this to your preferred location)
cache_base = "/home/joberant/NLP_2425b/shirab6"

# Set all relevant Hugging Face cache directories
os.environ["HF_HOME"] = cache_base
os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_base, "transformers")
os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_base, "datasets")
os.environ["HF_TOKENIZERS_CACHE"] = os.path.join(cache_base, "tokenizers")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from olmo_eval import HFTokenizer
from datasets import load_dataset



# Load and create all necesarry data

In [4]:
tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
            tokenizer_config.identifier,
            pad_token_id=tokenizer_config.pad_token_id,
            eos_token_id=tokenizer_config.eos_token_id,
            bos_token_id=tokenizer_config.bos_token_id,
        )

include_instance_metadata = False # Set to true when you want tp retrieve metadata, during training set this to False
work_dir = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/hp_final_tmp/dataset-cache"

dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,
    min_sequence_length=64,
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=str(work_dir),
    include_instance_metadata=include_instance_metadata,
)
kas_dataset = dataset_config.build()

Loading metadata: 100%|██████████| 8/8 [00:00<00:00, 257.28it/s]


In [5]:
ds = load_dataset("dhgottesman/popqa-kas")

In [6]:
"""
importsnt chunks has the following structure:
        {
            'entity_id': subject_id,
            'num_chunks': num_chunks,
            'chunks': sorted_chunks,
            'chunks_lengths': sorted_lengths
        }
"""
important_chunks = get_important_chunks(ds, 50, 100, kas_dataset.get_instance_lengths())

In [7]:
disjoint_entities = get_disjoint_entities(important_chunks, 406)

In [8]:
original_batches = np.load("/home/morg/students/gottesman3/knowledge-analysis-suite/OLMo-core/batch_indices.npy", allow_pickle=True)

In [35]:
intervals = [1, 50, 100]
save_file_path = '/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/chunk_ordering_data/random_last_20_percent_2_'

In [36]:
total_number_of_batches = 109672
min_start_point = int(total_number_of_batches * 0.9)
seed = 300
all_injection_points_per_entity = sample_random_injection_points(total_number_of_batches, disjoint_entities, min_start_point, seed)
all_injection_points_per_entity

29220
range(80452, 109672)
[91657, 101146, 84914, 82953, 101886, 98829, 93834, 107279, 93052, 90114, 100910, 106402, 87207, 104265, 104584, 102273, 88287, 97604, 106446, 86631, 86357, 95372, 109187, 85308, 95399, 93071, 88579, 101948, 108343, 84724, 90219, 91927, 95129, 91945, 89461, 99637, 94086, 85508, 99351, 86622, 87950, 81331, 94646, 106796, 81902, 102761, 93965, 98549, 96127, 98061, 107790, 90239, 106324, 103824, 92430, 87141, 89957, 109548, 102315, 98389, 84463, 103809, 103468, 90166, 105313, 105111, 91760, 103966, 80964, 97768, 84759, 90796, 84910, 88942, 105106, 106146, 87960, 94728, 103316, 93762, 90457, 107766, 95928, 101193, 81391, 106721, 95713, 89906, 99331, 94675, 106783, 84339, 82433, 95682, 101717, 96595, 107522, 93161, 85594, 90979, 93459, 97840, 106307, 84776, 104592, 88359, 95169, 89284, 82461, 80523, 89691, 93211, 104561, 92356, 98189, 81639, 86202, 95759, 95439, 93936, 108548, 97288, 82567, 83923, 91558, 82856, 90023, 96679, 99238, 86997, 93856, 89359, 83815, 8902

{'Q1129186': [91657,
  101146,
  84914,
  82953,
  101886,
  98829,
  93834,
  107279,
  93052,
  90114,
  100910,
  106402,
  87207,
  104265,
  104584,
  102273,
  88287,
  97604,
  106446,
  86631,
  86357,
  95372,
  109187,
  85308,
  95399,
  93071,
  88579,
  101948,
  108343,
  84724,
  90219,
  91927,
  95129,
  91945,
  89461,
  99637,
  94086,
  85508,
  99351,
  86622,
  87950,
  81331,
  94646,
  106796,
  81902,
  102761,
  93965,
  98549,
  96127,
  98061,
  107790,
  90239,
  106324,
  103824,
  92430,
  87141,
  89957,
  109548,
  102315,
  98389,
  84463,
  103809,
  103468,
  90166,
  105313],
 'Q30805': [105111,
  91760,
  103966,
  80964,
  97768,
  84759,
  90796,
  84910,
  88942,
  105106,
  106146,
  87960,
  94728,
  103316,
  93762,
  90457,
  107766,
  95928,
  101193,
  81391,
  106721,
  95713,
  89906,
  99331,
  94675,
  106783,
  84339,
  82433,
  95682,
  101717,
  96595,
  107522,
  93161,
  85594,
  90979,
  93459,
  97840,
  106307,
  84776,
  10459

In [37]:
create_swapping_dict_for_injection_points(disjoint_entities, all_injection_points_per_entity, original_batches, save_file_path)

{7484675: 8491586,
 8491586: 7484675,
 3431988: 404166,
 404166: 3431988,
 8967952: 10320847,
 10320847: 8967952,
 3918203: 7044565,
 7044565: 3918203,
 8814376: 7597693,
 7597693: 8814376,
 4573211: 844663,
 844663: 4573211,
 2237061: 4253971,
 4253971: 2237061,
 2611446: 3027093,
 3027093: 2611446,
 2954313: 3507304,
 3507304: 2954313,
 3119621: 5760391,
 5760391: 3119621,
 3260500: 1101471,
 1101471: 3260500,
 3286712: 6829817,
 6829817: 3286712,
 3398897: 2963594,
 2963594: 3398897,
 3398898: 3106179,
 3106179: 3398898,
 4777376: 2480455,
 2480455: 4777376,
 7641034: 10221336,
 10221336: 7641034,
 10348211: 3391301,
 3391301: 10348211,
 3577493: 10058773,
 10058773: 3577493,
 2824558: 519037,
 519037: 2824558,
 4562086: 2998758,
 2998758: 4562086,
 5665654: 3633260,
 3633260: 5665654,
 7127532: 3709907,
 3709907: 7127532,
 8283741: 1152198,
 1152198: 8283741,
 9279434: 868944,
 868944: 9279434,
 4076895: 6744790,
 6744790: 4076895,
 3751845: 5656908,
 5656908: 3751845,
 1596766: 30

In [48]:
with open('/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/chunk_ordering_data/last_10_percent_swapping_dict_interval_1.pkl', 'rb') as f:
    swapping = pickle.load(f)

In [49]:
len(swapping)

58440