In [1]:
!nvidia-smi

Mon Feb 12 19:59:07 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000001:00:00.0 Off |                    0 |
| N/A   39C    P0              66W / 400W |      7MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-80GB          On  | 00000002:00:00.0 Off |  

In [2]:
from modeling_vision import MM_LLMs, MM_LLMs_Config
from transformers import AutoModelForCausalLM, CLIPProcessor, CLIPModel,AutoModel, AutoTokenizer, AutoProcessor,AutoConfig,CLIPConfig, LlamaConfig, WhisperConfig, WhisperModel, LlamaModel, LlamaTokenizer
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import torch
import numpy as np
from torch import nn
from streaming import LocalDataset
from typing import List

In [3]:
from transformers.trainer_utils import get_last_checkpoint

# Look up the most recent checkpoint saved under the 'vision-alignment-mistral'
# folder; the result is reused below to load both the model and the tokenizer.
latest = get_last_checkpoint('vision-alignment-mistral')
# Bare expression so the notebook displays the value that was found.
latest

'vision-alignment-mistral/checkpoint-24800'

In [4]:
# Restore the multimodal wrapper from the latest checkpoint, requesting flash
# attention and bfloat16 weights.
# NOTE(review): both `dtype` and `torch_dtype` are passed, yet the captured
# output still warns that no torch dtype was specified for Flash Attention --
# confirm which keyword MM_LLMs actually forwards to the inner LLM.
model = MM_LLMs.from_pretrained(
    latest,
    flash_attention=True,
    dtype=torch.bfloat16,
    torch_dtype=torch.bfloat16,
)

The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
# Move the model to the GPU; binding the return value to `_` keeps the cell
# from displaying it.
_ = model.cuda()

In [6]:
# Sanity check: display the dtype the model reports after loading.
model.dtype

torch.bfloat16

In [7]:
# Display the `llm` submodule of the wrapper to inspect its layer structure.
model.llm

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32004, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNor

In [8]:
# Image preprocessing comes from the public SigLIP hub checkpoint, while the
# tokenizer is read from the same local checkpoint directory as the model.
image_processor = AutoProcessor.from_pretrained('google/siglip-base-patch16-384')
tokenizer = AutoTokenizer.from_pretrained(latest)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
from PIL import Image
import librosa
import torch
import numpy as np
from collections.abc import Mapping

class DataCollator():
    """Collate per-sample feature mappings into one batch dictionary.

    The resulting batch contains:

    * ``audio_index``: always an empty int tensor (audio is not indexed here).
    * ``image_index``: for every image in the batch, the position of the
      sample it belongs to.
    * every non-string, non-modality key of the first sample, stacked across
      samples (tensors / numpy arrays) or set to ``None``.
    * ``images`` / ``audios``: concatenation of all non-``None`` per-sample
      values, or ``None`` when no sample provides any.
    * ``image_starts`` / ``image_ends``: ids of the ``<image>`` / ``</image>``
      tokens, repeated once per sample.

    Stacked tensors must share the same shape; no padding is performed.
    """

    # Keys holding per-sample modality tensors that are concatenated rather
    # than stacked, because samples may carry a different number of items.
    _MODALITY_KEYS = ("audios", "images")

    def __init__(self, tokenizer):
        """Store the tokenizer used to resolve the image delimiter token ids."""
        self.tokenizer = tokenizer

    def __call__(self, features):
        """Build the batch dictionary from a non-empty list of samples.

        Args:
            features: list of mappings, or of objects whose attributes
                (``vars(obj)``) form the mapping.

        Returns:
            dict mapping feature names to batched tensors (or ``None``).
        """
        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]

        batch = {}
        bs = len(features)
        first = features[0]

        batch['audio_index'] = torch.tensor([], dtype=torch.int)

        # Record, once per image, which sample the image belongs to. A sample
        # without an 'images' entry is treated the same as images=None.
        image_owner = []
        for index, feature in enumerate(features):
            images = feature.get('images')
            if images is not None:
                image_owner.extend([index] * len(images))
        batch['image_index'] = torch.tensor(image_owner, dtype=torch.int)

        # Regular keys follow the first sample's schema.
        for k, v in first.items():
            if k in self._MODALITY_KEYS or isinstance(v, str):
                continue
            if v is None:
                batch[k] = None
            elif isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features]).squeeze(1)
            elif isinstance(v, np.ndarray):
                batch[k] = torch.tensor(np.stack([f[k] for f in features])).squeeze(1)

        # Modality keys are decided from ALL samples, not just the first one,
        # so a leading text-only sample cannot discard later samples' images
        # while 'image_index' still refers to them.
        for k in self._MODALITY_KEYS:
            if not any(k in f for f in features):
                continue
            present = [f[k] for f in features if f.get(k) is not None]
            batch[k] = torch.cat(present) if present else None

        image_start_id = self.tokenizer.convert_tokens_to_ids('<image>')
        image_end_id = self.tokenizer.convert_tokens_to_ids('</image>')
        batch['image_starts'] = torch.tensor([image_start_id] * bs, dtype=torch.int)
        batch['image_ends'] = torch.tensor([image_end_id] * bs, dtype=torch.int)

        return batch

