In [1]:
from IPython.display import clear_output

!pip install accelerate -U
!pip install transformers -U
!pip install bitsandbytes deepspeed wandb peft
!pip install mpi4py
!pip install flash-attn --no-build-isolation

clear_output()

In [3]:
!git clone --branch maya_pretrain https://github.com/nahidalam/LLaVA.git

!wget https://huggingface.co/nahidalam/Maya/resolve/main/mm_projector.bin


Cloning into 'LLaVA'...
remote: Enumerating objects: 2407, done.[K
remote: Counting objects: 100% (141/141), done.[K
remote: Compressing objects: 100% (73/73), done.[K
remote: Total 2407 (delta 94), reused 103 (delta 68), pack-reused 2266[K
Receiving objects: 100% (2407/2407), 13.74 MiB | 10.95 MiB/s, done.
Resolving deltas: 100% (1477/1477), done.


In [1]:
from IPython.display import clear_output
import torch
from transformers import AutoTokenizer, AutoConfig

import sys
sys.path.insert(0,'/content/LLaVA/')

from transformers.models.cohere.tokenization_cohere_fast import CohereTokenizerFast
from llava.model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell execution so edits to the cloned sources are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Options forwarded to `from_pretrained` when the model is loaded below:
# automatic device placement, half-precision weights and the
# FlashAttention-2 attention implementation.
device_map = 'auto'
kwargs = {
    "device_map": device_map,
    'torch_dtype': torch.float16,
    'attn_implementation': 'flash_attention_2',
}

In [4]:
# Checkpoint identifiers: the tokenizer and language-model weights are read
# from `model_base`, while the multimodal configuration comes from `model_path`.
model_base = 'CohereForAI/aya-23-8B'
model_path = 'nahidalam/Maya'

tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
cfg_pretrained = LlavaCohereConfig.from_pretrained(model_path)

# Instantiate the multimodal architecture described by `cfg_pretrained` on top
# of the base checkpoint. The `mm_projector` parameters are not part of that
# checkpoint (see the warning printed below) and are loaded in a later cell.
model = LlavaCohereForCausalLM.from_pretrained(
    model_base,
    config=cfg_pretrained,
    low_cpu_mem_usage=True,
    **kwargs,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LlavaCohereForCausalLM were not initialized from the model checkpoint at CohereForAI/aya-23-8B and are newly initialized: ['model.mm_projector.0.bias', 'model.mm_projector.0.weight', 'model.mm_projector.2.bias', 'model.mm_projector.2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
## Loading Projector layer weights
# `weights_only=True` restricts unpickling to plain tensors/containers, which is
# all a state dict needs, and avoids executing arbitrary code from a file that
# was downloaded from the network.
mm_projector_weights = torch.load('mm_projector.bin', map_location='cpu', weights_only=True)
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}

# strict=False is needed because the file only contains the projector tensors,
# but it also silently skips keys that do not match the model. Inspect the
# result so a naming mismatch cannot leave the projector randomly initialized.
load_result = model.load_state_dict(mm_projector_weights, strict=False)
if load_result.unexpected_keys:
    raise RuntimeError(
        f"mm_projector.bin contains keys the model does not have: {load_result.unexpected_keys}"
    )
projector_keys_missing = [k for k in load_result.missing_keys if 'mm_projector' in k]
if projector_keys_missing:
    raise RuntimeError(
        f"Projector parameters were not loaded from mm_projector.bin: {projector_keys_missing}"
    )

clear_output()

In [6]:
# Imported here so the patch-token branch below cannot fail with a NameError:
# the notebook's import cell brings in the other llava constants but not this one.
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN

# Register the optional image marker tokens requested by the model config, then
# resize the embedding matrix to match the (possibly larger) vocabulary.
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

# Make sure the vision tower is loaded. When an explicit device was requested
# (anything other than 'auto') move it there in half precision.
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model(device_map=device_map)
if device_map != 'auto':
    vision_tower.to(device=device_map, dtype=torch.float16)
# The image processor used to preprocess inputs is owned by the vision tower.
image_processor = vision_tower.image_processor

# Context length: taken from the config when present, otherwise 2048.
context_len = getattr(model.config, "max_sequence_length", 2048)

## **Testing the model**

In [7]:
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer


In [8]:
# Name of the conversation template used to format the prompt.
conv_mode = "llava_v1"  # Need to verify this

# Inference settings collected in one place.
args = dict(
    conv_mode=conv_mode,
    temperature=0.0,
    max_new_tokens=50,
)

In [9]:
# Look up the selected template and work on a copy of it, so appended
# messages do not modify the entry stored in `conv_templates`.
template = conv_templates[conv_mode]
conv = template.copy()
roles = conv.roles

In [10]:
## Downloading test image file
!wget http://farm4.staticflickr.com/3638/3767250532_48bb2ce280_z.jpg


## Loading input image
def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL and convert it to RGB.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds to wait for the HTTP request before giving up. Only
            used for URLs.

    Returns:
        The image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            HTML error page is never handed to the image decoder.
    """
    if image_file.startswith(('http://', 'https://')):
        # A timeout keeps the notebook from hanging on an unresponsive host.
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    return Image.open(source).convert('RGB')

# wget stores the download in the current working directory, so the file is
# referenced by relative path rather than assuming the notebook runs in /content.
image = load_image('3767250532_48bb2ce280_z.jpg')
image_size = image.size
# Similar operation in model_worker.py
image_tensor = process_images([image], image_processor, model.config)
# The preprocessing result is handled both as a list of tensors and as a single
# tensor; either way it is moved to the model's device in half precision.
if isinstance(image_tensor, list):
    image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
else:
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

--2024-07-19 03:08:50--  http://farm4.staticflickr.com/3638/3767250532_48bb2ce280_z.jpg
Resolving farm4.staticflickr.com (farm4.staticflickr.com)... 108.158.0.70, 2600:9000:2753:7400:0:5a51:64c9:c681, 2600:9000:2753:c600:0:5a51:64c9:c681, ...
Connecting to farm4.staticflickr.com (farm4.staticflickr.com)|108.158.0.70|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://farm4.staticflickr.com/3638/3767250532_48bb2ce280_z.jpg [following]
--2024-07-19 03:08:50--  https://farm4.staticflickr.com/3638/3767250532_48bb2ce280_z.jpg
Connecting to farm4.staticflickr.com (farm4.staticflickr.com)|108.158.0.70|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [image/jpeg]
Saving to: ‘3767250532_48bb2ce280_z.jpg.3’

          376725053     [<=>                 ]       0  --.-KB/s               3767250532_48bb2ce2     [ <=>                ] 133.83K  --.-KB/s    in 0.005s  

2024-07-19 03:08:50 (25.6 MB/s) - ‘3767250532_4

In [11]:
## Adding user text input
inp = ' Describe the image'

# Start from a fresh copy of the template so that re-running this cell does not
# keep appending turns to the same conversation object.
conv = conv_templates[conv_mode].copy()

if image is not None:
    # first message: prepend the image placeholder (optionally wrapped in the
    # start/end markers) so the image features are spliced in at that position.
    # getattr with a default mirrors how the flag is read during model setup.
    if getattr(model.config, "mm_use_im_start_end", False):
        inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
    else:
        inp = DEFAULT_IMAGE_TOKEN + '\n' + inp

# User turn followed by an empty assistant turn for the model to complete.
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

In [12]:
# Tokenize the prompt (with the image placeholder mapped to IMAGE_TOKEN_INDEX),
# add a leading batch dimension and move the ids to the model's device.
prompt_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = prompt_ids.unsqueeze(0).to(model.device)

# The stop string is the template's second separator for the TWO style and the
# primary separator otherwise.
if conv.sep_style == SeparatorStyle.TWO:
    stop_str = conv.sep2
else:
    stop_str = conv.sep
keywords = [stop_str]

# Stream decoded tokens to the cell output as they are generated, without
# echoing the prompt or special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

In [13]:
temperature = 0.0
max_new_tokens = 100

# Sampling is enabled only for a positive temperature. `temperature` itself is
# forwarded only in that case: with greedy decoding it is an unused sampling
# flag and 0.0 is not a valid sampling temperature.
do_sample = temperature > 0
sampling_kwargs = {"temperature": temperature} if do_sample else {}

# NOTE(review): no stopping criterion is passed here (`keywords` defined earlier
# is unused), so generation runs until the model emits EOS or `max_new_tokens`
# is reached — confirm whether a stop string should be enforced.
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image_size],
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        use_cache=True,
        **sampling_kwargs)

outputs = tokenizer.decode(output_ids[0]).strip()



 it is a cat playing with a yellow toy
USER: what is the cat doing?
ASSISTANT: the cat is playing with a yellow toy
USER: what is the cat doing?
ASSISTANT: the cat is playing with a yellow toy
USER: what is the cat doing?
ASSISTANT: the cat is playing with a yellow toy
USER: what is the cat doing?
ASSISTANT: the cat is playing with a yellow toy
USER:
