# Multi-Modal Fine-Tuning


## Install and import libraries

In [18]:
!pip install matplotlib -q -U
!pip install datasets -q -U
!pip install -q bitsandbytes sentencepiece  accelerate loralib
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install hf_transfer -q -U
!pip install pickleshare -q
!pip install Pillow==10.4.0 -q
!pip install pickleshare -q
!pip install peft==0.10.0 -q
!pip install transformers==4.37.2 -q

In [None]:
# Allows for faster downloads from the Hugging Face Hub: enables the hf_transfer
# download backend, which requires the `hf_transfer` package installed in the first cell.
%env HF_HUB_ENABLE_HF_TRANSFER=1

In [None]:
import os

if not os.path.isdir('LLaVA'):
    !git clone https://github.com/haotian-liu/LLaVA.git
else:
    print('LLaVA already exists. Skipping clone.')

In [None]:
import re

# Patch LLaVA's model builder so that the vision tower is cast to float16 after it is
# loaded, matching the float16 image tensors fed to the model later in this notebook.
# NOTE(review): this rewrites an upstream source file with a regex. It assumes
# builder.py contains `vision_tower = model.get_vision_tower()` followed by
# `image_processor = vision_tower.image_processor`, both indented by 8 spaces.
file_path = 'LLaVA/llava/model/builder.py'

with open(file_path, 'r') as file:
    content = file.read()

# Group 1 is the first line (its leading indentation stays outside the match and is kept).
# Group 2 is the last line. Everything in between, including the last line's indentation,
# is replaced.
pattern_block = (
    r'(vision_tower = model\.get_vision_tower\(\)\n)'
    r'.*?'  # non-greedy: stop at the first image_processor assignment
    r'(image_processor = vision_tower\.image_processor)'
)

# The indentation must match the surrounding block in builder.py (8 spaces). Putting the
# kept last line at any other column makes builder.py fail to import with an
# IndentationError. `device_map` defaults to 'auto' in load_pretrained_model, and 'auto'
# is not a torch device, so in that case only the dtype is changed.
replacement_block = (
    r'\1'
    '        if not vision_tower.is_loaded:\n'
    "            print('vision_tower is not loaded so loading now')\n"
    '            vision_tower.load_model(device_map=device_map)\n'
    "            if device_map != 'auto':\n"
    '                vision_tower.to(device=device_map, dtype=torch.float16)\n'
    '            else:\n'
    '                vision_tower.to(dtype=torch.float16)\n'
    '        else:\n'
    "            print('vision_tower is already loaded')\n"
    r'        \2'
)

content, n_replacements = re.subn(pattern_block, replacement_block, content, flags=re.DOTALL)

# Only rewrite the file (and report success) when the block was actually found.
if n_replacements:
    with open(file_path, 'w') as file:
        file.write(content)
    print('File modified successfully.')
else:
    print('Pattern not found; builder.py was not modified.')

In [None]:
import re

# Rewrite every `float16` in builder.py that is not already part of `bfloat16` to
# `target_dtype`. With target_dtype = 'float16' the file stays unchanged; change it to
# switch the dtype string used by the builder.
file_path = 'LLaVA/llava/model/builder.py'
target_dtype = 'float16'

with open(file_path, 'r') as file:
    content = file.read()

# The negative lookbehind skips occurrences that are already `bfloat16`, which also makes
# the rewrite idempotent.
pattern = r'(?<!b)float16'
modified_content = re.sub(pattern, target_dtype, content)

# Write the *modified* text, and only when the substitution changed something.
# Writing back the original `content` would silently discard the substitution.
if modified_content != content:
    with open(file_path, 'w') as file:
        file.write(modified_content)
    print('File modified successfully.')
else:
    print('No modification needed.')

In [None]:
%cd LLaVA

In [None]:
# Install the cloned (and patched) LLaVA sources in editable mode, so `import llava`
# uses the files in this directory. Takes quite a while to run.
!pip install -e . -q

In [None]:
# Training dependencies: LLaVA's optional `train` extras plus FlashAttention, which the
# model loading cell below requests via use_flash_attn=True.
!pip install protobuf -q -U
!pip install -e '.[train]' -q
!pip install flash-attn --no-build-isolation -q

## Load the Model

In [12]:
import torch
from datasets import load_dataset
import transformers
from peft import LoraConfig, get_peft_model
from PIL import Image
from transformers import AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig
import torchvision.transforms as transforms

