In [3]:
import os
import re
import shutil
from glob import glob

import numpy as np
from streaming import MDSWriter
from streaming import StreamingDataset
from tqdm import tqdm
from transformers import default_data_collator

In [5]:
# Remove any previous merged output before collecting inputs: 'nanot5-512'
# matches the 'nanot5-*' glob used to find the input folders, so leaving it in
# place would feed an old merged dataset back in as an input.
# shutil.rmtree is the portable equivalent of `rm -rf`; ignore_errors=True makes
# this a no-op when the folder does not exist.
shutil.rmtree('nanot5-512', ignore_errors=True)

In [6]:
def list_shard_folders(prefix='nanot5', exclude=('nanot5-512',)):
    """Return the input shard folders ``<prefix>-<N>`` sorted by numeric suffix N.

    Only directories whose name is exactly ``<prefix>-`` followed by digits are
    kept, so a stray entry such as ``nanot5-tmp`` cannot crash the integer sort
    key. Names listed in ``exclude`` are dropped: the merged output folder
    ``nanot5-512`` matches the same glob and must never be read back as an input.

    Args:
        prefix: folder-name prefix to glob for (relative to the working directory).
        exclude: folder names to skip even if they match.

    Returns:
        List of folder names, e.g. ['nanot5-0', 'nanot5-1', ..., 'nanot5-29'].
    """
    name_re = re.compile(rf'{re.escape(prefix)}-(\d+)')
    shards = [
        path for path in glob(f'{prefix}-*')
        if name_re.fullmatch(path) and path not in exclude and os.path.isdir(path)
    ]
    # Sort numerically so 'nanot5-10' comes after 'nanot5-9', not after 'nanot5-1'.
    return sorted(shards, key=lambda path: int(name_re.fullmatch(path).group(1)))

folders = list_shard_folders()
folders

['nanot5-0',
 'nanot5-1',
 'nanot5-2',
 'nanot5-3',
 'nanot5-4',
 'nanot5-5',
 'nanot5-6',
 'nanot5-7',
 'nanot5-8',
 'nanot5-9',
 'nanot5-10',
 'nanot5-11',
 'nanot5-12',
 'nanot5-13',
 'nanot5-14',
 'nanot5-15',
 'nanot5-16',
 'nanot5-17',
 'nanot5-18',
 'nanot5-19',
 'nanot5-20',
 'nanot5-21',
 'nanot5-22',
 'nanot5-23',
 'nanot5-24',
 'nanot5-25',
 'nanot5-26',
 'nanot5-27',
 'nanot5-28',
 'nanot5-29']

In [7]:
from streaming.base.format.mds.encodings import Encoding, _encodings

class Int32(Encoding):
    """MDS column encoding that stores an integer array as raw native int32 bytes.

    Registered below under the name 'int32' so it can be used as a column type
    in the ``columns`` mapping given to ``MDSWriter``.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` (an array-like of integers) as int32 bytes.

        The value is cast to int32 before ``tobytes()`` so that ``encode`` and
        ``decode`` always agree on element size. Without the cast, an int64
        array would be written as 8-byte values and read back by ``decode`` as
        twice as many (garbage) int32 values. A plain list is also accepted.

        Raises:
            ValueError: if any value does not fit in int32, instead of letting
                the cast silently wrap around.
        """
        arr = np.asarray(obj)
        if arr.dtype != np.int32:
            info = np.iinfo(np.int32)
            if arr.size and (arr.min() < info.min or arr.max() > info.max):
                raise ValueError('Int32 encoding: values out of int32 range')
            arr = arr.astype(np.int32)
        return arr.tobytes()

    def decode(self, data: bytes):
        """Rebuild the int32 array from bytes produced by ``encode``.

        ``np.frombuffer`` returns a read-only view over ``data``; callers that
        need to modify the array in place must ``.copy()`` it first.
        """
        return np.frombuffer(data, np.int32)

# Make the encoding available to MDS under the column type name 'int32'.
_encodings['int32'] = Int32

In [8]:
# Schema for the merged dataset: a single token-id column using the custom
# 'int32' encoding registered in the previous cell.
columns = {'input_ids': 'int32'}

# Compression and hash algorithm names passed to MDSWriter.
compression = 'zstd'
hashes = ('sha1', 'xxh64')

In [9]:
# Start from a clean output directory so nothing from a previous run is left in
# 'nanot5-512' when the merge below writes to it.
# shutil.rmtree is the portable equivalent of `rm -rf`; ignore_errors=True makes
# this a no-op when the folder does not exist.
shutil.rmtree('nanot5-512', ignore_errors=True)

In [10]:
# Merge every input shard folder into a single MDS dataset at 'nanot5-512'.
# Best-effort: a folder that raises is skipped so the remaining folders still
# get merged. Each failure is recorded under its folder name and summarised at
# the end instead of being printed anonymously.
# NOTE: samples already written from a folder before it failed stay in the
# output; they are not rolled back.
failed_folders = {}
with MDSWriter(out='nanot5-512', columns=columns, compression=compression, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = StreamingDataset(local=f)
            for i in tqdm(range(len(dataset)), desc=f):
                out.write(dataset[i])
        except Exception as e:
            failed_folders[f] = e
            print(f'skipping {f}: {e!r}')
if failed_folders:
    print(f'{len(failed_folders)} folder(s) failed: {list(failed_folders)}')

100%|██████████| 4398060/4398060 [03:27<00:00, 21184.03it/s]
100%|██████████| 635893/635893 [00:30<00:00, 20813.50it/s]
100%|██████████| 855400/855400 [00:42<00:00, 20180.59it/s]
100%|██████████| 804710/804710 [00:37<00:00, 21443.98it/s]
100%|██████████| 805724/805724 [00:41<00:00, 19298.82it/s]
100%|██████████| 661929/661929 [00:32<00:00, 20174.33it/s]
100%|██████████| 629884/629884 [00:39<00:00, 15854.09it/s]
100%|██████████| 622476/622476 [00:37<00:00, 16753.33it/s]
100%|██████████| 406799/406799 [00:24<00:00, 16362.57it/s]
100%|██████████| 303284/303284 [00:15<00:00, 19149.00it/s]
100%|██████████| 306371/306371 [00:17<00:00, 17435.18it/s]
100%|██████████| 307173/307173 [00:16<00:00, 18800.55it/s]
100%|██████████| 310475/310475 [00:14<00:00, 20895.22it/s]
100%|██████████| 312804/312804 [00:15<00:00, 20091.23it/s]
100%|██████████| 314389/314389 [00:15<00:00, 20882.24it/s]
100%|██████████| 313495/313495 [00:15<00:00, 20571.60it/s]
100%|██████████| 2007768/2007768 [01:35<00:00, 20971.9

In [11]:
%%time
# Re-open the merged output folder as a StreamingDataset to confirm it loads;
# %%time reports how long opening it takes.
dataset = StreamingDataset(local='nanot5-512')

CPU times: user 312 ms, sys: 96.7 ms, total: 409 ms
Wall time: 406 ms


In [13]:
# Spot-check the merged dataset: decode its first sample and show the shape of
# its token-id array.
# NOTE(review): the output folder is named 'nanot5-512', yet the saved output of
# this cell reports 568 ids for the first sample — confirm whether 512 is meant
# to be a sequence length or whether sample lengths are expected to vary.
np.shape(dataset[0]['input_ids'])

(568,)