In [1]:
from glob import glob
from streaming import MDSWriter
from streaming import LocalDataset, StreamingDataset
from transformers import default_data_collator, DataCollatorForLanguageModeling
from tqdm import tqdm
import numpy as np

In [2]:
# Gather the main tokenized folders and order them by the integer that follows
# the last '-' in each path (numeric order, so 'tokenized-10' comes after
# 'tokenized-2').
folders = glob('tokenized_indexes/tokenized-*')
folders.sort(key=lambda path: int(path.rsplit('-', 1)[-1]))

In [3]:
# Append the extra tokenized folders after the main ones, ordered by the
# integer that follows the last '-' in each path.
extra_folders = sorted(
    glob('tokenized_extra/tokenized-*'),
    key=lambda path: int(path.split('-')[-1]),
)
# Only add paths that are not already listed: `folders` is mutated in place,
# so re-running this cell would otherwise append the same folders again and
# duplicate their samples in the combined dataset.
folders += [path for path in extra_folders if path not in folders]

In [4]:
from streaming.base.format.mds.encodings import Encoding, _encodings

class UInt16(Encoding):
    """MDS column encoding that stores an array as raw ``uint16`` bytes.

    Only the element bytes are written (no shape or dtype header), so
    ``decode`` always yields a flat, one-dimensional ``uint16`` array.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` as uint16 bytes.

        Args:
            obj: A ``uint16`` array, or any integer array-like whose values
                all fit in ``[0, 65535]``.

        Returns:
            The raw bytes of the values as ``uint16``.

        Raises:
            ValueError: If ``obj`` is not integer-typed or holds values that
                cannot be represented as ``uint16``.
        """
        arr = np.asarray(obj)
        if arr.size == 0:
            return b''
        if arr.dtype != np.uint16:
            # decode() always reinterprets the bytes as uint16, so writing any
            # other dtype verbatim would silently round-trip to garbage.
            if arr.dtype.kind not in 'iu':
                raise ValueError(f'UInt16 encoding expects integer data, got dtype {arr.dtype}')
            if arr.min() < 0 or arr.max() > np.iinfo(np.uint16).max:
                raise ValueError('UInt16 encoding received values outside [0, 65535]')
            arr = arr.astype(np.uint16)
        return arr.tobytes()

    def decode(self, data: bytes):
        """Interpret ``data`` as a flat ``uint16`` array.

        ``np.frombuffer`` does not copy; when ``data`` is an immutable
        ``bytes`` object the returned array is read-only.
        """
        return np.frombuffer(data, np.uint16)

# Register the encoding under the name used in the `columns` spec.
# NOTE(review): this assigns into the library's private `_encodings` registry
# and replaces any existing 'uint16' entry — confirm that is intended for the
# installed streaming version.
_encodings['uint16'] = UInt16

In [5]:
# Column schema for the MDS writer: a single 'input_ids' column stored with
# the 'uint16' encoding.
columns = dict(input_ids='uint16')

# Compression setting.
# NOTE(review): the writer cell in this notebook passes compression=None
# explicitly, so this value appears to be unused there — confirm.
compression = 'zstd'

# Hash algorithm names handed to the writer.
hashes = ('sha1', 'xxh64')

In [6]:
# Remove any previous 'combine-all' output directory before writing.
# Pure-Python replacement for `!rm -rf combine-all` so the cell does not depend
# on a POSIX shell; ignore_errors=True mirrors `-f` (no error if it is missing).
import shutil

shutil.rmtree('combine-all', ignore_errors=True)

In [7]:
# Copy every sample from each tokenized folder into one MDS dataset written to
# 'combine-all'.
# NOTE(review): compression is passed as None here even though a `compression`
# setting is defined earlier in the notebook — confirm this is intentional.
failed_folders = []
with MDSWriter(out='combine-all', columns=columns, compression=None, hashes=hashes) as out:
    for folder in folders:
        try:
            dataset = StreamingDataset(local=folder)
            for i in tqdm(range(len(dataset)), desc=folder):
                out.write(dataset[i])
        except Exception as e:
            # Best-effort merge: keep going with the remaining folders, but say
            # which folder failed and remember it. Samples written before the
            # error stay in the output, so a failed folder may be partial.
            print(f'failed to merge {folder}: {e!r}')
            failed_folders.append(folder)

