In [1]:
from glob import glob
from streaming import MDSWriter
from streaming import LocalDataset
from transformers import default_data_collator, DataCollatorForLanguageModeling
from tqdm import tqdm
import numpy as np

In [2]:
def _trailing_index(path):
    """Return the integer that follows the last '-' in ``path``."""
    return int(path.rsplit('-', 1)[-1])


# Gather the primary tokenized folders and order them by their numeric suffix,
# so that 'tokenized-10' comes after 'tokenized-9' instead of after 'tokenized-1'.
folders = glob('tokenized_indexes/tokenized-*')
folders.sort(key=_trailing_index)

In [3]:
# Append the extra tokenized folders after the primary ones, ordered by their
# numeric suffix. Only paths that are not already in `folders` are added, so
# running this cell more than once does not list a folder twice (which would
# write its samples twice when the folders are merged).
extra_folders = sorted(
    glob('tokenized_extra/tokenized-*'),
    key=lambda path: int(path.split('-')[-1]),
)
folders.extend([path for path in extra_folders if path not in folders])

In [4]:
# Display the folder list (primary folders first, then the extra ones).
folders

['tokenized_indexes/tokenized-0',
 'tokenized_indexes/tokenized-1',
 'tokenized_indexes/tokenized-2',
 'tokenized_indexes/tokenized-3',
 'tokenized_indexes/tokenized-4',
 'tokenized_indexes/tokenized-5',
 'tokenized_indexes/tokenized-6',
 'tokenized_indexes/tokenized-7',
 'tokenized_indexes/tokenized-8',
 'tokenized_indexes/tokenized-9',
 'tokenized_indexes/tokenized-10',
 'tokenized_indexes/tokenized-11',
 'tokenized_indexes/tokenized-12',
 'tokenized_extra/tokenized-0',
 'tokenized_extra/tokenized-1',
 'tokenized_extra/tokenized-2',
 'tokenized_extra/tokenized-3',
 'tokenized_extra/tokenized-4',
 'tokenized_extra/tokenized-5',
 'tokenized_extra/tokenized-6',
 'tokenized_extra/tokenized-7',
 'tokenized_extra/tokenized-8',
 'tokenized_extra/tokenized-9',
 'tokenized_extra/tokenized-10',
 'tokenized_extra/tokenized-11',
 'tokenized_extra/tokenized-12']

In [4]:
from streaming.base.format.mds.encodings import Encoding, _encodings

class UInt16(Encoding):
    """MDS column encoding that stores an integer array as raw ``uint16`` bytes.

    ``decode`` always interprets the payload as ``uint16``, so ``encode``
    guarantees that the bytes it emits are native ``uint16`` values. Without
    that guarantee an array of another dtype (e.g. ``int64``) would be written
    with a different item size and silently decode to wrong values.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` to raw ``uint16`` bytes.

        Args:
            obj: Array-like of integers, each in ``[0, 65535]``.

        Returns:
            The values as ``uint16`` bytes in native byte order.

        Raises:
            TypeError: If ``obj`` is non-empty and not integer-typed.
            ValueError: If any value does not fit in ``uint16``.
        """
        arr = np.asarray(obj)
        if arr.size == 0:
            # Nothing to validate; an empty input encodes to an empty payload.
            return b''
        if arr.dtype != np.uint16:
            if not np.issubdtype(arr.dtype, np.integer):
                raise TypeError(
                    f'UInt16 encoding expects integer data, got dtype {arr.dtype}')
            limit = np.iinfo(np.uint16).max
            if arr.min() < 0 or arr.max() > limit:
                raise ValueError(
                    f'UInt16 encoding expects values in [0, {limit}], '
                    f'got range [{arr.min()}, {arr.max()}]')
            # Safe after the range check: no value wraps around.
            arr = arr.astype(np.uint16)
        return arr.tobytes()

    def decode(self, data: bytes):
        """Interpret ``data`` as a 1-D ``uint16`` array.

        The array returned by ``np.frombuffer`` is a read-only view over
        ``data``; copy it before modifying it in place.
        """
        return np.frombuffer(data, np.uint16)


