In [1]:
import os
import json
import torch
import xlora
import pickle
import random
import datetime
import xlsxwriter

In [2]:
# Show the working directory; relative paths used later (e.g. "best_joint.pt") resolve against it.
os.getcwd()

'/home/aritrad/moe-directory/moe-datasets/TDIUC/custom-moe/using-sbert'

In [3]:
import numpy as np
from torch import nn
from PIL import Image
from tqdm.auto import tqdm
from torch.optim import AdamW
from functools import partial
import matplotlib.pyplot as plt
from collections import Counter
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoConfig, BitsAndBytesConfig

In [4]:
# Record which Conda environment the kernel runs in (None when the variable is unset).
conda_env = os.getenv("CONDA_DEFAULT_ENV")
print(f"Current Conda environment: {conda_env}")

Current Conda environment: stable_env


### Load the Prototype Mixed-Reasoning Dataset (for Router Training)

In [5]:
# Unpickling (deserialization): load the prototype router-training records.
# NOTE: pickle.load can execute arbitrary code -- only load pickles from a trusted source.
router_train_pickle_path = '/home/aritrad/moe-directory/moe-datasets/TDIUC/prototype-train-set-8k-for-router-machine-automatic-llama3.2-annotation.pickle'
with open(router_train_pickle_path, 'rb') as pickle_file:
    mixed_reasoning_data_prototype = pickle.load(pickle_file)

In [6]:
# Peek at the first two records together with the total record count.
mixed_reasoning_data_prototype[:2], len(mixed_reasoning_data_prototype)

([{'image_id': 'COCO_val2014_000000500316.jpg',
   'question': 'What color ski gear does the person have?',
   'question_id': 20011883,
   'serial_no': 1509,
   'reasoning_type': 'Physical Reasoning.',
   'answer': 'blue'},
  {'image_id': 'COCO_val2014_000000308645.jpg',
   'question': 'How many people are in the room?',
   'question_id': 30270039,
   'serial_no': 2494,
   'reasoning_type': 'Quantity Reasoning.',
   'answer': 'one'}],
 8000)

In [7]:
# Tally how many records carry each reasoning-type label (checks class balance).
reasoning_counts = Counter()
for record in mixed_reasoning_data_prototype:
    reasoning_counts[record["reasoning_type"]] += 1
reasoning_counts

Counter({'Physical Reasoning.': 2000,
         'Quantity Reasoning.': 2000,
         'Spatial Reasoning.': 2000,
         'Social and Emotional Reasoning.': 2000})

In [8]:
"""# subsetting for testing

mixed_reasoning_data_prototype_xlora = mixed_reasoning_data_prototype_xlora[:500]"""

'# subsetting for testing\n\nmixed_reasoning_data_prototype_xlora = mixed_reasoning_data_prototype_xlora[:500]'

In [9]:
# Sanity-check the number of loaded records.
print(f'Length of mixed reasoning dataset: {len(mixed_reasoning_data_prototype)}')

Length of mixed reasoing dataset: 8000


In [10]:
from datasets import Dataset

In [11]:
# Instruction text prepended to every question before it is embedded.
prefix = "Generate a one word answer for the given image and question: "
# Folder containing the images referenced by each record's `image_id`.
image_folder_path = "/home/aritrad/moe-directory/moe-datasets/TDIUC/TDIUC/Images/val2014"

In [12]:
# Router output classes, one per expert; the position in this list is the class id.
expert_names = [
    "Physical Reasoning.",
    "Quantity Reasoning.",
    "Spatial Reasoning.",
    "Social and Emotional Reasoning.",
]
label2id = dict(zip(expert_names, range(len(expert_names))))

In [13]:
# Replace each string reasoning_type with its integer class id from label2id.
# Values that are already ints are kept as-is, so re-running this cell is safe
# (previously a second run raised KeyError because label2id has no int keys).
mixed_reasoning_data_prototype = [
    {
        **item,
        'reasoning_type': (
            item['reasoning_type']
            if isinstance(item['reasoning_type'], int)
            else label2id[item['reasoning_type']]
        ),
    }
    for item in mixed_reasoning_data_prototype
]

In [14]:
# Column-oriented view for Dataset.from_dict: prefixed question text,
# absolute image path, and integer reasoning-type label per record.
listToDictionary = {'question': [], 'image': [], 'reasoning_type': []}
for record in mixed_reasoning_data_prototype:
    listToDictionary['question'].append(prefix + record['question'])
    listToDictionary['image'].append(os.path.join(image_folder_path, record['image_id']))
    listToDictionary['reasoning_type'].append(record['reasoning_type'])

reasonining_type_set = Dataset.from_dict(listToDictionary)

In [15]:
# Number of rows in the Hugging Face dataset (same value as len()).
reasonining_type_set.num_rows

8000

In [16]:
# Hold out exactly 1,000 rows as a *validation* set (used for checkpoint selection);
# the fixed seed makes the split reproducible.
split = reasonining_type_set.train_test_split(seed=42, test_size=1000)

In [17]:
# The split's 'test' partition serves as the validation set.
train_set, val_set = split['train'], split['test']