if failed_folders:
    print(f'{len(failed_folders)} folder(s) failed and may be only partially merged: {failed_folders}')

100%|██████████| 796855/796855 [00:37<00:00, 21315.92it/s]
100%|██████████| 198350/198350 [00:09<00:00, 21047.43it/s]
100%|██████████| 97906/97906 [00:04<00:00, 21635.00it/s]
100%|██████████| 113487/113487 [00:05<00:00, 20761.16it/s]
100%|██████████| 111516/111516 [00:06<00:00, 18543.29it/s]
100%|██████████| 61967/61967 [00:02<00:00, 21664.59it/s]
100%|██████████| 45167/45167 [00:03<00:00, 13430.82it/s]
100%|██████████| 129551/129551 [00:06<00:00, 21531.84it/s]
100%|██████████| 238372/238372 [00:12<00:00, 18961.75it/s]
100%|██████████| 238387/238387 [00:12<00:00, 19721.06it/s]
100%|██████████| 577225/577225 [00:29<00:00, 19245.44it/s]
100%|██████████| 111841/111841 [00:05<00:00, 21265.39it/s]
100%|██████████| 110952/110952 [00:07<00:00, 14979.07it/s]
100%|██████████| 339991/339991 [00:18<00:00, 18262.81it/s]
100%|██████████| 344519/344519 [00:18<00:00, 18347.32it/s]
100%|██████████| 163652/163652 [00:08<00:00, 20416.54it/s]
100%|██████████| 500663/500663 [00:25<00:00, 19590.86it/s]
100

In [8]:
# Open the 'combine-all' directory with LocalDataset for inspection.
dataset = LocalDataset('combine-all')

In [9]:
# Display the number of samples in the combined dataset.
len(dataset)

4759901

In [11]:
from transformers import AutoTokenizer

# Load the tokenizer published with the 'ahxt/LiteLlama-460M-1T' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('ahxt/LiteLlama-460M-1T')

# Switch off the add_bos_token / add_eos_token flags and raise the maximum
# model length recorded on the tokenizer.
tokenizer.add_bos_token, tokenizer.add_eos_token = False, False
tokenizer.model_max_length = 32768

# Declare the EOS/BOS special tokens; add_special_tokens returns how many
# tokens were newly added to the vocabulary.
special_tokens_dict = dict(eos_token="</s>", bos_token='<s>')
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [12]:
# Decode the first combined sample back to text as a quick sanity check.
first_sample = dataset[0]
tokenizer.decode(first_sample['input_ids'])

'<s>Bahasa Melayu (Tulisan Jawi: بهاس ملايو; Rejang: ꤷꥁꤼ ꤸꥍꤾꤿꥈ) ialah salah satu daripada bahasa-bahasa Melayu-Polinesia di bawah keluarga bahasa Austronesia, yang merupakan bahasa rasmi di Brunei, Indonesia, Malaysia dan Singapura, serta dituturkan di Timor Leste dan sebahagian wilayah di Kemboja, Filipina dan Thailand. Jumlah penutur bahasa Melayu mencakupi lebih daripada 290 juta penutur (seramai 260 juta orang bertutur bahasa Indonesia) merentasi kawasan maritim Asia Tenggara. Sebagai salah satu daripada bahasa-bahasa yang paling luas digunakan di Asia Tenggara, bahasa Melayu mempunyai istilah perundangan yang berbeza di negara-negara terlibat bergantung pada sejarah dan budaya penggunaan bahasa Melayu di negara-negara tersebut. Di Malaysia, istilah "bahasa Melayu" ialah istilah "de jure" untuk pentakrifan rasmi bahasa kebangsaan negara Malaysia, manakala istilah "bahasa Malaysia" atau "bahasa Melayu Malaysia" seringkali digunakan mewakili perkara yang sama secara tidak formal di k