In [None]:
import os, torch, re
print(os.getenv("CONDA_DEFAULT_ENV"))
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
import json
import pickle
import random
import datasets
import numpy as np
from datasets import Image
from tqdm.auto import tqdm
from datasets import Dataset
from datasets import load_dataset
from torch.utils.data import DataLoader

### Force Determinism

In [None]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

### Load Dataset

In [None]:
# Parse the raw test annotations with the stdlib JSON reader.
# NOTE(review): `test_set` is rebound to dataset["test"] a few cells below, so this
# in-memory copy appears to be unused afterwards — confirm before removing.
with open('/home/aritrad/tallyQA/test.json', 'r') as file:
    test_set = json.load(file)

In [None]:
dataset = load_dataset(
    "json", 
    data_files={ 'test': "/home/aritrad/tallyQA/test.json"}
)

In [None]:
print(f'Length of the Test Set: { len(dataset["test"]) }')

In [None]:
test_set = dataset["test"]

In [None]:
test_set

In [None]:
test_set = test_set.remove_columns(['data_source', 'question_id', 'image_id'])

In [None]:
base_image_path = "/home/aritrad/cric/visual_genome/images"

In [None]:
def format_path(example):
    """Point ``example["image"]`` at the local image directory.

    Only the last path component of the stored value is kept
    (e.g. "VG_100K_2/1.jpg" -> "1.jpg") and appended to ``base_image_path``.
    The record is updated in place and also returned.
    """
    # Everything after the final "/" (the whole string when there is no "/").
    _, _, basename = example["image"].rpartition("/")
    example["image"] = f"{base_image_path}/{basename}"
    return example

In [None]:
test_set = test_set.map(format_path)

In [None]:
test_set = test_set.cast_column("image", Image())

### Filter out Complex Counting Questions

In [None]:
test_set_filtered = test_set.filter(lambda x: x["issimple"] == True)

In [None]:
len(test_set_filtered)

### Importing Models

In [None]:
DEVICE = 'cuda'

In [None]:
from peft import PeftModel
from transformers.image_utils import load_image
from transformers import AutoModelForImageTextToText, BitsAndBytesConfig, Idefics3ForConditionalGeneration, AutoProcessor, AutoModelForVision2Seq

In [None]:
model_id = "HuggingFaceTB/SmolVLM-256M-Instruct"
GRPO_finetuned_model_path = '/home/aritrad/main/SmolVLM-2B/RL/chkpts/chkpts_grpo/checkpoint-4000' 

In [None]:
# Load the base vision-language model in bfloat16 with the FlashAttention-2 attention
# implementation; device placement is delegated to `device_map='auto'`.
base_model = AutoModelForImageTextToText.from_pretrained(
    model_id, 
    dtype=torch.bfloat16, 
    _attn_implementation="flash_attention_2",
    device_map = 'auto',
)

# The processor is loaded from the same model id as the base model.
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "left"     # For batched generation.

In [None]:
# Wrap the base model with the PEFT adapter stored in the GRPO checkpoint directory.
# NOTE(review): the checkpoint path contains "SmolVLM-2B" while `model_id` is the 256M
# model — confirm the adapter was trained on this base.

peft_model = PeftModel.from_pretrained(base_model, GRPO_finetuned_model_path)

### DECLARE SYSTEM PROMPT

In [None]:
# The system prompt is extracted from DeepSeek R1 paper, modified for Quantity Reasoning
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question about an image, and the Assistant solves it. " # <--- Added "about an image"
    "The assistant first thinks about the reasoning process by analyzing visual elements and then provides the user with " # <--- Added "analyzing visual elements"
    "the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> "
    "</answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>\n\n"
    "Example:\n"
    "User: How many cats are in the image?\n"
    "Assistant: <think>\n"
    "1. Scanning the image, I see a black cat on the sofa.\n" # <--- "Scanning the image" reinforces vision
    "2. I also see a white cat under the table.\n"
    "Total count is 2.\n"
    "</think>\n"
    "<answer> 2 </answer>"
)

In [None]:
len(test_set_filtered)

In [None]:
test_set_filtered

### Prepare Chat Messages

