In [1]:
from transformers import AutoProcessor, AutoTokenizer
import torch
import numpy as np
import json
from PIL import Image
from typing import Mapping, List, Dict
from datasets import Audio
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset

In [2]:
class DataCollator():
    """Collate multimodal samples into a single padded batch.

    Each sample must be a mapping (or an object whose ``vars()`` is one) that
    provides ``input_ids``, ``audios``, ``images``, ``audios_bool`` and
    ``images_bool``. ``None`` samples are dropped before collation.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer used for padding and for marker-token id lookup."""
        self.tokenizer = tokenizer

    def __call__(self, features):
        """Build the batch dictionary from a list of samples.

        Args:
            features: list of samples; ``None`` entries are ignored.

        Returns:
            dict containing ``audio_index`` / ``image_index`` (row of the batch
            each audio / image belongs to), the concatenated ``audios`` /
            ``images`` (``None`` when no sample carries any), padded
            ``input_ids`` / ``attention_mask`` / ``labels``, the marker-token
            ids repeated per row (``image_starts``, ``image_ends``,
            ``audio_starts``, ``audio_ends``) and the positions and ids of the
            ``<image>`` / ``<audio>`` markers (``where_is_b``, ``where_is_k``,
            ``ls``).

        Raises:
            ValueError: if no non-``None`` sample is left to collate.
        """
        features = [f for f in features if f is not None]
        if not features:
            # Fail with a clear message instead of an IndexError on features[0].
            raise ValueError('DataCollator received no valid samples (all entries were None)')

        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]

        batch = {}
        bs = len(features)
        first = features[0]

        # Collect the owning row for every audio / image in plain lists and
        # build each index tensor once instead of re-concatenating per sample.
        audio_rows, image_rows = [], []
        for row, feature in enumerate(features):
            if feature['audios_bool'][0] and feature['audios'] is not None:
                audio_rows.extend([row] * len(feature['audios']))

            if feature['images_bool'][0] and feature['images'] is not None:
                image_rows.extend([row] * len(feature['images']))

        batch['audio_index'] = torch.tensor(audio_rows, dtype=torch.int)
        batch['image_index'] = torch.tensor(image_rows, dtype=torch.int)

        for k, v in first.items():

            if k not in (
                    "audios",
                    "images",
                    "input_ids",
                    "attention_mask"
            ) and not isinstance(v, str):
                if v is None:
                    batch[k] = None
                elif isinstance(v, torch.Tensor):
                    batch[k] = torch.stack([f[k] for f in features]).squeeze(1)
                elif isinstance(v, np.ndarray):
                    batch[k] = torch.tensor(np.stack([f[k] for f in features])).squeeze(1)
            elif k in ("audios", "images"):
                # Decide from the whole batch, not only from the first sample,
                # so the tensor stays consistent with the index built above.
                present = [f[k] for f in features if f[k] is not None]
                batch[k] = torch.cat(present) if present else None

        input_ids = [{'input_ids': f['input_ids'][0]} for f in features]
        input_ids = self.tokenizer.pad(input_ids)
        batch['input_ids'] = input_ids['input_ids']
        batch['attention_mask'] = input_ids['attention_mask']
        batch['labels'] = input_ids['input_ids'].clone()
        # Every position equal to pad_token_id is masked, not only the padded tail.
        batch['labels'][batch['labels'] == self.tokenizer.pad_token_id] = -100

        image_token = self.tokenizer.convert_tokens_to_ids('<image>')
        image_end_token = self.tokenizer.convert_tokens_to_ids('</image>')
        audio_token = self.tokenizer.convert_tokens_to_ids('<audio>')
        audio_end_token = self.tokenizer.convert_tokens_to_ids('</audio>')

        batch['image_starts'] = torch.tensor([image_token] * bs, dtype=torch.int)
        batch['image_ends'] = torch.tensor([image_end_token] * bs, dtype=torch.int)
        batch['audio_starts'] = torch.tensor([audio_token] * bs, dtype=torch.int)
        batch['audio_ends'] = torch.tensor([audio_end_token] * bs, dtype=torch.int)

        # (row, column) of every <image> / <audio> marker in the padded batch.
        where_is = torch.where((batch['input_ids'] == image_token) | (batch['input_ids'] == audio_token))
        batch['where_is_b'] = where_is[0]
        batch['where_is_k'] = where_is[1]
        # Token id found at each marker position; int64 even when nothing matched.
        batch['ls'] = batch['input_ids'][where_is].to(torch.int64)

        return batch

