In [13]:
from glob import glob
from streaming import MDSWriter
from streaming import StreamingDataset
from transformers import default_data_collator
from tqdm import tqdm
import numpy as np

In [3]:
# Report the on-disk size of each per-shard `tokenized-*` directory before merging.
!du -hs tokenized-*

1.1G	tokenized-0
635M	tokenized-1
2.7G	tokenized-10
1.1G	tokenized-11
4.3G	tokenized-12
4.3G	tokenized-13
1.1G	tokenized-14
939M	tokenized-15
844M	tokenized-16
827M	tokenized-17
3.0G	tokenized-18
917M	tokenized-19
4.2G	tokenized-2
890M	tokenized-3
4.2G	tokenized-4
734M	tokenized-5
4.1G	tokenized-6
951M	tokenized-7
3.4G	tokenized-8
893M	tokenized-9


In [4]:
# Collect every per-shard directory matching `tokenized-*` in the working directory.
# The listing is displayed in whatever order glob returns it (it is not sorted).
shard_pattern = 'tokenized-*'
folders = glob(shard_pattern)
folders

['tokenized-1',
 'tokenized-16',
 'tokenized-4',
 'tokenized-2',
 'tokenized-19',
 'tokenized-9',
 'tokenized-18',
 'tokenized-8',
 'tokenized-6',
 'tokenized-12',
 'tokenized-10',
 'tokenized-3',
 'tokenized-7',
 'tokenized-15',
 'tokenized-13',
 'tokenized-11',
 'tokenized-5',
 'tokenized-17',
 'tokenized-14',
 'tokenized-0']

In [5]:
from streaming.base.format.mds.encodings import Encoding, _encodings

class Int32(Encoding):
    """MDS column encoding that stores an array as raw int32 bytes.

    Only the flat element data is stored: ``encode`` drops the array shape and
    ``decode`` always returns a 1-D array.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` as int32 bytes.

        ``obj`` is coerced to ``np.int32`` first so the byte layout always
        matches what ``decode`` reads back. Without the coercion, an int64
        array would be written as 8-byte items and decoded as twice as many
        meaningless int32 values, and a plain list would fail outright.
        For an array that is already int32 the coercion is a no-op.
        """
        return np.asarray(obj, dtype=np.int32).tobytes()

    def decode(self, data: bytes):
        """Reinterpret ``data`` as a 1-D ``np.int32`` array (no copy is requested)."""
        return np.frombuffer(data, np.int32)

# Register the custom encoding under the name 'int32' in the encodings table
# imported from streaming.
# NOTE(review): `_encodings` is a private (underscore-prefixed) name, so this hook
# may not be stable across streaming versions — confirm when upgrading.
_encodings['int32'] = Int32

In [6]:
# Every column of the merged dataset uses the 'int32' encoding name.
columns = dict.fromkeys(
    ('input_ids', 'token_type_ids', 'attention_mask', 'labels'),
    'int32',
)
# Compression and hash settings handed to the writer alongside the column schema.
compression = 'zstd'
hashes = ('sha1', 'xxh64')

In [7]:
# Remove any existing `tokenized` output directory so the merge below starts clean.
# Destructive: this deletes the directory without prompting.
!rm -rf tokenized

In [8]:
# Merge every per-shard dataset in `folders` into a single MDS dataset written
# to `tokenized`, copying samples one by one.
failed_folders = []
with MDSWriter(out='tokenized', columns=columns, compression=compression, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = StreamingDataset(local=f)
            for i in tqdm(range(len(dataset))):
                out.write(dataset[i])
        except Exception as e:
            # Best-effort merge: keep going with the remaining folders, but record
            # which folder failed so the error is not lost among the progress bars.
            # Samples written before the failure stay in the output, so a failed
            # folder may be only partially merged.
            failed_folders.append(f)
            print(f'{f}: {type(e).__name__}: {e}')

# Surface failures once at the end so they are visible after the long run.
if failed_folders:
    print(f'{len(failed_folders)} folder(s) failed: {failed_folders}')

100%|██████████| 96298/96298 [00:30<00:00, 3112.38it/s]
100%|██████████| 168260/168260 [01:12<00:00, 2329.09it/s]
100%|██████████| 60663/60663 [00:19<00:00, 3106.63it/s]
100%|██████████| 61414/61414 [00:19<00:00, 3182.67it/s]
100%|██████████| 271845/271845 [01:44<00:00, 2604.84it/s]
100%|██████████| 136592/136592 [00:50<00:00, 2701.86it/s]
100%|██████████| 644092/644092 [04:12<00:00, 2555.48it/s]
100%|██████████| 48943/48943 [00:13<00:00, 3601.56it/s]
100%|██████████| 60098/60098 [00:24<00:00, 2481.63it/s]
100%|██████████| 62076/62076 [00:17<00:00, 3588.53it/s]
100%|██████████| 452002/452002 [02:52<00:00, 2625.16it/s]
100%|██████████| 135520/135520 [01:03<00:00, 2136.21it/s]
100%|██████████| 186496/186496 [01:15<00:00, 2468.61it/s]
100%|██████████| 144360/144360 [00:48<00:00, 2991.19it/s]
100%|██████████| 753896/753896 [06:09<00:00, 2040.98it/s]
100%|██████████| 164342/164342 [01:06<00:00, 2466.80it/s]
100%|██████████| 175792/175792 [01:08<00:00, 2582.57it/s]
100%|██████████| 127292/12

In [10]:
# Report the on-disk size of the merged `tokenized` directory.
!du -hs tokenized

23G	tokenized


In [20]:
%%time
# Open the merged dataset from the local `tokenized` directory and time how long that takes.
dataset = StreamingDataset(local='tokenized')

CPU times: user 299 ms, sys: 54.2 ms, total: 353 ms
Wall time: 353 ms


In [21]:
%%time

batch = [dataset[i] for i in range(10)]

CPU times: user 3.26 ms, sys: 0 ns, total: 3.26 ms
Wall time: 2.66 ms


In [18]:
# Size estimate: number of samples in the merged dataset times 4096.
# NOTE(review): 4096 is presumably the per-sample sequence length, which would make
# this the total token count — confirm against the tokenization step.
len(dataset) * 4096

16994238464

In [15]:
# Collate the list of samples in `batch` into a single model input.
# NOTE(review): despite the name `padded`, `default_data_collator` is understood to
# stack features without padding, so all samples presumably share one length — confirm.
padded = default_data_collator(batch)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Load a causal language model from the local `./mistral-191M` path.
model = AutoModelForCausalLM.from_pretrained('./mistral-191M')

In [None]:
# Smoke test: call the model once, unpacking the collated batch as keyword arguments.
# NOTE(review): the dataset schema includes `token_type_ids` and int32 `labels`;
# confirm the model's forward() accepts that key and that dtype before relying on this.
model(**padded)