In [10]:
# !wget https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav
# !wget https://huggingface.co/datasets/mesolitica/Malaysian-SFT/resolve/main/combine/combined-malaysian-sft-20k-sample.jsonl

In [12]:
# !wget https://gist.githubusercontent.com/huseinzol05/98974ae8c6c7a65d4bc0af9f5003786a/raw/2e06e71ef7349a57bc58cc9913ae6bae1f9f8447/mp.py

In [2]:
import os

# Process-wide environment for this notebook, set before any ML library is imported:
#   HF_HOME              -> directory Hugging Face libraries use for their cache
#   CUDA_VISIBLE_DEVICES -> empty string, so no GPU is visible (tokenization is CPU-only)
os.environ.update({
    'HF_HOME': '/home/husein/ssd3',
    'CUDA_VISIBLE_DEVICES': '',
})

In [13]:
import librosa
import torch
import torch.nn as nn
import pandas as pd
from datasets import Audio
from peft import LoraConfig, get_peft_model
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, AutoConfig, AutoModelForCausalLM
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import streaming
import numpy as np
from tqdm import tqdm
from glob import glob
import os
import json
import mp

class UInt32(Encoding):
    """
    MDS column encoding that stores an integer array as raw ``uint32`` bytes.

    ``decode`` always interprets the payload as ``uint32``, so ``encode`` coerces
    its input to that dtype first; otherwise an array of another dtype (for
    example the default ``int64``) would be written with a different item size
    and silently decode to garbage.
    """

    def encode(self, obj) -> bytes:
        """Return the ``uint32`` byte representation of ``obj`` (array or sequence of ints)."""
        # asarray is a no-op for arrays that are already uint32.
        return np.asarray(obj, dtype=np.uint32).tobytes()

    def decode(self, data: bytes):
        """Rebuild a flat ``uint32`` array from bytes produced by ``encode``."""
        return np.frombuffer(data, np.uint32)

# Register the encoding under the name used in the `columns` schema.
_encodings['uint32'] = UInt32

# Schema of every MDS sample: three packed integer arrays stored with the custom
# 'uint32' encoding, plus two string fields.
columns = dict(
    input_ids='uint32',
    position_ids='uint32',
    attention_mask='uint32',
    audio='str',
    text='str',
)
# Hash algorithms passed to MDSWriter.
hashes = ('sha1', 'xxh64')

In [4]:
# Load the Qwen2-Audio processor; only its tokenizer is used for packing text.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
tokenizer = processor.tokenizer

# Special audio marker tokens.
audio_token = "<|AUDIO|>"
audio_bos_token = "<|audio_bos|>"
audio_eos_token = "<|audio_eos|>"

# Resolve ids through the public tokenizer API (rather than the private
# `_convert_token_to_id_with_added_voc`) and reuse the constant defined above.
audio_token_id = tokenizer.convert_tokens_to_ids(audio_token)
pad_token_id = tokenizer.pad_token_id

torch_dtype = torch.bfloat16
# Most negative finite value representable in `torch_dtype`.
min_dtype = torch.finfo(torch_dtype).min

# Number of tokens per packed sample (default `block_size` of `loop`).
sequence_length = 4096

In [5]:
# Load the SFT conversations: one JSON document per line (JSONL).
# The encoding is pinned to UTF-8 so parsing does not depend on the machine's
# locale, and blank lines are skipped instead of crashing json.loads.
with open('combined-malaysian-sft-20k-sample.jsonl', encoding='utf-8') as fopen:
    combine = [json.loads(line) for line in fopen if line.strip()]

len(combine)

884949

In [6]:
import gc

def collator(batch, batch_position_ids):
    """
    Flatten one packed block into a single MDS sample dict.

    Parameters
    ----------
    batch : list of sequences of int
        Token-id segments that together make up the block.
    batch_position_ids : list of sequences of int
        Position ids for each segment, index-aligned with ``batch``.

    Returns
    -------
    dict
        ``input_ids`` and ``position_ids`` are the concatenated segments as
        ``uint32`` arrays. ``attention_mask`` is NOT a 0/1 mask: it holds the
        length of each segment, in order. ``audio`` and ``text`` are empty strings.
    """
    def as_uint32(values):
        return np.array(values).astype(np.uint32)

    flat_tokens, flat_positions, segment_lengths = [], [], []
    for idx, segment in enumerate(batch):
        flat_tokens += segment
        flat_positions += batch_position_ids[idx]
        segment_lengths.append(len(segment))

    sample = {}
    sample['input_ids'] = as_uint32(flat_tokens)
    sample['position_ids'] = as_uint32(flat_positions)
    sample['attention_mask'] = as_uint32(segment_lengths)
    sample['audio'] = ''
    sample['text'] = ''
    return sample

