In [1]:
import pandas as pd

df = pd.read_parquet('data/train-00000-of-00001.parquet').to_dict(orient = 'records')
len(df)

360298

In [2]:
from dynamicbatch_ttspipeline.fishspeech.load import load_vqgan
import torch
import IPython.display as ipd

  @autocast(enabled = False)
  @autocast(enabled = False)
  @autocast(enabled = False)
  @autocast(enabled = False)


In [3]:
from transformers import AutoTokenizer, AutoConfig
from transformers import AddedToken
import os
import numpy as np

tokenizer = AutoTokenizer.from_pretrained('mesolitica/SmolLM2-135M-firefly-vqgan')

In [4]:
speaker = df[0]['speaker']
t = df[0]['transcription']
splitted = df[0]['audio_filename'].split('/')
new_f = '/'.join([splitted[0] + '_vqgan'] + splitted[1:]).replace('.mp3', '.npy')
speech_t = np.load(new_f)
speech_t = ''.join([f'<|{t}|>' for t in speech_t.tolist()])
tts = f'<|text_start|>{speaker}: {t}<|text_end|><|speech_start|>{speech_t}<|speech_end|>'

In [5]:
tokens = tts.split('<|speech_start|>')[1].split('<|speech_end|>')[0]

In [6]:
import re

numbers = [int(t) for t in re.findall(r'<\|(\d+)\|>', tokens)]
np.array(numbers).reshape((-1, 8)).T.shape

(8, 153)

In [7]:
# model = load_vqgan(device = 'cuda')
# i = torch.tensor(np.array(numbers).reshape((-1, 8)).T[None])
# y_, _ = model.decode(i.cuda(), torch.tensor([i.shape[-1]]).cuda())
# ipd.Audio(y_.detach().cpu().numpy()[0, 0], rate = model.spec_transform.sample_rate)

In [8]:
import gc

def collator(batch, batch_position_ids):
    input_ids = []
    position_ids = []
    masks = []
    for i in range(len(batch)):
        l = len(batch[i])
        input_ids.extend(batch[i])
        position_ids.extend(batch_position_ids[i])
        masks.append(l)
    
    return {
        'input_ids': np.array(input_ids).astype(np.uint32),
        'position_ids': np.array(position_ids).astype(np.uint32),
        'attention_mask': np.array(masks).astype(np.uint32),
    }

def slice_and_balance(nested_list, size):
    first = []
    balance = []
    current_size = 0

    for sublist in nested_list:
        if current_size < size:
            remaining_space = size - current_size
            if len(sublist) <= remaining_space:
                first.append(sublist)
                current_size += len(sublist)
            else:
                first.append(sublist[:remaining_space])
                balance.append(sublist[remaining_space:])
                current_size = size
        else:
            balance.append(sublist)
    
    return first, balance

In [9]:
!rm -rf tokenized-2048
!mkdir tokenized-2048

In [10]:
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import streaming
import numpy as np
from tqdm import tqdm
from glob import glob
import os
import json

class UInt32(Encoding):
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.uint32)

_encodings['uint32'] = UInt32

columns = {
    'input_ids': 'uint32',
    'position_ids': 'uint32',
    'attention_mask': 'uint32',
}
hashes = 'sha1', 'xxh64'

In [13]:
tokenizer.model_max_length

8192

In [14]:
import time

def loop(files, block_size = 2048):
    rows, index = files
    out_root = f'tokenized-2048/tokenized-{index}'
    os.system(f'rm -rf {out_root}')
    count = 0
    temp = []
    position_ids = []
    last_block, last_position_block = None, None
    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for row in tqdm(rows):
            
            speaker = row['speaker']
            t = row['transcription']
            splitted = row['audio_filename'].split('/')
            new_f = '/'.join([splitted[0] + '_vqgan'] + splitted[1:]).replace('.mp3', '.npy')
            speech_t = np.load(new_f)
            speech_t = ''.join([f'<|{t}|>' for t in speech_t.tolist()])
            tts = f'<|text_start|>{speaker}: {t}<|text_end|><|speech_start|>{speech_t}<|speech_end|>'
            
            outputs = tokenizer(tts, add_special_tokens = False)
            if len(outputs['input_ids']) >= tokenizer.model_max_length:
                continue
            temp.append(outputs['input_ids'])
            position_ids.append(range(len(outputs['input_ids'])))
            count += len(outputs['input_ids'])
            
            while count >= block_size:
                block, temp = slice_and_balance(temp, block_size)
                block_position, position_ids = slice_and_balance(position_ids, block_size)
                count = count - block_size
                o = collator(block, block_position)
                last_block = block
                last_position_block = block_position
                out.write(o)
                
        block, _ = slice_and_balance(last_block, block_size - count)
        block_position, _ = slice_and_balance(last_position_block, block_size - count)

        block.extend(temp)
        block_position.extend(position_ids)

        o = collator(block, block_position)
        if len(o['input_ids']) == block_size:
            out.write(o)
            return o

In [15]:
loop((df[:1000], 0))

100%|██████████████████████████████████████| 1000/1000 [00:01<00:00, 993.29it/s]