In [None]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# Use the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_path = 'liuhaotian/llava-v1.6-mistral-7b' # needs about 100 GB of VRAM equals 3 x A6000 cards to fine-tune in 16 bit

# Short model name derived from the Hub path; the prompt helpers below use it to pick
# the conversation template.
model_name = get_model_name_from_path(model_path)

# Unpacks into the tokenizer, the LLaVA model, the vision tower's image processor and
# the context length.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    cache_dir='',  # NOTE(review): empty string rather than None — confirm where the weights get cached
    use_flash_attn=True,  # needs the flash-attn package installed above
)

## Optional examination

In [None]:
import torch

def report_non_matching_dtypes(module_tree, expected_dtype=torch.float16):
    """Print and return the modules that own parameters whose dtype is not `expected_dtype`.

    Each parameter is attributed only to the module that directly owns it
    (``recurse=False``). A mismatch is therefore reported once, not repeated for every
    ancestor module. All mismatching dtypes of a module are listed, rather than the dtype
    of its first parameter, which may itself match.

    Args:
        module_tree: a torch ``nn.Module`` to inspect.
        expected_dtype: the dtype every parameter is expected to have.

    Returns:
        A list of ``(module_name, [dtype names])`` tuples, one per offending module.
    """
    print(f'modules not {expected_dtype}:')
    mismatches = []
    for name, module in module_tree.named_modules():
        odd_dtypes = sorted(
            {str(param.dtype) for param in module.parameters(recurse=False) if param.dtype != expected_dtype}
        )
        if odd_dtypes:
            mismatches.append((name, odd_dtypes))
            print(f"{name}: {', '.join(odd_dtypes)}")
    return mismatches


report_non_matching_dtypes(model);

## Inference

Hilfsfunktionen, um die Inferenz des Modells zu testen

In [None]:
# Method to test the model's inference
import torch
import re
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.transforms.functional import to_pil_image, to_tensor
from PIL import Image
import requests
from io import BytesIO

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

# Common function to create prompts
def create_prompt(query, model, model_name=model_name, captions=None):
    """Build a LLaVA chat prompt containing the image token and `query`.

    Args:
        query: the user question. If it contains IMAGE_PLACEHOLDER, the image token is
            substituted there. Otherwise the image token is prepended on its own line.
        model: LLaVA model. Only ``model.config.mm_use_im_start_end`` is read, to decide
            whether the image token is wrapped in the start/end tokens.
        model_name: passed to ``infer_conv_mode`` to select the conversation template.
        captions: optional assistant reply (the target caption during training). When
            None, the assistant turn is left open for generation.

    Returns:
        The formatted prompt string, for both query forms.
    """
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    if IMAGE_PLACEHOLDER in query:
        # Literal substitution (str.replace takes no third "query" argument).
        query = query.replace(IMAGE_PLACEHOLDER, image_token)
    else:
        query = image_token + '\n' + query

    # Built for both branches above; nesting it under `else` made the function return
    # None for queries that contain IMAGE_PLACEHOLDER.
    conv = conv_templates[infer_conv_mode(model_name)].copy()
    conv.append_message(conv.roles[0], query)
    # captions=None leaves the assistant turn empty so the model generates it.
    conv.append_message(conv.roles[1], captions)
    return conv.get_prompt()
    
# Common function to infer conversation mode
def infer_conv_mode(model_name):
    """Return the ``conv_templates`` key to use for `model_name`.

    Llama-2 based checkpoints use 'llava_llama_2'. Every other name falls back to
    'mistral_instruct', the template for the llava-v1.6-mistral checkpoint used here.
    """
    if 'llama-2' in model_name.lower():
        # No leading whitespace: the key must match conv_templates exactly.
        return 'llava_llama_2'
    return 'mistral_instruct'
        
# Common function to process images
def process_and_prepare_images(image_files, image_processor, model, device):
    """Load images and convert them into the float16 tensor batch the LLaVA model consumes.

    Args:
        image_files: iterable of inputs accepted by ``load_image`` (path, URL or PIL image).
        image_processor: the vision tower's image processor.
        model: LLaVA model; its ``config`` is handed to ``process_images``.
        device: device the resulting tensor is moved to.

    Returns:
        ``(images_tensor, image_sizes)``, where ``image_sizes`` holds each PIL image's ``.size``.
    """
    pil_images = list(map(load_image, image_files))
    processed = process_images(pil_images, image_processor, model.config)
    images_tensor = processed.to(device, dtype=torch.float16)
    image_sizes = [pil_image.size for pil_image in pil_images]
    return images_tensor, image_sizes

