In [1]:
import jax.random as jrand
import jax
import jax.numpy as jnp
import nump as np
from tokenizers import normalizers, pre_tokenizers
from tokenizers.normalizers import NFD, StripAccents, Lowercase
from tokenizers.pre_tokenizers import Digits, Whitespace

from tokenizers import Tokenizer
from tokenizers.models import BPE

pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), Digits(individual_digits=True)])
normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])


files = [f'/Users/dashiell/workspace/wikitext/wikitext-103-raw/wiki.{split}.raw' for split in ['train', 'test', 'valid']]


None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [18]:
trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]"], vocab_size=8192)

In [19]:
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizer
tokenizer.normalizer = normalizer
tokenizer.train(files, trainer)






In [20]:
tokenizer.save("tokenizer/tokenizer-wiki8192.json")

In [2]:
tokenizer = Tokenizer.from_file("tokenizer/tokenizer-wiki8192.json")

In [4]:
output = tokenizer.encode("Hello, y'all! How are you 😁 ?")

In [32]:
tokenizer.token_to_id('[PAD]')

3

In [24]:
output.tokens

['Hel', 'lo', ',', 'y', "'", 'all', '!', 'How', 'are', 'you', '[UNK]', '?']

In [3]:
from datasets import load_dataset

data = load_dataset("wikipedia", "20220301.en")

Reusing dataset wikipedia (/Users/dashiell/.cache/huggingface/datasets/wikipedia/20220301.en/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)


  0%|          | 0/1 [00:00<?, ?it/s]

In [11]:
#data = data.rename_column('text', 'article')

tokenizer.encode(data['train'][0]['article'])

Encoding(num_tokens=11450, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [7]:
import numpy as np



In [45]:
import polars as pl


def tokenize_and_split(batch):
    seq_len = 512
    texts = '[SEP]'.join(batch['text'])
    cls_token = tokenizer.token_to_id('[CLS]')
    pad_token = tokenizer.token_to_id('[PAD]')
    token_ids = np.array(tokenizer.encode(texts).ids)
    leftover = token_ids.shape[0] % (seq_len - 1)
    pads = np.full(((seq_len - 1) - leftover,), pad_token)
    split_tokens = np.append(token_ids, pads).reshape((-1, seq_len - 1))
    num_splits = split_tokens.shape[0]
    extra_cls_tokens = np.full((num_splits, 1), cls_token)
    fully_tokenized = np.hstack([extra_cls_tokens, split_tokens])
    return np.vsplit(fully_tokenized, num_splits)

In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.notebook import tqdm
from pathlib import Path




def save_chunk(index, path, data):
    tokens = tokenize_and_split(data)
    df = pl.DataFrame({'tokens': tokens})
    df.write_parquet(path / f'chunk_{index}.parquet')
    
futures = []
chunk_size = 10_000
num_rows = data['train'].num_rows
fp = Path('/Users/dstander/workspace/simplex-score-matching/data/train')
with ThreadPoolExecutor(max_workers=32) as executor:
    for chunk_no, i in enumerate(range(0, num_rows, chunk_size)):
        j = min(i + chunk_size, num_rows)
        chunk = data['train'][i:j]
        futures.append(executor.submit(save_chunk, chunk_no, fp, chunk))

total_rows = 0
for i, fut in as_completed(futures):
    try:
        rows = fut.result()
    except Exception:
        print(f'Chunk {i} failed')
        continue
    total_rows += rows
    if i // 50 == 0:
        print(f'Saved {total_rows} records')


In [46]:
x = tokenize_and_split(data['train'][0:100])

In [50]:
for i in range(0, 100_000, 10_000):
    print(i)

0
10000
20000
30000
40000
50000
60000
70000
80000
90000


In [16]:
from tqdm.notebook import tqdm

for i in tqdm(range(data['train'].num_rows)):
    if i // 1000 == 0:
        print(data['train'][i])

  0%|          | 0/6458670 [00:00<?, ?it/s]

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [25]:
b = np.arange(25).reshape((5, 5))

In [34]:
np.full((5, 1), 45)

array([[45],
       [45],
       [45],
       [45],
       [45]])