In [1]:
import os
import sys
from pathlib import Path


# Add the src directory to Python path
olmo_core_path = Path.cwd() / "src"
if olmo_core_path.exists():
    sys.path.insert(0, str(olmo_core_path))

# Set your new cache base directory (change this to your preferred location)
cache_base = "/home/joberant/NLP_2425b/shirab6"

In [2]:
# Set all relevant Hugging Face cache directories
os.environ["HF_HOME"] = cache_base
os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_base, "transformers")
os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_base, "datasets")
os.environ["HF_TOKENIZERS_CACHE"] = os.path.join(cache_base, "tokenizers")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
from olmo_core.data import (
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    NumpyDatasetType,
    TokenizerConfig,
    DataCollator
)
from olmo_core.data.numpy_dataset import (
    VSLCurriculumType,
    VSLCurriculumConfig,
)

from olmo_eval import HFTokenizer

tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
            tokenizer_config.identifier,
            pad_token_id=tokenizer_config.pad_token_id,
            eos_token_id=tokenizer_config.eos_token_id,
            bos_token_id=tokenizer_config.bos_token_id,
        )

include_instance_metadata = True # Set to true when you want tp retrieve metadata, during training set this to False
work_dir = "/home/joberant/NLP_2425b/shirab6/knowledge-analysis-suite/OLMo-core/hp_final/dataset-cache"

dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,
    min_sequence_length=64,
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=str(work_dir),
    include_instance_metadata=include_instance_metadata,
)
dataset = dataset_config.build()

Loading metadata: 100%|██████████| 8/8 [00:00<00:00,  9.13it/s]


In [4]:
data_loader_config = NumpyDataLoaderConfig(
    global_batch_size=32768,
    seed=0,
    num_workers=4,
    prefetch_factor = 8,
)

dataloader = data_loader_config.build(dataset)
dataloader.reshuffle(1)

In [5]:
import json

with open("/home/morg/students/gottesman3/knowledge-analysis-suite/retrieval-eval/jobs/20250710_190639/queries_5_10.json", "r") as f:
    rare_entites = json.load(f)


In [21]:
from collections import defaultdict
import math
import pandas as pd

# Group entities by number of chunks
chunks_groups = defaultdict(list)
instance_lenghts = dataset.get_instance_lengths()
lists_of_chunks = []

for entity in rare_entites.keys():
    chunks_lengths = []
    chunks = rare_entites[entity]['M4:H+EL+Co+CoC']['chunks']
    
    # Calculate chunk lengths distribution
    for chunk in chunks:
        length = instance_lenghts[int(chunk)]
        bucket_length = 2 ** math.ceil(math.log2(length))
        chunks_lengths.append(bucket_length)
    
    # Group by number of chunks
    num_chunks = len(chunks)
    chunks_groups[num_chunks].append({
        'entity': entity,
        'chunks_lengths': chunks_lengths,
        'chunks_ids': chunks
    })

# Display grouped results
for num_chunks, entities in sorted(chunks_groups.items()):
    print(f"\n=== Entities with {num_chunks} chunks ({len(entities)} entities) ===")
    for entity_info in entities:
        print(f"  Entity: {entity_info['entity']}")
        print(f"    Chunk lengths: {entity_info['chunks_lengths']}")


=== Entities with 5 chunks (459 entities) ===
  Entity: Spent (band)
    Chunk lengths: [64, 256, 256, 128, 64]
  Entity: Stavros Stathakis
    Chunk lengths: [128, 64, 1024, 128, 64]
  Entity: The Cats (1968 film)
    Chunk lengths: [2048, 2048, 2048, 256, 512]
  Entity: Billy Lumley
    Chunk lengths: [2048, 128, 256, 1024, 64]
  Entity: Rajalakshmi Engineering College
    Chunk lengths: [256, 256, 2048, 2048, 64]
  Entity: Cool It (TV series)
    Chunk lengths: [128, 1024, 256, 64, 512]
  Entity: Richard Field (theologian)
    Chunk lengths: [512, 1024, 64, 128, 256]
  Entity: The Hand That First Held Mine
    Chunk lengths: [1024, 128, 1024, 64, 256]
  Entity: The Forest (2002 film)
    Chunk lengths: [128, 256, 64, 512, 512]
  Entity: The Search (2014 film)
    Chunk lengths: [256, 64, 128, 1024, 512]
  Entity: Ghost in the Shell 2: Innocence
    Chunk lengths: [1024, 2048, 2048, 2048, 128]
  Entity: Larry Coon
    Chunk lengths: [64, 512, 256, 2048, 2048]
  Entity: Time to Come


In [16]:
chunks_groups

