Ref:

Fine-tuning script: https://github.com/zhangfaen/finetune-Qwen2-VL/blob/main/finetune.py

Repo: https://github.com/zhangfaen/finetune-Qwen2-VL

#### Compatibility: Qwen2.5-VL requires Transformers 4.52 (installed in the `backup_env` environment)

In [None]:
import os 
# Show which conda environment the kernel is running in (None outside conda).
print(os.environ.get("CONDA_DEFAULT_ENV"))

In [None]:
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
import json
import torch
import pickle
import logging
import datasets
import datetime

In [None]:
from tqdm.auto import tqdm
from datasets import Dataset
from functools import partial
from torch.optim import AdamW
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig

In [None]:
from qwen_vl_utils import process_vision_info
from util.logutil import init_logger, get_logger

In [None]:
from datasets import Dataset
from datasets import load_dataset

### Initialize Logger

In [None]:
# Training-progress logger: INFO and above, formatted as "[timestamp] message"
# and written through a StreamHandler (stderr by default).
logger = logging.getLogger("train_logger")
logger.setLevel(logging.INFO)

formatter = logging.Formatter(
    "[%(asctime)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

console = logging.StreamHandler()
console.setFormatter(formatter)
# getLogger() returns the same logger object on every call, so re-running this
# cell would otherwise stack another handler and print each message repeatedly.
if not logger.handlers:
    logger.addHandler(console)


### Load Dataset

In [None]:
# Fetch the Quantity-Reasoning VQA dataset from the Hugging Face Hub (or the local cache).
dataset = load_dataset(path="dutta18/Quantity-Reasoning-VQA-23K")

In [None]:
# Keep only the "train" split of the loaded dataset dictionary.
dataset = dataset["train"]

In [None]:
# Inspect the selected split (shown as this cell's output).
dataset

In [None]:
# Restrict to the first 4,450 rows; the CoT answers loaded below are aligned
# with this subset, so the two sizes must match.
dataset = dataset.select(indices=range(4450))

In [None]:
# Re-inspect after subsetting to confirm the row count.
dataset

In [None]:
# Show the working directory (relative paths such as the CoT pickle resolve
# against it). os.getcwd() is plain Python; `pwd` only works as IPython automagic.
os.getcwd()

### Load Chain-of-Thought (CoT) Data

In [None]:
# Pre-generated chain-of-thought answers, stored as a pickled sequence.
# NOTE: unpickling can execute arbitrary code — only load this file from a trusted source.
with open('./qty-reasoning-cot-data-8000.pkl', mode='rb') as cot_file:
    cot_think_data = pickle.load(cot_file)

In [None]:
# Keep exactly one CoT entry per dataset row. Slicing to len(dataset) (rather
# than repeating the hard-coded 4450) keeps both sizes in sync if the subset
# above changes. Assumes the pickle's entries are in the same order as the
# dataset rows — TODO confirm.
cot_think_data = cot_think_data[:len(dataset)]

In [None]:
# Sanity check: number of CoT entries kept (should equal the dataset's row count).
len(cot_think_data)

In [None]:
# Attach the CoT answers as a new column; they become the assistant targets.
dataset = dataset.add_column(name="cot_think_data", column=cot_think_data)

In [None]:
# Confirm the new "cot_think_data" column is present.
dataset

### Split into Train, Validation & Test

In [None]:
from datasets import DatasetDict

In [None]:
# 1. Hold out 25% of the rows: train (75%) vs. a temporary held-out split (25%).
train_test = dataset.train_test_split(test_size=0.25, seed=42)

# 2. Split the held-out 25% into 40% / 60%: the 40% side becomes validation
#    (~10% of all rows) and the 60% side becomes test (~15% of all rows).
#    Fixed seeds keep the partition reproducible across runs.
test_val = train_test['test'].train_test_split(test_size=0.6, seed=42)

In [None]:
# Name the three partitions. Validation is the "train" side of the second split
# (40% of the held-out rows); test is its "test" side (60%).
splits = dict(
    train=train_test['train'],
    validation=test_val['train'],
    test=test_val['test'],
)
dataset_dict = DatasetDict(splits)

In [None]:
# Convenience handles for each split.
train_set = dataset_dict['train']
val_set = dataset_dict['validation']
test_set = dataset_dict['test']

In [None]:
# Inspect the training split.
train_set

In [None]:
# Inspect the validation split.
val_set

In [None]:
# Inspect the test split.
test_set

### Convert Each Split to Qwen Chat-Message Format

In [None]:
# Convert every training row into a chat sample: the user turn carries the image
# and the question, the assistant turn carries the CoT answer (the training target).
# Iterating rows reads each row once; indexing train_set[idx][...] three times
# per sample materialized (and decoded) the whole row, image included, each time.
formattedJSONTrain = [
    {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": row['image']},
                    {"type": "text", "text": row['question']},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": row['cot_think_data']},
                ],
            },
        ]
    }
    for row in tqdm(train_set)
]

