Ref:

https://github.com/huggingface/smollm/blob/main/finetuning/Smol_VLM_FT.ipynb

In [None]:
import os
import re
import json
import torch
import pickle
import logging
import datasets
from PIL import Image
from tqdm.auto import tqdm
from datasets import Dataset
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader

In [None]:
# Process-level configuration; must run before any tokenizer or CUDA context is created.
# Optional GPU pinning, currently disabled:
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Set TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably meant to silence the tokenizers fork warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Show the active conda environment name (prints None when the variable is unset).
print(os.getenv("CONDA_DEFAULT_ENV"))

### Set Up Logger

In [None]:
# Re-running this cell would otherwise stack handlers on the root logger and
# print every message several times, so detach whatever is already attached.
root_logger = logging.root
for stale_handler in list(root_logger.handlers):
    root_logger.removeHandler(stale_handler)

# Single stream handler so records show up in the Jupyter cell output.
notebook_handler = logging.StreamHandler()
logging.basicConfig(
    handlers=[notebook_handler],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,  # Switch to logging.DEBUG for more verbosity
)

# Module-level logger used by the rest of the notebook.
logger = logging.getLogger(__name__)
logger.info("Logging is set up in the notebook!")

### Load the Dataset

In [None]:
# Load the quantity-reasoning VQA dataset by its hub identifier.
dataset = load_dataset("dutta18/Quantity-Reasoning-VQA-23K")

In [None]:
# Keep only the 'train' split; all later cells operate on this single split.
dataset = dataset['train']

In [None]:
# Display the dataset summary as the cell output.
dataset

### Load COT Think Data

In [None]:
# Collect the 'generated_cot' string of every JSONL record, in file order.
# The resulting list is later attached to the dataset as a column, so its
# order and length are expected to line up with the dataset rows.
think_data = []

