In [1]:
from transformers import AutoTokenizer

# Tokenizer of the target model. Later cells use it to render chat templates,
# to turn the rendered text into token ids, and to decode samples for inspection.
MODEL_NAME = 'unsloth/Llama-3.2-3B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [2]:
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import streaming
import numpy as np
from tqdm import tqdm
from glob import glob
import os
import json

class UInt32(Encoding):
    """Column encoding that stores an array as raw ``uint32`` bytes.

    ``decode`` always reinterprets the stored buffer as ``np.uint32`` and
    returns a flat (1-D) array, so ``encode`` must always emit uint32 bytes.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` as the raw bytes of a uint32 array.

        The input is coerced to ``np.uint32`` first. Without the coercion an
        array of another dtype (e.g. the default int64) would be written with
        the wrong element width and ``decode`` would silently return garbage.
        For inputs that already are uint32 arrays the output is unchanged.
        """
        return np.asarray(obj, dtype=np.uint32).tobytes()

    def decode(self, data: bytes):
        """Reinterpret ``data`` as a flat array of uint32 values."""
        return np.frombuffer(data, np.uint32)

# Make the custom encoding selectable by the name 'uint32'.
_encodings['uint32'] = UInt32

# All three columns of a packed sample use the 'uint32' encoding.
columns = {
    name: 'uint32'
    for name in ('input_ids', 'position_ids', 'attention_mask')
}
# Hash algorithm names handed to every MDSWriter created in this notebook.
hashes = ('sha1', 'xxh64')

In [3]:
# Start from an empty output directory for the per-chunk tokenized shards.
# Pure-Python equivalent of `rm -rf tokenized-4k` followed by
# `mkdir tokenized-4k`, so the cell does not depend on a POSIX shell.
# WARNING: destructive - any previously tokenized shards are deleted.
import shutil

shutil.rmtree('tokenized-4k', ignore_errors=True)
os.makedirs('tokenized-4k', exist_ok=True)

In [4]:
# Parse every line of the JSONL file as one JSON document and collect the
# parsed objects, in file order, in `combine`.
combine = []
with open('combined-malaysian-sft.jsonl') as fopen:
    combine.extend(json.loads(line) for line in fopen)

len(combine)

1294946

In [5]:
with open('translation-instructions.json') as fopen:
    translation = json.load(fopen)

# Append every translation pair to `combine` as a two-message conversation:
# the 'input' field becomes the user turn, the 'output' field the assistant turn.
combine.extend(
    [
        {'role': 'user', 'content': pair['input']},
        {'role': 'assistant', 'content': pair['output']},
    ]
    for pair in translation
)

len(combine)

1364946

In [6]:
import gc

def collator(batch, batch_position_ids):
    """Flatten the fragments of one packed block into the stored sample arrays.

    Args:
        batch: sequence of token-id sequences, one per fragment in the block.
        batch_position_ids: position ids per fragment, indexed in parallel
            with ``batch``.

    Returns:
        dict of uint32 arrays:
        ``input_ids`` - all fragments concatenated in order;
        ``position_ids`` - the matching position ids concatenated in order;
        ``attention_mask`` - the length of each fragment (one entry per
        fragment, not a per-token 0/1 mask).
    """
    flat_tokens = []
    flat_positions = []
    fragment_lengths = []
    for idx, fragment in enumerate(batch):
        fragment_lengths.append(len(fragment))
        flat_tokens.extend(fragment)
        flat_positions.extend(batch_position_ids[idx])

    flattened = (
        ('input_ids', flat_tokens),
        ('position_ids', flat_positions),
        ('attention_mask', fragment_lengths),
    )
    return {key: np.array(values).astype(np.uint32) for key, values in flattened}

def slice_and_balance(nested_list, size):
    """Split a list of sequences after exactly ``size`` leading elements.

    Sequences are consumed in order. Whole sequences are kept while they fit;
    the sequence that crosses the ``size`` boundary is sliced in two, and
    everything after the boundary is returned untouched.

    Args:
        nested_list: iterable of sliceable sequences (lists, ranges, ...).
        size: total number of elements the first part may hold.

    Returns:
        ``(first, balance)``: ``first`` holds at most ``size`` elements in
        total, ``balance`` holds the remaining (possibly sliced) sequences.
    """
    head, tail = [], []
    filled = 0

    for chunk in nested_list:
        room = size - filled
        if room <= 0:
            # The head is already full: everything else is carried over.
            tail.append(chunk)
            continue

        chunk_len = len(chunk)
        if chunk_len > room:
            # This chunk crosses the boundary: cut it at the remaining room.
            head.append(chunk[:room])
            tail.append(chunk[room:])
            filled = size
        else:
            head.append(chunk)
            filled += chunk_len

    return head, tail

In [7]:
import time