In [None]:
def collate_fn(examples):
    """Turn a list of dataset rows into one processor batch for generation.

    For every row the system prompt and the question are merged into a single
    user turn (image placeholder first), rendered with the chat template plus
    the generation prompt, and extended with the forced "<think>\n1." opening.
    The rendered prompts and their images are then encoded together and the
    pixel values are cast to bfloat16.

    Args:
        examples: rows providing an "image" and a "question" entry.

    Returns:
        The batch produced by ``processor`` with ``pixel_values`` in bfloat16.
    """

    def _render_prompt(question):
        # The instructions live inside the user turn rather than a system turn.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": SYSTEM_PROMPT + "\n\nQuestion: " + question},
                ],
            }
        ]
        rendered = processor.apply_chat_template(conversation, add_generation_prompt=True)
        # Forced start: the model has to continue from the opened reasoning block.
        return (rendered + "<think>\n1.").strip()

    prompts, image_groups = [], []
    for sample in examples:
        picture = sample["image"]
        prompts.append(_render_prompt(sample["question"]))
        image_groups.append([picture])  # one image list per prompt

    # Encode the whole batch at once; padding aligns prompts of different length.
    encoded = processor(text=prompts, images=image_groups, return_tensors="pt", padding=True)

    # Cast the image tensor to bf16.
    encoded["pixel_values"] = encoded["pixel_values"].to(torch.bfloat16)
    return encoded

### Create Batches

In [None]:
# Create the test dataloader over the *filtered* split so that the generated outputs
# line up index-for-index with `test_set_filtered`, which is what every later cell
# (manual verification, ground-truth answers, accuracy) reads from.

batch_ = 32
test_loader = DataLoader(test_set_filtered, batch_size=batch_, shuffle=False, collate_fn=collate_fn)

### Generate Answers

In [None]:
# Accumulates one decoded continuation per example, in dataloader order.
decoded_outputs=list()

for batch in tqdm(test_loader):
    
    # Move every tensor produced by collate_fn onto the GPU.
    batch = {key: value.to('cuda') for key, value in batch.items()}
    
    # Inference only: gradient tracking is disabled for generation.
    with torch.no_grad():
        outputs = peft_model.generate(
            **batch, 
            max_new_tokens=256,
            do_sample=True,                # enable sampling
            temperature=0.6,               # randomness factor
            num_return_sequences=1,
            repetition_penalty=1.05,       # Slight penalty to prevent <think><think> loops
        )
        # Drop the first `input_ids.shape[-1]` positions of every row so that only the
        # tokens following the (padded) prompt are decoded.
        model_generated_output_only = outputs[:, batch["input_ids"].shape[-1]:]
        decoded_output = processor.batch_decode(
            model_generated_output_only, 
            skip_special_tokens=True, 
            clean_up_tokenization_spaces=False
        )
        decoded_outputs.extend(decoded_output)

In [None]:
print('<think>\n1.' + decoded_outputs[0])

### Manual Addition of `<think>` Tokens to Generated Outputs (Added Force-Tokens)

In [None]:
# Re-attach the forced prefix that collate_fn appended to every prompt: it is part of the
# prompt, so it is missing from the decoded continuation. The startswith guard keeps this
# cell idempotent — re-running it does not stack a second prefix.
decoded_outputs = [
    output if output.startswith('<think>\n1.') else '<think>\n1.' + output
    for output in decoded_outputs
]

### Manual Verification

In [None]:
idx = 0
print(decoded_outputs[idx])

In [None]:
print(test_set_filtered[idx]['question'])

In [None]:
image = test_set_filtered[idx]['image']
image.thumbnail((400, 800))
image

### End of Manual Verification

### Parsing Outputs

In [None]:
generated_outputs = list()

In [None]:
def process_clean_trace(rawOutputs, outputList):
    """Split each raw generation into (reasoning, answer) and append it to ``outputList``.

    Args:
        rawOutputs: iterable of decoded model responses.
        outputList: list that is extended in place with one tuple per response.

    The reasoning is the stripped text between ``<think>`` and ``</think>``; the
    answer is the stripped text between ``<answer>`` and ``</answer>``. A response
    lacking either opening tag (or one that is not a string) is recorded as
    ``('NULL', 'NULL')``. Nothing is returned.
    """
    for response_text in rawOutputs:
        try:
            thought_part = response_text.split('<think>')[1].strip().split("</think>")[0].strip()
            answer_part = response_text.split("<answer>")[1].split("</answer>")[0].strip()
        except (IndexError, AttributeError, TypeError):
            # Poorly formatted output: an opening tag is missing (IndexError) or the
            # item is not text. Only these are caught so that KeyboardInterrupt and
            # other unrelated failures are no longer silently swallowed.
            outputList.append(('NULL', 'NULL'))
        else:
            outputList.append((thought_part, answer_part))