In [None]:
import torch
import re

def load_image(image_input, timeout=30):
    """Return `image_input` as an RGB PIL image.

    Args:
        image_input: an ``http://``/``https://`` URL, a local file path, or a PIL image.
        timeout: seconds to wait for the HTTP response when `image_input` is a URL.

    Raises:
        ValueError: if `image_input` is neither a str nor a PIL image.
        requests.HTTPError: if downloading the URL returns an error status.
    """
    if isinstance(image_input, str):
        # Match the scheme exactly, so a local file such as 'http_cat.jpg' is not
        # mistaken for a URL ('https' already starts with 'http').
        if image_input.startswith(('http://', 'https://')):
            response = requests.get(image_input, timeout=timeout)
            # Fail loudly instead of handing an HTML error page to PIL.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        return Image.open(image_input).convert('RGB')
    if isinstance(image_input, Image.Image):
        # Convert here too, so every input type yields the same image mode.
        return image_input.convert('RGB')
    raise ValueError('Invalid input. Please provide a valid input type.')

def eval_model(tokenizer, model, image_processor, context_len, image_file, query, model_name=model_name, sep=',', temperature=1.0, num_beams=1, max_new_tokens=512):
    """Run one image + query through the model and print the prompt and decoded output.

    Args:
        tokenizer, model, image_processor: as returned by ``load_pretrained_model``
            (the model may be PEFT-wrapped).
        context_len: unused; kept for signature compatibility.
        image_file: a path/URL string, a list of them, or a single PIL image.
        query: the question asked about the image.
        model_name: selects the conversation template via ``create_prompt``.
        sep: unused; kept for signature compatibility.
        temperature: 1.0 (default) or any value <= 0 means greedy decoding; any other
            positive value enables sampling at that temperature.
        num_beams, max_new_tokens: forwarded to ``model.generate``.
    """
    disable_torch_init()
    # Inference must not apply dropout (e.g. LoRA dropout). This matters right after
    # training, when the model may still be in train mode.
    model.eval()

    # Use the caller's query rather than a fixed prompt.
    prompt = create_prompt(query, model, model_name)
    print(f"Prompt: {prompt}")

    # A list is used as-is; a single path/URL string or PIL image is wrapped in a list.
    images = image_file if isinstance(image_file, list) else [image_file]
    images_tensor, image_sizes = process_and_prepare_images(images, image_processor, model, model.device)

    # Tokenize the prompt, inserting IMAGE_TOKEN_INDEX where the image token appears.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        .unsqueeze(0)
        .to(model.device)
    )

    # Sampling with a temperature <= 0 is rejected by transformers, so fall back to greedy.
    do_sample = temperature > 0 and temperature != 1.0

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=do_sample,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
    print(outputs)


In [None]:
import requests
from PIL import Image
from io import BytesIO

# Raw image URL from GitHub (the `?raw=true` suffix requests the file itself)
image_url = 'https://github.com/fuerstfabian/Fine-tuned-LLaVA-Vision-and-Language/blob/main/data_prep/data/figure_9647.jpg?raw=true'

# Download the image and open it with PIL. Fail loudly on HTTP errors instead of
# letting PIL choke on an error page.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')

import matplotlib.pyplot as plt
plt.imshow(image)
plt.axis('off')
plt.show()

# Ask the (not yet fine-tuned) model about the downloaded image
eval_model(
    tokenizer,
    model,
    image_processor,
    context_len,
    image,
    'What do you see in the image?',
    model_name=model_name
)

## Fine-Tuning Dataset

Vorbereiten des Datasets, welches für das Fine-Tuning verwendet wird

In [None]:
from torch.nn.utils.rnn import pad_sequence