In [3]:
# Load the tokenizer from the local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained('./combine-tinyllama')
# Add the four marker tokens whose ids DataCollator looks up on every batch.
tokenizer.add_tokens(["<image>", "</image>", "<audio>", "</audio>"])
# Token limit per sample; MMDataset passes it as `max_length` with truncation on.
max_length = 8192

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Image preprocessor loaded from the SigLIP checkpoint.
image_processor = AutoProcessor.from_pretrained('google/siglip-base-patch16-384')
# Configured input height; MMDataset uses it for both sides of its blank placeholder image.
default_height = image_processor.image_processor.size['height']
# Audio preprocessor loaded from the Whisper checkpoint.
audio_processor = AutoProcessor.from_pretrained('mesolitica/malaysian-whisper-small')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
class ListOfDict(Encoding):
    """Column encoding that stores a list of dicts as UTF-8 encoded JSON."""

    def encode(self, obj: List[dict]) -> bytes:
        """Serialize ``obj`` to JSON text and return its UTF-8 bytes."""
        return json.dumps(obj).encode('utf-8')

    def decode(self, data: bytes) -> List[dict]:
        """Parse UTF-8 JSON bytes back into the Python object they describe."""
        return json.loads(data.decode('utf-8'))

# Make the encoding available under the 'list_of_dict' name.
_encodings['list_of_dict'] = ListOfDict

