In [None]:
!python -m pip install --upgrade pip -q
!pip install matplotlib -q -U

In [None]:
!pip install q- datasets
!pip install transformers -q -U
!pip install -q bitsandbytes sentencepiece accelerate loralib
!pip install -q -U git+https://github.com/huggingface/perft.git
!pip install hf_transfer -q -U
!pip install pickleshare -q

In [None]:
# Ask huggingface_hub to use the hf_transfer download backend installed above
# (presumably to speed up downloading the model weights).
%env HF_HUB_ENABLE_HF_TRANSFER=1

In [None]:
import os

if not os.path.isdir("LLaVA"):
    !git clone https://github.com/haotian-liu/LLaVA.git
else:
    print("LLaVA directory already exists. Skipping clone.")

In [None]:
import re

# LLaVA's model loader. This cell patches the vision-tower section of
# load_pretrained_model() so the tower is loaded if needed and moved to
# `device` in bfloat16.
file_path = 'LLaVA/llava/model/builder.py'

# Read the content of the file
with open(file_path, 'r') as file:
    content = file.read()

# Match from the `vision_tower = model.get_vision_tower()` line through the
# `image_processor = vision_tower.image_processor` line. The leading whitespace
# of the first line is captured so the inserted code is indented to match the
# surrounding function body (hard-coded indents produce an IndentationError).
pattern_block = re.compile(
    r'^(?P<indent>[ \t]*)(?P<start>vision_tower = model\.get_vision_tower\(\)\n)'
    r'.*?'  # non-greedy: everything between the two anchor lines
    r'^[ \t]*(?P<end>image_processor = vision_tower\.image_processor)',
    flags=re.DOTALL | re.MULTILINE,
)


def _build_replacement(match):
    """Return the patched block, indented like the matched `vision_tower` line."""
    indent = match.group('indent')
    body = indent + '    '
    return (
        f"{indent}{match.group('start')}"  # keep the starting line unchanged
        f"{indent}if not vision_tower.is_loaded:\n"
        f"{body}print('vision_tower is not loaded so loading it now')\n"
        f"{body}vision_tower.load_model(device_map=device_map)\n"
        f"{body}vision_tower.to(device=device, dtype=torch.bfloat16)\n"
        f"{indent}else:\n"
        f"{body}print('vision_tower is loaded')\n"
        f"{indent}{match.group('end')}"  # keep the ending line unchanged
    )


# Replace the block. Re-running is safe: the patched block still matches and is
# rewritten to the same text.
content, n_replacements = pattern_block.subn(_build_replacement, content)
if n_replacements == 0:
    raise RuntimeError(
        f"Could not find the vision_tower block in {file_path}; "
        "the LLaVA source layout may have changed."
    )

# Write the modified content back to the file
with open(file_path, 'w') as file:
    file.write(content)
print('The script has been updated successfully')

In [None]:
import re

file_path = 'LLaVA/llava/model/builder.py'

# Re-read the file so this cell operates on what is on disk now (the previous
# cell may have just rewritten it), not on a variable left over in the kernel.
with open(file_path, 'r') as file:
    content = file.read()

# 'float16' not preceded by 'b', i.e. leave existing 'bfloat16' alone.
pattern = r'(?<!b)float16'

modified_content, n_replacements = re.subn(pattern, 'bfloat16', content)

if n_replacements:
    # Write the modified content back to the file
    with open(file_path, 'w') as file:
        file.write(modified_content)

    print(f"Replaced {n_replacements} instance(s) of float16 with bfloat16.")
else:
    print('No replacement needed. All instances of float16 are already bfloat16.')
    

In [None]:
%cd LLaVA

In [None]:
#can take up to 5 mins
!pip install -e . -q

In [None]:
# !git pull
# !pip install -e . -q

!pip install protobuf -q -U
!pip install --upgrade Pillow -q
!pip install -e ".[train]" -q
!pip install flash-attn --no-build-isolation -q 


In [1]:
# Libraries for loading, LoRA fine-tuning and evaluating the model.
import torch
import torchvision.transforms as transforms
import transformers
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from PIL import Image
# The class is `TrainingArguments` (plural); the last name was truncated and is
# assumed to be `BitsAndBytesConfig` (bitsandbytes is installed above).
from transformers import AutoProcessor, BitsAndBytesConfig, Trainer, TrainingArguments

ModuleNotFoundError: No module named 'torch'