In [None]:
# Convert every validation row into a chat sample (user: image + question,
# assistant: CoT answer). Rows are iterated once instead of re-fetching the
# full row (including the decoded image) for each of the three fields.
formattedJSONVal = [
    {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": row['image']},
                    {"type": "text", "text": row['question']},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": row['cot_think_data']},
                ],
            },
        ]
    }
    for row in tqdm(val_set)
]

In [None]:
# Convert every test row into a chat sample (user: image + question,
# assistant: CoT answer). Rows are iterated once instead of re-fetching the
# full row (including the decoded image) for each of the three fields.
formattedJSONTest = [
    {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": row['image']},
                    {"type": "text", "text": row['question']},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": row['cot_think_data']},
                ],
            },
        ]
    }
    for row in tqdm(test_set)
]

### Prepare Dataloaders

In [None]:
from torch.utils.data import Dataset

In [None]:
class qtyDataset(Dataset):
    """Map-style torch Dataset over a pre-built list of chat samples.

    Items are returned unchanged; batching and tokenization are left to the
    DataLoader's collate function.
    """

    def __init__(self, formatted_json_data):
        super().__init__()
        self.data = formatted_json_data

    def __len__(self):
        return self.data.__len__()

    def __getitem__(self, idx):
        samples = self.data
        return samples[idx]

In [None]:
# Show the working directory again. os.getcwd() is plain Python; `pwd` only
# works as IPython automagic and breaks when this notebook is run as a script.
os.getcwd()

In [None]:
# Wrap each formatted split in a torch Dataset (train, val, test, in that order).
train_dataset, val_dataset, test_dataset = map(
    qtyDataset, (formattedJSONTrain, formattedJSONVal, formattedJSONTest)
)

In [None]:
def find_assistant_content_sublist_indexes(
    l,
    start_marker=(151644, 77091, 198),
    end_marker=(151645, 198),
):
    '''
    Locate the assistant-response spans in a tokenized chat so labels can be built.

    A chat sample may look like below:
        {
            "messages": [
                {'role': 'user', 'content': [{'type': 'image', 'image': 'train_data/1.jpeg'}, {'type': 'text', 'text': 'Describe this image'}]},
                {'role': 'assistant', 'content': [{'type': 'text', 'text': 'The image shows a young woman playing with her dog on a beach. ...'}]}
            ]
        }
    After apply_chat_template, the text will look like below:
        ['<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image<|im_end|>\n<|im_start|>assistant\nThe image shows a young woman playing with her dog on a beach. ...<|im_end|>\n']

    Args:
        l: token ids of one sequence (a plain list of ints).
        start_marker: ids that open an assistant turn. The default is taken to
            encode "<|im_start|>assistant\\n" for the Qwen tokenizer (verify if
            the tokenizer changes).
        end_marker: ids that close the turn; the default is "<|im_end|>\\n".

    Returns:
        List of half-open (begin, end) index pairs. `begin` is the first token
        after the start marker. `end` is just past the end marker, so the end
        marker is part of the span and the model learns to emit it and stop.
        A start marker with no later end marker yields no span.

    Raises:
        ValueError: if either marker is empty.
    '''
    start_marker = tuple(start_marker)
    end_marker = tuple(end_marker)
    if not start_marker or not end_marker:
        raise ValueError("start_marker and end_marker must be non-empty")
    n_start = len(start_marker)
    n_end = len(end_marker)
    first_start = start_marker[0]
    spans = []

    # Only test positions where the whole start marker fits inside the list;
    # this avoids reading past the end when a marker is truncated at the tail.
    for i in range(len(l) - n_start + 1):
        if l[i] != first_start or tuple(l[i:i + n_start]) != start_marker:
            continue
        begin = i + n_start
        # Pair this start with the first end marker that follows it.
        for j in range(begin, len(l) - n_end + 1):
            if tuple(l[j:j + n_end]) == end_marker:
                spans.append((begin, j + n_end))
                break

    return spans