class MMDataset(torch.utils.data.Dataset):
    """Dataset yielding tokenized conversations with their audio/image features.

    Every item starts with one placeholder audio clip (silence) and one
    placeholder image (all zeros), matched by the ``<image> <audio>`` prefix
    added to the text, so each item carries at least one audio and one image
    tensor.

    Relies on the notebook-level ``tokenizer``, ``audio_processor``,
    ``image_processor``, ``default_height`` and ``max_length``.
    """

    def __init__(self, folder):
        """Open the data source.

        Args:
            folder: path to a ``.json`` file (loaded fully into memory) or to a
                directory opened with ``LocalDataset``.
        """
        if folder.endswith('.json'):
            with open(folder) as fopen:
                self.dataset = json.load(fopen)
        else:
            self.dataset = LocalDataset(folder)

        # Sampling rate used for decoding audio and for the audio processor.
        self.sr = 16000
        self.audio = Audio(sampling_rate=self.sr)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``, or ``None`` if it fails.

        The returned tokenizer output additionally holds ``audios`` and
        ``images`` (features concatenated along dim 0) plus ``audios_bool`` and
        ``images_bool`` (one ``True`` per entry). Any exception is reported
        with the sample index and turned into ``None``, which DataCollator
        filters out.
        """
        try:
            data = self.dataset[idx]

            # Placeholder clip: 10 seconds of zeros at the target sampling rate.
            audio = np.zeros((self.sr * 10,))
            audio_features = audio_processor(audio, sampling_rate=self.sr, return_tensors='pt')
            audio_list = [audio_features['input_features']]
            audio_bool = [True]

            # Placeholder image: uint8 zeros, so the processor does not treat
            # it as an already-rescaled [0, 1] float image and warn about it.
            image = np.zeros((3, default_height, default_height), dtype=np.uint8)

            image_output = image_processor(
                images=image, return_tensors='pt')['pixel_values']
            image_list = [image_output]
            image_bool = [True]

            for x in data['filename']:
                if x.endswith('.mp3'):
                    audio = self.audio.decode_example(self.audio.encode_example(x))['array']

                    audio_features = audio_processor(
                        audio, sampling_rate=self.sr, return_tensors='pt')

                    audio_list.append(audio_features['input_features'])
                    audio_bool.append(True)

                elif x.endswith('.jpg'):
                    # Close the file handle as soon as the pixels are converted.
                    with Image.open(x) as handle:
                        image = handle.convert('RGB')

                    image_output = image_processor(
                        images=image, return_tensors='pt')['pixel_values']

                    image_list.append(image_output)
                    image_bool.append(True)

            full_text = tokenizer.apply_chat_template(data['conversations'], tokenize=False)
            # Markers for the placeholder image and audio prepended above.
            full_text = f'<image> <audio> {full_text}'

            outputs = tokenizer(
                full_text,
                return_tensors='pt',
                truncation=True,
                max_length=max_length,
                return_overflowing_tokens=False,
                return_length=False
            )

            # Both lists always contain at least the placeholder entry.
            outputs['audios'] = torch.cat(audio_list, dim=0)
            outputs['images'] = torch.cat(image_list, dim=0)
            outputs['audios_bool'] = audio_bool
            outputs['images_bool'] = image_bool

            return outputs
        except Exception as e:
            # Best-effort loading: report which sample failed and skip it.
            print(f'MMDataset: failed to load sample {idx}: {e!r}')
            return None

    def __len__(self):
        """Number of samples in the underlying data source."""
        return len(self.dataset)

In [6]:
# Open the 'mosaic-multimodal' directory as the training dataset.
train_dataset = MMDataset('mosaic-multimodal')
# Collator using the same tokenizer as the dataset.
data_collator = DataCollator(tokenizer=tokenizer)

In [7]:
import random

# Every valid sample index of the training dataset.
ranged = list(range(len(train_dataset)))

In [8]:
# Draw 10 samples at random (with replacement; no seed is set, so the picks vary per run).
b = [train_dataset[random.choice(ranged)] for i in range(10)]
# Collate them; despite its name, `input_ids` holds the whole batch dict.
input_ids = data_collator(b)

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [9]:
%%time
# Collate the same samples again so the cell timing covers collation only.
input_ids = data_collator(b)

CPU times: user 38.8 s, sys: 74.3 ms, total: 38.9 s
Wall time: 473 ms


In [10]:
# Display the full collated batch dict.
input_ids

{'audio_index': tensor([0, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9], dtype=torch.int32),
 'image_index': tensor([0, 0, 0, 1, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 8, 9], dtype=torch.int32),
 'audios': tensor([[[-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          [-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          [-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          ...,
          [-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          [-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          [-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000]],
 
         [[-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          [-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          [-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          ...,
          [-1.5000, -1.5000, -1.5000,  ..., -1.5000, -1.5000, -1.5000],
          [-1.5000, -1.5000, -1.5000,  ..., -1.5

In [11]:
from modeling_combine import MM_LLMs, MM_LLMs_Config

In [12]:
# Load the combined model from the same local checkpoint directory as the tokenizer.
model = MM_LLMs.from_pretrained('./combine-tinyllama')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [13]:
# Run the model's input-preparation step on the collated batch; `r` is inspected in the following cells.
r = model.prepare_inputs_for_generation(**input_ids)

In [20]:
# First 1100 label values of the first sample in the batch.
print(r['labels'][0].numpy().tolist()[:1100])

[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -10

In [25]:
# Label ids of the first sample that are not set to -100.
r['labels'][0][r['labels'][0] != -100]

tensor([29871,     1,   518, 25580, 29962, 32000, 32001, 32000, 32001,  1724,
          338,  4475,  1546,  7623, 29871, 29896,   322,  7623, 29871, 29906,
          518, 29914, 25580, 29962,  3112,   338, 25057,   393,   727,   338,
          263,  1513,  9443,  1546,  7623, 29871, 29896,   322,  7623, 29871,
        29906, 29889, 28908, 29871, 29896,   338,   385,  1967,   310,   263,
        16423,  6492,  1754,   411,   380,  7358, 29892,  1550,  7623, 29871,
        29906,  3697,  1023,  5648,  9763,  1634,   272,  2153, 13587,  2723,
          567,  1550,  9963, 29889,  3118,  1950,  3957,  1033,   367,   393,
          278, 16423,  6492,   297,  7623, 29871, 29896,   338,  1641, 15000,
          297,   263,  9763,  3461, 29892,   322,   278,  1634,   272,  2153,
          297,  7623, 29871, 29906,   526,  5353,   292,   372, 29889,  2398,
        29892,  1728,  5684,  3030, 29892,   445,   338,   925,  1580,  2785,
        29889,     2])

In [26]:
# Decode the non -100 label ids of the first sample back into text.
tokenizer.decode(r['labels'][0][r['labels'][0] != -100])

'<s> [INST]<image></image><image></image> What is related between picture 1 and picture 2 [/INST]It is unlikely that there is a direct relationship between picture 1 and picture 2. Picture 1 is an image of a garden plot made with sticks, while picture 2 shows two TV news reporters holding cups while talking. One possible connection could be that the garden plot in picture 1 is being featured in a news report, and the reporters in picture 2 are discussing it. However, without additional context, this is just speculation.</s>'

In [None]:
# Shape of the prepared input ids.
r['input_ids'].shape

In [None]:
# Shape of the prepared input embeddings.
r['inputs_embeds'].shape

In [30]:
# First 1100 attention-mask values of the first sample in the batch.
print(r['attention_mask'][0].numpy().tolist()[:1100])

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [None]:
# Boolean tensor marking the label positions that are set to -100.
r['labels'] == -100