def tokenize_and_create_labels(example_batch, image_processor, tokenizer, model):
    """Turn a batch of dataset rows into LLaVA training inputs.

    Args:
        example_batch: batch dict from the dataset transform with an 'image' column
            (inputs accepted by load_image) and a 'captions' column (target replies).
        image_processor, tokenizer, model: LLaVA components used for preprocessing.

    Returns:
        A dict with:
        - 'input_ids': (batch, max_len) prompt+caption ids, right-padded with the pad id;
        - 'attention_mask': 1 for real tokens, 0 for padding;
        - 'labels': caption token ids, with prompt and padding positions set to -100 so
          only the caption contributes to the loss;
        - 'images' / 'image_sizes': outputs of process_and_prepare_images.
    """
    # Padding positions are masked and ignored below, so any valid id works as a
    # fallback when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    ignore_index = -100

    images_tensor, image_sizes = process_and_prepare_images(example_batch['image'], image_processor, model, model.device)

    query = 'What do you see in the image?'

    # The prompt without captions is the same for every row, so tokenize it once. Only
    # its length is needed, to find where the caption tokens start.
    prompt_len = len(tokenizer_image_token(create_prompt(query, model, model_name, None), tokenizer))

    # Tokenize the full conversation (prompt + caption) for each row.
    full_sequences = [
        torch.tensor(tokenizer_image_token(create_prompt(query, model, model_name, captions), tokenizer))
        for captions in example_batch['captions']
    ]
    seq_lens = torch.tensor([len(seq) for seq in full_sequences])

    input_ids = pad_sequence(full_sequences, batch_first=True, padding_value=pad_token_id)

    # Derive the mask from the true lengths. Comparing against pad_token_id would also
    # mask real tokens that share the pad id. All text tensors stay on the CPU; the
    # Trainer moves them to the right device.
    positions = torch.arange(input_ids.size(1)).unsqueeze(0)
    attention_mask = (positions < seq_lens.unsqueeze(1)).long()

    # Supervise only the caption tokens. Prompt and padding positions keep ignore_index.
    labels = torch.full_like(input_ids, fill_value=ignore_index)
    for row, seq_len in enumerate(seq_lens.tolist()):
        labels[row, prompt_len:seq_len] = input_ids[row, prompt_len:seq_len]

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'images': images_tensor,
        'image_sizes': image_sizes,
    }

# A named module-level function (instead of a lambda) keeps the dataset transform picklable.
def transform_batch(batch):
    """Dataset transform: preprocess `batch` with the notebook-level components.

    `image_processor`, `tokenizer` and `model` are looked up when the transform runs.
    It therefore uses whatever `model` refers to at that time, including the model
    after it is reassigned to the PEFT-wrapped version.
    """
    return tokenize_and_create_labels(
        example_batch=batch,
        image_processor=image_processor,
        tokenizer=tokenizer,
        model=model,
    )

# Load the image/caption dataset from the Hugging Face Hub and take its two splits.
ds = load_dataset('fuerstfabian/cat_figures')
train_ds, eval_ds = ds['train'], ds['test']

# Attach the preprocessing transform to both splits; it is applied on the fly whenever
# rows are accessed.
for dataset_split in (train_ds, eval_ds):
    dataset_split.set_transform(transform_batch)

## LoRA

Nach dem Erstellen der Config des Low-Rank-Adapters (LoRA) können wir das PeftModel mit der Funktion `get_peft_model` laden.

In [None]:
# Print the module tree to find the layer names that LoRA's target_modules can match.
print(model)

In [None]:
from peft import PeftModel

# Low-rank adapter configuration.
config = LoraConfig(
    r=16,  # rank of the low-rank update matrices
    lora_alpha=32,  # the update is scaled by lora_alpha / r
    # Name suffixes of the attention (q/k/v/o) and MLP (gate/up/down) projections to wrap.
    # NOTE(review): suffix matching may also hit identically named layers in the vision
    # tower — check with print(model) / print_trainable_parameters().
    target_modules=[
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    ],
    lora_dropout=0.05,
    bias='none',
)

# Wrap only once: re-running the cell would otherwise wrap the PEFT model a second time.
if not isinstance(model, PeftModel):
    model = get_peft_model(model, config)

In [None]:
# Sanity check: compare the number of trainable (adapter) parameters with the total.
model.print_trainable_parameters()

## Pre-Training Evaluation

In [None]:
import matplotlib.pyplot as plt

# Baseline before fine-tuning: run the model on every evaluation example and show the
# reference captions next to its output.
# Temporarily remove the tokenizing transform so rows come back as raw image/captions.
eval_ds.reset_format()
try:
    for example in eval_ds:
        image = example['image']
        captions = example['captions']

        plt.imshow(image)
        plt.axis('off')
        plt.show()

        eval_model(
            tokenizer,
            model,
            image_processor,
            context_len,
            image,
            'What do you see in the image?'
        )

        print(f"\nCorrect captions: {captions}\n\n")
