In [1]:
# !pip3 install transformers -U

In [2]:
from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
from transformers import CLIPVisionModel, CLIPImageProcessor
from transformers import MistralConfig
import numpy as np
import requests
from PIL import Image

In [3]:
# Vision-encoder config taken from CLIP ViT-B/16 (see the displayed output:
# hidden_size 768, image_size 224, patch_size 16).
vision_config = CLIPVisionConfig.from_pretrained(
    'openai/clip-vit-base-patch16',
)
vision_config

CLIPVisionConfig {
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "image_size": 224,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "model_type": "clip_vision_model",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "projection_dim": 512,
  "transformers_version": "4.36.2"
}

In [4]:
# Text-decoder config (Mistral architecture) from the Malaysian Mistral instruction checkpoint.
text_config = MistralConfig.from_pretrained(
    'mesolitica/malaysian-mistral-7b-32k-instructions-v3.5',
)

In [5]:
# Combine the two sub-configs. Every other LlavaConfig field (e.g. image_token_index,
# vision_feature_layer) keeps its library default. The tokenizer cell later relies on the
# default image_token_index lining up with the id given to "<image>".
configuration = LlavaConfig(
    vision_config=vision_config,
    text_config=text_config,
)

In [6]:
# Build the LLaVA model from config only, so all weights start randomly initialized.
# The vision tower and language model are swapped for pretrained weights in the next cells.
# Any remaining modules (presumably the multi-modal projector) keep their random init.
# NOTE(review): this materializes a full-size random 7B decoder first (PyTorch default dtype,
# presumably float32), which is memory-heavy. Consider a lazier construction if RAM is tight.
model = LlavaForConditionalGeneration(config=configuration)

In [7]:
# Replace the randomly initialized vision tower with pretrained CLIP ViT-B/16 weights.
# Call from_pretrained on the class itself (rather than through the throwaway instance) so
# the loaded class is explicit. The guard keeps the swap type-preserving.
assert isinstance(model.vision_tower, CLIPVisionModel), type(model.vision_tower)
model.vision_tower = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch16')