# Collator instance bound to the checkpoint's tokenizer, which it uses to look
# up the '<image>' / '</image>' token ids.
collator = DataCollator(tokenizer)

In [10]:
# Read the configured input height from the inner image processor's size
# settings. NOTE(review): `height` is not referenced by any later visible cell.
height = image_processor.image_processor.size['height']

In [11]:
def prepare_dataset(messages, images: List[str] = None):
    """Turn a chat conversation plus optional image files into model inputs.

    Args:
        messages: chat messages in the format expected by the tokenizer's
            chat template.
        images: paths of image files to attach; ``None`` or an empty list
            means the sample carries no images.

    Returns:
        The tokenizer output for the rendered prompt (PyTorch tensors), with
        an extra ``'images'`` entry holding the processed pixel values, or
        ``None`` when no images were given.
    """
    if images:
        rgb_images = []
        for path in images:
            # Image.open is lazy and keeps the file open; convert() produces an
            # independent in-memory image, so the handle can be released
            # immediately instead of lingering until garbage collection.
            with Image.open(path) as img:
                rgb_images.append(img.convert('RGB'))
        image_output = image_processor(images=rgb_images, return_tensors='pt')['pixel_values']
    else:
        image_output = None

    # Render the conversation to plain text first, then tokenize it.
    prompt = tokenizer.apply_chat_template(messages, tokenize = False)
    outputs = tokenizer(
                    prompt,
                    return_tensors='pt',
                    return_overflowing_tokens=False,
                    return_length=False)

    outputs['images'] = image_output
    return outputs

In [12]:
# One user turn whose content carries the image placeholder tokens.
messages = [{'role': 'user', 'content': '<image> </image> ini gambar apa'}]
outputs = prepare_dataset(messages, images=['motosikal.jpeg'])

# Collate the single sample into a batch; labels reuse the input ids.
ok = collator([outputs])
ok['labels'] = ok['input_ids']

# Send every populated batch entry to the GPU (None entries stay as they are).
for k in list(ok):
    if ok[k] is None:
        continue
    ok[k] = ok[k].cuda()

# Pixel values must match the model's parameter dtype.
k = 'images'
if ok[k] is not None:
    ok[k] = ok[k].type(model.dtype)

In [13]:
# Let the model assemble its generation inputs from the batch; gradients are
# not needed for this inspection.
with torch.no_grad():
    model_inputs = model.prepare_inputs_for_generation(**ok)

# Take the token-level entries out of the prepared inputs; labels are pulled
# back to host memory as a numpy array.
r = model_inputs.pop('input_ids', None)
label = model_inputs.pop('labels', None).detach().cpu().numpy()

# Show the token-id shape next to the shape of the prepared input embeddings.
(ok['input_ids'].shape, model_inputs['inputs_embeds'].shape)

(torch.Size([1, 17]), torch.Size([1, 593, 4096]))

In [14]:
# Upload the model weights as safetensors. The organization is part of the
# repo id because the separate `organization` keyword of `push_to_hub` is
# deprecated in transformers.
model.push_to_hub('mesolitica/malaysian-mistral-siglip-base-384-vision-alignment', safe_serialization=True)



model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/mesolitica/malaysian-mistral-siglip-base-384-vision-alignment/commit/a441e48b21218675c7b7b1afff5cc50bd3b0eeab', commit_message='Upload MM_LLMs', commit_description='', oid='a441e48b21218675c7b7b1afff5cc50bd3b0eeab', pr_url=None, pr_revision=None, pr_num=None)

In [15]:
# Upload the image processor to the same repository as the model. The
# organization is part of the repo id because the separate `organization`
# keyword of `push_to_hub` is deprecated in transformers.
image_processor.push_to_hub('mesolitica/malaysian-mistral-siglip-base-384-vision-alignment', safe_serialization=True)



README.md:   0%|          | 0.00/5.18k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/mesolitica/malaysian-mistral-siglip-base-384-vision-alignment/commit/3137b6fea993473bd9c91331665f50872d26e5ac', commit_message='Upload processor', commit_description='', oid='3137b6fea993473bd9c91331665f50872d26e5ac', pr_url=None, pr_revision=None, pr_num=None)

In [16]:
# Upload the tokenizer to the same repository as the model. The organization
# is part of the repo id because the separate `organization` keyword of
# `push_to_hub` is deprecated in transformers.
tokenizer.push_to_hub('mesolitica/malaysian-mistral-siglip-base-384-vision-alignment', safe_serialization=True)

CommitInfo(commit_url='https://huggingface.co/mesolitica/malaysian-mistral-siglip-base-384-vision-alignment/commit/e5d07b083ba1f0adc6867ebd8cefb88b1a3f770f', commit_message='Upload tokenizer', commit_description='', oid='e5d07b083ba1f0adc6867ebd8cefb88b1a3f770f', pr_url=None, pr_revision=None, pr_num=None)