In [None]:
# Empty the accumulator first so that re-running this cell does not append a second
# copy of every parsed row (process_clean_trace only ever appends).
generated_outputs.clear()
process_clean_trace(decoded_outputs, generated_outputs)

In [None]:
groundTruthAnswer = test_set_filtered['answer']

In [None]:
len(generated_outputs), len(groundTruthAnswer)

### Calculate Proper Formatted Output Percentage.

In [None]:
def calculateFormattedOutputPercent(targetList):
    """Return the percentage of entries that were parsed successfully.

    Args:
        targetList: sequence of (reasoning, answer) tuples; an entry whose first
            element is the string 'NULL' counts as badly formatted.

    Returns:
        Share of well-formatted entries as a float in [0, 100]; 0.0 for an empty
        sequence (previously this raised ZeroDivisionError).
    """
    total = len(targetList)
    if total == 0:
        return 0.0

    malformed = sum(1 for item in targetList if item[0] == 'NULL')
    return ((total - malformed) / total) * 100

In [None]:
# Score the parsed (reasoning, answer) tuples; the helper inspects item[0] == 'NULL',
# which only makes sense for the tuples in `generated_outputs`.
calculateFormattedOutputPercent(generated_outputs)

### Cosine Function

In [None]:
from sentence_transformers import SentenceTransformer, util
sbert = SentenceTransformer('all-mpnet-base-v2', device = 'cuda')

In [None]:
def findCosSim(word1:str, word2:str) -> float:
    """Return the cosine similarity of two texts, rounded to two decimals.

    Both inputs are embedded with the module-level ``sbert`` encoder and compared
    with ``util.pytorch_cos_sim``.

    Args:
        word1: first text.
        word2: second text.

    Returns:
        The similarity score as a float rounded to 2 decimal places.
    """
    # Compute the embeddings
    embedding1 = sbert.encode(word1, convert_to_tensor=True)
    embedding2 = sbert.encode(word2, convert_to_tensor=True)

    # Compute cosine similarity; .item() extracts the single score as a Python number.
    cosine_score = util.pytorch_cos_sim(embedding1, embedding2)
    return round(cosine_score.item(), 2)

### Accuracy - Short Answers

In [None]:
# Keep only the answer half of every parsed (reasoning, answer) tuple.
# (The previous version indexed an undefined name, `decoded_ouputs`.)
decoded_ouputs_short_answers = [answer for _, answer in generated_outputs]

In [None]:
# Ground-truth answers rendered as strings so they can be compared with the parsed text.
groundTruth_outputs_short_answers = [str(truth) for truth in groundTruthAnswer]

In [None]:
groundTruth_outputs_short_answers[500:505]

In [None]:
decoded_ouputs_short_answers[500:505]

In [None]:
# Exact-match accuracy (in percent) between ground-truth and predicted short answers.
# The length check fails loudly if predictions and ground truth come from different
# splits instead of silently scoring misaligned pairs.
assert len(groundTruth_outputs_short_answers) == len(decoded_ouputs_short_answers), (
    "prediction / ground-truth length mismatch - were both built from the same split?"
)
num_correct = sum(
    truth.strip() == predicted.strip()
    for truth, predicted in zip(groundTruth_outputs_short_answers, decoded_ouputs_short_answers)
)
accuracy = (
    (num_correct / len(decoded_ouputs_short_answers)) * 100
    if decoded_ouputs_short_answers
    else 0.0
)

In [None]:
print(f'Final Short Answer Accuracy: {round(accuracy, 2)} %')

### Check Unmatched Indices

In [None]:
# Positions where the stripped prediction differs from the stripped ground truth.
unmatched_indices = [
    position
    for position, truth in enumerate(groundTruth_outputs_short_answers)
    if truth.strip() != decoded_ouputs_short_answers[position].strip()
]

In [None]:
unmatched_indices[0:10]

In [None]:
idx=28

In [None]:
# Show the full generation for the inspected index. `decoded_outputs` already carries the
# forced '<think>\n1.' prefix, so nothing is prepended here. (The previous version read an
# undefined name, `resultGeneratedAnswers`, and used a prefix this notebook never generates.)
print(decoded_outputs[idx])

In [None]:
print('Q:', test_set_filtered[idx]['question'],'\n\nA:', test_set_filtered[idx]['answer'], '\n\nidx:', idx)

In [None]:
test_set_filtered[idx]['image']