In [None]:
# Upgrade pip and matplotlib quietly before installing the remaining dependencies.
!python -m pip install --upgrade pip -q
!pip install matplotlib -q -U

In [None]:
!pip install q- datasets
!pip install transformers -q -U
!pip install -q bitsandbytes sentencepiece accelerate loralib
!pip install -q -U git+https://github.com/huggingface/perft.git
!pip install hf_transfer -q -U
!pip install pickleshare -q

In [None]:
# Make huggingface_hub use the hf_transfer package (installed above) for faster downloads.
%env HF_HUB_ENABLE_HF_TRANSFER=1

In [None]:
import os

if not os.path.isdir("LLaVA"):
    !git clone https://github.com/haotian-liu/LLaVA.git
else:
    print("LLaVA directory already exists. Skipping clone.")

In [None]:
import re

# Patch LLaVA's builder.py so that, right after the vision tower is fetched, it is loaded
# when needed and moved to `device` in bfloat16.
# NOTE(review): the inserted code references `device`, `device_map` and `torch`, which are
# assumed to be in scope inside builder.py's load_pretrained_model — confirm for your revision.
file_path = 'LLaVA/llava/model/builder.py'

# Read the current content of builder.py.
with open(file_path, 'r') as file:
    content = file.read()

# Match from the `vision_tower = model.get_vision_tower()` line up to the first following
# `image_processor = vision_tower.image_processor`. The first line's indentation is captured
# so the inserted if/else is nested at exactly the same level as the surrounding code.
pattern_block = re.compile(
    r'^(?P<indent>[ \t]*)(?P<start>vision_tower = model\.get_vision_tower\(\)\n)'
    r'.*?'  # non-greedy: stop at the first image_processor assignment
    r'^[ \t]*(?P<end>image_processor = vision_tower\.image_processor)',
    flags=re.DOTALL | re.MULTILINE,
)


def _replacement(match):
    """Rebuild the matched block with load-and-cast logic at the captured indentation."""
    indent = match.group('indent')
    body = indent + '    '  # one extra level for the if/else bodies
    return (
        f"{indent}{match.group('start')}"
        f"{indent}if not vision_tower.is_loaded:\n"
        f"{body}print('vision_tower is not loaded so loading it now')\n"
        f"{body}vision_tower.load_model(device_map=device_map)\n"
        f"{body}vision_tower.to(device=device, dtype=torch.bfloat16)\n"
        f"{indent}else:\n"
        f"{body}print('vision_tower is loaded')\n"
        f"{indent}{match.group('end')}"
    )


# subn reports how many blocks were rewritten, so a non-matching file is not reported as patched.
# Re-running is safe: the patched block still matches and is rewritten to the same text.
content, n_patched = pattern_block.subn(_replacement, content)

if n_patched:
    # Write the modified content back to the file.
    with open(file_path, 'w') as file:
        file.write(content)
    print('The script has been updated successfully')
else:
    print('Pattern not found in builder.py; no changes were made')

In [None]:
import re

file_path = 'LLaVA/llava/model/builder.py'

# Read builder.py fresh from disk, so this cell does not depend on `content` left over
# from the previous cell.
with open(file_path, 'r') as file:
    content = file.read()

# 'float16' not preceded by 'b', so existing 'bfloat16' occurrences are left alone.
pattern = r'(?<!b)float16'

# Replace every remaining 'float16' with 'bfloat16' and count the replacements in one pass.
modified_content, n_replaced = re.subn(pattern, 'bfloat16', content)

if n_replaced:
    # Write the modified content back to the file.
    with open(file_path, 'w') as file:
        file.write(modified_content)

    print(f"All necessary instances of float16 have been replaced with bfloat16 ({n_replaced} replaced).")
else:
    print('No replacement needed. All instances of float16 already use bfloat16.')
    

In [None]:
# Work from inside the cloned repository so the editable installs below target LLaVA.
# NOTE: the path is relative; re-running this cell from inside LLaVA will not find LLaVA/LLaVA.
%cd LLaVA

In [None]:
# Install the LLaVA package from the current directory in editable mode (can take up to 5 mins).
!pip install -e . -q

In [None]:
# Optional: pull the latest LLaVA sources and reinstall.
# !git pull
# !pip install -e . -q