In [8]:
# Replace the randomly initialized decoder with the pretrained Malaysian Mistral weights,
# loading through the same class LLaVA instantiated for its text tower.
model.language_model = type(model.language_model).from_pretrained(
    'mesolitica/malaysian-mistral-7b-32k-instructions-v3.5',
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
from transformers import AutoProcessor, AutoTokenizer

In [10]:
# Borrow bakLlava's LLaVA processor wrapper. Its tokenizer and image-processing settings
# are overridden in later cells to match the Mistral tokenizer and CLIP ViT-B/16.
processor = AutoProcessor.from_pretrained(
    'llava-hf/bakLlava-v1-hf',
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [11]:
# Malaysian Mistral tokenizer extended with the LLaVA image placeholder and a pad token.
tokenizer = AutoTokenizer.from_pretrained('mesolitica/malaysian-mistral-7b-32k-instructions-v3.5')
num_added_tokens = tokenizer.add_tokens(["<image>", "<pad>"])

# The id assigned to "<image>" depends on the base vocab size, while the model looks for
# image positions using configuration.image_token_index (per transformers' Llava docs).
# Fail loudly here instead of silently mis-aligning image features later.
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
assert image_token_id == configuration.image_token_index, (
    f"<image> has id {image_token_id}, but LlavaConfig.image_token_index is "
    f"{configuration.image_token_index}"
)
num_added_tokens

2

In [12]:
# CLIP ViT-B/16's own processor, used as the reference for image-preprocessing settings.
clip_processor = AutoProcessor.from_pretrained(
    'openai/clip-vit-base-patch16',
)

In [13]:
# Swap in the extended Malaysian Mistral tokenizer.
processor.tokenizer = tokenizer
# Use CLIP ViT-B/16's image preprocessing as a whole. Copying only `crop_size` leaves
# bakLlava's resize target (`size`, presumably 336px for its larger vision tower; verify).
# Images would then be resized to that edge and center-cropped to 224px, discarding border
# content instead of matching the vision tower's expected preprocessing.
processor.image_processor = clip_processor.image_processor

In [14]:
# Grow the token-embedding matrix to one row per tokenizer entry, covering the two added tokens.
model.resize_token_embeddings(new_num_tokens=len(tokenizer))

Embedding(32002, 4096)

In [15]:
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
# The same single-image chat prompt twice, used to build a batch of two below.
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
prompt1 = prompt
# Which vision-tower hidden state to inspect; taken from the LLaVA config.
vision_feature_layer = model.config.vision_feature_layer

In [16]:
# Fetch the sample COCO image. The timeout (seconds) stops a stalled connection from hanging
# the cell, and raise_for_status surfaces HTTP errors instead of a confusing PIL decode error.
with requests.get(image_file, stream=True, timeout=30) as response:
    response.raise_for_status()
    raw_image = Image.open(response.raw)
    raw_image.load()  # Image.open is lazy; decode the pixels before the response is closed.

# Batch of two (prompt, image) pairs. padding=True pads token sequences to a common length.
image_array = np.array(raw_image)
inputs = processor(
    [prompt, prompt1],
    [image_array, image_array],
    return_tensors='pt',
    padding=True,
)

In [17]:
# (batch of 2 prompts, padded token length)
inputs['input_ids'].shape

torch.Size([2, 15])

In [18]:
# (batch, channels, height, width) after image preprocessing
inputs['pixel_values'].shape

torch.Size([2, 3, 224, 224])

In [19]:
import torch

pixel_values = inputs['pixel_values']
# Inspection only: no_grad avoids building an autograd graph across every vision layer.
with torch.no_grad():
    image_outputs = model.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

# One token per image patch plus one extra (class) token, each of width hidden_size;
# for ViT-B/16 at 224px that is 14*14 + 1 = 197 tokens of size 768.
num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
assert selected_image_feature.shape == (
    pixel_values.shape[0],
    num_patches + 1,
    vision_config.hidden_size,
), selected_image_feature.shape
selected_image_feature.shape

torch.Size([2, 197, 768])

In [20]:
# Upload the processor (tokenizer + image-processor config) to the Hub repository.
processor.push_to_hub(
    repo_id='huseinzol05/dummy-clip-vit-base-patch16-malaysian-mistral-7b-32k-instructions-v3.5',
)

CommitInfo(commit_url='https://huggingface.co/huseinzol05/dummy-clip-vit-base-patch16-malaysian-mistral-7b-32k-instructions-v3.5/commit/866802f1eb3eb647ea533493e60361605e4cad29', commit_message='Upload processor', commit_description='', oid='866802f1eb3eb647ea533493e60361605e4cad29', pr_url=None, pr_revision=None, pr_num=None)

In [21]:
import torch

# Cast floating-point weights to bfloat16 before upload.
# nn.Module.type() would cast *every* parameter and buffer, including integer index buffers
# (e.g. position-id buffers used for embedding lookups). That leaves the in-memory model
# unable to run its embeddings. Module.to(dtype) only casts floating-point tensors.
model = model.to(torch.bfloat16)

In [22]:
# Upload the assembled LLaVA model (sharded safetensors) to the same Hub repository.
model.push_to_hub(
    repo_id='huseinzol05/dummy-clip-vit-base-patch16-malaysian-mistral-7b-32k-instructions-v3.5',
)

model-00003-of-00003.safetensors:   0%|          | 0.00/4.86G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/huseinzol05/dummy-clip-vit-base-patch16-malaysian-mistral-7b-32k-instructions-v3.5/commit/989076d8d3221599b26b93bf58e03ecc3bcdcbef', commit_message='Upload LlavaForConditionalGeneration', commit_description='', oid='989076d8d3221599b26b93bf58e03ecc3bcdcbef', pr_url=None, pr_revision=None, pr_num=None)