In [1]:
# Show the GPUs on this machine with their current memory use and utilisation.
!nvidia-smi

Thu Feb 15 16:44:28 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000001:00:00.0 Off |                    0 |
| N/A   58C    P0             267W / 400W |  45485MiB / 81920MiB |     99%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000002:00:00.0 Off |  

In [2]:
import os

# Restrict this kernel to physical GPU 5. The variable is written before torch
# is imported further down, so it is already in place when CUDA is first used.
os.environ.update(CUDA_VISIBLE_DEVICES='5')

In [3]:
from modeling_audio import MM_LLMs, MM_LLMs_Config
from transformers import AutoModelForCausalLM, CLIPProcessor, CLIPModel,AutoModel, AutoTokenizer, AutoProcessor,AutoConfig,CLIPConfig, LlamaConfig, WhisperConfig, WhisperModel, LlamaModel, LlamaTokenizer
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import torch
import numpy as np
from torch import nn
from streaming import LocalDataset
from typing import List

In [4]:
# Register the custom config and model classes with their default transformers
# auto classes. The two registrations are independent of each other.
MM_LLMs_Config.register_for_auto_class()
MM_LLMs.register_for_auto_class()

In [5]:
from transformers.trainer_utils import get_last_checkpoint

# Resolve the newest `checkpoint-*` directory written under the training
# output folder.
latest = get_last_checkpoint('audio-alignment-mistral')
# get_last_checkpoint returns None when the folder holds no checkpoint
# directory; fail here with a clear message instead of inside from_pretrained.
if latest is None:
    raise FileNotFoundError(
        "no checkpoint-* directory found under 'audio-alignment-mistral'"
    )
latest

'audio-alignment-mistral/checkpoint-7800'

In [6]:
# Restore the multimodal model from the newest checkpoint, requesting bfloat16
# weights and flash attention.
# NOTE(review): both `dtype` and `torch_dtype` are passed, yet the loader still
# warns about a missing torch dtype - confirm which keyword MM_LLMs honours.
model = MM_LLMs.from_pretrained(
    latest,
    torch_dtype=torch.bfloat16,
    dtype=torch.bfloat16,
    flash_attention=True,
)

The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# Display the configuration object attached to the loaded model.
model.config

MM_LLMs_Config {
  "_name_or_path": "audio-alignment-mistral/checkpoint-7800",
  "architectures": [
    "MM_LLMs"
  ],
  "audio_config": {
    "_name_or_path": "mesolitica/malaysian-whisper-small",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "add_cross_attention": false,
    "apply_spec_augment": false,
    "architectures": [
      "WhisperForConditionalGeneration"
    ],
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": [
      220,
      50257
    ],
    "bos_token_id": 50257,
    "chunk_size_feed_forward": 0,
    "classifier_proj_size": 256,
    "cross_attention_hidden_size": null,
    "d_model": 768,
    "decoder_attention_heads": 12,
    "decoder_ffn_dim": 3072,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 12,
    "decoder_start_token_id": 50258,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_attention_heads": 12,
    "encoder_ffn_dim": 3072,
  

In [8]:
# Move the model onto the GPU. The result is bound to `_` so the cell does not
# echo a return value.
_ = model.cuda()

In [9]:
# The tokenizer is read from the checkpoint directory itself.
tokenizer = AutoTokenizer.from_pretrained(latest)
# The audio processor is loaded from the same Whisper repository that the
# model's audio_config names.
audio_processor = AutoProcessor.from_pretrained('mesolitica/malaysian-whisper-small')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
# Inspect the generation config of the wrapped LLM (BOS/EOS token ids).
# The override below is left disabled; NOTE(review): enable it only if the
# tokenizer's EOS id differs from the one displayed here.
# model.llm.generation_config.eos_token_id = tokenizer.eos_token_id
model.llm.generation_config

GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2
}

In [11]:
from PIL import Image
import librosa
import torch
import numpy as np
from collections.abc import Mapping

class DataCollator():
    """Collate per-sample feature mappings into a single batch dict.

    Each sample is a mapping (or an object whose ``vars()`` is one) holding
    tensor / ndarray fields plus an ``'audios'`` entry and, optionally, an
    ``'images'`` entry. The returned dict contains:

    * every non-string field of the first sample whose value is a tensor or
      ndarray, stacked over the batch with dimension 1 squeezed (fields that
      are ``None`` on the first sample stay ``None``; other value types are
      skipped);
    * ``'audios'`` / ``'images'``: the non-``None`` per-sample tensors
      concatenated along dim 0, or ``None`` when no sample provides any;
    * ``'audio_index'``: for every row of ``'audios'``, the batch position of
      the sample that row came from;
    * ``'audio_starts'`` / ``'audio_ends'``: the ids of the ``<audio>`` and
      ``</audio>`` tokens, repeated once per sample.
    """

    def __init__(self, tokenizer):
        """Keep the tokenizer used to look up the audio marker token ids."""
        self.tokenizer = tokenizer

    def __call__(self, features):
        """Build the batch dict for ``features``, a non-empty list of samples."""
        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]

        batch = {}
        bs = len(features)
        first = features[0]

        # Batch position of the owning sample for every audio row. Collected in
        # a list so the tensor is built once rather than re-concatenated.
        audio_owner = []
        for index, feature in enumerate(features):
            if feature['audios'] is not None:
                audio_owner.extend([index] * len(feature['audios']))
        batch['audio_index'] = torch.tensor(audio_owner, dtype=torch.int)

        for k, v in first.items():
            if k in ("audios", "images"):
                # Decide over the whole batch, not just the first sample:
                # otherwise a leading sample without audio would drop the audio
                # of later samples while audio_index still refers to it.
                present = [f[k] for f in features if f[k] is not None]
                batch[k] = torch.cat(present) if present else None
            elif isinstance(v, str):
                # Plain strings are not batched.
                continue
            elif v is None:
                batch[k] = None
            elif isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features]).squeeze(1)
            elif isinstance(v, np.ndarray):
                batch[k] = torch.tensor(np.stack([f[k] for f in features])).squeeze(1)

        start_id = self.tokenizer.convert_tokens_to_ids('<audio>')
        end_id = self.tokenizer.convert_tokens_to_ids('</audio>')
        batch['audio_starts'] = torch.tensor([start_id] * bs, dtype=torch.int)
        batch['audio_ends'] = torch.tensor([end_id] * bs, dtype=torch.int)

        return batch

