In [None]:
import os
import json
import torch
import pickle
import random
import datetime

In [None]:
import numpy as np
from torch import nn
from PIL import Image
from tqdm.auto import tqdm
from torch.optim import AdamW
from functools import partial
import matplotlib.pyplot as plt
from collections import Counter
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoConfig, BitsAndBytesConfig

In [None]:
from util.vision_util import process_vision_info
from util.logutil import init_logger, get_logger

In [None]:
conda_env = os.environ.get("CONDA_DEFAULT_ENV")
print(f"Current Conda environment: {conda_env}")

### Load the Prototype Test Mixed Reasoning Dataset:

In [None]:
pwd

In [None]:
# Unpickling (de-serialization) of the annotated prototype test set.
# NOTE(review): pickle.load can execute arbitrary code while loading — only open files from a trusted source.
# NOTE(review): the variable name says "manual" while the file name says "machine-automatic" — confirm which annotation this is.

with open('/home/aritrad/MOE-Directory/moe-datasets/TDIUC/custom-moe/Test-Set/prototype-test-set-1.5K-machine-automatic-llama3.2-annotation.pickle', 'rb') as file:
    test_set_manual_annt = pickle.load(file)

In [None]:
test_set_manual_annt[0:2], len(test_set_manual_annt)

In [None]:
print(f'Lenght of mixed reasoing dataset: {len(test_set_manual_annt)}')

In [None]:
from datasets import Dataset

In [None]:
image_folder_path = "/home/aritrad/MOE-Directory/moe-datasets/TDIUC/TDIUC/Images/val2014"
prefix = "Generate a one word answer for the given image and question: "

In [None]:
expert_names = ["Physical Reasoning.", "Quantity Reasoning.", "Spatial Reasoning.", "Social and Emotional Reasoning."]
label2id = {name: idx for idx, name in enumerate(expert_names)}

In [None]:
# Replace each sample's textual reasoning type with its integer expert id.
# The isinstance guard keeps this cell idempotent: on a re-run the values are
# already integers, and an unguarded label2id[...] lookup would raise KeyError.
test_set_manual_annt = [
    {
        **item,
        'reasoning_type': (
            item['reasoning_type']
            if isinstance(item['reasoning_type'], int)
            else label2id[item['reasoning_type']]
        ),
    }
    for item in test_set_manual_annt
]

In [None]:
listToDictionary = {
    'question': [ prefix + dict_['question'] for dict_ in test_set_manual_annt ], 
    'image': [ os.path.join(image_folder_path, dict_['image_id']) for dict_ in test_set_manual_annt ],
    'answer': [ dict_['answer'] for dict_ in test_set_manual_annt ], 
}

test_set = Dataset.from_dict(listToDictionary)

In [None]:
test_set

## Model Loading

In [None]:
device = "cuda:0"

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_compute_dtype = torch.bfloat16,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_use_double_quant=True,
)

In [None]:
backbone = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype = torch.bfloat16,
    attn_implementation = "flash_attention_2",
    quantization_config = bnb_config,
    device_map = device
)

In [None]:
total_params = sum(p.numel() for p in backbone.parameters())
trainable_params = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Load processor. 
# The default range for the number of visual tokens per image in the model is 4-16384. You can set min_pixels and max_pixels according to your needs, such as a token count range of 256-1280, to balance speed and memory usage.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=256*28*28, max_pixels=512*28*28, padding_side="left", use_fast=True)

### Fetching Expert Names:

In [None]:
from pathlib import Path

In [None]:
expert_names = []

# Hard code the order of loading the experts as in the Training Data.
inference_adapter_names = [
    "physical",      # trains on "Physical Reasoning."
    "quantitative",  # trains on "Quantity Reasoning."
    "spatial",       # trains on "Spatial Reasoning."
    "social"         # trains on "Social and Emotional Reasoning."
]

ADAPTER_ROOT = '/home/aritrad/MOE-Directory/moe-datasets/TDIUC/custom-moe/moe-end2end/2-2B/best_adapters'

for name in inference_adapter_names:
    path = Path(ADAPTER_ROOT) / f"{name}" 
    backbone.load_adapter(path, adapter_name=name, is_trainable=False)
    expert_names.append(name)
    