In [None]:
import torch

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = 'liuhaotian/llava-v1.6-mistral-7b'
# model_path = "Trelis/llava-v1.6-mistral-7b-PATCHED"

model_name = get_model_name_from_path(model_path)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    # Forward the detected device rather than relying on the loader's default.
    device=device,
    # NOTE(review): an empty string is passed through as the download cache
    # location — confirm this is the intended cache directory.
    cache_dir='',
    # flash-attn kernels require a CUDA GPU, so only request them there.
    use_flash_attn=(device == "cuda"),
)
 

In [None]:
#print(model)

In [None]:
#print(processor)

In [None]:
#print(model.config)
#print(tokenizer.pad_token_id)
#print(tokenizer)

In [None]:
import torch

# List every module that directly owns a parameter whose dtype is not
# torch.bfloat16, together with those offending dtypes. Only a module's own
# parameters are inspected (recurse=False), so each parameter is reported once
# by its owner rather than again by every parent container.
print("Modules not torch.bfloat16:")
for name, module in model.named_modules():
    non_bf16_dtypes = {
        param.dtype
        for param in module.parameters(recurse=False)
        if param.dtype != torch.bfloat16
    }
    if non_bf16_dtypes:
        print(f"{name}: {', '.join(sorted(str(dtype) for dtype in non_bf16_dtypes))}")

In [None]:
import requests
from io import BytesIO

import matplotlib.pyplot as plt
from PIL import Image

# Image URL (use GitHub's /raw/ path so the response is the file, not an HTML page).
image_url = 'https://github.com/TrelisResearch/install-guides/raw/main/knight_and_rook.jpg'

# Download the image. The timeout prevents a stalled connection from hanging
# the kernel; raise_for_status() surfaces HTTP errors instead of letting PIL
# fail on an error page.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
# NOTE(review): converted to 3-channel RGB on the assumption that the LLaVA
# image processor expects RGB input — confirm.
image = Image.open(BytesIO(response.content)).convert('RGB')

# Display the image using matplotlib
plt.imshow(image)
plt.axis('off')
plt.show()

# Caption the downloaded PIL image.
# NOTE(review): upstream llava.eval.run_llava.eval_model (imported earlier)
# takes a single `args` namespace; this positional call assumes a custom
# eval_model(tokenizer, model, image_processor, context_len, image, query)
# defined in a cell not shown here — confirm for the installed version.
eval_model(
    tokenizer,
    model,
    image_processor,
    context_len,
    image,
    "What do you see in this picture?",
)

In [None]:
#finetuning dataset
#preparation for finetuning

def tokenize_and_create_labels(example_batch, image_processor, tokenizer,
                               model=None, device=None, model_name=None,
                               ignore_index=-100):
    """Convert a batch of {'image', 'caption'} rows into LLaVA training inputs.

    Every row is tokenized as the prompt built by ``create_prompt`` for a fixed
    query plus that row's caption. Only tokens after the caption-less prompt are
    supervised; prompt tokens and padding receive ``ignore_index`` in ``labels``.

    Args:
        example_batch: batch mapping with 'image' and 'caption' lists.
        image_processor: image processor returned by ``load_pretrained_model``.
        tokenizer: tokenizer returned by ``load_pretrained_model``.
        model, device, model_name: optional; when omitted they are read from
            the notebook globals of the same name, so callers may either pass
            them explicitly or rely on the notebook state.
        ignore_index: label value excluded from the loss; -100 is the default
            ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.

    Returns:
        dict with 'input_ids', 'attention_mask', 'images', 'image_sizes' and
        'labels'. The token tensors have shape (batch, max_seq_len) and are on
        ``device``.
    """
    # Local imports: neither name is imported in the notebook's import cell.
    from torch.nn.utils.rnn import pad_sequence
    from llava.mm_utils import tokenizer_image_token

    # Fall back to the notebook-level objects for anything not passed in.
    notebook_globals = globals()
    if model is None:
        model = notebook_globals["model"]
    if device is None:
        device = notebook_globals["device"]
    if model_name is None:
        model_name = notebook_globals["model_name"]

    def _tokenize(prompt):
        # flatten() accepts a 1-D or (1, seq_len) tensor and always returns 1-D.
        return tokenizer_image_token(prompt, tokenizer, return_tensors="pt").flatten()

    # Padding positions are excluded via attention_mask and labels below, so the
    # fill value only has to be a valid token id; use 0 if no pad token is set.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    # NOTE(review): process_and_prepare_images and create_prompt are not defined
    # in this part of the notebook and must come from an earlier cell. The
    # original call was truncated after `image_files,`; image_processor (otherwise
    # unused here), model and device are assumed to be the remaining arguments —
    # confirm against the helper's definition.
    images_tensor, image_sizes = process_and_prepare_images(
        example_batch['image'], image_processor, model, device
    )

    query = "What do you see in this picture?"

    # The caption-less prompt is identical for every row, so the number of
    # leading tokens to exclude from the loss is computed once.
    prompt_len = len(_tokenize(create_prompt(query, model, model_name, None)))

    # Full conversations: the prompt followed by each row's caption.
    sequences = [
        _tokenize(create_prompt(query, model, model_name, caption))
        for caption in example_batch['caption']
    ]

    # Right-pad into a (batch, max_seq_len) tensor.
    input_ids = pad_sequence(
        sequences, batch_first=True, padding_value=pad_token_id
    ).to(device)

    # 1 for real tokens, 0 for padding. Built from the true sequence lengths
    # rather than `input_ids != pad_token_id`, so a real token that shares the
    # pad id is not masked out.
    lengths = torch.tensor([len(seq) for seq in sequences], device=device)
    positions = torch.arange(input_ids.size(1), device=device)
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

    # Supervise only caption tokens: prompt positions and padding keep ignore_index.
    labels = torch.full_like(input_ids, fill_value=ignore_index)
    labels[:, prompt_len:] = input_ids[:, prompt_len:]
    labels[attention_mask == 0] = ignore_index

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "images": images_tensor,
        "image_sizes": image_sizes,
        "labels": labels,
    }

    return inputs