# Read as UTF-8 explicitly instead of relying on the locale's default encoding.
with open('/home/aritrad/MOE-Directory/COT-Data-Qty-23K/Quantity-Reasoning-VQA-23K-Reasoning-Trace.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        # Skip blank lines (e.g. a trailing newline at EOF) rather than
        # letting json.loads raise on an empty document.
        if not line.strip():
            continue
        record = json.loads(line)
        think_data.append(record['generated_cot'])

In [None]:
# Number of reasoning traces loaded; this should equal the dataset row count,
# because the list is attached to the dataset as a column in a later cell.
len(think_data)

In [None]:
# Eyeball the first reasoning trace.
print(think_data[0])

### Merge with Dataset

In [None]:
# Attach the reasoning traces as a new "cot_think_data" column.
# NOTE(review): assumes the JSONL rows are in the same order as the dataset rows — confirm.
dataset = dataset.add_column("cot_think_data", think_data)

In [None]:
# Display the dataset summary, now including the new column.
dataset

In [None]:
# Check a single sample

dataset[0]

### Converting Output Number Words to Numeric Strings

In [None]:
# Spelled-out cardinals from "zero" to "twenty"; each word's position in the
# tuple is the integer it stands for.
_number_words = (
    "zero", "one", "two", "three", "four", "five", "six",
    "seven", "eight", "nine", "ten", "eleven", "twelve",
    "thirteen", "fourteen", "fifteen", "sixteen", "seventeen",
    "eighteen", "nineteen", "twenty",
)

# Lookup table: number word -> integer value.
num_map = dict(zip(_number_words, range(len(_number_words))))

In [None]:
def convert_answer(example):
    """Rewrite a spelled-out ground-truth answer as a numeric string.

    Args:
        example: a dataset row with an "answer" field (e.g. "Six").

    Returns:
        dict: ``{"answer": <str>}``. Words found in ``num_map`` become their
        digit string ("six" -> "6"). Any other label (already numeric, or a
        word outside the map) is returned stripped but otherwise unchanged,
        instead of being replaced by the literal string "None".
    """
    raw = str(example["answer"]).strip()
    word = raw.lower()

    if word in num_map:
        return {"answer": str(num_map[word])}

    # Keep labels the map does not cover; turning them into "None" would
    # destroy the ground truth and could falsely match a CoT answer of "None".
    return {"answer": raw}

In [None]:
# Rewrite the "answer" column of every row with convert_answer.
dataset = dataset.map(convert_answer)

In [None]:
# Check a single sample
# Mismatch: 10, 100, 105, 600
# NOTE(review): the indices above are presumably rows where the CoT answer
# disagrees with the label — confirm.
dataset[200]

In [None]:
# Display the dataset summary as the cell output.
dataset

### Rejection Sampling

In [None]:
def normalize_text(text):
    """Standardize an answer so two answers can be compared with ``==``.

    Steps:
    1. ``None`` becomes the empty string.
    2. The value is converted to ``str``, stripped and lower-cased.
    3. Every punctuation character is removed ('6.' -> '6', '**6**' -> '6').
    4. A number word found in ``num_map`` becomes its digit string ('six' -> '6').

    Args:
        text: the raw answer (any type; ``None`` allowed).

    Returns:
        str: always a string, so 'six' and '6' normalize to the same value.
    """
    if text is None:
        return ""

    # Basic cleanup
    text = str(text).strip().lower()

    # Drop all non-word, non-space characters, then strip again because
    # removing punctuation can leave whitespace at the edges ("6 ." -> "6 ").
    text = re.sub(r'[^\w\s]', '', text).strip()

    # Convert number words to digits; keep the result a string so it compares
    # equal to answers that were already numeric strings.
    if text in num_map:
        text = str(num_map[text])

    return text

In [None]:
def rejection_sampling_filter(example):
    """Decide whether a row's chain-of-thought answer agrees with its label.

    The row is kept (True) only when the text inside the first
    ``<answer>...</answer>`` pair of ``example['cot_think_data']`` equals
    ``example['answer']`` after both have been passed through
    ``normalize_text``. Rows without an ``<answer>`` tag are rejected (False).
    """
    reference = example['answer']
    trace = example['cot_think_data']

    # Tags are matched case-insensitively; DOTALL lets the answer span lines.
    answer_tag = re.search(
        r"<answer>(.*?)</answer>",
        trace,
        re.DOTALL | re.IGNORECASE,
    )
    if answer_tag is None:
        # Nothing to compare against: reject.
        return False

    return normalize_text(reference) == normalize_text(answer_tag.group(1))

In [None]:
# --- MAIN EXECUTION ---
# Report the row count before filtering.
print(f"Original Dataset Size: {len(dataset)}")

# Apply the Rejection Sampling
# load_from_cache_file=False ensures we re-run logic if we changed code
filtered_dataset = dataset.filter(rejection_sampling_filter, load_from_cache_file=False)

In [None]:
# Print Statistics

# Row counts before and after filtering, and how many rows were dropped.
original_count = len(dataset)
filtered_count = len(filtered_dataset)
rejected_count = original_count - filtered_count

In [None]:
# Summary of the filter; the retention rate is kept / original, as a percentage.
# NOTE: the division raises ZeroDivisionError if the original dataset is empty.
print(f"\n--- Rejection Sampling Results ---\n")
print(f"Original: {original_count}")
print(f"Kept:     {filtered_count}")
print(f"Rejected: {rejected_count}")
print(f"Retention Rate: {(filtered_count/original_count)*100:.2f}%")

In [None]:
# From here on, `dataset` refers to the filtered rows only.
dataset = filtered_dataset

In [None]:
# Display the filtered dataset summary as the cell output.
dataset

### Split Into Train, Test & Val

In [None]:
from datasets import DatasetDict

In [None]:
# 1. First create train (75%) and temp (25%); the seed is fixed for reproducibility
train_test = dataset.train_test_split(test_size=0.25, seed=42)

# 2. Split the temp set 40/60: its 'train' part (10% of all rows) is used as
#    validation and its 'test' part (15% of all rows) as test
test_val = train_test['test'].train_test_split(test_size=0.6, seed=42)

In [None]:
# Name the three partitions; the 'train' half of the second split serves as validation.
splits = {
    'train': train_test['train'],
    'validation': test_val['train'],
    'test': test_val['test'],
}

# Bundle the partitions into a single DatasetDict.
dataset_dict = DatasetDict(splits)

In [None]:
# Convenience handles for the individual partitions.
train_set, val_set, test_set = dataset_dict['train'], dataset_dict['validation'], dataset_dict['test']

In [None]:
# Report the size of each partition.
print(f'Length of the train set: {len(train_set)}, Val set: {len(val_set)} and Test Set: {len(test_set)}')

In [None]:
# Check A Particular Sample For Reproducibility
# (with the fixed seeds, the same questions should appear on every run)

print(train_set[100]['question'])
print(val_set[100]['question'])
print(test_set[100]['question'])

### Importing Model

In [None]:
import torch
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoModelForImageTextToText, BitsAndBytesConfig, Idefics3ForConditionalGeneration, AutoProcessor, AutoModelForVision2Seq

In [None]:
# Hub identifier of the base checkpoint; used for both the model and the processor.
model_id = "HuggingFaceTB/SmolVLM-256M-Instruct"

In [None]:
# Device that validation batches are moved to (the training loop and model
# loading spell out the same 'cuda:0' string literally).
device = 'cuda:0'

### Model Loading & Quantization

In [None]:
# 4-bit NF4 quantization with double quantization; compute is done in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [None]:
# Load the base model quantized with bnb_config and place it on the first GPU.
# NOTE(review): the attention backend is passed as `_attn_implementation`
# (leading underscore); confirm the installed transformers version honours this
# spelling and that flash-attn is available.
# NOTE(review): device_map repeats the 'cuda:0' literal instead of using `device`.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config = bnb_config,
    torch_dtype = torch.bfloat16, 
    _attn_implementation = "flash_attention_2",
    device_map = 'cuda:0'
)

