In [None]:
import json
import math
import multiprocessing
import warnings

import bs4
import numpy as np
import torch
from bs4 import MarkupResemblesLocatorWarning
from transformers import AutoTokenizer

# bs4 raises MarkupResemblesLocatorWarning when the text handed to
# BeautifulSoup looks like a file name or URL rather than markup. Silence it
# for the whole notebook so it does not flood the output.
warnings.filterwarnings("ignore", category=MarkupResemblesLocatorWarning)

In [None]:
def text_with_newdataset(elem):
    """Flatten ``elem`` into plain text, turning ``br``/``p`` nodes into newlines.

    Iterates ``elem.descendants`` in order. Each string node is copied through
    unchanged; each non-string node whose ``name`` is ``"br"`` or ``"p"``
    contributes one newline character. Any other node adds nothing itself.

    Returns:
        The concatenation of all collected pieces as a single string.
    """
    pieces = []
    for node in elem.descendants:
        if isinstance(node, str):
            pieces.append(node)
            continue
        if node.name in ("br", "p"):
            pieces.append("\n")

    return "".join(pieces)


def parse_line(line):
    """Convert one JSON-encoded line into a plain-text block of its posts.

    The decoded object's ``"posts"`` list (empty if the key is absent) is
    walked in order. Each post is rendered as a ``--- <no>`` header line
    followed by its text: the ``"com"`` markup parsed with BeautifulSoup's
    lxml parser, flattened by ``text_with_newdataset`` and stripped, or an
    empty string when the post has no ``"com"`` key.

    Returns:
        The rendered posts joined by newlines, terminated by a ``-----`` line.

    Raises:
        KeyError: if a post has no ``"no"`` key.
    """
    record = json.loads(line)

    def render(post):
        # Only touch the HTML parser when the post actually carries a body.
        body = ""
        if "com" in post:
            parsed = bs4.BeautifulSoup(post["com"], "lxml")
            body = text_with_newdataset(parsed).strip()
        return f"--- {post['no']}\n{body}"

    rendered = [render(post) for post in record.get("posts", [])]
    return "\n".join(rendered) + "\n-----\n"

In [None]:
# Tokenizer used further down to turn the parsed text into token ids.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Holds one flattened text string per parsed input line.
dataset_str = []

In [None]:
# Parse every line of the NDJSON file in parallel, one parse_line call per line.
#
# Pool.map already returns the results as an ordered list, so bind it directly
# rather than appending onto a list created in an earlier cell: re-running this
# cell then replaces the data instead of silently duplicating it.
#
# JSON text is UTF-8, so the encoding is pinned instead of relying on the
# platform default. The file is opened first so a missing file fails before
# any worker processes are started.
with open("dataset.ndjson", encoding="utf-8") as file, multiprocessing.Pool() as pool:
    # imap is fine too, but it's slower
    dataset_str = pool.map(parse_line, file)

Since Hugging Face's tokenizer only supports int64 output, we have to work around that by tokenizing in chunks and downcasting each chunk before the tensor gets too big

I only have 256 GB of RAM to work with :(

In [None]:
# Number of parsed records joined into a single string per tokenizer call.
chunk_size = 100000
num_chunks = math.ceil(len(dataset_str) / chunk_size)

# uint16 can only represent ids 0..65535. Stop here if the tokenizer could emit
# a larger id, because the cast below would otherwise wrap around silently.
if len(tokenizer) > 2**16:
    raise ValueError(
        f"tokenizer vocabulary ({len(tokenizer)} ids) does not fit in uint16"
    )

dataset_token = []

for start_index in range(0, len(dataset_str), chunk_size):
    # Slicing clamps at the end of the list, so the last chunk may be shorter.
    chunk = "\n".join(dataset_str[start_index:start_index + chunk_size]) + "\n"

    # With return_tensors="pt" a single string comes back as a (1, seq_len)
    # batch. Drop only that batch axis: a bare squeeze() would also collapse a
    # length-1 sequence into a 0-d tensor, which torch.cat cannot concatenate.
    chunk = tokenizer(chunk, return_tensors="pt")["input_ids"].squeeze(0)
    # Downcast immediately so the int64 tensors never accumulate in memory.
    dataset_token.append(chunk.to(torch.uint16))

# Free the raw text; only the token tensors are needed from here on.
del dataset_str

In [None]:
# Merge the per-chunk token tensors end to end into one tensor.
dataset_token = torch.cat(dataset_token, dim=0)

We store the memmap as uint16, since our vocabulary is ~50k tokens and uint16 goes up to 65535

Using int64 would just waste space

In [None]:
# Create (or overwrite) a flat uint16 file with one slot per token, copy the
# tokens into it, then flush so the data is written out to disk.
mp = np.memmap(
    "dataset.dat",
    dtype=np.uint16,
    mode="w+",
    shape=(dataset_token.numel(),),
)
mp[:] = dataset_token.numpy()
mp.flush()