In [None]:
def collate_fn(batch, processor, device):
    """Tokenize a batch of chat samples and build assistant-only LM labels.

    Args:
        batch: list of samples, each a dict with a "messages" list in the chat
            format built above (user turn with image + text, assistant turn).
        processor: the model's processor; used for the chat template and for
            image/text preprocessing.
        device: device that the processor outputs (and the labels) are moved to.

    Returns:
        (inputs, labels_ids): `inputs` is the processor output on `device`.
        `labels_ids` is an int64 tensor with one row per sequence and the same
        length as `inputs['input_ids']`. Every position outside an assistant
        response is -100, so only the responses contribute to the loss.

    Raises:
        ValueError: if the processor returns a different number of sequences
            than there are samples.
    """
    messages = [m['messages'] for m in batch]
    texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=False) for msg in messages]
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    inputs = inputs.to(device)

    input_ids_lists = inputs['input_ids'].tolist()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(messages) != len(input_ids_lists):
        raise ValueError(
            f"processor returned {len(input_ids_lists)} sequences for {len(messages)} samples"
        )

    labels_list = []
    for ids_list in input_ids_lists:
        # Start fully masked (padding, system/user turns) and copy in only the
        # assistant spans.
        label_ids = [-100] * len(ids_list)
        for begin, end in find_assistant_content_sublist_indexes(ids_list):
            label_ids[begin:end] = ids_list[begin:end]
        labels_list.append(label_ids)

    # Build labels on the same device as the inputs so the two stay together.
    labels_ids = torch.tensor(labels_list, dtype=torch.int64, device=inputs['input_ids'].device)
    return inputs, labels_ids

### Model Loading & Quantization Config (the 4-bit config is defined but not yet applied to the model)

In [None]:
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor

In [None]:
# Quantization Configuration.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  
    bnb_4bit_compute_dtype=torch.float16,  # Use float16 for computation
    bnb_4bit_use_double_quant=True,  
    bnb_4bit_quant_type="nf4",  
)

In [None]:
# Base model: Qwen2.5-VL-3B-Instruct in bf16 with FlashAttention-2. device_map='auto'
# lets transformers/accelerate place the weights on the available devices.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    device_map='auto',
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

# Processor (chat template + image preprocessing).
# Each 28x28 pixel patch becomes one visual token, so these bounds keep every
# image at roughly 256-512 visual tokens (the model's default range is 4-16384).
# Fewer tokens trade image detail for speed and memory.
# Left padding is used for batched inputs.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    use_fast=True,
    padding_side="left",
    max_pixels=512 * 28 * 28,
    min_pixels=256 * 28 * 28,
)

In [None]:
# Save a local copy of the base (pre-LoRA) model and its processor.
# NOTE(review): the directory is named "Llama-3.2-11B-Vision-Instruct" but it
# receives Qwen2.5-VL-3B-Instruct weights. This looks like a copy-paste leftover;
# rename it (and update any consumers of this path) if that's unintended.
for _artifact in (model, processor):
    _artifact.save_pretrained("/home/aritrad/MOE-Directory/temp/Llama-3.2-11B-Vision-Instruct")

In [None]:
# model = prepare_model_for_kbit_training(model)

In [None]:
# model

### LoRA Settings

In [None]:
device = "cuda"  # the collate function moves every batch here