defaultdict(list,
            {7: [{'entity': 'The Black Unicorn',
               'chunks_lengths': [512, 1024, 2048, 1024, 2048, 512, 1024],
               'chunks_ids': ['2242896',
                '4408164',
                '2893075',
                '3908515',
                '5050684',
                '9089587',
                '8807791']},
              {'entity': 'Paul Humphries',
               'chunks_lengths': [128, 256, 128, 2048, 512, 64, 64],
               'chunks_ids': ['8742624',
                '8833363',
                '2221139',
                '3126825',
                '8471539',
                '8833365',
                '8833364']},
              {'entity': 'Dex (video game)',
               'chunks_lengths': [256, 1024, 2048, 1024, 1024, 64, 512],
               'chunks_ids': ['9766668',
                '5880547',
                '1717899',
                '408925',
                '6766970',
                '408926',
                '6655050']},
              {

In [18]:
chunks_groups[7][0]

{'entity': 'The Black Unicorn',
 'chunks_lengths': [512, 1024, 2048, 1024, 2048, 512, 1024],
 'chunks_ids': ['2242896',
  '4408164',
  '2893075',
  '3908515',
  '5050684',
  '9089587',
  '8807791']}

In [22]:
groups_of_chunks_lengths_to_batch = {}
for num_chunks, entities_list in chunks_groups.items():
    groups_of_chunks_lengths_to_batch[num_chunks] = []
    for entity in entities_list:
        for length in entity['chunks_lengths']:
            groups_of_chunks_lengths_to_batch[num_chunks].append(length)

In [23]:
groups_of_chunks_lengths_to_batch

{7: [512,
  1024,
  2048,
  1024,
  2048,
  512,
  1024,
  128,
  256,
  128,
  2048,
  512,
  64,
  64,
  256,
  1024,
  2048,
  1024,
  1024,
  64,
  512,
  512,
  64,
  128,
  256,
  512,
  1024,
  64,
  256,
  1024,
  512,
  1024,
  256,
  256,
  64,
  64,
  64,
  1024,
  512,
  256,
  128,
  1024,
  128,
  2048,
  64,
  2048,
  2048,
  2048,
  1024,
  2048,
  256,
  1024,
  128,
  512,
  512,
  512,
  64,
  64,
  64,
  256,
  128,
  64,
  128,
  128,
  256,
  128,
  128,
  512,
  1024,
  1024,
  2048,
  64,
  1024,
  2048,
  1024,
  128,
  512,
  64,
  2048,
  64,
  512,
  2048,
  2048,
  256,
  2048,
  256,
  256,
  2048,
  512,
  256,
  64,
  1024,
  512,
  64,
  512,
  512,
  64,
  512,
  1024,
  128,
  512,
  2048,
  2048,
  64,
  256,
  1024,
  128,
  128,
  1024,
  2048,
  2048,
  64,
  1024,
  1024,
  1024,
  2048,
  64,
  256,
  256,
  1024,
  256,
  128,
  1024,
  512,
  64,
  512,
  256,
  64,
  128,
  64,
  1024,
  1024,
  2048,
  256,
  2048,
  64,
  1024,
  256,
  512

In [33]:
import heapq

def build_batches(entities_chunks, global_batch_size):
    batches = []
    open_batches = []  # Heap of (-available_space, batch_index)
    batch_space = []   # True available space of each batch

    # Sort chunks descending per entity
    for chunks in entities_chunks.values():
        chunks.sort(reverse=True)

    for entity_id, chunks in entities_chunks.items():
        for chunk_size in chunks:
            placed = False
            temp = []

            while open_batches:
                neg_available_space, batch_idx = heapq.heappop(open_batches)
                current_space = batch_space[batch_idx]
                batch = batches[batch_idx]

                if entity_id in batch or chunk_size > current_space:
                    temp.append((neg_available_space, batch_idx))
                    continue

                # Place chunk
                batch[entity_id] = chunk_size
                batch_space[batch_idx] -= chunk_size
                heapq.heappush(open_batches, (-batch_space[batch_idx], batch_idx))
                placed = True
                break

            # Restore skipped
            for item in temp:
                heapq.heappush(open_batches, item)

            if not placed:
                # Create new batch
                new_batch = {entity_id: chunk_size}
                batches.append(new_batch)
                remaining = global_batch_size - chunk_size
                batch_space.append(remaining)
                heapq.heappush(open_batches, (-remaining, len(batches) - 1))

    return batches

In [34]:
global_batch_size = 32768

In [35]:
batches = build_batches(groups_of_chunks_lengths_to_batch, global_batch_size)

In [36]:
batches

[{7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256, 6: 2048},
 {7: 2048, 5: 64, 8: 256, 9: 256

In [37]:
from numpy import size


batch_sizes = {}
for index, batch in enumerate(batches):
    size = 0
    for num_chunks, length in batch.items():
        size += num_chunks * length
    
    batch_sizes[index] = size


In [38]:
batch_sizes

{0: 31296,
 1: 31296,
 2: 31296,
 3: 31296,
 4: 31296,
 5: 31296,
 6: 31296,
 7: 31296,
 8: 31296,
 9: 31296,
 10: 31296,
 11: 31296,
 12: 31296,
 13: 31296,
 14: 31296,
 15: 31296,
 16: 31296,
 17: 31296,
 18: 31296,
 19: 31296,
 20: 31296,
 21: 31296,
 22: 31296,
 23: 31296,
 24: 31296,
 25: 31296,
 26: 31296,
 27: 31296,
 28: 31296,
 29: 31296,
 30: 31296,
 31: 31296,
 32: 31296,
 33: 31296,
 34: 31296,
 35: 31296,
 36: 31296,
 37: 31296,
 38: 31296,
 39: 31296,
 40: 31296,
 41: 31296,
 42: 31296,
 43: 31296,
 44: 31296,
 45: 31296,
 46: 31296,
 47: 31296,
 48: 31296,
 49: 31296,
 50: 31296,
 51: 31296,
 52: 31296,
 53: 31296,
 54: 31296,
 55: 31296,
 56: 31296,
 57: 31296,
 58: 31296,
 59: 31296,
 60: 31296,
 61: 31296,
 62: 31296,
 63: 31296,
 64: 31296,
 65: 31296,
 66: 31296,
 67: 31296,
 68: 31296,
 69: 31296,
 70: 31296,
 71: 31296,
 72: 31296,
 73: 31296,
 74: 31296,
 75: 31296,
 76: 31296,
 77: 31296,
 78: 31296,
 79: 31296,
 80: 31296,
 81: 31296,
 82: 31296,
 83: 31296,
 8