In [18]:
# Row counts of the train / validation partitions.
train_set.num_rows, val_set.num_rows

(7000, 1000)

### Load Main Test Set (For Measuring Router Accuracy)

In [19]:
# Unpickling (deserialization): load the held-out test records used to measure router accuracy.
# NOTE: pickle.load can execute arbitrary code -- only load pickles from a trusted source.
test_pickle_path = '/home/aritrad/moe-directory/moe-datasets/TDIUC/prototype-test-set-2k-machine-automatic-llama3.2-annotation.pickle'
with open(test_pickle_path, 'rb') as pickle_file:
    test_set = pickle.load(pickle_file)

In [20]:
# Replace each string reasoning_type with its integer class id from label2id.
# Values that are already ints are kept as-is, so re-running this cell on its own
# no longer raises KeyError (label2id has no int keys).
test_set = [
    {
        **item,
        'reasoning_type': (
            item['reasoning_type']
            if isinstance(item['reasoning_type'], int)
            else label2id[item['reasoning_type']]
        ),
    }
    for item in test_set
]

In [21]:
# Same column layout as the training dataset: prefixed question, image path, integer label.
listToDictionary = {'question': [], 'image': [], 'reasoning_type': []}
for record in test_set:
    listToDictionary['question'].append(prefix + record['question'])
    listToDictionary['image'].append(os.path.join(image_folder_path, record['image_id']))
    listToDictionary['reasoning_type'].append(record['reasoning_type'])

test_set = Dataset.from_dict(listToDictionary)

In [22]:
# First test row plus the total number of test rows.
test_set[0], test_set.num_rows

({'question': 'Generate a one word answer for the given image and question: How many stop signs can be seen?',
  'image': '/home/aritrad/moe-directory/moe-datasets/TDIUC/TDIUC/Images/val2014/COCO_val2014_000000252332.jpg',
  'reasoning_type': 1},
 2000)

## Model Loading

In [None]:
# Use the first GPU when available; fall back to CPU so the notebook still runs without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
from transformers import AutoTokenizer, AutoModel

In [None]:
# Pretrained SBERT encoder and its matching tokenizer, used to embed the question text.
sbert_checkpoint = 'sentence-transformers/all-mpnet-base-v2'
tokenizer = AutoTokenizer.from_pretrained(sbert_checkpoint)
sbert_model = AutoModel.from_pretrained(sbert_checkpoint).to(device)

### Dataset Pre-processing

In [None]:
def collate_fn(examples):
    """Batch dataset rows for the router.

    The question strings stay a plain list because they are tokenized later, inside
    the training/evaluation loops. The integer ``reasoning_type`` labels are stacked
    into a ``torch.long`` tensor.

    Args:
        examples: list of dataset rows, each with ``"question"`` and ``"reasoning_type"`` keys.

    Returns:
        dict with ``"questions"`` (list[str]) and ``"labels"`` (LongTensor of shape (B,)).
    """
    batch_questions, batch_labels = [], []
    for row in examples:
        batch_questions.append(row["question"])
        batch_labels.append(row["reasoning_type"])
    return {
        'questions': batch_questions,
        'labels': torch.tensor(batch_labels, dtype=torch.long),
    }

In [None]:
# Batches of raw question strings + integer labels (see collate_fn).
# Only the training loader shuffles, with a seeded generator so the batch order is
# reproducible. Validation/test order has no effect on accuracy, so those loaders
# are kept deterministic.
BATCH_SIZE = 32

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(42),
)

val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

test_loader = DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# Smoke-test the collate function: fetch and show one training batch (if any).
sample_batch = next(iter(train_loader), None)
if sample_batch is not None:
    print(sample_batch)

In [None]:
# Number of batches each loader yields per pass.
print(
    f'Number of batches in the Training Set: {len(train_loader)}, '
    f'Test Set: {len(test_loader)} and Val Set: {len(val_loader)}'
)

### Router Definition

In [None]:
class Router(nn.Module):
    """Small MLP mapping a sentence embedding to one logit per expert.

    Layers: Linear(hidden -> hidden // 2) -> GELU -> Dropout
            -> Linear(hidden // 2 -> hidden // 4) -> GELU -> Dropout
            -> Linear(hidden // 4 -> n_experts)

    Args:
        hidden: width of the input embedding. The default (768) matches the
            mean-pooled SBERT embeddings this notebook feeds in.
        n_experts: number of output classes (experts).
        dropout: dropout probability after each hidden layer (default 0.1, unchanged).
    """

    def __init__(self, hidden=768, n_experts=4, dropout=0.1):
        super().__init__()
        # Layer order/indices are unchanged so existing state_dict checkpoints still load.
        self.net = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, hidden // 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden // 4, n_experts),
        )

    def forward(self, fused):
        """Return unnormalized expert logits, shape (B, n_experts), for input of shape (B, hidden)."""
        return self.net(fused)

