In [10]:
import torch
from pathlib import Path

def _load_data_shard(file: Path):
    header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32
    assert header[0] == 20240520, "magic number mismatch in the data .bin file"
    assert header[1] == 1, "unsupported version"
    num_tokens = int(header[2]) # number of tokens (claimed)
    with file.open("rb", buffering=0) as f:
        tokens = torch.empty(num_tokens, dtype=torch.uint16) # avoid pin_memory copy by @YouJiacheng
        f.seek(256 * 4)
        nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
        assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
    return tokens

# Load one training shard and preview its token tensor.
shard_num = 2  # index of the training shard to inspect
shard_path = Path(f"data/fineweb10B/fineweb_train_{shard_num:06d}.bin")
shard = _load_data_shard(shard_path)

shard

tensor([ 262, 4802, 1393,  ..., 6792, 1363,  318], dtype=torch.uint16)

In [12]:
# Split the shard into documents at EOT (<|endoftext|>, id 50256) tokens and
# measure how many tokens each document contributes to this shard.
EOT_TOKEN = 50256

is_eot = shard == EOT_TOKEN
# flatten() rather than squeeze(): with exactly one EOT, squeeze() yields a 0-d
# tensor, and iterating over / concatenating a 0-d tensor fails.
doc_boundaries = is_eot.nonzero().flatten()  # positions of every EOT token
doc_ids = is_eot.cumsum(0)  # number of EOT tokens at or before each position

# Token count between consecutive EOTs, computed with one vectorised diff
# instead of a Python loop over every boundary. Virtual delimiters at -1 and
# len(shard) capture the fragments before the first and after the last EOT;
# the "- 1" excludes the EOT token itself from each length.
edges = torch.cat([
    torch.tensor([-1], device=doc_boundaries.device),
    doc_boundaries,
    torch.tensor([len(shard)], device=doc_boundaries.device),
])
doc_lengths = (edges.diff() - 1).tolist()

# An empty edge fragment only means the shard starts or ends exactly on an EOT;
# it is not a document, so drop it on both sides.
if doc_lengths and doc_lengths[-1] == 0:
    doc_lengths.pop()
if doc_lengths and doc_lengths[0] == 0:
    doc_lengths.pop(0)

# Derive the count from doc_lengths so the two statistics always agree
# ("EOT count + 1" over-counts whenever the shard ends on an EOT).
num_docs = len(doc_lengths)

total_tokens = len(shard)
# Guard against an empty shard (no documents) instead of dividing by zero.
avg_doc_length = sum(doc_lengths) / num_docs if num_docs else 0.0
print(f"Total number of tokens: {total_tokens}")
print(f"Number of documents: {num_docs}")
print(f"Document lengths: {doc_lengths[:5]}...")  # Show first 5 document lengths
print(f"Average document length: {avg_doc_length:.2f} tokens")


Total number of tokens: 100000000
Number of documents: 144404
Document lengths: [226, 4554, 462, 637, 575]...
Average document length: 691.50 tokens
