In [15]:
import mp
import os
import pyarrow as pa
import numpy as np
from streaming import MDSWriter
from tqdm import tqdm

In [6]:
from streaming.base.format.mds.encodings import Encoding, _encodings

class Int32(Encoding):
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.int32)

_encodings['int32'] = Int32

In [7]:
columns = {
    'input_ids': 'int32',
    'token_type_ids': 'int32',
    'attention_mask': 'int32',
    'labels': 'int32',
}
compression = 'zstd'
hashes = 'sha1', 'xxh64'

In [8]:
from glob import glob

files = glob('combine-lm_*_of_00020.jsonl-grouped-4096')
files

['combine-lm_00017_of_00020.jsonl-grouped-4096',
 'combine-lm_00005_of_00020.jsonl-grouped-4096',
 'combine-lm_00008_of_00020.jsonl-grouped-4096',
 'combine-lm_00012_of_00020.jsonl-grouped-4096',
 'combine-lm_00007_of_00020.jsonl-grouped-4096',
 'combine-lm_00014_of_00020.jsonl-grouped-4096',
 'combine-lm_00006_of_00020.jsonl-grouped-4096',
 'combine-lm_00013_of_00020.jsonl-grouped-4096',
 'combine-lm_00016_of_00020.jsonl-grouped-4096',
 'combine-lm_00011_of_00020.jsonl-grouped-4096',
 'combine-lm_00018_of_00020.jsonl-grouped-4096',
 'combine-lm_00002_of_00020.jsonl-grouped-4096',
 'combine-lm_00009_of_00020.jsonl-grouped-4096',
 'combine-lm_00019_of_00020.jsonl-grouped-4096',
 'combine-lm_00001_of_00020.jsonl-grouped-4096',
 'combine-lm_00003_of_00020.jsonl-grouped-4096',
 'combine-lm_00015_of_00020.jsonl-grouped-4096',
 'combine-lm_00004_of_00020.jsonl-grouped-4096',
 'combine-lm_00000_of_00020.jsonl-grouped-4096',
 'combine-lm_00010_of_00020.jsonl-grouped-4096']

In [16]:
def loop(files):
    files, index = files
    out_root = f'tokenized-{index}'
    os.system(f'rm -rf {out_root}')
    with MDSWriter(out=out_root, columns=columns, compression=compression, hashes=hashes, 
                   size_limit = 67108864 * 2) as out:
        for f in files:
            memory_mapped_stream = pa.memory_map(f)
            opened_stream = pa.ipc.open_stream(memory_mapped_stream)
            for a in tqdm(opened_stream):
                s = a.to_struct_array()
                for i in range(len(s)):
                    keys = list(s[i])
                    a_ = {}
                    for k in keys:
                        a_[k] = np.array(s[i][k].as_py()).astype(np.int32)
                    out.write(a_)

In [None]:
mp.multiprocessing(files, loop, cores = 20, returned = False)

2570it [05:50,  7.30it/s]
7464it [06:35, 18.87it/s]
7464it [07:57, 15.62it/s]
7464it [08:06, 15.36it/s]
7464it [08:11, 15.20it/s]
7464it [12:20, 10.08it/s]
5816it [13:12,  7.64it/s]