In [None]:
# Build the router head. This cell is required both for training (the optimizer cell
# below takes router.parameters()) and for evaluating a saved checkpoint -- do not
# comment it out. Seed first so the random router initialization is reproducible.
torch.manual_seed(42)
router = Router(n_experts=len(expert_names)).to(device)
print("Router Initialized ✓")

In [None]:
# Count the router parameters that receive gradients.
param_count = 0
for tensor in router.parameters():
    if tensor.requires_grad:
        param_count += tensor.numel()
print(f"Trainable parameters: {param_count}")

In [None]:
def mean_pooling(last_hidden_state, attention_mask):
    """Average token embeddings over the positions selected by the attention mask.

    Args:
        last_hidden_state: token embeddings of shape (B, T, D).
        attention_mask: (B, T) mask; nonzero entries mark tokens to include
            (for tokenizer output, 1 = real token, 0 = padding).

    Returns:
        (B, D) tensor with the masked mean per sequence. The denominator is clamped
        to 1e-9, so a sequence whose mask is all zeros yields a zero vector.
    """
    weights = attention_mask[..., None].expand(last_hidden_state.size()).float()
    token_sum = (last_hidden_state * weights).sum(dim=1)
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_sum / token_count

In [None]:
# Fine-tune the SBERT encoder and train the router head together with one optimizer
# (a single parameter group, so every parameter uses the same learning rate).
trainable_params = [*sbert_model.parameters(), *router.parameters()]
optimizer = AdamW(trainable_params, lr=5e-6)

# The router emits raw logits; CrossEntropyLoss applies log-softmax internally.
loss_func = nn.CrossEntropyLoss()
epochs = 20

### Train and Evaluate

In [None]:
def evaluate(loader):
    """Return the router's classification accuracy over every batch in ``loader``.

    Both the SBERT encoder and the router are switched to eval mode (dropout off), and
    the pass runs under ``torch.inference_mode()``. Each batch is tokenized, embedded
    with SBERT, mean-pooled over the attention mask, and classified by taking the argmax
    of the router logits. The models are left in eval mode; the training loop switches
    them back to train mode at the start of each epoch.

    Args:
        loader: iterable of batches from ``collate_fn`` -- dicts with ``"questions"``
            (list[str]) and ``"labels"`` (LongTensor).

    Returns:
        float: fraction of correctly routed examples, in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no examples (accuracy would be undefined).
    """
    sbert_model.eval()
    router.eval()
    correct = total = 0

    with torch.inference_mode():

        for batch in tqdm(loader, desc="evaluate"):
            questions, labels = batch["questions"], batch["labels"].to(device)

            # tokenize & embed (same preprocessing as the training loop)
            enc = tokenizer(
                questions,
                padding=True,
                truncation=True,
                return_tensors="pt"
            ).to(device)

            out    = sbert_model(**enc)
            embeds = mean_pooling(out.last_hidden_state, enc["attention_mask"])

            # router forward: predicted expert = index of the highest logit
            preds = router(embeds).argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total   += labels.size(0)

    if total == 0:
        raise ValueError("evaluate() received a loader with no examples; accuracy is undefined.")
    return correct / total

In [None]:
# Best validation accuracy seen so far; a checkpoint is written whenever it improves.
best = 0.0
print("[•] Training router …")

In [None]:
for epoch in range(epochs):
    # ---- one training epoch: SBERT encoder and router are updated jointly ----
    sbert_model.train()
    router.train()

    for step, batch in enumerate(train_loader):
        batch_labels = batch["labels"].to(device)

        # Tokenize the raw questions, then mean-pool SBERT token states over the
        # attention mask to get one embedding per question.
        encoded = tokenizer(
            batch["questions"],
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        token_states = sbert_model(**encoded).last_hidden_state
        question_embeds = mean_pooling(token_states, encoded["attention_mask"])

        # Router logits -> cross-entropy against the integer expert labels.
        loss = loss_func(router(question_embeds), batch_labels)

        # Update, then clear gradients so the next batch starts from zero.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 40 == 0:
            print(f"Epoch {epoch+1} Step {step:4d} → loss {loss.item():.4f}")

    # ---- validate; keep the checkpoint with the best validation accuracy ----
    val_acc = evaluate(val_loader)
    print(f"\n|| Epoch {epoch+1} → val acc = {val_acc:.3%} ||\n")

    if val_acc > best:
        best = val_acc
        checkpoint = {
            "sbert_model": sbert_model.state_dict(),
            "router":      router.state_dict(),
            "optimizer":   optimizer.state_dict(),
        }
        torch.save(checkpoint, "best_joint.pt")

In [None]:
# Score the weights with the best validation accuracy (saved by the training loop),
# not whatever the final epoch left in memory.
best_ckpt_path = "best_joint.pt"
if os.path.exists(best_ckpt_path):
    best_ckpt = torch.load(best_ckpt_path, map_location=device)
    sbert_model.load_state_dict(best_ckpt["sbert_model"])
    router.load_state_dict(best_ckpt["router"])
else:
    print(f"Warning: {best_ckpt_path} not found; evaluating the current in-memory weights.")

test_acc = evaluate(test_loader)
print(f'Test Set accuracy of the Router is: {test_acc:.4%}')