In [5]:
from glob import glob
from streaming import MDSWriter
from streaming import StreamingDataset
from transformers import default_data_collator, DataCollatorForLanguageModeling
from tqdm import tqdm
import numpy as np

In [6]:
def _shard_number(path):
    """Return the integer that follows the last '-' in ``path``."""
    return int(path.rsplit('-', 1)[-1])


# Order numerically so that 'tokenized-10' comes after 'tokenized-9'
# instead of right after 'tokenized-1' (plain string order).
folders = glob('tokenized_indexes/tokenized-*')
folders.sort(key=_shard_number)
folders

['tokenized_indexes/tokenized-0',
 'tokenized_indexes/tokenized-1',
 'tokenized_indexes/tokenized-2',
 'tokenized_indexes/tokenized-3',
 'tokenized_indexes/tokenized-4',
 'tokenized_indexes/tokenized-5',
 'tokenized_indexes/tokenized-6',
 'tokenized_indexes/tokenized-7',
 'tokenized_indexes/tokenized-8',
 'tokenized_indexes/tokenized-9',
 'tokenized_indexes/tokenized-10',
 'tokenized_indexes/tokenized-11',
 'tokenized_indexes/tokenized-12',
 'tokenized_indexes/tokenized-13',
 'tokenized_indexes/tokenized-14',
 'tokenized_indexes/tokenized-15',
 'tokenized_indexes/tokenized-16',
 'tokenized_indexes/tokenized-17',
 'tokenized_indexes/tokenized-18',
 'tokenized_indexes/tokenized-19',
 'tokenized_indexes/tokenized-20',
 'tokenized_indexes/tokenized-21',
 'tokenized_indexes/tokenized-22',
 'tokenized_indexes/tokenized-23',
 'tokenized_indexes/tokenized-24',
 'tokenized_indexes/tokenized-25',
 'tokenized_indexes/tokenized-26',
 'tokenized_indexes/tokenized-27',
 'tokenized_indexes/tokenized-

In [7]:
from streaming.base.format.mds.encodings import Encoding, _encodings

class UInt16(Encoding):
    """MDS encoding that stores an integer array as raw native-endian uint16 bytes.

    ``encode`` and ``decode`` are inverses for flat uint16 arrays. Shape is not
    stored: ``decode`` always returns a 1-D array.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` as uint16 bytes (two bytes per element).

        Args:
            obj: Integer array-like whose values all lie in ``[0, 65535]``.

        Returns:
            The raw element buffer in C order.

        Raises:
            TypeError: If ``obj`` does not have an integer dtype.
            ValueError: If any value does not fit in uint16.
        """
        arr = np.asarray(obj)
        if arr.dtype != np.uint16:
            # decode() always reads 2-byte elements, so writing the bytes of any
            # other dtype (e.g. int64 token ids) would round-trip to garbage.
            if not np.issubdtype(arr.dtype, np.integer):
                raise TypeError(f'UInt16 encoding requires an integer array, got dtype {arr.dtype}')
            limit = np.iinfo(np.uint16).max
            if arr.size and (arr.min() < 0 or arr.max() > limit):
                raise ValueError(f'values outside [0, {limit}] cannot be stored as uint16')
            arr = arr.astype(np.uint16)
        return arr.tobytes()

    def decode(self, data: bytes):
        """Interpret ``data`` as a flat uint16 array.

        The result is a view over ``data`` (no copy), so it is read-only when
        ``data`` is an immutable ``bytes`` object.
        """
        return np.frombuffer(data, np.uint16)


# Register the encoding under the name 'uint16' so it can be referenced by name.
# NOTE(review): `_encodings` is a private registry of the streaming package; this
# assignment replaces any existing 'uint16' entry -- confirm that is intended.
_encodings['uint16'] = UInt16

In [8]:
# Output schema passed to MDSWriter: a single 'input_ids' field stored with
# the encoding named 'uint16'.
columns = dict(input_ids='uint16')

# Compression and hash algorithm names passed to MDSWriter.
compression = 'zstd'
hashes = ('sha1', 'xxh64')

In [7]:
# Merge every per-folder dataset into one MDS dataset written to 'tokenized'.
# Best-effort: a folder that raises is reported and skipped so the remaining
# folders are still merged; failures are collected for a final summary.
failed_folders = []
with MDSWriter(out='tokenized', columns=columns, compression=compression, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = StreamingDataset(local=f)
            # Label the progress bar with the folder so each bar is attributable.
            for i in tqdm(range(len(dataset)), desc=f):
                out.write(dataset[i])
        except Exception as e:
            # Samples written before the failure remain in the output, so a
            # failed folder may be only partially merged.
            failed_folders.append(f)
            print(f'{f}: {e!r}')

if failed_folders:
    print(f'{len(failed_folders)} folder(s) failed: {failed_folders}')

100%|██████████| 906379/906379 [01:42<00:00, 8804.31it/s] 
100%|██████████| 177681/177681 [00:43<00:00, 4111.60it/s] 
100%|██████████| 187986/187986 [00:26<00:00, 6983.59it/s] 
100%|██████████| 187877/187877 [00:42<00:00, 4466.45it/s] 
100%|██████████| 152901/152901 [00:22<00:00, 6818.62it/s] 
100%|██████████| 121434/121434 [00:20<00:00, 5980.49it/s]
100%|██████████| 120487/120487 [00:22<00:00, 5261.07it/s]
100%|██████████| 104511/104511 [00:19<00:00, 5275.63it/s]
100%|██████████| 77188/77188 [00:15<00:00, 5006.33it/s]
100%|██████████| 77607/77607 [00:14<00:00, 5537.91it/s]
100%|██████████| 77579/77579 [00:13<00:00, 5860.13it/s]
100%|██████████| 77750/77750 [00:16<00:00, 4598.58it/s] 
100%|██████████| 77709/77709 [00:12<00:00, 6174.10it/s]
100%|██████████| 77729/77729 [00:11<00:00, 6838.79it/s] 
100%|██████████| 77219/77219 [00:12<00:00, 6000.09it/s]
100%|██████████| 517771/517771 [01:14<00:00, 6932.16it/s] 
100%|██████████| 94853/94853 [00:19<00:00, 4750.32it/s] 
100%|██████████| 9437

In [9]:
# Open the merged output directory ('tokenized') for indexed reads.
dataset = StreamingDataset(
    local='tokenized',
)

In [10]:
# NOTE(review): 4096 is presumably the fixed number of tokens per sample, which
# would make this the total token count of the dataset -- confirm.
4096 * len(dataset)

49726361600

In [12]:
%%time

# Time a single indexed lookup on `dataset` and display the returned sample.
dataset[7849]

CPU times: user 0 ns, sys: 1.91 ms, total: 1.91 ms
Wall time: 5.04 ms


{'input_ids': array([29871, 30198, 30177, ..., 30256, 30162, 29871], dtype=uint16)}