# Register the encoding under the name used by the `columns` spec.
# NOTE(review): `_encodings` is a private registry of the streaming package and
# this assignment replaces any entry already stored under 'uint16' — confirm
# this is still appropriate for the installed streaming version.
_encodings['uint16'] = UInt16

In [5]:
# Column schema for the merged dataset: one 'input_ids' column stored with the
# 'uint16' encoding.
columns = dict(input_ids='uint16')

# Hash algorithm names passed to the writer.
hashes = ('sha1', 'xxh64')

# NOTE(review): the writer cell passes compression=None, so this value is
# currently not applied — confirm which of the two is intended.
compression = 'zstd'

In [6]:
# Delete any existing 'combine-all' directory (the writer's output location).
!rm -rf combine-all

In [7]:
# Open the first folder as a LocalDataset.
# NOTE(review): the merge loop rebinds `dataset` for every folder, so this
# handle looks like a leftover sanity check — confirm it is still needed.
dataset = LocalDataset(local=folders[0])

In [8]:
# Merge every tokenized folder into a single MDS dataset at 'combine-all'.
#
# The merge is best-effort: a folder that fails to open or copy is skipped so
# the remaining folders are still written. Each failure is recorded and
# reported together with its folder name, so an incomplete merge cannot be
# mistaken for a complete one.
#
# NOTE(review): `compression` is set to 'zstd' in the config cell, but the
# writer is given compression=None — confirm that uncompressed output is
# intended.
failed_folders = []
with MDSWriter(out='combine-all', columns=columns, compression=None, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = LocalDataset(local=f)
            # Label each progress bar with its folder so the logs are traceable.
            for i in tqdm(range(len(dataset)), desc=f):
                out.write(dataset[i])
        except Exception as e:
            # Samples copied before the error remain in the output, so a
            # failed folder may be only partially present in 'combine-all'.
            failed_folders.append(f)
            print(f'{f}: {e!r}')

if failed_folders:
    print(f'WARNING: {len(failed_folders)} of {len(folders)} folders failed: {failed_folders}')

100%|██████████| 910600/910600 [00:44<00:00, 20382.38it/s]
100%|██████████| 78209/78209 [00:08<00:00, 9061.67it/s] 
100%|██████████| 79160/79160 [00:03<00:00, 23297.82it/s]
100%|██████████| 79268/79268 [00:10<00:00, 7812.87it/s] 
100%|██████████| 79486/79486 [00:03<00:00, 24253.02it/s]
100%|██████████| 79563/79563 [00:11<00:00, 7141.72it/s] 
100%|██████████| 79588/79588 [00:04<00:00, 17337.33it/s]
100%|██████████| 79360/79360 [00:11<00:00, 6888.19it/s] 
100%|██████████| 212948/212948 [00:15<00:00, 13356.01it/s]
100%|██████████| 426352/426352 [00:37<00:00, 11468.45it/s]
100%|██████████| 105709/105709 [00:06<00:00, 17269.23it/s]
100%|██████████| 527994/527994 [00:34<00:00, 15209.52it/s]
100%|██████████| 400576/400576 [00:34<00:00, 11623.75it/s]
100%|██████████| 24698/24698 [00:02<00:00, 8606.91it/s] 
100%|██████████| 23515/23515 [00:01<00:00, 23097.02it/s]
100%|██████████| 23345/23345 [00:03<00:00, 6131.03it/s] 
100%|██████████| 23544/23544 [00:01<00:00, 23340.43it/s]
100%|██████████| 23

In [12]:
# Reopen the merged 'combine-all' output as a LocalDataset.
dataset = LocalDataset(local='combine-all')

In [13]:
# Number of samples in the merged dataset multiplied by 4096.
# NOTE(review): 4096 is presumably the fixed number of tokens per sample,
# which would make this the total token count — confirm against the tokenizer step.
len(dataset) * 4096

14114934784