# Processor for the same checkpoint; used for the chat template and for batching
# text and images in collate_fn.
processor = AutoProcessor.from_pretrained(model_id)

### LORA Config

In [None]:
# LoRA rank shared by every adapted module.
rank_everywhere = 32

In [None]:
# Adapter configuration: alpha is twice the rank, 5% dropout, Gaussian init.
# NOTE(review): target_modules presumably covers the attention and MLP
# projections of both the text and vision towers — verify against
# model.named_modules().
lora_config = LoraConfig(
    r=rank_everywhere,
    lora_alpha=rank_everywhere*2,
    lora_dropout=0.05,
    target_modules=[ 'k_proj', 'q_proj', 'v_proj', 'out_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'mlp.fc1', 'mlp.fc2'],
    init_lora_weights="gaussian",
    inference_mode = False
)

In [None]:
# Prepare the quantized model for training (with gradient checkpointing enabled),
# then wrap it with the LoRA adapters described by lora_config.
qlora_model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
qlora_model = get_peft_model(qlora_model, lora_config)

### Count Number of Params.

In [None]:
def report_trainable_params(model):
    """Print, in millions, how many parameters of ``model`` require gradients.

    Args:
        model: any object whose ``parameters()`` yields items exposing
            ``requires_grad`` and ``numel()``.
    """
    grad_param_count = 0
    for param in model.parameters():
        # Frozen parameters do not count towards the total.
        if param.requires_grad:
            grad_param_count += param.numel()

    print(f"Total trainable params: {grad_param_count/1e6:.1f} M")

# Print the trainable-parameter count of the LoRA-wrapped model.
report_trainable_params(qlora_model)

In [None]:
# Display the training split summary as the cell output.
train_set

In [None]:
# Id of the "<image>" special token: find its position in the tokenizer's
# additional special tokens and read the id at the same position.
# list.index raises ValueError if "<image>" is not among them.
image_token_id = processor.tokenizer.additional_special_tokens_ids[processor.tokenizer.additional_special_tokens.index("<image>")]

In [None]:
def collate_fn(examples):
    """Turn a list of dataset rows into one padded training batch.

    Each row becomes a two-turn chat: a user turn (image placeholder plus the
    question) and an assistant turn holding the chain-of-thought text. The
    chats are rendered with the processor's chat template, batched with
    padding, and a ``labels`` tensor is derived from ``input_ids``.

    Args:
        examples: rows providing "image", "question" and "cot_think_data".

    Returns:
        The processor's batch with an added "labels" entry and with
        "pixel_values" cast to bfloat16.
    """
    prompts = []
    image_groups = []

    for sample in examples:
        picture = sample["image"]
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": sample["question"]},
            ],
        }
        assistant_turn = {
            "role": "assistant",
            "content": [
                {"type": "text", "text": sample["cot_think_data"]},
            ],
        }
        rendered = processor.apply_chat_template(
            [user_turn, assistant_turn],
            add_generation_prompt=False,
        )
        prompts.append(rendered.strip())
        # One image per sample, wrapped in its own list.
        image_groups.append([picture])

    # Tokenize the text and preprocess the images together, padding the batch.
    batch = processor(text=prompts, images=image_groups, return_tensors="pt", padding=True)

    # Labels are a copy of the inputs in which padding and image-placeholder
    # positions are set to -100.
    targets = batch["input_ids"].clone()
    ignored = (targets == processor.tokenizer.pad_token_id) | (targets == image_token_id)
    targets[ignored] = -100
    batch["labels"] = targets

    # Cast the image tensor explicitly to bfloat16.
    batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)
    return batch

### Validation Function