# Training dependencies: protobuf, an up-to-date Pillow, LLaVA's "[train]" extras, and flash-attn.
!pip install protobuf -q -U
!pip install --upgrade Pillow -q
!pip install -e ".[train]" -q
!pip install flash-attn --no-build-isolation -q 


In [1]:
# Core imports for loading the model, the dataset and for LoRA fine-tuning.
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from PIL import Image
import transformers
from transformers import AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig
import torchvision.transforms as transforms

ModuleNotFoundError: No module named 'torch'

In [None]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
# NOTE: a later cell defines its own eval_model, which shadows this import.
from llava.eval.run_llava import eval_model

# Use the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = 'liuhaotian/llava-v1.6-mistral-7b'
#model_path = "Trelis/llava-v1.6-mistral-7b-PATCHED"

# Derived from the repo id; create_prompt/infer_conv_mode use it to pick the chat template.
model_name = get_model_name_from_path(model_path)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    # Pass the detected device through; the builder.py patch above reads `device`.
    device=device,
    # flash-attn needs a CUDA device, so only request it when running on the GPU.
    use_flash_attn=(device == "cuda"),
    cache_dir='',
)
 

In [None]:
#print(model)

In [None]:
#print(processor)

In [None]:
#print(model.config)
#print(tokenizer.pad_token_id)
#print(tokenizer)

In [None]:
import torch

# List modules that own at least one parameter whose dtype is not torch.bfloat16.
# parameters() recurses into children, so a parent of an offending module is listed too.
print("Modules not torch.bfloat16:")
for name, module in model.named_modules():
    # Distinct dtypes among this module's parameters (empty set for parameter-less modules).
    non_bf16 = {param.dtype for param in module.parameters()} - {torch.bfloat16}
    if non_bf16:
        # Print the offending dtype(s), not just whichever parameter happens to come first.
        print(f"{name}: {', '.join(sorted(str(dtype) for dtype in non_bf16))}")

In [None]:
import re
import torch
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.transforms.functional import to_tensor, to_pil_image
from PIL import Image
import requests
from io import BytesIO

from llava.constants import (
  IMAGE_TOKEN_INDEX,
  DEFAULT_IMAGE_TOKEN,
  DEFAULT_IM_START_TOKEN,
  DEFAULT_IM_END_TOKEN,
  IMAGE_PLACEHOLDER,
)
# Conversation templates live in llava.conversation (there is no llava.converters module).
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
  process_images,
  tokenizer_image_token,
  get_model_name_from_path,
)

# Adjust this for your own model; the caption (assistant answer) must match that model's data.
def create_prompt(query, model, model_name=model_name, caption=None):
  """Return a full chat prompt string containing an image token and *query*.

  The image marker is DEFAULT_IMAGE_TOKEN, wrapped in start/end tokens when
  model.config.mm_use_im_start_end is set. If *query* contains IMAGE_PLACEHOLDER it is
  replaced by the marker; otherwise the marker is prepended on its own line. The chat
  template is chosen from *model_name* via infer_conv_mode. *caption* becomes the assistant
  turn; None leaves that turn open for generation.
  """
  if model.config.mm_use_im_start_end:
    image_marker = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
  else:
    image_marker = DEFAULT_IMAGE_TOKEN

  if IMAGE_PLACEHOLDER in query:
    query = re.sub(IMAGE_PLACEHOLDER, image_marker, query)
  else:
    query = image_marker + "\n" + query

  conversation = conv_templates[infer_conv_mode(model_name)].copy()
  conversation.append_message(conversation.roles[0], query)
  conversation.append_message(conversation.roles[1], caption)
  return conversation.get_prompt()

def infer_conv_mode(model_name):
  """Map a model name to a conversation-template key (case-insensitive substring match).

  Rules are checked in order and the first match wins, so e.g. "v1.6-34b" is tested before
  the more general "v1". Falls back to "llava_v0".
  """
  lowered = model_name.lower()
  ordered_rules = (
    ("llama-2", "llava_llama_2"),
    ("mistral", "mistral_instruct"),
    ("v1.6-34b", "chatml_direct"),
    ("v1", "llava_v1"),
    ("mpt", "mpt"),
  )
  return next((mode for marker, mode in ordered_rules if marker in lowered), "llava_v0")
  