print("Experts:", expert_names)  

### Load the Trained Router:

In [None]:
class Router(nn.Module):
    """Small MLP that produces one logit per expert from a feature vector.

    Three linear layers shrink the width hidden -> hidden // 2 -> hidden // 4
    -> n_experts; each of the first two is followed by GELU and Dropout(0.1).
    """

    def __init__(self, hidden=768, n_experts=4):
        super().__init__()
        half, quarter = hidden // 2, hidden // 4

        # Layer order is unchanged so parameter names stay net.0 / net.3 / net.6,
        # which is what a saved state dict of this class is keyed on.
        layers = [
            nn.Linear(hidden, half),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(half, quarter),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(quarter, n_experts),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, fused):
        """Return the expert logits; the last dimension of `fused` must be `hidden` wide."""
        return self.net(fused)

In [None]:
# Build the router and restore its trained weights to evaluate a saved checkpoint on other test sets.
router = Router(n_experts=len(expert_names)).to(device)

# Comment this block out when training.
# map_location places the tensors on `device` even if the checkpoint was saved from another device.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(
    '/home/aritrad/MOE-Directory/moe-datasets/TDIUC/custom-moe/Checkpoints/best_router.pt',
    map_location=device,
)
router.load_state_dict(checkpoint)

# Inference only: eval mode disables the router's Dropout layers, which would otherwise
# randomly perturb the logits and make the expert choice non-deterministic.
router.eval()

print("Router Initialized ✓")

### Contextual Text Embedding From SBERT

In [None]:
from sentence_transformers import SentenceTransformer, util
sbert = SentenceTransformer('all-mpnet-base-v2', device = device)

In [None]:
# sbert

In [None]:
def get_text_repr(batch):
    """
    Embed the questions of a batch with the module-level SBERT model.

    batch: dict; only batch["question"] (List[str]) is read here.

    Returns the result of sbert.encode(..., convert_to_tensor=True) computed on
    `device` — presumably a float32 tensor of shape (B, 768); TODO confirm.
    """
    questions = batch["question"]

    # encode() is expected to run without gradient tracking, so SBERT stays
    # frozen — TODO confirm for the installed sentence-transformers version.
    return sbert.encode(questions, convert_to_tensor=True, device=device)

### Choose Expert

In [None]:
@torch.no_grad()
def choose_expert(batch):
    """
    Pick one expert per sample of a batch.

    batch: dict whose batch["question"] (List[str]) is embedded by get_text_repr.

    Returns a tensor holding, for every sample, the index of the expert with the
    highest router logit.

    The router contains Dropout layers, so it is forced into eval mode for the
    forward pass (in train mode dropout would randomly change the logits and
    therefore the chosen expert). The mode it had before the call is restored.
    """
    # Get the sentence embeddings for the questions
    sent_vec = get_text_repr(batch)

    was_training = router.training
    router.eval()
    try:
        logits = router(sent_vec)
    finally:
        # leave the router in whatever mode the caller had it in
        router.train(was_training)

    # Highest-scoring expert for each sample
    return logits.argmax(dim=-1)

In [None]:
"""# Unit Test: Test the choose_expert function.

sample_batch = next(iter(test_loader))
print(choose_expert(sample_batch))"""

### Generate