{'input_ids': array([49595, 49955, 49696, ..., 49666, 49890, 49153], dtype=uint32),
 'position_ids': array([ 314,  315,  316, ..., 1812, 1813, 1814], dtype=uint32),
 'attention_mask': array([ 138,   95, 1815], dtype=uint32)}

In [16]:
dataset = LocalDataset('tokenized-2048/tokenized-0')
len(dataset)

586

In [17]:
dataset[0]

{'attention_mask': array([1274,  774], dtype=uint32),
 'input_ids': array([49154, 11062,  1483, ..., 49651, 49408, 49282], dtype=uint32),
 'position_ids': array([  0,   1,   2, ..., 771, 772, 773], dtype=uint32)}

In [18]:
tokenizer.decode(dataset[0]['input_ids'])

'<|text_start|>Osman: Sedangkan dalam bahasa Perancis , frira hanya bererti menggoreng di dalam minyak goreng yang banyak hingga terendam .<|text_end|><|speech_start|><|361|><|704|><|26|><|639|><|759|><|587|><|669|><|533|><|530|><|752|><|18|><|479|><|599|><|348|><|708|><|535|><|768|><|712|><|227|><|639|><|679|><|348|><|302|><|327|><|529|><|478|><|495|><|479|><|989|><|739|><|268|><|646|><|328|><|15|><|770|><|545|><|733|><|178|><|846|><|534|><|522|><|7|><|785|><|738|><|453|><|539|><|219|><|508|><|351|><|59|><|465|><|386|><|455|><|448|><|354|><|447|><|755|><|694|><|663|><|788|><|674|><|540|><|590|><|805|><|264|><|65|><|544|><|312|><|427|><|215|><|159|><|447|><|351|><|62|><|308|><|388|><|346|><|226|><|62|><|286|><|948|><|680|><|622|><|478|><|345|><|16|><|909|><|447|><|936|><|902|><|751|><|852|><|315|><|823|><|470|><|965|><|503|><|269|><|810|><|512|><|789|><|29|><|518|><|560|><|751|><|21|><|107|><|548|><|580|><|467|><|77|><|760|><|949|><|530|><|629|><|916|><|104|><|264|><|751|><|247|><|785|

In [19]:
# !wget https://gist.githubusercontent.com/huseinzol05/98974ae8c6c7a65d4bc0af9f5003786a/raw/2e06e71ef7349a57bc58cc9913ae6bae1f9f8447/mp.py

In [20]:
from multiprocess import Pool
import mp

chunks = mp.chunks(df, 20000)
pool = Pool(10)
pooled = pool.map(loop, chunks)
pool.close()
pool.join()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [21]:
folders = sorted(glob('tokenized-2048/tokenized-*'), key = lambda x: int(x.split('-')[-1]))
folders

['tokenized-2048/tokenized-0',
 'tokenized-2048/tokenized-1',
 'tokenized-2048/tokenized-2',
 'tokenized-2048/tokenized-3',
 'tokenized-2048/tokenized-4',
 'tokenized-2048/tokenized-5',
 'tokenized-2048/tokenized-6',
 'tokenized-2048/tokenized-7',
 'tokenized-2048/tokenized-8',
 'tokenized-2048/tokenized-9',
 'tokenized-2048/tokenized-10',
 'tokenized-2048/tokenized-11',
 'tokenized-2048/tokenized-12',
 'tokenized-2048/tokenized-13',
 'tokenized-2048/tokenized-14',
 'tokenized-2048/tokenized-15',
 'tokenized-2048/tokenized-16',
 'tokenized-2048/tokenized-17',
 'tokenized-2048/tokenized-18']

In [22]:
with MDSWriter(
    out='smollm2-speech-semantic-multipack-2048', columns=columns, compression=None, hashes=hashes) as out:
    for f in folders:
        try:
            dataset = LocalDataset(local=f)
            for i in tqdm(range(len(dataset))):
                out.write(dataset[i])
        except Exception as e:
            print(e)
            pass

100%|██████████████████████████████████| 11692/11692 [00:00<00:00, 28964.85it/s]
100%|██████████████████████████████████| 11617/11617 [00:00<00:00, 23850.61it/s]
100%|██████████████████████████████████| 12904/12904 [00:00<00:00, 25673.80it/s]
100%|██████████████████████████████████| 13956/13956 [00:00<00:00, 22783.10it/s]
100%|██████████████████████████████████| 13866/13866 [00:00<00:00, 26494.07it/s]
100%|██████████████████████████████████| 12705/12705 [00:00<00:00, 25609.02it/s]
100%|██████████████████████████████████| 12289/12289 [00:00<00:00, 25136.33it/s]
100%|██████████████████████████████████| 12951/12951 [00:00<00:00, 26063.83it/s]
100%|██████████████████████████████████| 14811/14811 [00:00<00:00, 22708.50it/s]
100%|██████████████████████████████████| 14884/14884 [00:00<00:00, 22629.24it/s]
100%|██████████████████████████████████| 13082/13082 [00:00<00:00, 25237.33it/s]
100%|██████████████████████████████████| 10739/10739 [00:00<00:00, 22355.06it/s]
100%|███████████████████████

In [23]:
dataset = LocalDataset('smollm2-speech-semantic-multipack-2048')

In [24]:
(len(dataset) * 2048) / 1e9

0.458508288