# Collator instance bound to the checkpoint's tokenizer.
collator = DataCollator(tokenizer=tokenizer)

In [12]:
def prepare_dataset(messages, audio: List[str] = None, sr = 16000):
    """Turn chat messages and optional audio files into model-ready features.

    Args:
        messages: chat turns handed to ``tokenizer.apply_chat_template``.
        audio: paths of the audio files to load; ``None`` or an empty list
            means the sample carries no audio.
        sr: sampling rate used both for loading the files and for feature
            extraction.

    Returns:
        The tokenizer output for the rendered prompt, with an extra
        ``'audios'`` entry holding the processor's ``input_features`` (or
        ``None`` when no audio was given).
    """
    # An empty list is handled like None: there is nothing to load, and the
    # feature extractor should not be called on an empty batch.
    if audio:
        waveforms = [librosa.load(path, sr=sr)[0] for path in audio]
        audio_features = audio_processor(
            waveforms, sampling_rate=sr, return_tensors='pt')['input_features']
    else:
        audio_features = None

    # Render the chat template to text first, then tokenize that text.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    outputs = tokenizer(
        prompt,
        return_tensors='pt',
        return_overflowing_tokens=False,
        return_length=False,
    )

    outputs['audios'] = audio_features
    return outputs

In [13]:
# Build a one-sample batch from a single user turn that references an audio clip.
messages = [{'role': 'user', 'content': '<audio> </audio> audio ni tentang apa'}]
outputs = prepare_dataset(messages, audio=['test.mp3'])
ok = collator([outputs])
# The labels are simply the prompt ids.
ok['labels'] = ok['input_ids']

# Move every populated field to the GPU; entries that are None stay None.
ok = {
    name: (value.cuda() if value is not None else value)
    for name, value in ok.items()
}

# Cast the audio features to the model's dtype.
if ok['audios'] is not None:
    ok['audios'] = ok['audios'].type(model.dtype)

In [14]:
# Let the model build its generation inputs from the batch without tracking
# gradients.
with torch.no_grad():
    model_inputs = model.prepare_inputs_for_generation(**ok)

# Take the token ids and labels out of the prepared inputs; the labels are
# converted to a NumPy array on the CPU.
r = model_inputs.pop('input_ids', None)
label = model_inputs.pop('labels', None).detach().cpu().numpy()

# Compare the raw prompt shape with the shape of the prepared input embeddings.
(ok['input_ids'].shape, model_inputs['inputs_embeds'].shape)

(torch.Size([1, 16]), torch.Size([1, 503, 4096]))

In [15]:
# Upload the model weights as safetensors. The organisation is part of the repo
# id because the separate `organization=` keyword is deprecated in transformers.
model.push_to_hub(
    'mesolitica/malaysian-mistral-malaysian-whisper-small-audio-alignment',
    safe_serialization=True,
)



model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/285M [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/mesolitica/malaysian-mistral-malaysian-whisper-small-audio-alignment/commit/fb8e179382ab387fc46cabb9c8a529190acdb1f3', commit_message='Upload MM_LLMs', commit_description='', oid='fb8e179382ab387fc46cabb9c8a529190acdb1f3', pr_url=None, pr_revision=None, pr_num=None)

In [16]:
# Upload the audio processor to the same repository. The organisation is part
# of the repo id because the separate `organization=` keyword is deprecated in
# transformers.
audio_processor.push_to_hub(
    'mesolitica/malaysian-mistral-malaysian-whisper-small-audio-alignment',
    safe_serialization=True,
)



README.md:   0%|          | 0.00/5.18k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/mesolitica/malaysian-mistral-malaysian-whisper-small-audio-alignment/commit/c3bc88e628878d7f1e0418ad23cccaaabd1b9fe0', commit_message='Upload processor', commit_description='', oid='c3bc88e628878d7f1e0418ad23cccaaabd1b9fe0', pr_url=None, pr_revision=None, pr_num=None)

In [17]:
# Upload the tokenizer to the same repository. The organisation is part of the
# repo id because the separate `organization=` keyword is deprecated in
# transformers.
tokenizer.push_to_hub(
    'mesolitica/malaysian-mistral-malaysian-whisper-small-audio-alignment',
    safe_serialization=True,
)

CommitInfo(commit_url='https://huggingface.co/mesolitica/malaysian-mistral-malaysian-whisper-small-audio-alignment/commit/dba5b3b9c9feaed3b90ec4c49c7160574896e21e', commit_message='Upload tokenizer', commit_description='', oid='dba5b3b9c9feaed3b90ec4c49c7160574896e21e', pr_url=None, pr_revision=None, pr_num=None)