## Overview

The **Knowledge Analysis Suite (KAS)** models are trained on a snapshot of Wikipedia from **April 2025**. The dataset creation process includes the following steps:

1. **Text Extraction**  
   Clean text is extracted from the Wikipedia XML dump.

2. **Hyperlink Extraction**  
   All hyperlinks present in each document are identified and extracted.

3. **Entity Linking**  
   An entity linking model (**Maverick**) is used to identify entity mentions that are not already linked.

4. **Coreference Resolution**  
   A coreference model is run to group references to the same entity, improving the coverage of entity mentions across the document.
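
To make step 2 concrete, here is a minimal sketch of hyperlink extraction that keeps character offsets into the cleaned text. This is the same form the mention metadata uses later (`char_start`/`char_end`). It assumes links are written in MediaWiki `[[Target]]` / `[[Target|anchor text]]` syntax and ignores nested markup. `strip_links` is a hypothetical helper, and the real pipeline may rely on a dedicated wikitext parser instead.

```python
import re
from typing import List, Tuple

# Matches [[Target]] or [[Target|anchor text]] (nested brackets are not handled).
LINK_RE = re.compile(r"\[\[([^\[\]|]+)(?:\|([^\[\]]*))?\]\]")


def strip_links(wikitext: str) -> Tuple[str, List[dict]]:
    """Replace wiki links by their anchor text and record each link's span in the clean text."""
    out_parts: List[str] = []
    links: List[dict] = []
    out_len = 0  # length of the clean text built so far
    last = 0
    for match in LINK_RE.finditer(wikitext):
        # Copy the plain text that precedes this link.
        before = wikitext[last:match.start()]
        out_parts.append(before)
        out_len += len(before)

        target = match.group(1).strip()
        anchor = match.group(2) if match.group(2) is not None else match.group(1)
        links.append({"target": target, "char_start": out_len, "char_end": out_len + len(anchor)})
        out_parts.append(anchor)
        out_len += len(anchor)
        last = match.end()
    out_parts.append(wikitext[last:])
    return "".join(out_parts), links
```

For example, `strip_links("Greene pitches for the [[Cincinnati Reds|Reds]] in [[Major League Baseball]].")` returns the clean sentence plus two link records whose offsets point at "Reds" and "Major League Baseball".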

---

Once the metadata is prepared, each document is tokenized, and all tokenized documents are stored in a single concatenated array. The documents are then split into chunks of the following token lengths:

- 2048
- 1024
- 512
- 256
- 128
- 64
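
One way to produce chunks of these lengths is a greedy power-of-two decomposition, in the spirit of the dataset-decomposition idea behind variable-sequence-length (VSL) training. The idea is to cut as many 2048-token chunks as fit, then at most one chunk of each smaller length, and to drop any remainder shorter than 64 tokens. This is a sketch under that assumption, not the exact KAS procedure. In particular, it ignores the rule (stated in the attention-mask cell of the Dataloader section) that chunks are never split inside an entity mention. `decompose_document` is a hypothetical name.

```python
from typing import List, Tuple

CHUNK_LENGTHS = [2048, 1024, 512, 256, 128, 64]  # longest first


def decompose_document(start: int, length: int) -> List[Tuple[int, int]]:
    """Split the document occupying tokens [start, start + length) of the concatenated
    array into (offset, chunk_length) pieces, longest first."""
    chunks = []
    offset, remaining = start, length
    for size in CHUNK_LENGTHS:
        while remaining >= size:
            chunks.append((offset, size))
            offset += size
            remaining -= size
    # Whatever is left (fewer than 64 tokens) is dropped.
    return chunks
```

For a 3000-token document starting at offset 0, this yields `[(0, 2048), (2048, 512), (2560, 256), (2816, 128)]` and drops the last 56 tokens. Because every length is a power of two, only the 2048-token size can repeat.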

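Chunks of the same length are grouped into batches, so each batch contains chunks of a single length. Every batch holds the same total number of tokens, the **global batch size** (set via `global_batch_size` in the dataloader configuration below).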
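
Because each batch has a fixed token budget, the number of chunks per batch depends on the chunk length. The short calculation below uses the `global_batch_size=262144` (2^18 tokens) from the dataloader configuration in this notebook.

```python
GLOBAL_BATCH_SIZE = 262_144  # tokens per batch, as in the dataloader config in this notebook
CHUNK_LENGTHS = [2048, 1024, 512, 256, 128, 64]

# Every chunk length divides the token budget exactly, so each batch is a full (n, length) matrix.
assert all(GLOBAL_BATCH_SIZE % length == 0 for length in CHUNK_LENGTHS)
chunks_per_batch = {length: GLOBAL_BATCH_SIZE // length for length in CHUNK_LENGTHS}

for length, n in chunks_per_batch.items():
    print(f"{length:>5}-token chunks -> batch shape ({n}, {length})")
```

The 2048-token case gives a `(128, 2048)` batch, and the 64-token case gives `(4096, 64)`.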

During training, at each step, the model processes one batch of chunks. Gradients are computed based on this batch, and the optimizer updates the model parameters accordingly.

---

### File Structure

- **Metadata Files**  
  `/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/part-[0-7]-00000_new_2.csv`

- **Concatenated Documents**  
  `/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/part-[0-7]-00000.npy`

- **Dataset and Dataloader Cache**  
  `/home/morg/students/gottesman3/knowledge-analysis-suite/OLMo-core/hp_final/dataset-cache`

---

### Your Task

Your primary task is to implement functionality that controls:

- How the **Dataloader** organizes and delivers batches during training.
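
As a possible starting point, here is a minimal, framework-free sketch of one batching policy. It groups chunk indices by length, shuffles each group deterministically from `(seed, epoch)`, and cuts each group into batches of `global_batch_size // length` chunks. It then shuffles the order in which batches are delivered. Everything here is a hypothetical design choice, not OLMo-core's API: the name `build_batches`, the `chunk_lengths` input, dropping incomplete batches, and the fully shuffled batch order. In this notebook, the real ordering appears to be governed by the VSL curriculum configured on the dataset.

```python
import random
from collections import defaultdict
from typing import Dict, List


def build_batches(chunk_lengths: List[int], global_batch_size: int, seed: int, epoch: int) -> List[List[int]]:
    """Hypothetical scheduler: returns batches of chunk indices; every batch has a single
    chunk length and holds exactly `global_batch_size` tokens."""
    # String seeds are hashed deterministically, so the same (seed, epoch) gives the same order.
    rng = random.Random(f"{seed}:{epoch}")

    # 1. Group chunk indices by length.
    buckets: Dict[int, List[int]] = defaultdict(list)
    for idx, length in enumerate(chunk_lengths):
        buckets[length].append(idx)

    # 2. Shuffle within each bucket and cut it into full batches.
    batches: List[List[int]] = []
    for length in sorted(buckets):
        indices = buckets[length]
        rng.shuffle(indices)
        per_batch = global_batch_size // length
        # Incomplete trailing batches are dropped in this sketch.
        for i in range(0, len(indices) - per_batch + 1, per_batch):
            batches.append(indices[i:i + per_batch])

    # 3. Shuffle the order in which batches (and therefore lengths) are delivered.
    rng.shuffle(batches)
    return batches
```

Passing a different `epoch` yields a different but equally reproducible order, which is one reason a data loader needs to know its epoch.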


```python
import os

# Point the Hugging Face caches to a custom location (change this to your preferred directory).
# Set these before transformers / datasets / huggingface_hub are imported.
cache_base = "/home/morg/students/gottesman3/.cache/huggingface"

os.environ["HF_HOME"] = cache_base
os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_base, "transformers")
os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_base, "datasets")
os.environ["HF_TOKENIZERS_CACHE"] = os.path.join(cache_base, "tokenizers")
# Disable tokenizer parallelism to avoid fork-related warnings with data loader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
```

## Dataset

```python
from olmo_core.data import (
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    NumpyDatasetType,
    TokenizerConfig,
    DataCollator,
)
from olmo_core.data.numpy_dataset import (
    VSLCurriculumType,
    VSLCurriculumConfig,
)
from olmo_eval import HFTokenizer

# Dolma2 tokenizer, wrapped in HFTokenizer so token IDs can be decoded back to text.
tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
    tokenizer_config.identifier,
    pad_token_id=tokenizer_config.pad_token_id,
    eos_token_id=tokenizer_config.eos_token_id,
    bos_token_id=tokenizer_config.bos_token_id,
)

# Set to True to retrieve per-chunk metadata; set it to False during training.
include_instance_metadata = True
work_dir = "/home/morg/students/gottesman3/knowledge-analysis-suite/OLMo-core/hp_final/dataset-cache"

dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,  # longest chunk length
    min_sequence_length=64,    # shortest chunk length
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=work_dir,
    include_instance_metadata=include_instance_metadata,
)
dataset = dataset_config.build()
```

```
Loading metadata: 100%|██████████| 8/8 [00:00<00:00, 217.11it/s]
```


```python
# Iterate in dataset order (NOT the order in which the model sees chunks during training)
# and stop at the first chunk longer than 1024 tokens.
for chunk in dataset:
    if len(chunk["input_ids"]) > 1024:
        break
chunk
```

```
{'input_ids': tensor([86791, 59174,   320,  ...,   813,  3070,  3974]),
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, ...   (output truncated)
```


```python
# Decode the chunk's token IDs back to text.
text = tokenizer.decode(chunk["input_ids"].tolist())
text
```

```
'Hunter Greene (baseball)\n\nChristian Hunter Greene (born August 6, 1999) is an American professional baseball pitcher for the Cincinnati Reds of Major League Baseball (MLB). The Reds selected him second overall in the 2017 MLB Draft.\n\nBorn in Los Angeles, California, Greene learned how to pitch at the Major League Baseball Urban Youth Academy in Compton. His fastball velocity was already during his first year at Notre Dame High School, and by the time he graduated in 2017, it was up to . The Reds drafted Greene out of high school, and he joined their farm system rather than playing college baseball. Greene suffered an ulnar collateral ligament injury partway through the 2018 season and underwent Tommy John surgery the following year. The COVID-19 pandemic kept him from pitching for another year, but once he returned in 2021, he quickly rose through the minor leagues.\n\nGreene made the Reds\' Opening Day roster in 2022. In only the second game of his major league career, he set an 
```

### Visualizing the metadata

```python
# Extract the chunk's metadata; "entities" holds the entity mentions found in this chunk.
metadata = chunk["metadata"]
entities = metadata["entities"]
```
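
The visualization code reads only a few fields from each entry of `entities`. Judging from those accesses, a mention looks roughly like the toy records below. The values are made up for illustration, apart from the Hunter Greene QID taken from the widget output, and real records may carry more fields. A small helper, hypothetically named `top_candidates`, recovers each mention's surface text and its highest-scoring candidate.

```python
from typing import Dict, List, Tuple

# Toy chunk text and mentions following the inferred schema (values are illustrative).
example_text = "Hunter Greene pitches for the Reds."
example_mentions: List[Dict] = [
    {
        "char_start": 0,
        "char_end": 13,
        "candidates": [
            {
                "qid": "Q27983473",
                "name": "Hunter Greene",
                "scores_by_source": {"hyperlinks": 1.0, "coref_cluster": 0.9},
                "aggregated_score": 0.95,
            },
        ],
    },
    {
        "char_start": 30,
        "char_end": 34,
        "candidates": [
            # Placeholder QIDs, not real Wikidata identifiers.
            {"qid": "Q_TEAM", "name": "Cincinnati Reds",
             "scores_by_source": {"entity_linking": 0.8}, "aggregated_score": 0.8},
            {"qid": "Q_COLOR", "name": "Red",
             "scores_by_source": {"entity_linking": 0.1}, "aggregated_score": 0.1},
        ],
    },
]


def top_candidates(text: str, mentions: List[Dict]) -> List[Tuple[str, str, float]]:
    """Return (surface text, best QID, aggregated score) for each mention that has candidates."""
    rows = []
    for m in mentions:
        candidates = m.get("candidates", [])
        if not candidates:
            continue
        best = max(candidates, key=lambda c: c.get("aggregated_score", 0.0))
        surface = text[m["char_start"]:m["char_end"]]
        rows.append((surface, best["qid"], best.get("aggregated_score", 0.0)))
    return rows
```

On the real chunk, the corresponding call would be `top_candidates(text, entities)`.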

```python
import html
from collections import Counter

import ipywidgets as widgets
from IPython.display import HTML, display

_style = """
<style>
.compact-mention-output {
    font-size: 11px !important;
    line-height: 1.4 !important;
    max-height: 1100px;
    min-height: 550px;
    max-width: 1300px;
    min-width: 900px;
    width: 98vw;
    overflow-y: auto;
    background: #fcfcfc;
    padding: 24px 36px 24px 16px;
    border: 1px solid #ddd;
    border-radius: 10px;
}
</style>
"""
display(HTML(_style))


def display_highlighted_mentions(text, mentions_list):
    """Highlight the mentions of a selected entity (QID) in `text`, with one checkbox and
    one score-threshold slider per mention source."""
    mention_fields = ["hyperlinks", "entity_linking", "coref", "coref_cluster"]

    # Index the mentions: QID counts, first known name per QID, score ranges, span -> candidates.
    qid_counts = Counter()
    qid_names = {}
    field_score_ranges = {field: [] for field in mention_fields}
    all_scores = []
    mention_map = {}  # (start, end) -> {qid -> (scores, agg, name)}

    for m in mentions_list:
        key = (m["char_start"], m["char_end"])
        for c in m.get("candidates", []):
            qid = c["qid"]
            name = c.get("name", "")
            scores = c.get("scores_by_source", {})
            agg = c.get("aggregated_score", 0.0)

            qid_counts[qid] += 1
            if name:
                qid_names.setdefault(qid, name)
            all_scores.append(agg)
            for field in mention_fields:
                if field in scores:
                    field_score_ranges[field].append(scores[field])

            mention_map.setdefault(key, {})[qid] = (scores, agg, name)

    if not qid_counts:
        display(HTML("<i>No entity candidates found in this chunk.</i>"))
        return

    # Dropdown options, most frequent QIDs first: "name (QID) [count]".
    sorted_qids = sorted(qid_counts.items(), key=lambda x: (-x[1], x[0]))
    qid_labels = [
        (f"{qid_names[qid]} ({qid}) [{count}]" if qid in qid_names else f"{qid} [{count}]", qid)
        for qid, count in sorted_qids
    ]

    min_score = round(min(all_scores), 2)
    max_score = round(max(all_scores), 2)

    # Default threshold per source (0.6 unless listed), clamped to the observed score range.
    default_thresholds = {"hyperlinks": 1.0, "entity_linking": 0.57}
    field_sliders = {}
    for field in mention_fields:
        values = field_score_ranges[field]
        fmin = round(min(values), 2) if values else min_score
        fmax = round(max(values), 2) if values else max_score
        default_val = max(fmin, min(fmax, default_thresholds.get(field, 0.6)))

        field_sliders[field] = widgets.FloatSlider(
            value=default_val,
            min=fmin,
            max=fmax,
            step=0.01,
            description=f'{field} ≥',
            readout_format='.2f',
            continuous_update=True,
            layout=widgets.Layout(width='280px'),
        )

    checkboxes = {
        field: widgets.Checkbox(value=True, description=field, indent=False)
        for field in mention_fields
    }
    checkbox_widgets = list(checkboxes.values())

    dropdown = widgets.Dropdown(
        options=qid_labels,
        description='Entity:',
        value=qid_labels[0][1],
        style={'description_width': 'initial'},
    )

    html_out = widgets.HTML()

    def render_text(selected_qid, field_thresholds, active_fields):
        # Keep mentions of the selected QID that pass at least one active source threshold.
        highlight_spans = []
        for (start, end), qid_map in mention_map.items():
            if selected_qid not in qid_map:
                continue
            scores, agg, _ = qid_map[selected_qid]
            if any(scores.get(field, 0.0) >= field_thresholds[field] for field in active_fields):
                highlight_spans.append((start, end, agg))

        if not highlight_spans:
            return f"<div class='compact-mention-output'>{html.escape(text)}</div>"

        # Rank spans by aggregated score (1 = highest).
        highlight_spans_sorted = sorted(highlight_spans, key=lambda x: -x[2])
        span_ranks = {(start, end): rank + 1 for rank, (start, end, _) in enumerate(highlight_spans_sorted)}

        # Turn spans into open/close events; at equal offsets, close before opening.
        events = []
        for start, end, score in highlight_spans:
            rank = span_ranks[(start, end)]
            events.append((start, 'start', score, rank))
            events.append((end, 'end', None, None))
        events.sort(key=lambda x: (x[0], 0 if x[1] == 'end' else 1))

        out = []
        last_idx = 0
        open_spans = 0
        for idx, typ, score, rank in events:
            if last_idx < idx:
                # Escape the raw text so characters like "<" and "&" cannot break the HTML.
                out.append(html.escape(text[last_idx:idx]))
            if typ == 'start':
                out.append(
                    '<span style="background-color: #fff574; border-radius:3px; padding:2px 4px; font-size: 11px;" '
                    f'title="aggregated: {score}, rank: {rank}"><b>({rank})</b> '
                )
                open_spans += 1
            elif open_spans:
                out.append('</span>')
                open_spans -= 1
            last_idx = idx

        out.append(html.escape(text[last_idx:]))
        return f'<div class="compact-mention-output">{"".join(out)}</div>'

    def update_html(change=None):
        qid = dropdown.value
        active_fields = [f for f, cb in checkboxes.items() if cb.value]
        thresholds = {f: field_sliders[f].value for f in active_fields}
        html_out.value = render_text(qid, thresholds, active_fields)

    # Re-render whenever any control changes.
    dropdown.observe(update_html, names='value')
    for cb in checkbox_widgets:
        cb.observe(update_html, names='value')
    for slider in field_sliders.values():
        slider.observe(update_html, names='value')

    update_html()

    controls = widgets.VBox(
        [
            dropdown,
            widgets.HBox(checkbox_widgets),
            widgets.HBox([field_sliders[f] for f in mention_fields]),
        ],
        layout=widgets.Layout(width='99%', max_width='1300px'),
    )
    container = widgets.VBox(
        [controls, html_out],
        layout=widgets.Layout(width='99%', max_width='1350px'),
    )
    display(container)
```


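```python
display_highlighted_mentions(text, entities)
```

*Output: an interactive widget with an entity dropdown (first entry: `Hunter Greene (Q27983473) [128]`), one checkbox and one threshold slider per mention source, and the chunk text with the selected entity's mentions highlighted.*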
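
The widget's highlighting rule can also be applied without ipywidgets, for example to count or export mentions programmatically. The sketch below uses the hypothetical name `mentions_for_qid` and roughly mirrors `render_text` in the visualization cell: a mention is kept for a QID if any selected source score reaches that source's threshold. The defaults are the widget's initial slider values (1.0 for hyperlinks, 0.57 for entity linking, 0.6 otherwise), before clamping to the observed score range. Unlike the widget, it does not merge duplicate spans.

```python
from typing import Dict, List, Optional, Tuple

DEFAULT_THRESHOLDS = {"hyperlinks": 1.0, "entity_linking": 0.57, "coref": 0.6, "coref_cluster": 0.6}


def mentions_for_qid(
    mentions: List[Dict],
    qid: str,
    thresholds: Optional[Dict[str, float]] = None,
) -> List[Tuple[int, int, float]]:
    """Return (char_start, char_end, aggregated_score) for mentions of `qid` that pass at
    least one source threshold, sorted by aggregated score (highest first).
    Only the sources listed in `thresholds` are considered."""
    thresholds = DEFAULT_THRESHOLDS if thresholds is None else thresholds
    spans = []
    for m in mentions:
        for c in m.get("candidates", []):
            if c["qid"] != qid:
                continue
            scores = c.get("scores_by_source", {})
            if any(scores.get(src, 0.0) >= t for src, t in thresholds.items()):
                spans.append((m["char_start"], m["char_end"], c.get("aggregated_score", 0.0)))
    return sorted(spans, key=lambda s: -s[2])
```

For the chunk above, the corresponding call would be `mentions_for_qid(entities, "Q27983473")`.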

## Dataloader

```python
data_loader_config = NumpyDataLoaderConfig(
    global_batch_size=262144,  # tokens per batch (2**18)
    seed=0,
    num_workers=4,
    prefetch_factor=8,
)

dataloader = data_loader_config.build(dataset)
```

### Important Note
You must set the dataloader's epoch before iterating over it. Below, we set the epoch to 1.

```python
# Set the epoch (and reshuffle the data order) to 1.
dataloader.reshuffle(1)
```

### Iterating over the dataloader

```python
# Fetch (and print) the first batch of the epoch.
for batch in dataloader:
    print(batch)
    break
```

```python
batch.keys()
```

```
dict_keys(['input_ids', 'attention_mask', 'index', 'metadata'])
```

```python
# 128 chunks of length 2048 = 262,144 tokens (the global batch size)
batch["input_ids"].shape
```

```
torch.Size([128, 2048])
```

```python
# Attention mask of row 11 in the batch.
# Chunks are never split in the middle of an entity mention, so some chunks are shorter
# than 2048 tokens. Such chunks are padded, and the padding positions get attention-mask
# value 0 so that they have no effect.
batch["attention_mask"][11]
```

```
tensor([1., 1., 1.,  ..., 1., 1., 0.])
```
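
The rule in the comment above has two parts: chunks are never cut inside an entity mention, and the resulting shorter chunks are padded with attention-mask zeros. A minimal sketch of both steps on token offsets follows. `safe_boundary` and `pad_chunk` are hypothetical helpers. Mention spans are assumed to be half-open `[start, end)` token ranges, and padding is assumed to go at the end, as the trailing `0.` in the mask above suggests. The actual KAS chunker may differ, for instance in how it treats a mention that starts exactly at the chunk start.

```python
from typing import List, Tuple


def safe_boundary(chunk_start: int, target_end: int, mention_spans: List[Tuple[int, int]]) -> int:
    """Move a proposed chunk end back until it no longer falls inside any mention."""
    end = target_end
    moved = True
    while moved:  # repeat, since moving back may land inside an earlier, overlapping mention
        moved = False
        for m_start, m_end in mention_spans:
            # The boundary cuts this mention if it lies strictly inside it.
            if chunk_start < m_start < end < m_end:
                end = m_start
                moved = True
    return end


def pad_chunk(token_ids: List[int], seq_len: int, pad_id: int) -> Tuple[List[int], List[int]]:
    """Right-pad to `seq_len`; the attention mask is 1 for real tokens and 0 for padding."""
    n_pad = seq_len - len(token_ids)
    if n_pad < 0:
        raise ValueError("chunk is longer than seq_len")
    return token_ids + [pad_id] * n_pad, [1] * len(token_ids) + [0] * n_pad
```

For example, a mention covering tokens `[2040, 2052)` straddles the 2048-token boundary, so the chunk would end at token 2040 and be padded by 8 positions.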

```python
# "index" lists, for each row of the batch, the index of that chunk in the *dataset*.
# The next cells demonstrate this.
len(batch["index"])
```

```
128
```

```python
# Decode the first chunk in the batch.
tokenizer.decode(batch["input_ids"][0].tolist())
```

```
'2013 Israeli legislative election\n\nEarly legislative elections were held in Israel on 22 January 2013 to elect the 120 members of the nineteenth Knesset. Public debate over the Tal Law had nearly led to early elections in 2012, but they were aborted at the last moment after Kadima briefly joined the government. The elections were later called in early October 2012 after failure to agree on the budget for the 2013 fiscal year.\n\nThe elections saw the Likud Yisrael Beiteinu alliance emerge as the largest faction in the Knesset, winning 31 of the 120 seats. Likud leader Benjamin Netanyahu formed the country\'s thirty-third government after establishing a coalition with Yesh Atid, the Jewish Home, and Hatnua, which between them held 68 seats.\n\nFollowing the 2009 elections, in which right-wing and religious parties won the majority (65 out of 120, or 54%) of the seats, opposition leader Benjamin Netanyahu established a government including right-wing parties Likud, Yisrael Beiteinu, t
```

```python
# Decode the dataset chunk referenced by batch["index"][0]; it matches the batch row above.
tokenizer.decode(dataset[batch["index"][0].item()]["input_ids"].tolist())
```

```
'2013 Israeli legislative election\n\nEarly legislative elections were held in Israel on 22 January 2013 to elect the 120 members of the nineteenth Knesset. Public debate over the Tal Law had nearly led to early elections in 2012, but they were aborted at the last moment after Kadima briefly joined the government. The elections were later called in early October 2012 after failure to agree on the budget for the 2013 fiscal year.\n\nThe elections saw the Likud Yisrael Beiteinu alliance emerge as the largest faction in the Knesset, winning 31 of the 120 seats. Likud leader Benjamin Netanyahu formed the country\'s thirty-third government after establishing a coalition with Yesh Atid, the Jewish Home, and Hatnua, which between them held 68 seats.\n\nFollowing the 2009 elections, in which right-wing and religious parties won the majority (65 out of 120, or 54%) of the seats, opposition leader Benjamin Netanyahu established a government including right-wing parties Likud, Yisrael Beiteinu, t
```

```python
# The metadata match as well.
batch["metadata"][0] == dataset[batch["index"][0].item()]["metadata"]
```

```
True
```
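
The cells above check only the first row. A sketch that extends the same check to every row of the batch could look like this. `check_batch_against_dataset` is a hypothetical helper. It assumes rows are right-padded, in line with the attention-mask cell above, so it compares only the first `len(item["input_ids"])` positions of each row.

```python
from typing import Any, Callable


def check_batch_against_dataset(batch: dict, get_item: Callable[[int], dict]) -> int:
    """Assert that each batch row matches the dataset chunk it was built from.

    `get_item(i)` should return dataset[i], e.g. `lambda i: dataset[i]`. Values are converted
    to plain lists first, so tensors, NumPy arrays and lists all work.
    Returns the number of rows checked.
    """
    def as_list(x: Any) -> list:
        return x.tolist() if hasattr(x, "tolist") else list(x)

    indices = as_list(batch["index"])
    for row, idx in enumerate(indices):
        item = get_item(int(idx))
        item_ids = as_list(item["input_ids"])
        row_ids = as_list(batch["input_ids"][row])
        # Compare only the unpadded prefix of the row.
        assert row_ids[:len(item_ids)] == item_ids, f"row {row}: input_ids differ"
        if "metadata" in batch:
            assert batch["metadata"][row] == item["metadata"], f"row {row}: metadata differ"
    return len(indices)
```

Calling `check_batch_against_dataset(batch, lambda i: dataset[i])` would return 128 for the 128-row batch above if every row matches. Otherwise it raises an `AssertionError` naming the first mismatching row.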