## Overview

The **LMEnt** models are trained on a snapshot of Wikipedia from **April 2025**. The dataset creation process includes the following steps:

1. **Text Extraction**  
   Clean text is extracted from the Wikipedia XML dump.

2. **Hyperlink Extraction**  
   All hyperlinks present in each document are identified and extracted.

3. **Entity Linking**  
   An entity linking model (**Maverick**) is used to identify entity mentions that are not already linked.

4. **Coreference Resolution**  
   A coreference model is run to group references to the same entity, improving the coverage of entity mentions across the document.

---

Once the metadata is prepared, each document is tokenized and stored in a concatenated array. These documents are then split into chunks of the following token lengths:

- 2048
- 1024
- 512
- 256
- 128
- 64
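One way to picture how a document ends up in chunks of these lengths is a greedy power-of-two decomposition of its token count. The sketch below is an assumption for illustration, not necessarily the exact procedure used to build this dataset; `split_into_p2_chunks` is a hypothetical helper, and the treatment of a leftover shorter than 64 tokens is likewise assumed.

```python
def split_into_p2_chunks(doc_len, min_len=64, max_len=2048):
    """Greedily cover `doc_len` tokens with power-of-two chunk lengths.

    Returns (chunk_lengths, leftover), where leftover is the number of
    trailing tokens too short to form a chunk of at least `min_len`.
    """
    chunk_lengths = []
    remaining = doc_len
    size = max_len
    while size >= min_len:
        # Only the largest size can repeat; smaller sizes fit at most once.
        while remaining >= size:
            chunk_lengths.append(size)
            remaining -= size
        size //= 2
    return chunk_lengths, remaining


lengths, leftover = split_into_p2_chunks(3000)
print(lengths, leftover)  # [2048, 512, 256, 128] 56
```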

Chunks of the same length are grouped into batches, where each batch contains a total of **32,768 tokens** (global batch size).
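Because every batch holds a fixed 32,768 tokens, the number of chunks per batch depends only on the chunk length. The arithmetic can be checked with a few lines (the constant names are chosen here for illustration):

```python
GLOBAL_BATCH_SIZE = 32_768  # tokens per batch, as stated above
CHUNK_LENGTHS = [2048, 1024, 512, 256, 128, 64]

# Each batch contains chunks of a single length, so the chunk count
# is simply the token budget divided by that length.
chunks_per_batch = {length: GLOBAL_BATCH_SIZE // length for length in CHUNK_LENGTHS}

for length, count in chunks_per_batch.items():
    print(f"{length:>4} tokens/chunk -> {count:>3} chunks/batch")
```

A batch of 2048-token chunks therefore contains 16 chunks, while a batch of 64-token chunks contains 512.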

During training, at each step, the model processes one batch of chunks. Gradients are computed based on this batch, and the optimizer updates the model parameters accordingly.

---

### File Structure

- **Metadata Files**  
  `LMEnt-Dataset/dataset-metadata/part-[0-7]-00000.csv.gz`, which decompress to `LMEnt-Dataset/dataset-metadata/part-[0-7]-00000.csv`

- **Concatenated Documents**  
  `LMEnt-Dataset/dataset_tokenized/part-[0-7]-00000.npy`

- **Dataset and Dataloader Cache**  
  Copy `LMEnt-Dataset/dataset-cache` to `<root_dir>/LMEnt/OLMo-core/hp_final/dataset-cache`
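
To inspect a metadata shard without decompressing it to disk first, something like the following may be enough. It is a minimal sketch using only the standard library; `peek_csv_gz` is a hypothetical helper, and it makes no assumption about the column names, which are not documented here.

```python
import csv
import gzip
from itertools import islice


def peek_csv_gz(path, n_rows=3):
    """Return the header and the first `n_rows` data rows of a gzipped CSV."""
    with gzip.open(path, mode="rt", encoding="utf-8", newline="") as handle:
        reader = csv.reader(handle)
        header = next(reader, None)  # None if the file is empty
        rows = list(islice(reader, n_rows))
    return header, rows


# Example call (path pattern taken from the list above):
# header, rows = peek_csv_gz("LMEnt-Dataset/dataset-metadata/part-0-00000.csv.gz")
# print(header)
```

If a row contains very long fields, the limit set by `csv.field_size_limit` may need to be raised before reading.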


## Imports

In [2]:
import os
import sys

# put your repo's src FIRST
sys.path.insert(0, "/home/morg/students/gottesman3/LMEnt/OLMo-core/src")
sys.modules.pop("olmo_core", None)

import olmo_core
print("USING:", olmo_core.__file__)

sys.path.insert(0, "/home/morg/students/gottesman3/LMEnt/dolma/python")
import dolma
print("USING:", dolma.__file__)

os.environ["HOME"] = "/home/morg/students/gottesman3"

# Set your new cache base directory (change this to your preferred location)
cache_base = "/home/morg/students/gottesman3/.cache/huggingface"

# Set all relevant Hugging Face cache directories
os.environ["HF_HOME"] = cache_base
os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_base, "transformers")
os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_base, "datasets")
os.environ["HF_TOKENIZERS_CACHE"] = os.path.join(cache_base, "tokenizers")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

USING: /home/morg/students/gottesman3/LMEnt/OLMo-core/src/olmo_core/__init__.py
USING: /home/morg/students/gottesman3/LMEnt/dolma/python/dolma/__init__.py
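
The cell above hard-codes one user's paths. A more portable variant could derive the same cache variables from a single base directory. This is only a sketch: `configure_hf_cache` and the `LMENT_CACHE` environment variable are hypothetical names introduced here, while the variables being set are the same ones used above.

```python
import os
from pathlib import Path


def configure_hf_cache(cache_base):
    """Point the Hugging Face cache variables used above at `cache_base`."""
    base = Path(cache_base).expanduser()
    env = {
        "HF_HOME": str(base),
        "TRANSFORMERS_CACHE": str(base / "transformers"),
        "HF_DATASETS_CACHE": str(base / "datasets"),
        "HF_TOKENIZERS_CACHE": str(base / "tokenizers"),
        "TOKENIZERS_PARALLELISM": "false",
    }
    os.environ.update(env)
    return env


# LMENT_CACHE is a hypothetical variable; fall back to a per-user default.
configure_hf_cache(os.environ.get("LMENT_CACHE", "~/.cache/huggingface"))
```

As in the original cell, these variables should be set before the Hugging Face libraries are imported.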


In [None]:
import json
import torch
import numpy as np

# OLMo-Core
from olmo_core.data.utils import load_array_slice_into_tensor
from olmo_core.data import (
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    NumpyDatasetType,
    TokenizerConfig,
    DataCollator
)
from olmo_core.data.numpy_dataset import (
    VSLCurriculumType,
    VSLCurriculumConfig,
)

from olmo_eval import HFTokenizer



## Tokenizer

In [4]:
tokenizer_config = TokenizerConfig.dolma2()
tokenizer = HFTokenizer(
    tokenizer_config.identifier,
    pad_token_id=tokenizer_config.pad_token_id,
    eos_token_id=tokenizer_config.eos_token_id,
    bos_token_id=tokenizer_config.bos_token_id,
)

## Dataset / Dataloader

In [39]:
include_instance_metadata = True  # Set to True to retrieve metadata; set to False during training
work_dir = "/home/morg/students/gottesman3/knowledge-analysis-suite/OLMo-core/hp_final/dataset-cache"

dataset_config = NumpyDatasetConfig.glob(
    "/home/morg/students/gottesman3/knowledge-analysis-suite/dolma/python/final_tokenizations_with_offsets/no_special/*.npy",  # can be globs
    name=NumpyDatasetType.kas_vsl,
    max_sequence_length=2048,
    min_sequence_length=64,
    vsl_curriculum=VSLCurriculumConfig(name=VSLCurriculumType.grow_p2, num_cycles=8, balanced=False),
    tokenizer=tokenizer_config,
    work_dir=str(work_dir),
    include_instance_metadata=include_instance_metadata,
)
dataset = dataset_config.build()

Loading metadata: 100%|██████████| 8/8 [00:00<00:00, 170.49it/s]


In [40]:
dataset.tokenizer

AttributeError: 'NumpyKASVSLDataset' object has no attribute 'tokenizer'

In [None]:
print(tokenizer.decode(dataset[0]['input_ids'].tolist()))

Ab Balutak

Ab Balutak (, also Romanized as Āb Balūṭak; also known as Ābbalūṭak) is a village in Dehdasht-e Sharqi Rural District, in the Central District of Kohgiluyeh County,


In [None]:
print(json.dumps(dataset[0]['metadata'], indent=2))

{
  "start": 0,
  "end": 98,
  "id": 40123111,
  "src": "/home/morg/dataset/maverick/maverick_6.json",
  "loc": 9915,
  "title": "Ab Balutak",
  "entities": [
    {
      "char_start": 31,
      "char_end": 39,
      "text_mention": "Romanize",
      "candidates": [
        {
          "qid": "Q976327",
          "name": "Romanize",
          "scores_by_source": {
            "hyperlinks": 1.0,
            "entity_linking": 0.0,
            "coref": 0.0,
            "coref_cluster": 0.0
          },
          "aggregated_score": 0.4
        }
      ]
    },
    {
      "char_start": 97,
      "char_end": 129,
      "text_mention": "Dehdasht-e Sharqi Rural District",
      "candidates": [
        {
          "qid": "Q5685337",
          "name": "Dehdasht-e Sharqi Rural District",
          "scores_by_source": {
            "hyperlinks": 1.0,
            "entity_linking": 0.92,
            "coref": 0.0,
            "coref_cluster": 0.0
          },
          "aggregated_score": 0.68
    

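The `char_start` and `char_end` offsets in the metadata above appear to index into the decoded chunk text, and each candidate carries an `aggregated_score`. A small filter along these lines can keep only confident mentions. It is a sketch: `confident_mentions` and the 0.5 threshold are illustrative choices, and the sample values are copied from the output above.

```python
def confident_mentions(entities, min_score=0.5):
    """Keep mentions whose best candidate reaches `min_score`.

    Returns a list of (char_start, char_end, qid, aggregated_score).
    """
    kept = []
    for mention in entities:
        candidates = mention.get("candidates", [])
        if not candidates:
            continue
        best = max(candidates, key=lambda c: c.get("aggregated_score", 0.0))
        if best.get("aggregated_score", 0.0) >= min_score:
            kept.append((mention["char_start"], mention["char_end"],
                         best["qid"], best["aggregated_score"]))
    return kept


# Values copied from the metadata printed above (fields trimmed for brevity).
entities = [
    {"char_start": 31, "char_end": 39,
     "candidates": [{"qid": "Q976327", "aggregated_score": 0.4}]},
    {"char_start": 97, "char_end": 129,
     "candidates": [{"qid": "Q5685337", "aggregated_score": 0.68}]},
]
print(confident_mentions(entities))  # [(97, 129, 'Q5685337', 0.68)]

# The offsets slice the decoded text directly:
text = "Ab Balutak\n\nAb Balutak (, also Romanized as"
print(text[31:39])  # Romanize
```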
In [None]:
data_loader_config = NumpyDataLoaderConfig(
    global_batch_size=32768,
    seed=0,
    num_workers=4,
    prefetch_factor=8,
)

dataloader = data_loader_config.build(dataset)

dataloader.reshuffle(1)



In [None]:
# Fetch only the first batch for inspection
for batch in dataloader:
    break

In [None]:
for i, chunk_id in enumerate(batch["index"]):
    # Each batch row should match the dataset chunk it was drawn from
    expected = tokenizer.decode(dataset[chunk_id]["input_ids"].tolist())
    assert tokenizer.decode(batch["input_ids"][i].tolist()) == expected

## Visualizing Metadata

In [None]:
from IPython.display import display, HTML
from collections import Counter
import ipywidgets as widgets

_style = """
<style>
.compact-mention-output {
    font-size: 11px !important;
    line-height: 1.4 !important;
    max-height: 1100px;
    min-height: 550px;
    max-width: 1300px;
    min-width: 900px;
    width: 98vw;
    overflow-y: auto;
    background: #fcfcfc;
    padding: 24px 36px 24px 16px;
    border: 1px solid #ddd;
    border-radius: 10px;
}
</style>
"""
display(HTML(_style))

def display_highlighted_mentions(text, mentions_list):
    mention_fields = ["hyperlinks", "entity_linking", "coref", "coref_cluster"]

    # Build index from mentions list
    qid_counts = Counter()
    field_score_ranges = {field: [] for field in mention_fields}
    all_scores = []
    mention_map = {}  # (start, end) → {qid → (scores, agg, name)}

    for m in mentions_list:
        start, end = m["char_start"], m["char_end"]
        key = (start, end)
        for c in m.get("candidates", []):
            qid = c["qid"]
            name = c.get("name", "")
            scores = c.get("scores_by_source", {})
            agg = c.get("aggregated_score", 0.0)

            qid_counts[qid] += 1
            all_scores.append(agg)
            for field in mention_fields:
                if field in scores:
                    field_score_ranges[field].append(scores[field])

            mention_map.setdefault(key, {})[qid] = (scores, agg, name)

    # Sort QIDs and generate dropdown options with names
    sorted_qids = sorted(qid_counts.items(), key=lambda x: (-x[1], x[0]))
    qid_labels = []
    for qid, count in sorted_qids:
        name = None
        for m in mentions_list:
            for c in m.get("candidates", []):
                if c["qid"] == qid and c.get("name"):
                    name = c["name"]
                    break
            if name:
                break
        label = f"{name} ({qid}) [{count}]" if name else f"{qid} [{count}]"
        qid_labels.append((label, qid))

    min_score = round(min(all_scores), 2) if all_scores else 0.0
    max_score = round(max(all_scores), 2) if all_scores else 1.0

    field_sliders = {}
    for field in mention_fields:
        values = field_score_ranges[field]
        fmin = round(min(values), 2) if values else min_score
        fmax = round(max(values), 2) if values else max_score

        # Set default value based on field
        if field == "hyperlinks":
            default_val = 1.0
        else:
            default_val = 0.6

        # Clamp within range
        default_val = max(fmin, min(fmax, default_val))

        field_sliders[field] = widgets.FloatSlider(
            value=default_val,
            min=fmin,
            max=fmax,
            step=0.01,
            description=f'{field} ≥',
            readout_format='.2f',
            continuous_update=True,
            layout=widgets.Layout(width='280px')
        )

    checkboxes = {field: widgets.Checkbox(value=True, description=field, indent=False)
                  for field in mention_fields}
    checkbox_widgets = list(checkboxes.values())

    dropdown = widgets.Dropdown(
        options=qid_labels,
        description='Entity:',
        value=qid_labels[0][1] if qid_labels else None,  # avoid IndexError when there are no candidates
        style={'description_width': 'initial'}
    )

    html_out = widgets.HTML()

    def render_text(selected_qid, field_thresholds, active_fields):
        highlight_spans = []
        for (start, end), qid_map in mention_map.items():
            if selected_qid not in qid_map:
                continue
            scores, agg, _ = qid_map[selected_qid]
            if any(scores.get(field, 0.0) >= field_thresholds[field] for field in active_fields):
                highlight_spans.append((start, end, agg))

        highlight_spans_sorted = sorted(highlight_spans, key=lambda x: -x[2])
        span_ranks = {(start, end): rank + 1 for rank, (start, end, _) in enumerate(highlight_spans_sorted)}

        if not highlight_spans:
            return f"<div class='compact-mention-output'>{text}</div>"

        events = []
        for (start, end, score) in highlight_spans:
            rank = span_ranks[(start, end)]
            events.append((start, 'start', score, rank))
            events.append((end, 'end', None, None))
        events.sort(key=lambda x: (x[0], 0 if x[1] == 'end' else 1))

        out = []
        last_idx = 0
        highlight_stack = []

        for idx, typ, score, rank in events:
            if last_idx < idx:
                out.append(text[last_idx:idx])
            if typ == 'start':
                out.append(
                    f'<span style="background-color: #fff574; border-radius:3px; padding:2px 4px; font-size: 11px;" title="aggregated: {score}, rank: {rank}">'
                    f'<b>({rank})</b> '
                )
                highlight_stack.append('open')
            elif typ == 'end' and highlight_stack:
                out.append('</span>')
                highlight_stack.pop()
            last_idx = idx

        out.append(text[last_idx:])
        return f'<div class="compact-mention-output">{"".join(out)}</div>'

    def update_html(change=None):
        qid = dropdown.value
        active_fields = [f for f, cb in checkboxes.items() if cb.value]
        thresholds = {f: field_sliders[f].value for f in active_fields}
        html_out.value = render_text(qid, thresholds, active_fields)

    dropdown.observe(update_html, names='value')
    for cb in checkbox_widgets:
        cb.observe(update_html, names='value')
    for slider in field_sliders.values():
        slider.observe(update_html, names='value')

    update_html()

    controls = widgets.VBox([
        dropdown,
        widgets.HBox(checkbox_widgets),
        widgets.HBox([field_sliders[f] for f in mention_fields])
    ], layout=widgets.Layout(width='99%', max_width='1300px'))

    container = widgets.VBox([controls, html_out],
                             layout=widgets.Layout(width='99%', max_width='1350px'))
    display(container)


In [None]:
text, entities = tokenizer.decode(batch["input_ids"][0].tolist()), batch["metadata"][0]["entities"]
display_highlighted_mentions(text, entities)

VBox(children=(VBox(children=(Dropdown(description='Entity:', options=(('George L. Fox (chaplain) (Q5541454) […
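
When the interactive widget is not needed (for example outside a notebook), the same idea can be reduced to a static HTML string. The sketch below is a simplified stand-in, not the function above: `highlight_spans` is a hypothetical helper that escapes the text, wraps each span in `<mark>`, and skips any span that overlaps an earlier one.

```python
import html


def highlight_spans(text, spans):
    """Wrap each (char_start, char_end) span of `text` in <mark> tags."""
    out = []
    last = 0
    for start, end in sorted(spans):
        if start < last:
            # Overlaps a span that was already emitted; skip it.
            continue
        out.append(html.escape(text[last:start]))
        out.append(f"<mark>{html.escape(text[start:end])}</mark>")
        last = end
    out.append(html.escape(text[last:]))
    return "".join(out)


print(highlight_spans("a <b> c", [(2, 5)]))  # a <mark>&lt;b&gt;</mark> c
```

Escaping the text before inserting it into HTML keeps characters such as `<` in the source article from being interpreted as markup.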

## Index

In [None]:
from elasticsearch import Elasticsearch

def get_esclient(scheme="https", host="132.67.130.202", port=9200):
    return Elasticsearch(
        f"{scheme}://{host}:{port}", 
        basic_auth=("elastic", ""), 
        request_timeout=3000, 
        max_retries=10, 
        retry_on_timeout=True,
        verify_certs=False,
        ssl_show_warn=False
    )

In [None]:
CASE_SENSITIVE_INDEX_NAME = "enwiki_case_sensitive"
CASE_INSENSITIVE_INDEX_NAME = "enwiki"


In [None]:
es = get_esclient()

mapping_cs = es.indices.get_mapping(index=CASE_SENSITIVE_INDEX_NAME)
mapping_ci = es.indices.get_mapping(index=CASE_INSENSITIVE_INDEX_NAME)

In [None]:
ci = {'enwiki': {'mappings': {'dynamic': 'false', 'properties': {'article_id': {'type': 'integer'}, 'chunk_id': {'type': 'integer'}, 'end': {'type': 'integer'}, 'entities': {'type': 'nested', 'properties': {'candidates': {'type': 'nested', 'properties': {'aggregated_score': {'type': 'float'}, 'name': {'type': 'text', 'fields': {'raw': {'type': 'keyword'}}}, 'qid': {'type': 'keyword'}, 'scores_by_source': {'properties': {'coref': {'type': 'float'}, 'coref_cluster': {'type': 'float'}, 'entity_linking': {'type': 'float'}, 'hyperlinks': {'type': 'float'}}}}}, 'char_end': {'type': 'integer'}, 'char_start': {'type': 'integer'}, 'text_mention': {'type': 'text', 'fields': {'raw': {'type': 'keyword'}}}}}, 'metadata_source': {'type': 'text'}, 'start': {'type': 'integer'}, 'text': {'type': 'text', 'fields': {'raw': {'type': 'keyword'}}}, 'title': {'type': 'text'}}}}}

print(json.dumps(ci, indent=4))

{
    "enwiki": {
        "mappings": {
            "dynamic": "false",
            "properties": {
                "article_id": {
                    "type": "integer"
                },
                "chunk_id": {
                    "type": "integer"
                },
                "end": {
                    "type": "integer"
                },
                "entities": {
                    "type": "nested",
                    "properties": {
                        "candidates": {
                            "type": "nested",
                            "properties": {
                                "aggregated_score": {
                                    "type": "float"
                                },
                                "name": {
                                    "type": "text",
                                    "fields": {
                                        "raw": {
                                            "type": "keyword"
              

In [None]:
print(mapping_ci)

{'enwiki': {'mappings': {'dynamic': 'false', 'properties': {'article_id': {'type': 'integer'}, 'chunk_id': {'type': 'integer'}, 'end': {'type': 'integer'}, 'entities': {'type': 'nested', 'properties': {'candidates': {'type': 'nested', 'properties': {'aggregated_score': {'type': 'float'}, 'name': {'type': 'text', 'fields': {'raw': {'type': 'keyword'}}}, 'qid': {'type': 'keyword'}, 'scores_by_source': {'properties': {'coref': {'type': 'float'}, 'coref_cluster': {'type': 'float'}, 'entity_linking': {'type': 'float'}, 'hyperlinks': {'type': 'float'}}}}}, 'char_end': {'type': 'integer'}, 'char_start': {'type': 'integer'}, 'text_mention': {'type': 'text', 'fields': {'raw': {'type': 'keyword'}}}}}, 'metadata_source': {'type': 'text'}, 'start': {'type': 'integer'}, 'text': {'type': 'text', 'fields': {'raw': {'type': 'keyword'}}}, 'title': {'type': 'text'}}}}}
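
Because the mapping declares both `entities` and `entities.candidates` as `nested`, a lookup by QID has to descend through two nested levels. A query body could be built roughly as follows. This is a sketch: `build_qid_query` is a hypothetical helper, and only field names that appear in the mapping above are used.

```python
def build_qid_query(qid, min_score=0.0):
    """Build a query matching chunks with a candidate `qid` scoring >= `min_score`."""
    return {
        "nested": {
            "path": "entities",
            "query": {
                "nested": {
                    "path": "entities.candidates",
                    "query": {
                        "bool": {
                            # Both conditions must hold on the same candidate object.
                            "filter": [
                                {"term": {"entities.candidates.qid": qid}},
                                {"range": {"entities.candidates.aggregated_score": {"gte": min_score}}},
                            ]
                        }
                    },
                }
            },
        }
    }


query = build_qid_query("Q5685337", min_score=0.6)
```

With the client created above, the body could then be passed along the lines of `es.search(index=CASE_INSENSITIVE_INDEX_NAME, query=query, size=5)`, assuming an 8.x Python client.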
