In [29]:
import torch
import json
from transformers import AutoProcessor, AutoTokenizer
from datasets import Audio
from PIL import Image
from collections.abc import Mapping

In [22]:
class MMDataset(torch.utils.data.Dataset):

    def __init__(self, folder):
        if folder.endswith('.json'):
            with open(folder) as fopen:
                self.dataset = json.load(fopen)
        elif folder.endswith('.jsonl'):
            self.dataset = []
            with open(folder) as fopen:
                for l in fopen:
                    self.dataset.append(json.loads(l))
        else:
            self.dataset = LocalDataset(folder)  # Assuming LocalDataset is defined elsewhere
        
        self.image_processor = AutoProcessor.from_pretrained('google/siglip-base-patch16-224')
        self.audio_processor = AutoProcessor.from_pretrained('mesolitica/malaysian-whisper-small')
        self.tokenizer = AutoTokenizer.from_pretrained('mesolitica/malaysian-tinyllama-1.1b-16k-instructions-v3')
        
        self.tokenizer.pad_token = self.tokenizer.unk_token
        self.tokenizer.add_bos_token = False
        self.tokenizer.add_eos_token = False
        self.tokenizer.padding_side = "right"
        self.tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
        self.sr = 16000
        self.audio = Audio(sampling_rate=self.sr)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        audio_list = []
        image_list = []

        for x in data['filename']:
            if x.endswith('.mp3'):
                audio = self.audio.decode_example(self.audio.encode_example(x))['array']

                audio_features = self.audio_processor(audio, sampling_rate=self.sr, return_tensors='pt')
                
                audio_list.append(audio_features['input_features']) 

            elif x.endswith('.jpg'):
                image = Image.open(x)

                image_output = self.image_processor(images=image, return_tensors='pt')['pixel_values']

                image_list.append(image_output)

        full_text = self.tokenizer.apply_chat_template(data['conversations'], tokenize=False) # Assuming preprocessor_new is defined elsewhere

        outputs = self.tokenizer(full_text, return_tensors='pt',truncation=True,padding="max_length",max_length=4096,return_overflowing_tokens=False,return_length=False)

        outputs['labels'] = outputs['input_ids'].clone()

        outputs['audios'] = torch.cat(audio_list, dim=0) if audio_list else None
        outputs['images'] = torch.cat(image_list, dim=0) if image_list else None

        return outputs

    def __len__(self):
        return len(self.dataset)

In [23]:
dataset = MMDataset('prepared-combine-ms.jsonl')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [24]:
dataset[1]

{'input_ids': tensor([[    1,   518, 25580,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[    1,   518, 25580,  ...,     0,     0,     0]]), 'audios': tensor([[[ 0.9664,  0.5291,  0.3115,  ..., -0.5044, -0.5044, -0.5044],
         [ 0.9912,  0.4331,  0.4939,  ..., -0.5044, -0.5044, -0.5044],
         [ 0.8548,  0.9311,  0.9561,  ..., -0.5044, -0.5044, -0.5044],
         ...,
         [ 0.2473,  0.1495,  0.1284,  ..., -0.5044, -0.5044, -0.5044],
         [ 0.2915, -0.0713, -0.0756,  ..., -0.5044, -0.5044, -0.5044],
         [ 0.2650, -0.2462, -0.5044,  ..., -0.5044, -0.5044, -0.5044]]]), 'images': tensor([[[[ 0.4667, -0.1686, -0.0667,  ...,  0.7725,  0.7647,  0.7725],
          [ 0.6157, -0.1686, -0.0980,  ...,  0.7725,  0.7647,  0.7804],
          [ 0.6157, -0.2235, -0.2078,  ...,  0.7725,  0.7725,  0.7725],
          ...,
          [ 0.3725,  0.3255,  0.2941,  ..., -0.2471, -0.2549, -0.2471],
          [ 0.3333,  0.3098,  0.3020

In [25]:
class DataCollator():

    def __init__(self, tokenizer):

        self.tokenizer = tokenizer

    def __call__(self, features):

        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]

        batch = {}
        bs = len(features)
        first = features[0]

        batch['audio_index'] = torch.tensor([], dtype=torch.int)
        batch['image_index'] = torch.tensor([], dtype=torch.int)

        for index, feature in enumerate(features):
            local_index = index % (bs // torch.cuda.device_count()) if bs > 1 else index % (bs)
            if feature['audios'] is not None:
                batch['audio_index'] = torch.cat([batch['audio_index'], torch.tensor(
                    [local_index] * len(feature['audios']), dtype=torch.int)])

            if feature['images'] is not None:
                batch['image_index'] = torch.cat([batch['image_index'], torch.tensor(
                    [local_index] * len(feature['images']), dtype=torch.int)])

        for k, v in first.items():

            if k not in ("audios", "images") and not isinstance(v, str):
                if v is None:
                    batch[k] = None
                elif isinstance(v, torch.Tensor):
                    batch[k] = torch.stack([f[k] for f in features]).squeeze(1)
                elif isinstance(v, np.ndarray):
                    batch[k] = torch.tensor(np.stack([f[k] for f in features])).squeeze(1)
            elif k in ("audios", "images"):
                if v is None:
                    batch[k] = None
                else:
                    batch[k] = torch.cat([f[k] for f in features if f[k] is not None])

        batch['image_starts'] = torch.tensor(
            [self.tokenizer.convert_tokens_to_ids('<image>')] * bs, dtype=torch.int)
        batch['image_ends'] = torch.tensor(
            [self.tokenizer.convert_tokens_to_ids('</image>')] * bs, dtype=torch.int)
        batch['audio_starts'] = torch.tensor(
            [self.tokenizer.convert_tokens_to_ids('<audio>')] * bs, dtype=torch.int)
        batch['audio_ends'] = torch.tensor(
            [self.tokenizer.convert_tokens_to_ids('</audio>')] * bs, dtype=torch.int)

        return batch

In [27]:
collator = DataCollator(dataset.tokenizer)

In [30]:
collator([dataset[i] for i in range(3)])

{'audio_index': tensor([0, 1, 2, 2], dtype=torch.int32),
 'image_index': tensor([0, 0, 0, 1, 1, 2, 2], dtype=torch.int32),
 'input_ids': tensor([[    1,   518, 25580,  ...,     0,     0,     0],
         [    1,   518, 25580,  ...,     0,     0,     0],
         [    1,   518, 25580,  ...,     0,     0,     0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'labels': tensor([[    1,   518, 25580,  ...,     0,     0,     0],
         [    1,   518, 25580,  ...,     0,     0,     0],
         [    1,   518, 25580,  ...,     0,     0,     0]]),
 'audios': tensor([[[ 0.0796,  0.4540,  0.5375,  ..., -0.5824, -0.5824, -0.5824],
          [ 0.3646,  0.4318,  0.3781,  ..., -0.5824, -0.5824, -0.5824],
          [ 0.2090,  0.3394,  0.2992,  ..., -0.5824, -0.5824, -0.5824],
          ...,
          [-0.5824, -0.5824, -0.5824,  ..., -0.5824, -0.5824, -0.5824],
          [-0.5824, -0.5824, -0.5824,  ..., -0.5824, -0.58