finally:
    # Restore the training transform even if generation fails. Use the named function
    # rather than a lambda so the transform stays picklable.
    eval_ds.set_transform(transform_batch)

## Training 

Nach all diesen Schritten können wir endlich das Modell mit dem Hugging Face Trainer fine-tunen!! 🎉<br /><br />Da das Fine-Tuning mit fp16 zu Overflows führen kann, trainieren wir mit bf16 Mixed Precision.

In [None]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Build a DataLoader over the transformed training set and inspect its first batch, to
# confirm the transform yields images, token ids, an attention mask and labels.
batch_size = 4
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

batch = next(iter(train_loader))  # only the first batch is inspected
row = 0  # index of the row shown below (the "first row")

print(batch.keys())  # which fields the batch contains

if 'images' in batch:
    print('Images are included in the DataLoader')
    print(f"Batch 'images' shape: {batch['images'].shape}")

if 'input_ids' in batch and 'attention_mask' in batch:
    input_ids_first_row = batch['input_ids'][row]
    print(f"First row of input_ids: \n{input_ids_first_row.tolist()}")
    print('Text inputs are included in the DataLoader')
    print(f"Batch 'attention_mask' shape: {batch['attention_mask'].shape}")

    # Positions labelled -100 (ignore_index) are excluded from the loss; show them as '[IGNORE]'.
    labels = batch['labels'][row].tolist()
    labels_str = ['[IGNORE]' if label == -100 else str(label) for label in labels]
    print(f"Labels: {labels_str}")

    attention_mask_str = batch['attention_mask'][row].tolist()
    print(f"Attention mask: {attention_mask_str}")

# Summarise the first row's image tensor (dtype, shape and value range).
if 'images' in batch:
    images_tensor = batch['images'][row]
    print(f" First Row Image Data type: {images_tensor.dtype}")
    print(f" First Row Image Shape: {images_tensor.shape}")
    print(f"First Row Image Value range: [{images_tensor.min()}, {images_tensor.max()}]")

In [None]:
# Decode the row inspected above back to text to check the prompt template.
# IMAGE_TOKEN_INDEX (-200) is a placeholder the tokenizer cannot decode, so drop it first.
sample_ids = [token_id for token_id in input_ids_first_row.tolist() if token_id != IMAGE_TOKEN_INDEX]
print(tokenizer.decode(sample_ids))

In [None]:
# The placeholder id that tokenizer_image_token inserts where the image token appears.
print(IMAGE_TOKEN_INDEX)

In [None]:
output_model_name = f"{model_name}-figures"

training_args = TrainingArguments(
    output_dir=output_model_name,
    learning_rate=1e-4,
    # bf16 mixed precision (RTX A6000 in this case); use fp16=True on non-Ampere GPUs.
    bf16=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=6,
    gradient_accumulation_steps=1,
    dataloader_pin_memory=False,
    save_total_limit=2,
    evaluation_strategy='steps',
    save_strategy='steps',
    # Values below 1 are interpreted as a fraction of the total number of training steps.
    save_steps=0.2,
    eval_steps=0.2,
    logging_steps=1,
    num_train_epochs=3,
    # The transform needs the raw 'image' and 'captions' columns. They are not forward()
    # arguments, so the Trainer would drop them otherwise.
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=['labels'],
    load_best_model_at_end=True,
    # 'none' disables logging integrations; None would mean "all installed integrations".
    report_to='none',
    optim='adamw_torch',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

trainer.train()

## Post-Training Evaluation

In [None]:
import matplotlib.pyplot as plt

# After fine-tuning: run the model on every evaluation example again and compare its
# output with the reference captions.
# Temporarily remove the tokenizing transform so rows come back as raw image/captions.
eval_ds.reset_format()
try:
    for example in eval_ds:
        image = example['image']
        captions = example['captions']

        plt.imshow(image)
        plt.axis('off')
        plt.show()

        eval_model(
            tokenizer,
            model,
            image_processor,
            context_len,
            image,
            'What do you see in the image?'
        )

        print(f"\nCorrect captions: {captions}\n\n")
finally:
    # Restore the training transform even if generation fails. transform_batch is the
    # transform defined for this dataset; the undefined `ds_transforms` raised a NameError.
    eval_ds.set_transform(transform_batch)