In [9]:
from glob import glob
from streaming import MDSWriter
from streaming import LocalDataset, StreamingDataset
from transformers import default_data_collator, DataCollatorForLanguageModeling
from tqdm import tqdm
import numpy as np

In [2]:
# Collect the primary tokenized folders and order them by the integer after
# the last '-' in the path, so 'tokenized-10' sorts after 'tokenized-2'.
folders = glob('tokenized_indexes/tokenized-*')
folders.sort(key=lambda path: int(path.rsplit('-', 1)[-1]))

In [3]:
# Append the extra tokenized folders (ordered by their integer suffix) after
# the primary ones. Only paths not already present are added, so re-running
# this cell does not duplicate folders and cause samples to be merged twice.
extra_folders = sorted(
    glob('tokenized_extra/tokenized-*'),
    key=lambda path: int(path.split('-')[-1]),
)
folders = folders + [path for path in extra_folders if path not in folders]

In [5]:
from streaming.base.format.mds.encodings import Encoding, _encodings

class UInt16(Encoding):
    """Column encoding that stores an integer array as raw ``uint16`` bytes.

    ``encode`` and ``decode`` are inverses only for ``uint16`` data, so
    ``encode`` normalises its input to that dtype instead of trusting the
    caller.
    """

    def encode(self, obj) -> bytes:
        """Serialise ``obj`` to the raw bytes of a ``uint16`` array.

        Args:
            obj: A ``uint16`` array, or an integer array-like whose values
                all fit in the ``uint16`` range.

        Returns:
            The array contents as bytes (two bytes per element).

        Raises:
            TypeError: If ``obj`` holds non-integer data.
            ValueError: If any value falls outside the ``uint16`` range.
        """
        arr = np.asarray(obj)
        if arr.dtype != np.uint16:
            # tobytes() on any other dtype would emit bytes that decode()
            # silently misreads, so convert explicitly and refuse lossy casts.
            if arr.size:
                if arr.dtype.kind not in 'iu':
                    raise TypeError(
                        f'UInt16 encoding expects integer data, got dtype {arr.dtype}'
                    )
                limits = np.iinfo(np.uint16)
                if arr.min() < limits.min or arr.max() > limits.max:
                    raise ValueError('UInt16 encoding: values do not fit in uint16')
            arr = arr.astype(np.uint16)
        return arr.tobytes()

    def decode(self, data: bytes):
        """Rebuild the array from bytes produced by ``encode``.

        Args:
            data: Raw bytes holding consecutive ``uint16`` values.

        Returns:
            A one-dimensional ``uint16`` array viewing ``data``; no shape
            information is stored, and the view is read-only when ``data``
            is an immutable ``bytes`` object.
        """
        return np.frombuffer(data, np.uint16)

# Make the custom codec available under the column-type name 'uint16'.
# NOTE(review): `_encodings` is a private registry of the streaming package;
# this replaces any entry already stored under the same key — confirm that is intended.
_encodings.update(uint16=UInt16)

In [6]:
# Schema of the merged dataset: a single 'input_ids' column stored with the
# 'uint16' encoding.
columns = dict(input_ids='uint16')

# Writer settings.
# NOTE(review): the MDSWriter call in this notebook passes compression=None
# rather than this value — confirm which of the two is intended.
compression = 'zstd'
hashes = ('sha1', 'xxh64')

In [10]:
# Delete any existing 'combine-all' directory so the merge below starts from
# an empty output location (destructive: a previous merge result is lost).
!rm -rf combine-all

In [11]:
# Merge every per-folder dataset into one MDS dataset under 'combine-all'.
# compression=None is passed explicitly, so the `compression` variable defined
# earlier is not applied here.
failed_folders = []
with MDSWriter(out='combine-all', columns=columns, compression=None, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = StreamingDataset(local=f)
            # Label each bar with its folder so the progress log is attributable.
            for i in tqdm(range(len(dataset)), desc=f):
                out.write(dataset[i])
        except Exception as e:
            # Best-effort merge: keep going with the remaining folders, but
            # record the failure so a gap in the output is not silent. Samples
            # written before the error stay in the output.
            failed_folders.append((f, repr(e)))
            print(f'failed to merge {f}: {e!r}')

# Summarise failures once at the end, where they are not buried between
# progress bars.
if failed_folders:
    print(f'{len(failed_folders)} of {len(folders)} folders were not merged completely:')
    for folder, error in failed_folders:
        print(f'  {folder}: {error}')

100%|██████████| 923825/923825 [00:46<00:00, 19667.13it/s]
100%|██████████| 76835/76835 [00:07<00:00, 10425.98it/s]
100%|██████████| 77612/77612 [00:03<00:00, 20657.35it/s]
100%|██████████| 77634/77634 [00:03<00:00, 20489.35it/s]
100%|██████████| 77766/77766 [00:03<00:00, 21671.17it/s]
100%|██████████| 77714/77714 [00:03<00:00, 21079.42it/s]
100%|██████████| 77674/77674 [00:03<00:00, 21183.28it/s]
100%|██████████| 77401/77401 [00:07<00:00, 10778.78it/s]
100%|██████████| 236271/236271 [00:17<00:00, 13560.57it/s]
100%|██████████| 401254/401254 [00:30<00:00, 13273.08it/s]
100%|██████████| 108914/108914 [00:08<00:00, 13158.79it/s]
100%|██████████| 544131/544131 [00:48<00:00, 11157.76it/s]
100%|██████████| 426205/426205 [00:36<00:00, 11690.56it/s]
100%|██████████| 25582/25582 [00:14<00:00, 1825.25it/s]
100%|██████████| 24501/24501 [00:09<00:00, 2490.48it/s]
100%|██████████| 24503/24503 [00:11<00:00, 2137.81it/s]
100%|██████████| 24575/24575 [00:12<00:00, 1998.45it/s]
100%|██████████| 24384/

In [12]:
# Open the merged output in 'combine-all' for inspection. This rebinds
# `dataset`, which the merge loop used for the per-folder inputs.
dataset = LocalDataset(local='combine-all')

In [13]:
# Sample count scaled by 4096 — presumably the number of tokens per sample,
# which would make this the total token count; TODO confirm the sequence length.
4096 * len(dataset)

14349328384

In [14]:
# Report the total on-disk size of the merged dataset directory.
!du -hs combine-all

27G	combine-all


In [15]:
# Number of samples in the merged dataset.
len(dataset)

3503254