def loop(files, block_size = 3072):
    """Tokenize one chunk of conversations and write it as fixed-size packed blocks.

    Every row is rendered with the tokenizer's chat template, tokenized, and
    appended to a running buffer. Whenever the buffer holds at least
    ``block_size`` tokens, exactly ``block_size`` tokens are cut from its
    front (splitting a conversation across two blocks when necessary) and
    written as one sample to ``tokenized-4k/tokenized-{index}``.

    After the last row, leftover tokens (fewer than ``block_size``) are
    completed to a full block by prepending the leading tokens of the last
    written block, and that block is written as the final sample. If nothing
    is left over, or no full block was ever written (so there is nothing to
    borrow from), no tail sample is written.

    Args:
        files: ``(rows, index)`` pair. ``rows`` is an iterable of conversations
            accepted by ``tokenizer.apply_chat_template``; ``index`` selects
            the output shard directory.
        block_size: number of tokens per written sample; must be positive.

    Returns:
        The tail sample dict (as produced by ``collator``) if one was written,
        otherwise ``None``.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    import shutil  # local import keeps this worker function self-contained

    if block_size <= 0:
        # `count >= block_size` would stay true forever and the packing loop
        # below would never terminate.
        raise ValueError(f'block_size must be positive, got {block_size}')

    rows, index = files
    out_root = f'tokenized-4k/tokenized-{index}'
    # Remove any stale shard directory without going through a shell.
    shutil.rmtree(out_root, ignore_errors=True)

    count = 0  # number of tokens currently buffered in `temp`
    temp = []
    position_ids = []
    last_block, last_position_block = None, None
    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for row in tqdm(rows):
            prompt = tokenizer.apply_chat_template(row, tokenize=False)
            # The rendered template is tokenized as-is; no extra special
            # tokens are added at this step.
            input_ids = tokenizer(prompt, add_special_tokens = False)['input_ids']
            temp.append(input_ids)
            # Position ids restart at 0 for every conversation; slicing the
            # range later keeps them continuous across a block boundary.
            position_ids.append(range(len(input_ids)))
            count += len(input_ids)
            while count >= block_size:
                block, temp = slice_and_balance(temp, block_size)
                block_position, position_ids = slice_and_balance(position_ids, block_size)
                count -= block_size
                last_block, last_position_block = block, block_position
                out.write(collator(block, block_position))

        if count == 0 or last_block is None:
            # count == 0: nothing is left over; padding would just write the
            #   last block a second time.
            # last_block is None: the chunk never filled a single block, so
            #   there are no earlier tokens to pad the short tail with
            #   (previously this path raised TypeError).
            return None

        # Borrow just enough leading tokens from the last full block to turn
        # the leftover into one more full block.
        missing = block_size - count
        block, _ = slice_and_balance(last_block, missing)
        block_position, _ = slice_and_balance(last_position_block, missing)

        block.extend(temp)
        block_position.extend(position_ids)

        o = collator(block, block_position)
        if len(o['input_ids']) == block_size:
            out.write(o)
            return o
        return None

In [8]:
# loop((combine[:10000], 0))

In [9]:
from multiprocess import Pool
import mp

# Split the corpus into tasks for `loop` and tokenize them on 10 worker
# processes. NOTE(review): `mp.chunks` is a project helper; it presumably
# yields `(rows, index)` pairs of up to 50000 rows each, which is the shape
# `loop` unpacks - confirm against the helper.
chunks = mp.chunks(combine, 50000)
# The context manager shuts the workers down even when a task raises, so a
# failed run does not leave idle worker processes behind. `pool.map` only
# returns once every task has finished, so nothing is cut short on exit.
with Pool(10) as pool:
    pooled = pool.map(loop, chunks)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [02:16<00:00, 365.05it/s]
 64%|███████████████████████████████████████████████████████████████                                   | 32175/50000 [00:30<00:16, 1058.45it/s]
 38%|█████████████████████████████████████▉                                                             | 19174/50000 [03:07<04:29, 114.47it/s]
 64%|███████████████████████████████████████████████████████████████▍                                   | 32019/50000 [03:32<01:58, 152.35it/s]
 11%|███████████▎                                                                                        | 5654/50000 [00:06<00:52, 836.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:44<00:00, 1125.24it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:50<00:00, 999.

In [10]:
def _shard_index(path):
    """Return the integer after the last '-' of a shard directory path."""
    return int(path.rsplit('-', 1)[-1])

# Sort numerically, so that tokenized-10 comes after tokenized-9 rather than
# after tokenized-1 as a plain string sort would place it.
folders = sorted(glob('tokenized-4k/tokenized-*'), key=_shard_index)
folders

['tokenized-4k/tokenized-0',
 'tokenized-4k/tokenized-1',
 'tokenized-4k/tokenized-2',
 'tokenized-4k/tokenized-3',
 'tokenized-4k/tokenized-4',
 'tokenized-4k/tokenized-5',
 'tokenized-4k/tokenized-6',
 'tokenized-4k/tokenized-7',
 'tokenized-4k/tokenized-8',
 'tokenized-4k/tokenized-9',
 'tokenized-4k/tokenized-10',
 'tokenized-4k/tokenized-11',
 'tokenized-4k/tokenized-12',
 'tokenized-4k/tokenized-13',
 'tokenized-4k/tokenized-14',
 'tokenized-4k/tokenized-15',
 'tokenized-4k/tokenized-16',
 'tokenized-4k/tokenized-17',
 'tokenized-4k/tokenized-18',
 'tokenized-4k/tokenized-19',
 'tokenized-4k/tokenized-20',
 'tokenized-4k/tokenized-21',
 'tokenized-4k/tokenized-22',
 'tokenized-4k/tokenized-23',
 'tokenized-4k/tokenized-24',
 'tokenized-4k/tokenized-25',
 'tokenized-4k/tokenized-26',
 'tokenized-4k/tokenized-27']

In [11]:
# Remove any previously merged dataset before it is rebuilt.
# Pure-Python equivalent of `rm -rf packing-4k`, so no shell is required.
# WARNING: destructive - the existing 'packing-4k' directory is deleted.
import shutil

shutil.rmtree('packing-4k', ignore_errors=True)

In [12]:
# Merge all per-chunk shard directories into a single dataset, 'packing-4k'.
# The merge is deliberately best-effort: a folder that cannot be read is
# skipped instead of aborting the run. Skipped folders are reported by name
# and collected in `failed_folders` so an incomplete merge is not overlooked.
failed_folders = []
with MDSWriter(out='packing-4k', columns=columns, compression=None, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = LocalDataset(local=f)
            for i in tqdm(range(len(dataset))):
                out.write(dataset[i])
        except Exception as e:
            failed_folders.append(f)
            print(f'skipping {f}: {e!r}')

if failed_folders:
    print(f'{len(failed_folders)} shard folder(s) were not fully merged: {failed_folders}')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 11324/11324 [00:01<00:00, 9107.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 12691/12691 [00:01<00:00, 9912.93it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 20838/20838 [00:02<00:00, 9725.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 14405/14405 [00:01<00:00, 11139.13it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 15445/15445 [00:01<00:00, 11597.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 17804/17804 [00:01<00:00, 13029.68it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 29525/29525 [00:02<00:00, 11724.

In [13]:
# Reopen the merged dataset and report its size in billions of tokens, taking
# every sample to hold 3072 tokens (the default `block_size` of `loop`).
dataset = LocalDataset(local='packing-4k')
tokens_per_sample = 3072
len(dataset) * tokens_per_sample / 1e9

1.138980864

In [21]:
# Spot-check: decode the third-to-last packed sample back into text.
tokenizer.decode(dataset[-3]['input_ids'])

' and 9\n    AND f.n between 1 and 7\n) primes\nWHERE p2 + p3 + p5 + p7 + p11 = 100;\n```\n\nPertanyaan di atas akan menghasilkan semua kombinasi nombor gandaan daripada 2, 3, 5, 7, 11, dan 13 yang nilainya kurang daripada atau sama dengan 100 (dengan mengehadkan nilai n pada setiap faktor perdana). Kemudian, pertanyaan akan memilih semua rangkaian gandaan nombor perdana yang jika ditambah akan menghasilkan 100.\n\nHarap dicatat bahawa pertanyaan ini berdasarkan jadual "nombor" yang harus ada dalam pangkalan data anda, dan jadual ini mesti mengandungi sekurang-kurangnya satu lajur "n" dengan nilai 1 hingga 100. Jika jadual ini belum wujud dalam pangkalan data anda, maka pertanyaan tidak akan berjalan betul dan perlu dibuat terlebih dahulu.<|eot_id|><|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 04 Jan 2025\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nterjemah ke bahasa melayu `Dalam database SQL, untuk meran

In [22]:
# Spot-check: decode the second-to-last packed sample back into text.
tokenizer.decode(dataset[-2]['input_ids'])

' 3, 4]\nlist2 = [3, 4, 5, 6]\nresult = find_intersection(list1, list2)\nprint(result)\n```\n\nPada baris pertama, sebuah fungsi bernama `find_intersection` dengan dua parameter `list1` dan `list2` didefinisikan. \n\nPada baris kedua, sebuah variabel bernama `intersection` yang akan menyimpan hasil temuan irisan dari kedua list didefinisikan sebagai list kosong. \n\nPada baris ketiga, program melakukan perulangan untuk setiap item pada `list1`. \n\nPada baris keempat, program memeriksa apakah item yang sedang diperiksa juga ada pada `list2`. \n\nPada baris kelima, jika item ditemukan di kedua list, maka ia akan ditambahkan ke dalam variabel `intersection`.\n\nPada baris keenam, setelah semua perulangan selesai dilakukan, hasil irisan yang ditemukan akan dikembalikan oleh fungsi `find_intersection`.\n\nPada baris kedelapan dan kesembilan, `list1` dan `list2` didefinisikan sebagai list berisi angka-angka. \n\nPada baris kesepuluh, program memanggil fungsi `find_intersection` dengan mengi