In [None]:
@torch.no_grad()
def generate_answer(batch, expert_idx):
    """
    Answer every sample of a batch with the adapter selected for it.

    batch      : dict with parallel lists under "image" and "question"
    expert_idx : tensor with one index into `expert_names` per sample

    Returns a list with one (adapter_name, answer) tuple per entry of expert_idx,
    in the same order. Samples are processed one at a time.
    """
    results = []
    selected = expert_idx.tolist()

    for pos in range(len(selected)):
        adapter = expert_names[selected[pos]]

        # Activate this sample's expert before doing anything else
        backbone.set_adapter(adapter)

        # Single-turn chat message: the image followed by the question text
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": batch["image"][pos]},
                    {"type": "text",  "text": batch["question"][pos] + " ?"},
                ],
            }
        ]

        # Render the chat template to a prompt string ending in the generation prompt
        prompt = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

        # Vision inputs come from the project helper
        image_inputs, video_inputs = process_vision_info(conversation)

        model_inputs = processor(
            text=prompt,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)

        generated = backbone.generate(**model_inputs, max_new_tokens=128)

        # Keep only the tokens produced after the prompt, then decode them
        prompt_len = model_inputs.input_ids.shape[1]
        decoded = processor.batch_decode(
            [generated[0, prompt_len:]],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        results.append((adapter, decoded[0].strip()))

    return results

### Test-Set Pre-processing

In [None]:
def collate_fn(examples):
    """
    Turn a list of per-sample dicts into one dict of parallel lists.

    examples: list of dicts, each with the keys "image", "question" and "answer".

    Returns a dict with three lists, each in the order of `examples`:
      - batch["image"]
      - batch["question"]
      - batch["answer"]
    """
    fields = ("image", "question", "answer")
    return {field: [ex[field] for ex in examples] for field in fields}


In [None]:
test_loader = DataLoader(test_set,
                          batch_size = 16,
                          shuffle = False,
                          collate_fn = collate_fn)

In [None]:
# Test the batch.

for batch in test_loader:
    print(batch)
    break

## Final Loop

In [None]:
OUTPUT_JSONL = "/home/aritrad/MOE-Directory/moe-datasets/TDIUC/custom-moe/JSON-Reports/end-to-end-trained-architecture-predictions-report.jsonl"
chosen_experts, answers = [],[]

In [None]:
# Reset the accumulators in this cell so that re-running it does not append a second
# copy of the predictions, which would misalign `answers` with the ground truth later.
chosen_experts, answers = [], []

with Path(OUTPUT_JSONL).open("w") as fout:

    for batch in tqdm(test_loader):
        # ❶ run the router: one expert index per sample
        expert_idx = choose_expert(batch)

        # ❷ generate with the chosen expert per sample: list of (adapter, answer)
        preds = generate_answer(batch, expert_idx)

        # ❸ accumulate the results and write one JSON line per sample
        for question, groundTruthAnswer, (adapter, answer) in zip(
                batch["question"],
                batch["answer"],
                preds
        ):
            chosen_experts.append(adapter)
            answers.append(answer)

            fout.write(json.dumps({
                # Drop the instruction prefix, which ends at the first ': '.
                # maxsplit=1 keeps the question intact even if it contains ': ' itself.
                "question"     : question.split(': ', 1)[1],
                "chosenExpert" : adapter,
                "groundTruth"  : groundTruthAnswer,
                "answer"       : answer,
            }) + "\n")

print(f"✓ Done. Predictions saved to {OUTPUT_JSONL}")

In [None]:
groundtruth_answer = test_set['answer']
generated_answer = answers

In [None]:
pwd

### Calculate Accuracy

### Evaluate Exact String Match (EM)

In [None]:
# Exact match: a prediction counts as correct when it equals the ground truth
# after stripping surrounding whitespace and lower-casing both sides.
total_predictions = len(generated_answer)
correct_predictions = sum(
    1
    for i in range(total_predictions)
    if generated_answer[i].strip().lower() == groundtruth_answer[i].strip().lower()
)

# Accuracy as a percentage
accuracy = (correct_predictions / total_predictions) * 100

print(f"Accuracy: {accuracy:.2f} %")

### BERT Score

Evaluating with BERT Score:

Precision (P): How much of the candidate's content is relevant.

Recall (R): How much of the reference's content is covered by the candidate.

F1 Score (F1): Harmonic mean of Precision and Recall, commonly used as the final metric.

In [None]:
from bert_score import score

# Example references and candidates
# references = ['stool','no','person','stool','sign','bronze','door','no','red','chair','red','black']
# candidates = ['stool','no','child','stool','sign','gold','picture','no','brown','chair','brown','black']

In [None]:
# Compute BERTScore: candidates = generated_answer (model output), references = groundtruth_answer.

P, R, F1 = score(generated_answer, groundtruth_answer, lang="en", verbose=True, device='cuda')

In [None]:
# Print scores
print("Mean Precision:", np.round(np.mean(P.tolist()) * 100, 2) )
print("Mean Recall:", np.round(np.mean(R.tolist()) * 100, 2) )
print("Mean F1 Score:", np.round(np.mean(F1.tolist()) * 100, 2) )

### Evaluating with BLEU-1 Score

In [None]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

In [None]:
# Per-sample BLEU-1 scores of the most recent calculate_bleu_1_score call.
bleu_scores = []

def calculate_bleu_1_score(ground_truth, predicted):
    """
    Average word-level BLEU-1 over parallel lists of answers.

    ground_truth : list[str], reference answers
    predicted    : list[str], generated answers, aligned with ground_truth

    Returns the mean sentence-level BLEU-1 as a float, or 0.0 when there are no pairs.

    Side effect: the module-level `bleu_scores` list is overwritten (not extended)
    with the per-sample scores, so repeated calls do not mix results of earlier runs.
    """
    # Only the unigram precision carries weight, which makes this BLEU-1
    weights = [1.0, 0.0, 0.0, 0.0]

    # Smoothing function to handle cases with no n-gram matches
    smoothing_function = SmoothingFunction().method1

    def _tokens(text):
        # sentence_bleu works on token sequences; a raw str would be iterated
        # character by character and yield a character-level score instead.
        return text.strip().lower().split()

    per_sample = [
        sentence_bleu(
            [_tokens(gt)],
            _tokens(pred),
            weights=weights,
            smoothing_function=smoothing_function,
        )
        for gt, pred in zip(ground_truth, predicted)
    ]

    # Replace the contents in place so the module-level list object stays the same
    bleu_scores[:] = per_sample

    if not per_sample:
        return 0.0

    return sum(per_sample) / len(per_sample)

In [None]:
# Calculate the BLEU score

avg_bleu = calculate_bleu_1_score(groundtruth_answer, generated_answer)
print(f"Average BLEU score: {np.round(avg_bleu*100, 2)}")

### ROUGE Score

In [None]:
from rouge_score import rouge_scorer

In [None]:
def calculate_avg_rouge_scores(references, candidates):
    """
    Mean ROUGE-1 / ROUGE-2 / ROUGE-L F-measure over aligned reference/candidate pairs.

    references : iterable of reference strings
    candidates : iterable of candidate strings, paired with references by position

    Returns a dict with the keys "ROUGE-1", "ROUGE-2" and "ROUGE-L", each holding
    the mean F-measure rounded to four decimals.
    """
    metrics = ('rouge1', 'rouge2', 'rougeL')
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    # One list of per-pair F-measures for every metric
    f_measures = {metric: [] for metric in metrics}

    for ref, cand in zip(references, candidates):
        result = scorer.score(ref, cand)
        for metric in metrics:
            f_measures[metric].append(result[metric].fmeasure)

    averaged = {}
    averaged["ROUGE-1"] = np.round(np.mean(f_measures['rouge1']), 4)
    averaged["ROUGE-2"] = np.round(np.mean(f_measures['rouge2']), 4)
    averaged["ROUGE-L"] = np.round(np.mean(f_measures['rougeL']), 4)
    return averaged

In [None]:
avg_rouge = calculate_avg_rouge_scores(groundtruth_answer, generated_answer)
avg_rouge = { k:round(v*100, 2) for k,v in avg_rouge.items()}
print("Average ROUGE scores:", avg_rouge)

## Cosine Accuracy

In [None]:
sbert = SentenceTransformer('all-mpnet-base-v2', device = device)

In [None]:
def findCosSim(word1: str, word2: str) -> float:
    """
    Cosine similarity between the SBERT embeddings of two strings.

    word1, word2 : the two texts to compare

    Returns the similarity as a Python float rounded to two decimals, so a
    caller comparing against a threshold sees the rounded value.
    """
    # Compute the embeddings (each string is encoded on its own)
    embedding1 = sbert.encode(word1, convert_to_tensor=True)
    embedding2 = sbert.encode(word2, convert_to_tensor=True)

    # Compute cosine similarity; .item() below extracts the single score
    cosine_score = util.pytorch_cos_sim(embedding1, embedding2)
    return round(cosine_score.item(), 2)

In [None]:
cosineAccuracy = [ findCosSim( generated_answer[idx], groundtruth_answer[idx] ) > 0.71 for idx in tqdm(range(len(generated_answer))) ]

In [None]:
( sum(cosineAccuracy) / len(cosineAccuracy) ) * 100