In [None]:
import os
import json
import torch
import xlora
import pickle
import random
import datetime
import xlsxwriter

In [None]:
# Show the kernel's working directory; relative checkpoint paths later in the notebook resolve against it.
os.getcwd()

In [None]:
import numpy as np
from torch import nn
from PIL import Image
from tqdm.auto import tqdm
from torch.optim import AdamW
from functools import partial
import matplotlib.pyplot as plt
from collections import Counter
from torch.utils.data import DataLoader
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoConfig, BitsAndBytesConfig

In [None]:
# NOTE(review): these project-local imports are disabled by wrapping them in a string
# literal, so they never execute; none of the imported names are used in the visible
# cells. Delete this cell, or restore the imports if the helpers are needed again.
"""from util.vision_util import process_vision_info
from util.logutil import init_logger, get_logger"""

In [None]:
# Report the Conda environment the kernel runs in; None when CONDA_DEFAULT_ENV is unset.
conda_env = os.environ.get("CONDA_DEFAULT_ENV", None)
print("Current Conda environment: {}".format(conda_env))

In [None]:
# NOTE(review): disabled helper kept as a string literal, so it is never defined or called.
# If enabled it would overwrite dict_['image_id'] with a full path under the TDIUC val2014
# image folder; the active cells instead build image paths inline with os.path.join.
"""def formatImagePath(dict_):
    imageName = dict_['image_id']
    # modifiedImageName = f"COCO_val2014_{imageName:012d}.jpg"
    imagePath = '/home/aritrad/moe-directory/moe-datasets/TDIUC/TDIUC/Images/val2014'
    dict_['image_id'] = os.path.join(imagePath, imageName)
    return dict_"""

### Load the Prototype Mixed-Reasoning Dataset (for Router Training)

In [None]:
# Unpickling (deserialization) of the 8k prototype set used to train the router.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# pickles you created yourself or otherwise trust.
router_train_pickle_path = '/home/aritrad/moe-directory/moe-datasets/TDIUC/prototype-train-set-8k-for-router-machine-automatic-llama3.2-annotation.pickle'

with open(router_train_pickle_path, 'rb') as file:
    mixed_reasoning_data_prototype = pickle.load(file)

In [None]:
# Peek at the first two raw records alongside the total record count.
mixed_reasoning_data_prototype[:2], len(mixed_reasoning_data_prototype)

In [None]:
# Count reasoning types
reasoning_counts = Counter(item["reasoning_type"] for item in mixed_reasoning_data_prototype)
reasoning_counts

In [None]:
# NOTE(review): disabled subsetting snippet (a string literal, not executed). It refers to
# `mixed_reasoning_data_prototype_xlora`, which is not defined in the visible cells —
# presumably `mixed_reasoning_data_prototype` was meant; confirm before re-enabling.
"""# subsetting for testing

mixed_reasoning_data_prototype_xlora = mixed_reasoning_data_prototype_xlora[:500]"""

In [None]:
# Report the size of the router-training pool (fixes the "reasoing" typo in the message).
print(f'Length of mixed reasoning dataset: {len(mixed_reasoning_data_prototype)}')

In [None]:
from datasets import Dataset

In [None]:
# Folder holding the TDIUC val2014 images, and the instruction text prepended to every question.
image_folder_path, prefix = (
    "/home/aritrad/moe-directory/moe-datasets/TDIUC/TDIUC/Images/val2014",
    "Generate a one word answer for the given image and question: ",
)

In [None]:
# Router output classes, in logit order; label2id maps each name to its position.
expert_names = [
    "Physical Reasoning.",
    "Quantity Reasoning.",
    "Spatial Reasoning.",
    "Social and Emotional Reasoning.",
]
label2id = dict(zip(expert_names, range(len(expert_names))))

In [None]:
# Display the label-name -> class-id mapping.
dict(label2id)

In [None]:
# Replace each record's reasoning_type string with its integer class id from label2id.
# Values that are already ints pass through unchanged, so re-running this cell is safe.
# Previously a second run raised KeyError because the labels had already been converted.
# Unknown label strings still raise KeyError, which surfaces annotation typos early.
mixed_reasoning_data_prototype = [
    {
        **item,
        'reasoning_type': (
            label2id[item['reasoning_type']]
            if isinstance(item['reasoning_type'], str)
            else item['reasoning_type']
        ),
    }
    for item in mixed_reasoning_data_prototype
]

In [None]:
# Turn the per-record list into the column-oriented mapping Dataset.from_dict expects:
# the prompt-prefixed question, the absolute image path, and the integer reasoning label.
listToDictionary = {
    column: [build(record) for record in mixed_reasoning_data_prototype]
    for column, build in (
        ('question', lambda record: prefix + record['question']),
        ('image', lambda record: os.path.join(image_folder_path, record['image_id'])),
        ('reasoning_type', lambda record: record['reasoning_type']),
    )
}

reasonining_type_set = Dataset.from_dict(listToDictionary)

In [None]:
# Number of rows in the router-training pool before the train/val split.
reasonining_type_set.num_rows