def slice_and_balance(nested_list, size):
    """
    Split a list of sequences into a head holding at most ``size`` items and the rest.

    Sequences are consumed in order. A sequence that straddles the ``size``
    boundary is cut: its leading part goes to the head, its trailing part becomes
    the first entry of the remainder. Everything after the boundary is passed
    through to the remainder untouched.

    Returns
    -------
    tuple
        ``(head, remainder)`` as two new lists of sequences.
    """
    head, remainder = [], []
    filled = 0

    for chunk in nested_list:
        room = size - filled
        if room <= 0:
            # Head is already full: everything else is carried over.
            remainder.append(chunk)
            continue

        if len(chunk) > room:
            # Chunk crosses the boundary: cut it exactly at the free space.
            head.append(chunk[:room])
            remainder.append(chunk[room:])
            filled = size
        else:
            head.append(chunk)
            filled += len(chunk)

    return head, remainder

In [7]:
!mkdir tokenized-4k

In [8]:
import time

def loop(files, block_size = sequence_length):
    """
    Tokenize one chunk of conversations and pack it into fixed-size MDS samples.

    Every conversation is rendered with the tokenizer's chat template and
    tokenized; the token streams are concatenated and cut into samples of exactly
    ``block_size`` tokens. Position ids restart at 0 for each conversation and
    continue across a cut. The leftover tokens at the end (fewer than
    ``block_size``) are written as one more sample, padded at the FRONT with the
    leading tokens of the last full block so that it is also ``block_size`` long.

    Parameters
    ----------
    files : tuple
        ``(rows, index)``. ``rows`` is an iterable of conversations accepted by
        ``tokenizer.apply_chat_template``; ``index`` selects the output directory
        ``tokenized-4k/tokenized-{index}``.
    block_size : int
        Number of tokens in every written sample.

    Returns
    -------
    dict or None
        The back-filled tail sample if one was written, otherwise ``None``
        (no leftover tokens, or the chunk never filled a single block).
    """
    import shutil

    rows, index = files
    out_root = f'tokenized-4k/tokenized-{index}'
    # Start from a clean output directory without going through a shell.
    shutil.rmtree(out_root, ignore_errors=True)
    count = 0  # tokens currently buffered in `temp`
    temp = []
    position_ids = []
    last_block, last_position_block = None, None
    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:
        for row in tqdm(rows):
            prompt = tokenizer.apply_chat_template(row, tokenize=False)
            outputs = tokenizer(prompt, add_special_tokens = False)
            temp.append(outputs['input_ids'])
            position_ids.append(range(len(outputs['input_ids'])))
            count += len(outputs['input_ids'])
            # A single long conversation can fill several blocks, hence `while`.
            while count >= block_size:
                block, temp = slice_and_balance(temp, block_size)
                block_position, position_ids = slice_and_balance(position_ids, block_size)
                count = count - block_size
                o = collator(block, block_position)
                last_block = block
                last_position_block = block_position
                out.write(o)

        # Nothing left over: re-using the last block here would only write an
        # exact duplicate of it. No full block at all: there is nothing to pad
        # the tail with, so a block_size sample cannot be formed.
        if count == 0 or last_block is None:
            return None

        # Pad the tail at the front with the first (block_size - count) tokens of
        # the last full block, keeping tokens and position ids aligned.
        block, _ = slice_and_balance(last_block, block_size - count)
        block_position, _ = slice_and_balance(last_position_block, block_size - count)

        block.extend(temp)
        block_position.extend(position_ids)

        o = collator(block, block_position)
        if len(o['input_ids']) == block_size:
            out.write(o)
            return o

In [14]:
from multiprocess import Pool
import mp

# Split the dataset into pieces of 50000 rows and tokenize them in 10 worker
# processes. NOTE(review): `mp.chunks` comes from the downloaded helper gist;
# `loop` unpacks each item as a `(rows, index)` pair — confirm against mp.py.
chunks = mp.chunks(combine, 50000)
pool = Pool(10)
try:
    pooled = pool.map(loop, chunks)
finally:
    # Always release the worker processes, even when a worker raised.
    pool.close()
    pool.join()

 11%|█████████                                                                        | 5623/50000 [00:07<01:02, 709.43it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (10352 > 8192). Running this sequence through the model will result in indexing errors
100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:24<00:00, 2065.95it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:24<00:00, 2034.56it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:37<00:00, 1335.10it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:17<00:00, 645.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 50000/50000 [01:25<00:00, 582.98it/s]
100%|█████████████████████████████████████████████████████████████

In [15]:
# Collect the per-chunk output directories and order them by their numeric
# suffix, so that 'tokenized-10' sorts after 'tokenized-9' rather than after
# 'tokenized-1'.
folders = glob('tokenized-4k/tokenized-*')
folders.sort(key = lambda path: int(path.rpartition('-')[2]))
folders

['tokenized-4k/tokenized-0',
 'tokenized-4k/tokenized-1',
 'tokenized-4k/tokenized-2',
 'tokenized-4k/tokenized-3',
 'tokenized-4k/tokenized-4',
 'tokenized-4k/tokenized-5',
 'tokenized-4k/tokenized-6',
 'tokenized-4k/tokenized-7',
 'tokenized-4k/tokenized-8',
 'tokenized-4k/tokenized-9',
 'tokenized-4k/tokenized-10',
 'tokenized-4k/tokenized-11',
 'tokenized-4k/tokenized-12',
 'tokenized-4k/tokenized-13',
 'tokenized-4k/tokenized-14',
 'tokenized-4k/tokenized-15',
 'tokenized-4k/tokenized-16',
 'tokenized-4k/tokenized-17']