In [None]:
@torch.no_grad()
def do_validation():
    """Compute the mean per-batch loss of ``qlora_model`` over ``val_loader``.

    The model is switched to eval mode for the pass and is always switched
    back to train mode afterwards, even if the pass is interrupted, so an
    aborted validation cannot leave later training running in eval mode.

    Returns:
        float: the average of the per-batch losses, or NaN when the loader
        yields no batches (NaN never compares as an improvement).
    """
    qlora_model.eval()
    val_loss = 0.0
    num_batches = 0

    try:
        for batch in tqdm(val_loader, desc="Validating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = qlora_model(**batch)
            val_loss += outputs["loss"].item()
            num_batches += 1
    finally:
        # Restore training mode no matter how the loop ended.
        qlora_model.train()

    # Release cached GPU memory held by the validation pass before training resumes.
    torch.cuda.empty_cache()

    if num_batches == 0:
        return float("nan")
    return val_loss / num_batches

### Training Params

In [None]:
# Micro-batch size, number of epochs, and how many micro-batches are
# accumulated per optimizer step (4 * 2 = 8 samples per step).
batch_ = 4
epochs = 5
grad_accum_steps = 2

In [None]:
# Training batches are reshuffled every epoch; validation order is fixed.
# Both loaders build their batches with collate_fn.
train_loader = DataLoader(train_set, batch_size=batch_, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=batch_, shuffle=False, collate_fn=collate_fn)

In [None]:
# AdamW over the model's parameters; the learning rate is multiplied by 0.95
# each time scheduler.step() is called (once per epoch in the training loop).
optimizer = torch.optim.AdamW(qlora_model.parameters(), lr=2e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

In [None]:
# Counter of optimizer steps taken, and the lowest validation loss seen so far.
global_step = 0
best_val_loss = float("inf")

In [None]:
# Put the model in training mode (result assigned to `_` to suppress cell output).
_ = qlora_model.train()
# NOTE(review): this sets an attribute on the wrapper object; the cache flag is
# normally read from the model config (config.use_cache) — confirm this line
# has the intended effect.
qlora_model.use_cache = False

In [None]:
# Directory under which checkpoints are written.
# NOTE(review): the path says "SmolVLM-2B" while model_id is the 256M checkpoint — confirm.
saveDir = '/home/aritrad/main/SmolVLM-2B/RL/chkpts/'

### Training Loop

In [None]:
# Fine-tuning loop with gradient accumulation, validation every 10 optimizer
# steps, a "best" checkpoint on validation improvement and one checkpoint per epoch.
num_train_batches = len(train_loader)

for epoch in tqdm(range(epochs)):

    # Sum of the unscaled per-batch losses, so the epoch average is a true mean
    # and is directly comparable with the validation loss.
    accumulated_loss = 0.0

    for idx, batch in enumerate(train_loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = qlora_model(**batch)
        raw_loss = outputs["loss"]

        # Only the gradient is scaled: dividing by the accumulation factor makes
        # the summed gradients equal the average over the accumulated micro-batches.
        (raw_loss / grad_accum_steps).backward()
        batch_loss = raw_loss.item()
        accumulated_loss += batch_loss

        # Step on every accumulation boundary, and also on the final batch of
        # the epoch so leftover gradients are applied instead of leaking into
        # the next epoch (or being dropped after the last one).
        at_accum_boundary = (idx + 1) % grad_accum_steps == 0
        at_epoch_end = (idx + 1) == num_train_batches

        if at_accum_boundary or at_epoch_end:

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            logger.info(f"[ Epoch {epoch+1} | idx: {idx} | Optim Step {global_step} | Train Loss: {batch_loss:.4f} ]")

            if global_step % 10 == 0:
                avg_val_loss = do_validation()
                logger.info(f"Val Loss @ Optim step: {global_step} -> {avg_val_loss:.4f}\n")

                # Keep a separate copy of the adapter with the lowest validation loss so far.
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    qlora_model.save_pretrained(os.path.join(saveDir, 'best-smolvlm-256M-qty-chkpt-32'))
                    logger.info(f"***** ✅ Checkpoint Saved *****\n")

    # StepLR Scheduler is updated at Last.
    scheduler.step()
    # Per-epoch checkpoint, saved regardless of validation performance.
    qlora_model.save_pretrained(os.path.join(saveDir, f'smolvlm-256M-qty-chkpt-{epoch+1}'))
    logger.info(f"***** ✅ Checkpoint Saved *****\n")
    logger.info(f"Epoch {epoch+1} completed. Avg loss: {accumulated_loss / max(num_train_batches, 1):.4f}")