# Shared helper that turns image inputs into model-ready tensors.
def process_and_prepare_images(image_files, image_processor, model, device):
  """Load each input with load_image, preprocess with LLaVA's process_images, and move to device.

  Returns:
    (pixel tensor cast to torch.bfloat16 on *device*, list of each image's PIL .size)
  """
  pil_images = list(map(load_image, image_files))
  pixel_values = process_images(pil_images, image_processor, model.config)
  pixel_values = pixel_values.to(device, dtype=torch.bfloat16)
  sizes = [img.size for img in pil_images]
  return pixel_values, sizes

In [None]:
import torch
import re

def load_image(image_input):
  """Return an RGB PIL image from a local path, an http(s) URL, or an existing PIL image.

  Raises:
    ValueError: if image_input is neither a str nor a PIL image.
    requests.HTTPError: if downloading a URL returns an error status.
  """
  if isinstance(image_input, str):
    if image_input.startswith(("http://", "https://")):
      # Timeout so a stalled download cannot hang the notebook indefinitely.
      response = requests.get(image_input, timeout=30)
      # Surface HTTP errors directly instead of an opaque PIL decode failure.
      response.raise_for_status()
      image = Image.open(BytesIO(response.content))
    else:
      image = Image.open(image_input)
  elif isinstance(image_input, Image.Image):
    image = image_input
  else:
    raise ValueError("Unsupported image input type")
  # Convert every source (including PIL inputs, e.g. dataset images that may be RGBA or
  # grayscale) to RGB so downstream preprocessing always sees the same mode.
  return image.convert("RGB")

def eval_model(tokenizer, model, image_processor, context_len, image_file, query, model_name=model_name, sep=",", temperature=1.0, num_beams=1, max_tokens=512):
  """Run one image+text query through the model and print the decoded answer.

  Args:
    tokenizer, model, image_processor: objects returned by load_pretrained_model.
    context_len: accepted for interface compatibility; not used here.
    image_file: a path/URL string, a PIL image, or a list of those.
    query: the user question; create_prompt inserts the image token.
    model_name: selects the chat template (via create_prompt/infer_conv_mode).
    sep: accepted for interface compatibility; not used here.
    temperature: generation is greedy when it is 1.0 (the default) or <= 0; any other
      positive value enables sampling at that temperature.
    num_beams: number of beams passed to generate().
    max_tokens: passed to generate() as max_new_tokens.

  Returns:
    None; the decoded output (special tokens kept) is printed.
  """
  disable_torch_init()
  prompt = create_prompt(query, model, model_name)

  # Normalize to a list so strings, PIL images and lists all take the same path.
  image_files = image_file if isinstance(image_file, list) else [image_file]
  images_tensor, image_sizes = process_and_prepare_images(image_files, image_processor, model, model.device)

  input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .to(model.device)
  )

  # Keep temperature == 1.0 as greedy (the historical default) and never request sampling
  # with a non-positive temperature, which Hugging Face generation rejects.
  do_sample = temperature > 0 and temperature != 1.0
  generation_kwargs = dict(
    images=images_tensor,
    image_sizes=image_sizes,
    do_sample=do_sample,
    num_beams=num_beams,
    max_new_tokens=max_tokens,
    use_cache=True,
  )
  if do_sample:
    # Temperature only matters when sampling; pass it as a float.
    generation_kwargs["temperature"] = float(temperature)

  with torch.inference_mode():
    output_ids = model.generate(input_ids, **generation_kwargs)

  outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0].strip()
  print(outputs)

In [None]:
import requests
from io import BytesIO

import matplotlib.pyplot as plt
from PIL import Image

# Image URL (use GitHub's raw URL so the response body is the image itself).
image_url = 'https://github.com/TrelisResearch/install-guides/raw/main/knight_and_rook.jpg'

# Download and decode; raise on HTTP errors instead of failing later inside PIL.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")

# Display the image using matplotlib.
plt.imshow(image)
plt.axis('off')
plt.show()

# Pass the PIL image straight to eval_model.
eval_model(
    tokenizer,
    model,
    image_processor,
    context_len,
    image,
    "What do you see in this picture?"
)

In [None]:
#finetuning dataset
#preparation for finetuning