In [None]:
# Carve a fixed 1,000-example validation split (seed 42) out of the router-training pool.
# The "test" half of this split is used for validation; the separately loaded main
# test set is used for the final accuracy.
split = reasonining_type_set.train_test_split(seed=42, test_size=1000)

In [None]:
# The held-out "test" portion of the split serves as the validation set.
train_set, val_set = split['train'], split['test']

In [None]:
# Row counts of the training and validation splits.
train_set.num_rows, val_set.num_rows

### Load Main Test Set (For Measuring Router Accuracy)

In [None]:
# Unpickling (deserialization) of the 2k main test set used to measure router accuracy.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# pickles you created yourself or otherwise trust.
router_test_pickle_path = '/home/aritrad/moe-directory/moe-datasets/TDIUC/prototype-test-set-2k-machine-automatic-llama3.2-annotation.pickle'

with open(router_test_pickle_path, 'rb') as file:
    test_set = pickle.load(file)

In [None]:
# Replace each test record's reasoning_type string with its integer class id from label2id.
# Values that are already ints pass through unchanged, so re-running this cell is safe.
# Previously a re-run raised KeyError, including after test_set was rebound to a Dataset
# whose rows already hold int labels.
# Unknown label strings still raise KeyError, which surfaces annotation typos early.
test_set = [
    {
        **item,
        'reasoning_type': (
            label2id[item['reasoning_type']]
            if isinstance(item['reasoning_type'], str)
            else item['reasoning_type']
        ),
    }
    for item in test_set
]

In [None]:
# Turn the test records into the column-oriented mapping Dataset.from_dict expects:
# the prompt-prefixed question, the absolute image path, and the integer reasoning label.
listToDictionary = {
    column: [build(record) for record in test_set]
    for column, build in (
        ('question', lambda record: prefix + record['question']),
        ('image', lambda record: os.path.join(image_folder_path, record['image_id'])),
        ('reasoning_type', lambda record: record['reasoning_type']),
    )
}

test_set = Dataset.from_dict(listToDictionary)

In [None]:
# First test example alongside the number of test rows.
test_set[0], test_set.num_rows

## Model Loading

In [None]:
# Use the first GPU when CUDA is available; otherwise fall back to CPU. Without the
# fallback, later .to(device) calls fail on CPU-only machines.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
from sentence_transformers import SentenceTransformer, util

# Sentence encoder that embeds questions for the router. Its parameters are never
# handed to the optimizer, so it stays frozen during router training.
sbert = SentenceTransformer(model_name_or_path='all-mpnet-base-v2', device=device)

### Dataset Pre-processing

In [None]:
def collate_fn(examples):
    """Collate dataset rows into a router batch.

    Args:
        examples: sequence of mappings, each with a "question" string and an
            integer "reasoning_type" label.

    Returns:
        dict with "questions" (list of the question strings, in input order) and
        "labels" (1-D torch.long tensor of the matching labels).
    """
    question_list, label_list = [], []
    for example in examples:
        question_list.append(example["question"])
        label_list.append(example["reasoning_type"])
    return dict(
        questions=question_list,
        labels=torch.tensor(label_list, dtype=torch.long),
    )

In [None]:
# Batches of 32 collated examples (question strings + label tensor) per split.
# Only the training data is shuffled. It uses a seeded generator (same seed as the
# train/val split) so the shuffling order is reproducible across runs.
# evaluate() just counts correct predictions with the router in eval mode, so batch
# order cannot change val/test accuracy; those splits are read in fixed order.
train_loader = DataLoader(train_set,
                          batch_size=32,
                          shuffle=True,
                          generator=torch.Generator().manual_seed(42),
                          collate_fn=collate_fn)

val_loader = DataLoader(val_set,
                        batch_size=32,
                        shuffle=False,
                        collate_fn=collate_fn)

test_loader = DataLoader(test_set,
                         batch_size=32,
                         shuffle=False,
                         collate_fn=collate_fn)

In [None]:
# Sanity check: pull one collated batch from the training loader and print it.
batch = next(iter(train_loader), None)
if batch is not None:
    print(batch)

In [None]:
# Number of batches each loader yields per pass.
print(
    'Number of batches in the Training Set: {}, Test Set: {} and Val Set: {}'.format(
        len(train_loader), len(test_loader), len(val_loader)
    )
)

### Router Definition

In [None]:
# Confirm the working directory that './router_best.pt' and 'router_best_1.pt' resolve against.
os.getcwd()

In [None]:
class Router(nn.Module):
    """Small MLP that maps a question embedding to one logit per expert.

    Architecture: Linear(hidden -> hidden/2) -> GELU -> Dropout
                  -> Linear(hidden/2 -> hidden/4) -> GELU -> Dropout
                  -> Linear(hidden/4 -> n_experts)

    Args:
        hidden: size of the input embedding (default 768).
        n_experts: number of output logits / expert classes.
        dropout: dropout probability after each hidden GELU (default 0.1, the
            value previously hard-coded).

    The layer positions inside ``self.net`` are unchanged, so state dicts saved
    by earlier versions (keys ``net.0``, ``net.3``, ``net.6``) still load.
    """

    def __init__(self, hidden=768, n_experts=4, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, hidden // 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden // 4, n_experts),
        )

    def forward(self, fused):          # (B, hidden) -> (B, n_experts); previously mislabelled (B, 1536)
        """Return raw (unnormalised) expert logits for a batch of embeddings."""
        return self.net(fused)

