In [30]:
import torch

torch.set_grad_enabled(False)

import torch.nn as nn
import pandas as pd
from datasets import Audio
from transformers import AutoTokenizer
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import streaming
import numpy as np
from tqdm import tqdm
from glob import glob
import os
import json
from multiprocess import Pool
import itertools

def chunks(l, n):
    for i in range(0, len(l), n):
        yield (l[i: i + n], i // n)

def multiprocessing(strings, function, cores=6, returned=True):
    df_split = chunks(strings, len(strings) // cores)
    pool = Pool(cores)
    pooled = pool.map(function, df_split)
    pool.close()
    pool.join()

    if returned:
        return list(itertools.chain(*pooled))

class UInt32(Encoding):
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.uint32)

_encodings['uint32'] = UInt32

columns = {
    'input_ids': 'uint32',
    'position_ids': 'uint32',
    'attention_mask': 'uint32',
    'audio': 'str',
    'text': 'str'
}
hashes = 'sha1', 'xxh64'

In [2]:
df_mapping = pd.read_parquet('chunk-streaming-flatten.parquet')
mapping = {}
for i in tqdm(range(len(df_mapping))):
    mapping[df_mapping.iloc[i]['filename_audio']] = i

100%|██████████| 5327569/5327569 [01:11<00:00, 74898.87it/s]


In [3]:
df = pd.read_parquet('chunk-streaming.parquet').to_dict(orient = 'records')
df[0]

{'chunk': array([array(['Menurutnya, kejadian itu dipercayai',
               'chunk-streaming/prepare-dataset-normalizer-text-malay-news-husein_41906_0.mp3'],
              dtype=object)                                                                     ,
        array(['berlaku',
               'chunk-streaming/prepare-dataset-normalizer-text-malay-news-husein_41906_1.mp3'],
              dtype=object)                                                                     ,
        array(['di Kilometer tiga puluh empat lebuh raya berkenaan pada pukul enam titik sepuluh petang,',
               'chunk-streaming/prepare-dataset-normalizer-text-malay-news-husein_41906_2.mp3'],
              dtype=object)                                                                               ,
        array(['kelmarin.',
               'chunk-streaming/prepare-dataset-normalizer-text-malay-news-husein_41906_3.mp3'],
              dtype=object)                                                         

In [4]:
with open('accept-streaming-chunk.json') as fopen:
    accepted = set(json.load(fopen))

df = [df[i] for i in range(len(df)) if i in accepted]
len(df)

977091

In [5]:
tokenizer = AutoTokenizer.from_pretrained('mesolitica/Malaysian-TTS-1.7B')

In [6]:
import gc

def collator(batch, batch_position_ids):
    input_ids = []
    position_ids = []
    masks = []
    for i in range(len(batch)):
        l = len(batch[i])
        input_ids.extend(batch[i])
        position_ids.extend(batch_position_ids[i])
        masks.append(l)
    
    return {
        'input_ids': np.array(input_ids).astype(np.uint32),
        'position_ids': np.array(position_ids).astype(np.uint32),
        'attention_mask': np.array(masks).astype(np.uint32),
        'audio': '',
        'text': '',
    }

def slice_and_balance(nested_list, size):
    first = []
    balance = []
    current_size = 0

    for sublist in nested_list:
        if current_size < size:
            remaining_space = size - current_size
            if len(sublist) <= remaining_space:
                first.append(sublist)
                current_size += len(sublist)
            else:
                first.append(sublist[:remaining_space])
                balance.append(sublist[remaining_space:])
                current_size = size
        else:
            balance.append(sublist)
    
    return first, balance

In [25]:
!rm -rf tokenized-4k-qwen3-streaming
!mkdir tokenized-4k-qwen3-streaming

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [26]:
chunk = df[0]['chunk']
speaker = 'husein' if 'husein' in chunk[0][1] else 'idayu'
prompt = ''
for c in chunk:
    with open(f'chunk-streaming-flatten/{mapping[c[1]]}.json') as fopen:
        token = json.load(fopen)
    token = ''.join([f'<|speech_{t}|>' for t in token])
    prompt += f'{c[0]}<|speech_start|>{token}<|im_end|>'
prompt = f'streaming,{speaker}: {prompt}'

In [27]:
import time

sequence_length = 4096
def loop(files, block_size = sequence_length):
    rows, index = files
    out_root = f'tokenized-4k-qwen3-streaming/tokenized-{index}'
    os.system(f'rm -rf {out_root}')
    count = 0
    temp = []
    position_ids = []
    last_block, last_position_block = None, None
    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for row in tqdm(rows):

            chunk = row['chunk']
            speaker = 'husein' if 'husein' in row['chunk'][0][1] else 'idayu'
            prompt = ''
            for c in chunk:
                with open(f'chunk-streaming-flatten/{mapping[c[1]]}.json') as fopen:
                    token = json.load(fopen)
                token = ''.join([f'<|speech_{t}|>' for t in token])
                prompt += f'{c[0]}<|speech_start|>{token}<|im_end|>'
            prompt = f'streaming,{speaker}: {prompt}'
            
            outputs = tokenizer(prompt, add_special_tokens = False)
            position = range(len(outputs['input_ids']))
            length = len(outputs['input_ids'])
            
            if count + length > block_size:
                o = collator(temp, position_ids)
                out.write(o)
                temp = [outputs['input_ids']]
                position_ids = [position]
                count = length
                
            else:
                temp.append(outputs['input_ids'])
                position_ids.append(range(len(outputs['input_ids'])))
                count += len(outputs['input_ids'])
        
        if len(temp):
            o = collator(temp, position_ids)
            out.write(o)
            

In [28]:
loop((df[:100], 0))

100%|██████████| 100/100 [00:00<00:00, 585.21it/s]


In [None]:
from multiprocess import Pool

chunks = chunks(df, 20000)
pool = Pool(20)
pooled = pool.map(loop, chunks)
pool.close()
pool.join()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av