In [None]:
LORA_Rank = 16  # rank r of the LoRA update matrices

In [None]:
# LoRA adapters with alpha = 2 * rank and light dropout. Target modules are
# matched by name suffix. "attn.qkv" / "attn.proj" presumably hit the vision
# encoder's attention — confirm against model.named_modules().
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=LORA_Rank,
    lora_alpha=2 * LORA_Rank,
    lora_dropout=0.05,
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
        # fused-qkv / output attention projections
        "attn.qkv", "attn.proj",
    ],
)

#### Get PEFT Wrapper Model

In [None]:
# Wrap the base model with the LoRA adapters described by lora_config.
qlora_qwen_model = get_peft_model(model=model, peft_config=lora_config)

In [None]:
def report_trainable_params(model=None):
    """Print and return the number of trainable parameters.

    Args:
        model: module to inspect. Defaults to the global `qlora_qwen_model`,
            so existing zero-argument calls keep working.

    Returns:
        Count of parameter elements with requires_grad=True.
    """
    if model is None:
        model = qlora_qwen_model

    trainable = 0
    total = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n

    pct = 100.0 * trainable / total if total else 0.0
    print(f"Total trainable params: {trainable/1e6:.1f} M ({pct:.2f}% of {total/1e6:.1f} M)")
    return trainable

In [None]:
# Print how many parameters remain trainable after wrapping with LoRA.
report_trainable_params()

### Create & Test Dataloader

In [None]:
batchSize_ = 4  # samples per micro-batch, shared by all three loaders

In [None]:
# One DataLoader per split, each using collate_fn bound to the processor and
# device. Only the training loader shuffles; val/test iterate in order.
train_loader, val_loader, test_loader = (
    DataLoader(
        split_ds,
        batch_size=batchSize_,
        collate_fn=partial(collate_fn, processor=processor, device=device),
        shuffle=do_shuffle,
    )
    for split_ds, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
# Batches per loader, one per line.
print(
    f'Length of the Train Dataloader: {len(train_loader)}',
    f'Length of the Val Dataloader: {len(val_loader)}',
    f'Length of the Test Dataloader: {len(test_loader)}',
    sep='\n',
)

In [None]:
# Test the dataloader before the forward pass
for batch in train_loader:
    inputs, labels = batch

    for k, v in inputs.items():
        print(f'{k} -> {v.dtype}')
    break

### Validation Function

In [None]:
def do_validation():
    """Return the average validation loss of `qlora_qwen_model` over `val_loader`.

    Runs the model in eval mode under torch.no_grad(). Afterwards, even if an
    exception escapes, train mode is restored and the CUDA cache is emptied.

    Returns:
        Mean of the per-batch losses returned by the model (each batch weighted
        equally), or float('nan') if the loader yields no batches.
    """
    qlora_qwen_model.eval()
    val_loss = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for batch in tqdm(val_loader):
                inputs, labels = batch
                outputs = qlora_qwen_model(**inputs, labels=labels)
                val_loss += outputs.loss.item()
                num_batches += 1
    finally:
        # Always hand the model back in training mode.
        qlora_qwen_model.train()
        torch.cuda.empty_cache()

    if num_batches == 0:
        return float('nan')
    return val_loss / num_batches

In [None]:
def do_test():
    """Return the average test loss of `qlora_qwen_model` over `test_loader`.

    Runs the model in eval mode under torch.no_grad(). Afterwards, even if an
    exception escapes, train mode is restored and the CUDA cache is emptied.

    Returns:
        Mean of the per-batch losses returned by the model (each batch weighted
        equally), or float('nan') if the loader yields no batches.
    """
    qlora_qwen_model.eval()
    test_loss = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for batch in tqdm(test_loader):
                inputs, labels = batch
                outputs = qlora_qwen_model(**inputs, labels=labels)
                test_loss += outputs.loss.item()
                num_batches += 1
    finally:
        # Always hand the model back in training mode.
        qlora_qwen_model.train()
        torch.cuda.empty_cache()

    if num_batches == 0:
        return float('nan')
    return test_loss / num_batches

### Training Hyperparams

In [None]:
from transformers import get_cosine_schedule_with_warmup

In [None]:
# Optimization hyperparameters.
LR = 5e-5                        # peak learning rate for the warmup + cosine schedule
epochs = 1
weight_decay = 0.00              # AdamW weight decay (disabled)
gradient_accumulation_steps = 2  # micro-batches per optimizer step

In [None]:
# Training-progress state updated by the train loop.
global_step = 0                # optimizer steps taken so far
best_val_loss = float("inf")   # lowest validation loss seen; any finite loss improves on it

In [None]:
# Schedule length in optimizer steps. The loop steps only after a full
# accumulation window, so floor division drops a trailing partial window.
steps_per_epoch = len(train_loader) // gradient_accumulation_steps
total_train_steps = epochs * steps_per_epoch
# Warm up over the first 5% of steps (quick); for medium-length runs 0.03 also works.
num_warmup_steps = int(total_train_steps * 0.05)

In [None]:
print(f"{total_train_steps} {num_warmup_steps}")  # total optimizer steps, warmup steps

In [None]:
# AdamW over the trainable (LoRA) parameters only. Frozen base weights never
# receive gradients, so AdamW would skip them anyway. Excluding them keeps the
# param group small, so zero_grad()/step() don't walk every frozen tensor.
optimizer = AdamW(
    [p for p in qlora_qwen_model.parameters() if p.requires_grad],
    lr=LR,
    weight_decay=weight_decay,
)

In [None]:
# Linear warmup to LR over num_warmup_steps, then cosine decay until total_train_steps.
scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_train_steps,
)