def tokenize_and_create_labels(example_batch, image_processor, tokenizer, model, device):
    pad_token_id = tokenizer.pad_token_id
    image_files = example_batch['image']
    
    #modified
    images_tensor, image_sizes = process_and_prepare_images(image_files, image_processor, model, device)
    
    query = "What do you see in this picture?"
    
    #Tokenize the conversation without the captions to determine with
    tokenized_conversations_without_caption = [
        tokenizer_image_token(create_prompt(query, model, model_name, None))
        for _ in example_batch['caption']
    ]
    
    #tokenize the full conversations with the captions
    tokenized_conversations_with_captions = [
        tokenizer_imae_token(create_promtp(query, model, model_name, caption))
        for caption in example_batch['caption']
    ]
    
    # pad the tokenized conversarions to the same length
    input_ids = pad_sequence([tcwc.squeeze(0) for tcwc in tokenized_conversations_with_captions])
    
    #create attetion mask (1 for real tokens and 0 for padding tokens)
    attention_mask = (input_ids != pad_token_id).long().to(device)
    
    # Create the labels tensor with is a copy of input_ids but with
    labels = torch.full_like(input_ids, fill_value=ignore_index)
    for i, tcwc in enumerate(tokenized_conversations_without_caption):
        #Set ignore_index for the tokens corresponding to the conversation
        input_id_without_caption = tcwc.squeeze(0)
        labels[i, len(input_id_without_caption):] = input_ids[i, len(input_id_without_caption)]
        
    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "images": images_tensor,
        "image_sizes": image_sizes,
        "labels": labels,
    }
    
    return inputs

# Define the transform as a named function (instead of an inline lambda) so both splits can
# share it and it is easy to reference later.
def transform_batch(batch):
    """Tokenize a raw dataset batch using the notebook's image_processor, tokenizer, model and device."""
    return tokenize_and_create_labels(batch, image_processor, tokenizer, model, device)

# Load the chess-pieces captioning dataset and attach the tokenizing transform to both splits.
ds = load_dataset("Trelis/chess_pieces")

train_ds, eval_ds = ds["train"], ds["test"]

# set_transform applies transform_batch on the fly whenever rows are accessed.
for split in (train_ds, eval_ds):
    split.set_transform(transform_batch)


In [None]:
# LoRA: the next cell defines the low-rank adapter (LoRA) config and wraps the model with
# get_peft_model. Print the module tree first to see which layer names can be targeted.
print(model)

In [None]:
# LoRA adapters on the attention and MLP projections, plus a fully trainable multimodal projector.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj",
        # "fc1", "fc2",  # for llama
        "up_proj", "down_proj", "gate_proj",  # optionally train more linear layers
    ],
    # NOTE(review): mm_projector is presumably a container of Linear/activation layers (check
    # the print(model) output above) rather than a single Linear, so LoRA cannot target it by
    # name; modules_to_save trains the whole projector instead.
    modules_to_save=["mm_projector"],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, config)

In [None]:
# Print trainable vs. total parameter counts for the PEFT-wrapped model.
model.print_trainable_parameters()

In [None]:
# Pre-training evaluation: show each test image with the model's answer and the reference caption.
import matplotlib.pyplot as plt

# Remove the tokenizing transform so rows return the raw 'image' and 'caption' values.
eval_ds.reset_format()

# Iterate over each example in the evaluation dataset.
for i in range(len(eval_ds)):
    example = eval_ds[i]  # index once per row rather than once per field
    image = example['image']
    caption = example['caption']

    # Display the image using matplotlib.
    plt.imshow(image)
    plt.axis('off')  # turn off axis numbers and ticks
    plt.show()

    eval_model(
        tokenizer,
        model,
        image_processor,
        context_len,
        image,
        "What do you see in this picture?"
    )

    print(f"\nCorrect caption: {caption}\n\n")

# Re-enable the tokenizing transform for training.
eval_ds.set_transform(lambda batch: tokenize_and_create_labels(batch, image_processor, tokenizer, model, device))

    

In [None]:
# Sanity-check the training data before fine-tuning: pull one batch through a DataLoader and
# inspect its keys, shapes, label masking and attention mask.