# A named, module-level function rather than an inline lambda: it is reused
# for several dataset splits and, unlike a lambda, can be pickled by reference.
def transform_batch(batch):
    """Dataset transform: tokenize `batch` with the notebook's image processor and tokenizer."""
    prepared = tokenize_and_create_labels(batch, image_processor, tokenizer)
    return prepared

# Load the chess-pieces captioning dataset and attach the on-the-fly
# tokenization transform to both the train and test splits.
ds = load_dataset("Trelis/chess_pieces")

train_ds, eval_ds = ds["train"], ds["test"]

for split in (train_ds, eval_ds):
    split.set_transform(transform_batch)


In [None]:
# LoRA setup, step 1: inspect the module tree. The layer names printed here are
# what `target_modules` in the LoraConfig below is matched against.
model_summary = str(model)
print(model_summary)

In [None]:
# Wrap the model with LoRA adapters. Each target name is matched against the
# end of module names, so check them against the printed model above.
# NOTE(review): suffix matching means q_proj/k_proj/v_proj may also match the
# attention layers of the vision tower — confirm that is intended.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj",
        # "fc1", "fc2",  # for llama
        # For mistral, also train the multimodal projector. `mm_projector` is a
        # container that LoRA cannot wrap directly, so its Linear layers are
        # targeted. Assumes the default Linear-GELU-Linear projector (Linear
        # layers at indices 0 and 2) — confirm against the printed model.
        "mm_projector.0", "mm_projector.2",
        "up_proj", "down_proj", "gate_proj",  # optionally train more linear layers
    ],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, config)

In [None]:
# Report how many parameters the LoRA wrapping left trainable.
model.print_trainable_parameters()

In [None]:
# Pre-training evaluation: caption each test image with the not-yet-fine-tuned
# model and print the reference caption next to it.
import matplotlib.pyplot as plt

# Temporarily disable the tokenization transform so rows come back as the
# original {'image', 'caption'} values.
eval_ds.reset_format()

# Iterate over each example in the evaluation dataset (one row fetch each).
for row in eval_ds:
    image = row['image']
    caption = row['caption']

    # Display the image using matplotlib
    plt.imshow(image)
    plt.axis('off')  # turn off axis numbers and ticks
    plt.show()

    # NOTE(review): relies on the same positional eval_model signature as the
    # single-image cell above; upstream llava.eval.run_llava.eval_model takes a
    # single `args` namespace — confirm a custom eval_model is defined.
    eval_model(
        tokenizer,
        model,
        image_processor,
        context_len,
        image,
        "What do you see in this picture?"
    )

    print(f"\nCorrect caption: {caption}\n\n")

# Re-enable the tokenization transform, using the same function as the training
# split so both splits produce identically prepared batches.
eval_ds.set_transform(transform_batch)

    