In [None]:
saveDir = "/home/aritrad/main/Qwen2.5-VL-3B/GRPO/chkpts"  # best-validation adapter checkpoints go here

### Train Loop

In [None]:
# LoRA fine-tuning with gradient accumulation. Every 150 optimizer steps the
# model is evaluated on the validation split, and the adapter is saved to
# saveDir whenever the validation loss improves.

# from_pretrained() returns the model in eval mode and nothing above switches it,
# so enable train mode explicitly. Otherwise LoRA dropout stays off until the
# first validation pass calls .train().
qlora_qwen_model.train()

for epoch in range(epochs):

    accumulated_loss = 0.0  # sum of unscaled per-batch losses this epoch
    window_loss = 0.0       # sum of unscaled losses in the current accumulation window
    # Drop gradients left from a trailing partial window (when len(train_loader)
    # is not a multiple of gradient_accumulation_steps), so they don't leak into
    # this epoch's first update.
    optimizer.zero_grad()

    for idx, batch in enumerate(train_loader):
        inputs, labels = batch

        outputs = qlora_qwen_model(**inputs, labels=labels)
        # Scale so the gradients summed over a window equal the gradient of the
        # window's mean loss.
        loss = outputs.loss / gradient_accumulation_steps
        loss.backward()

        batch_loss = outputs.loss.item()
        accumulated_loss += batch_loss
        window_loss += batch_loss

        if (idx+1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(qlora_qwen_model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            # Mean unscaled loss over the whole window, not just the last micro-batch.
            train_loss = window_loss / gradient_accumulation_steps
            window_loss = 0.0
            logger.info(f"[ Epoch {epoch+1} | idx: {idx} | Optim Step {global_step} | Train Loss: {train_loss:.4f} ]")

            if global_step % 150 == 0:
                avg_val_loss = do_validation()
                logger.info(f"Val Loss @ Optim step: {global_step} -> {avg_val_loss:.4f}\n")

                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    qlora_qwen_model.save_pretrained(os.path.join(saveDir, 'qwen2.5-qty-chkpt'))
                    logger.info(f"***** ✅ Checkpoint Saved *****\n")

    # Mean unscaled per-batch loss over the epoch.
    num_batches = len(train_loader)
    epoch_avg_loss = accumulated_loss / num_batches if num_batches else float('nan')
    logger.info(f"Epoch {epoch+1} completed. Avg loss: {epoch_avg_loss:.4f}")