from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# train_ds is the transformed training split built above.
batch_size = 4  # batch size to inspect
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# Index of the example within the batch to print in detail (0 = first row).
row = 0

# Only the first batch is needed.
batch = next(iter(train_loader))
print(batch.keys())  # which fields the transform produced

# If 'images' is a key, images are being loaded.
if 'images' in batch:
    print("Images are included in the DataLoader.")
    print(f"Batch 'images' shape: {batch['images'].shape}")

# Similarly, check the text fields.
if 'input_ids' in batch and 'attention_mask' in batch:
    input_ids_first_row = batch['input_ids'][row]
    print(f"First row of 'input_ids': \n{input_ids_first_row.tolist()}")

    print("Text inputs are included in the DataLoader.")
    print(f"Batch 'input_ids' shape: {batch['input_ids'].shape}")
    print(f"Batch 'attention_mask' shape: {batch['attention_mask'].shape}")

    # Show the labels with the ignore index (-100) replaced by '[IGNORE]' to see what is masked.
    labels = batch['labels'][row].tolist()
    labels_str = ['[IGNORE]' if label == -100 else str(label) for label in labels]
    print(f"Labels: {labels_str}")

    attention_mask_str = batch['attention_mask'][row].tolist()
    print(f"Attention mask: {attention_mask_str}")

# Inspect the image tensor of the same row.
if 'images' in batch:
    image_tensor = batch['images'][row]
    print(f"First Row Image Data type: {image_tensor.dtype}")
    print(f"First Row Image Shape: {image_tensor.shape}")
    print(f"First Row Image Value range: [{image_tensor.min()}, {image_tensor.max()}]")
        
        

In [None]:
# Decode a hand-picked list of token ids to see which text they correspond to.
output_sample = [1, 733, 16289, 28793, 28705, 13, 3195, 511, 368, 1032, 297, 456, 575]
decoded_text = tokenizer.decode(output_sample)
print(decoded_text)

In [None]:
# Show the id passed to tokenizer_image_token as the image-token index.
print(IMAGE_TOKEN_INDEX)

In [2]:
# TEST_QUERY = "[INST] <image>\nWhat do you see in this picture? [/INST]"

#image_token_se = DEFAULT_IMAGE_TOKEN

# print(f"image_token_se is: {image_token_se}")

# print(tokenizer_image_token(test_query, tokenizer, IMAGE_TOKEN_INDEX))


In [None]:
# Fine-tune with the Hugging Face Trainer. Mixed-precision fp16 can overflow, so bf16 is used.
from transformers import Trainer, TrainingArguments

output_model_name = f"{model_name}-chess"

training_args = TrainingArguments(
    output_dir=output_model_name,
    learning_rate=1e-4,
    # fp16=True,  # for non-Ampere GPUs
    bf16=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=6,
    gradient_accumulation_steps=1,
    # The dataset transform already moves tensors to `device`, so do not pin memory.
    dataloader_pin_memory=False,
    save_total_limit=2,
    # NOTE: older transformers releases name this argument `evaluation_strategy`.
    eval_strategy="steps",
    save_steps=0.2,  # values in [0, 1) are fractions of total training steps
    eval_steps=0.2,
    logging_steps=1,
    num_train_epochs=3,
    # max_steps=3,
    # Keep 'image', 'image_sizes', etc.; the model's forward signature does not list them all.
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    # The string "none" disables logging integrations (None meant "all" in transformers v4).
    report_to="none",
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

trainer.train()


In [None]:
# Eval after training 
import matplotlib.pyplot as plt

# Temporarily disable the transformation to access the original data
eval_ds.reset_format()

# Iterate over each example in the evaluation dataset
for i in range(len(eval_ds)):
    # Access the original image and caption for the current row
    image = eval_ds[i]['image']
    caption = eval_ds[i]['caption']
    
    #Display the image using matplotlib
    plt.imshow(image)
    plt.axis('off') #turn off axis numbers and ticks
    plt.show()
    
    eval_model(
        tokenizer,
        model,
        image_processor,
        context_len,
        image,
        "What do you see in the picture?"
    )
    
    print(f"\nCorrect caption: {caption}\n\n")
        
#Re-enable the transformation if needed
eval_ds.set_transform(lambda batch: ds_transforms(batch, image_processor, tokenizer, model))
    