In [None]:
# Build the router. When LOAD_ROUTER_CHECKPOINT is True, the weights are initialised from
# the saved state dict (used when testing a trained router on other test sets).
# Set it to False to train from scratch instead of commenting lines out.
# NOTE: the training loop saves its best checkpoint to "router_best_1.pt",
# not to the file loaded here.
LOAD_ROUTER_CHECKPOINT = True

router = Router(n_experts=len(expert_names)).to(device)

if LOAD_ROUTER_CHECKPOINT:
    # map_location remaps saved tensors onto the current device, so a checkpoint saved
    # on a GPU still loads when `device` differs (e.g. CPU).
    checkpoint = torch.load('./router_best.pt', map_location=device)
    router.load_state_dict(checkpoint)

print("Router Initialized ✓")

In [None]:
# Total number of elements across the router's trainable (requires_grad) tensors.
param_count = sum(t.numel() for t in filter(lambda t: t.requires_grad, router.parameters()))
print("Trainable parameters: {}".format(param_count))

In [None]:
# Training hyper-parameters: 12 epochs of AdamW (lr 1e-4) on cross-entropy over expert ids.
# Only the router's parameters are optimised; the SBERT encoder is not included.
epochs = 12
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=router.parameters(), lr=1e-4)

### Get Embedding

In [None]:
def get_text_repr(batch):
    """Embed the questions of a collated batch with the SBERT encoder.

    Args:
        batch: dict produced by ``collate_fn`` with
            - batch["questions"]: List[str]  (prompt-prefixed question texts)
            - batch["labels"]:    Tensor     (not used here)

    Returns:
        torch.Tensor of shape (B, 768) on ``device``, one embedding per question.
    """
    # Per the original note, SBERT.encode runs without gradient tracking. SBERT's
    # parameters are also not passed to the optimizer, so the encoder stays frozen
    # while the router trains.
    embeds = sbert.encode(
        batch["questions"],
        convert_to_tensor=True,
        device=device,
    )
    return embeds

### Train and Evaluate

In [None]:
def evaluate(loader, model=None):
    """Return the top-1 accuracy of a router over ``loader``.

    Args:
        loader: iterable of collated batches — dicts with "questions" (list of str)
            and "labels" (LongTensor of class ids), as built by ``collate_fn``.
        model: router to evaluate; defaults to the notebook-global ``router``
            (the original behaviour).

    Returns:
        Fraction of correctly classified examples, in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no examples. Previously this surfaced as a
            bare ZeroDivisionError.

    Note: the model is switched to eval mode and left there; training code must
    call ``model.train()`` again (the training loop does so each epoch).
    """
    model = router if model is None else model
    model.eval()
    correct = total = 0

    with torch.inference_mode():
        for batch in tqdm(loader):
            labels = batch["labels"].to(device)
            fused = get_text_repr(batch)
            preds = model(fused).argmax(dim=-1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate() received a loader with no examples")
    return correct / total

In [None]:
best = 0.0  # highest validation accuracy seen so far; checkpointing compares against it
print("[•] Training router …")

In [None]:
# Train the router on SBERT question embeddings. After each epoch, measure validation
# accuracy and save the best-so-far weights to "router_best_1.pt". `best` holds the best
# validation accuracy so far and is initialised to 0.0 in the preceding cell.
best_state = None  # in-memory copy of the best weights, restored after training

for epoch in range(epochs):

    router.train()
    for step, batch in enumerate(train_loader):

        labels = batch["labels"].to(device)
        fused = get_text_repr(batch)   # (B,768) float32
        logits = router(fused)

        loss = loss_func(logits, labels)
        # Clear gradients before backward so grads left by an interrupted earlier run
        # cannot accumulate into this step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            print(f"epoch {epoch+1} step {step:4d} loss {loss.item():.4f}")

    # validation
    val_acc = evaluate(val_loader)
    print(f"\n|| Epoch {epoch+1} → val acc = {val_acc: .3%} ||\n")
    if val_acc > best:
        best = val_acc
        torch.save(router.state_dict(), "router_best_1.pt")
        best_state = {k: v.detach().clone() for k, v in router.state_dict().items()}

# Continue with the best-validation weights rather than the last epoch's, so the test
# accuracy reported afterwards matches the saved checkpoint.
if best_state is not None:
    router.load_state_dict(best_state)

In [None]:
# Accuracy of the router on the main test set, printed as a percentage.
test_acc = evaluate(test_loader)
print('Test Set accuracy of the Router is: {:.4f}'.format(test_acc * 100))