In [1]:
import shutil
import warnings
from glob import glob

import numpy as np
from streaming import MDSWriter
from streaming import StreamingDataset, LocalDataset
from tqdm import tqdm
from transformers import default_data_collator, DataCollatorForLanguageModeling

In [2]:
# Tokenized shard folders, ordered by their trailing numeric index
# (so `tokenized-10` sorts after `tokenized-9`, unlike a plain string sort).
folders = sorted(
    glob('tokenized_indexes-filtered/tokenized-*'),
    key=lambda path: int(path.rsplit('-', 1)[-1]),
)

In [3]:
# Display the discovered folders to confirm they are in numeric index order.
folders

['tokenized_indexes-filtered/tokenized-0',
 'tokenized_indexes-filtered/tokenized-1',
 'tokenized_indexes-filtered/tokenized-2',
 'tokenized_indexes-filtered/tokenized-3',
 'tokenized_indexes-filtered/tokenized-4',
 'tokenized_indexes-filtered/tokenized-5']

In [4]:
from streaming.base.format.mds.encodings import Encoding, _encodings

class UInt16(Encoding):
    """MDS encoding for a variable-length array of uint16 values (token ids).

    A sample is stored as the raw ``np.uint16`` bytes of the array (native
    byte order); decoding returns a read-only ``np.uint16`` view over them.
    """

    def encode(self, obj) -> bytes:
        """Serialize an integer array-like to raw uint16 bytes.

        Args:
            obj: array-like of integers (e.g. an ``np.ndarray`` or a list).

        Returns:
            Two bytes per element; multi-dimensional input is flattened in
            C order.

        Raises:
            TypeError: if ``obj`` is not integer-valued.
            ValueError: if any value falls outside the uint16 range.
        """
        arr = np.asarray(obj)
        if arr.dtype == np.uint16:
            # Already the stored dtype: write its bytes directly.
            return arr.tobytes()
        if arr.size == 0:
            return b''
        # Calling tobytes() on a non-uint16 array (e.g. int64) would emit the
        # wrong element width and decode as garbage, so convert explicitly.
        if arr.dtype.kind not in 'iu':
            raise TypeError(f'UInt16 encoding expects integer values, got dtype {arr.dtype}')
        info = np.iinfo(np.uint16)
        if arr.min() < info.min or arr.max() > info.max:
            raise ValueError(f'UInt16 encoding got values outside [{info.min}, {info.max}]')
        return arr.astype(np.uint16).tobytes()

    def decode(self, data: bytes):
        """Deserialize bytes produced by ``encode`` into a uint16 array."""
        return np.frombuffer(data, np.uint16)

# Register under the name 'uint16' so columns declared as 'uint16' map to this class.
# NOTE(review): `_encodings` is a private `streaming` registry, and newer
# `streaming` releases may already ship their own (scalar) 'uint16' encoding,
# which this replaces for the whole process -- confirm against the installed version.
_encodings['uint16'] = UInt16

In [5]:
# MDS schema and writer options for the combined dataset.
# 'uint16' is the name the custom UInt16 encoding is registered under.
columns = dict(input_ids='uint16')

# NOTE(review): `compression` is not passed to the writer cell, which uses
# compression=None -- either wire this through or drop it.
compression = 'zstd'
hashes = ('sha1', 'xxh64')

In [6]:
# Start from an empty output directory so the writer cell never mixes new
# shards with stale ones from a previous run. Pure Python instead of
# `!rm -rf`: portable, and a failed delete raises rather than being ignored.
try:
    shutil.rmtree('combine-dedup-text-dataset-filtered')
except FileNotFoundError:
    pass  # nothing to clean up on a first run

In [7]:
# Concatenate every tokenized folder, in the order of `folders`, into one MDS dataset.
# NOTE(review): shards are written with compression=None, so the `compression`
# value set in the config cell is not used here -- confirm that is intended.
failed = {}  # folder -> (samples written from it before the error, exception)
with MDSWriter(out='combine-dedup-text-dataset-filtered', columns=columns, compression=None, hashes=hashes) as out:
    for no, f in enumerate(folders):
        written = 0
        try:
            source = StreamingDataset(local=f)
            for i in tqdm(range(len(source)), desc=f'[{no}] {f}'):
                out.write(source[i])
                written += 1
        except Exception as e:
            # Best-effort: keep going with the remaining folders, but record the
            # failure instead of only printing it. Samples already written from
            # this folder stay in the output, so it may hold a partial copy.
            print(f'{f}: {e!r}')
            failed[f] = (written, e)

if failed:
    warnings.warn(
        f'{len(failed)} of {len(folders)} folders failed; the combined dataset is incomplete: '
        + ', '.join(f'{f} (after {n} samples: {e!r})' for f, (n, e) in failed.items())
    )

0


100%|██████████| 747558/747558 [03:30<00:00, 3550.29it/s]


1


100%|██████████| 303050/303050 [01:26<00:00, 3522.33it/s]


2


100%|██████████| 1452124/1452124 [06:25<00:00, 3764.68it/s]


3


100%|██████████| 1021877/1021877 [04:36<00:00, 3689.94it/s]


4


100%|██████████| 660256/660256 [03:01<00:00, 3629.32it/s]


5


100%|██████████| 357323/357323 [01:37<00:00, 3677.88it/s]


In [8]:
# Re-open the combined output from local disk to inspect its size.
dataset = LocalDataset('combine-dedup-text-dataset-filtered')

In [9]:
# Total token count, assuming every sample holds exactly SEQ_LEN token ids.
# NOTE(review): 4096 is not derived anywhere in this notebook -- confirm it
# matches the block size used when the `tokenized-*` folders were produced.
SEQ_LEN = 4096
len(dataset) * SEQ_LEN

18604802048