In [None]:
import os
import json
import torch
import pickle
import datetime
from torch import nn
from functools import partial

In [None]:
from PIL import Image
from peft import LoraConfig
from torch.optim import AdamW

In [None]:
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, BitsAndBytesConfig

In [None]:
from tqdm.auto import tqdm
from collections import Counter
from util.vision_util import process_vision_info
from util.logutil import init_logger, get_logger

In [None]:
# Target device for the router, SBERT encoder and per-batch tensors
# (the Qwen backbone itself is placed via device_map="auto").
device = "cuda:0"

In [None]:
# Turn off HuggingFace tokenizers' internal thread parallelism; commonly set to
# avoid its fork warning / potential deadlock once the process forks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

### Dataset Loading:

In [None]:
# Deserialize the router-training prototype set (a list of record dicts).
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
PROTOTYPE_PICKLE_PATH = '/home/aritrad/moe-directory/moe-datasets/TDIUC/prototype-train-set-8k-for-router-machine-automatic-llama3.2-annotation.pickle'

with open(PROTOTYPE_PICKLE_PATH, 'rb') as pickle_file:
    mixed_reasoning_data_prototype = pickle.load(pickle_file)

In [None]:
# Peek at the first two records and the dataset size. Only a cell's *last*
# expression is auto-displayed, so this line must be printed explicitly.
print(mixed_reasoning_data_prototype[0:2], len(mixed_reasoning_data_prototype))

# Class balance across reasoning types (shown as the cell output).
reasoning_counts = Counter(item["reasoning_type"] for item in mixed_reasoning_data_prototype)
reasoning_counts

### Subsetting the Dataset:

In [None]:
# Optional subsetting for quick experiments: set to an int (e.g. 500) to keep
# only the first N records, or leave as None to use the full dataset.
SUBSET_SIZE = None

if SUBSET_SIZE is not None:
    mixed_reasoning_data_prototype = mixed_reasoning_data_prototype[:SUBSET_SIZE]

In [None]:
# Sanity-check how many records are loaded.
print(f'Length of mixed reasoing dataset: {len(mixed_reasoning_data_prototype)}')

In [None]:
from datasets import Dataset

In [None]:
# Directory holding the images referenced by each record's 'image_id'.
# NOTE(review): images are read from val2014 although the pickle is named
# "train-set" -- confirm this is the intended image folder.
image_folder_path = "/home/aritrad/moe-directory/moe-datasets/TDIUC/TDIUC/Images/val2014"
# Instruction prepended to every question before it is sent to the model.
prefix = "Generate a one word answer for the given image and question: "

In [None]:
# Router class names exactly as stored in the raw 'reasoning_type' field; the
# list position of each name is its integer class id (router target).
# NOTE(review): `expert_names` is re-bound later to the LoRA adapter folder
# names (same order), so re-running this cell after that point changes its meaning.
expert_names = ["Physical Reasoning.", "Quantity Reasoning.", "Spatial Reasoning.", "Social and Emotional Reasoning."]
label2id = {name: idx for idx, name in enumerate(expert_names)}

In [None]:
# Replace each textual reasoning_type with its integer class id from label2id.
# Records that already hold an int (e.g. when this cell is re-run) are kept
# as-is, so the cell is idempotent instead of raising KeyError on re-execution.

mixed_reasoning_data_prototype = [
    {
        **item,
        'reasoning_type': (
            item['reasoning_type']
            if isinstance(item['reasoning_type'], int)
            else label2id[item['reasoning_type']]
        ),
    }
    for item in mixed_reasoning_data_prototype
]

In [None]:
# Inspect one record after label encoding.
mixed_reasoning_data_prototype[0]

In [None]:
# Build a column-oriented HuggingFace dataset: prefixed question, absolute
# image path, gold answer and integer router label per record.
# The HF class is imported under an alias because a later cell re-binds the bare
# name `Dataset` to torch.utils.data.Dataset (which has no `from_dict`); the
# alias keeps this cell correct when it is re-run after that point.
from datasets import Dataset as HFDataset

listToDictionary = {
    'question': [ prefix + dict_['question'] for dict_ in mixed_reasoning_data_prototype ], 
    'image': [ os.path.join(image_folder_path, dict_['image_id']) for dict_ in mixed_reasoning_data_prototype ],
    'answer': [ dict_['answer'] for dict_ in mixed_reasoning_data_prototype ], 
    'expert_labels': [ dict_['reasoning_type'] for dict_ in mixed_reasoning_data_prototype ], 
}

mixed_reasoning_data = HFDataset.from_dict(listToDictionary)

In [None]:
# Total number of records in the HF dataset.
len(mixed_reasoning_data)

In [None]:
# Split into Train and Val Set: hold out 20% for validation; the fixed seed
# makes the split reproducible across runs.

split = mixed_reasoning_data.train_test_split(test_size=0.2, seed=42)

In [None]:
# The 'test' half of the 80/20 split is used as the validation set.
train_set, val_set = split['train'], split['test']

In [None]:
# Sizes of the train / validation splits.
len(train_set), len(val_set)

### Creating JSON in the Qwen Chat Format

In [None]:
def produceFormattedJSON(targetSet):
    """Convert VQA rows into Qwen chat-format message records.

    Every row of ``targetSet`` must provide 'image', 'question' and 'answer'
    fields. It becomes one record made of a user turn (image + question text)
    followed by an assistant turn (the answer). All values are rendered as
    strings.

    Args:
        targetSet: sized collection indexable by int (e.g. an HF Dataset split).

    Returns:
        list[dict]: one ``{"messages": [user_turn, assistant_turn]}`` per row,
        in row order.
    """
    records = []
    for position in tqdm(range(len(targetSet))):
        row = targetSet[position]
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": f"{row['image']}"},
                {"type": "text", "text": f"{row['question']}"},
            ],
        }
        assistant_turn = {
            "role": "assistant",
            "content": [{"type": "text", "text": f"{row['answer']}"}],
        }
        records.append({"messages": [user_turn, assistant_turn]})
    return records

In [None]:
# Convert both splits into Qwen chat-format message records.
formattedJSONTrain = produceFormattedJSON(train_set)
formattedJSONVal = produceFormattedJSON(val_set)

In [None]:
# Save the chat-format splits to JSON files (read back by the Dataset class).
output_file_train = "/home/aritrad/moe-directory/moe-datasets/TDIUC/trash/mixed-train.json"
output_file_val = "/home/aritrad/moe-directory/moe-datasets/TDIUC/trash/mixed-val.json"

# Create the target directories so a fresh run does not fail on open().
for _json_path in (output_file_train, output_file_val):
    os.makedirs(os.path.dirname(_json_path), exist_ok=True)

# Use `indent` for pretty printing
with open(output_file_train, "w") as file1, open(output_file_val, 'w') as file2:
    json.dump(formattedJSONTrain, file1, indent=4)  
    json.dump(formattedJSONVal, file2, indent=4)  

print(f"Data saved to: \n{output_file_train} || {output_file_val}")

In [None]:
# Timestamped run directory passed to the project logger (util.logutil).
output_dir = f'train_output/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}/'
init_logger(output_dir)
logger = get_logger()

# Same value as the device set near the top of the notebook (re-stated here).
device = "cuda:0"

### Dataset Class

In [None]:
from torch.utils.data import Dataset

In [None]:
class mixedTrainData(Dataset):
    """Pairs pre-formatted Qwen chat messages with their router labels.

    Messages come from a JSON file (a list of ``{"messages": [...]}`` records).
    Labels come from the dataset split that file was generated from. Row ``i``
    of both sources must describe the same sample.

    Args:
        msg_path: path to the JSON list of message records.
        orig_dataset: int-indexable dataset whose rows carry an
            ``'expert_labels'`` field, in the same order as the JSON file.

    Raises:
        ValueError: if the JSON file and ``orig_dataset`` differ in length,
            which would otherwise silently misalign messages and labels.
    """

    def __init__(self, msg_path, orig_dataset):
        with open(msg_path, encoding="utf-8") as f:
            self.msgs = json.load(f)
        self.orig = orig_dataset   # HF Dataset with 'expert_labels'
        if len(self.msgs) != len(self.orig):
            raise ValueError(
                f"{msg_path} has {len(self.msgs)} records but orig_dataset has "
                f"{len(self.orig)} rows; messages and labels would be misaligned"
            )

    def __len__(self):
        return len(self.msgs)

    def __getitem__(self, idx):
        """Return ``{"messages": [...], "expert_label": <label>}`` for sample ``idx``."""
        return {
            "messages": self.msgs[idx]["messages"],
            # label is pulled from the original split at the same index
            "expert_label": self.orig[idx]["expert_labels"],
        }

In [None]:
# Pair each saved JSON split with the HF split it was built from (same row order).
train_dataset = mixedTrainData(output_file_train, train_set)
val_dataset = mixedTrainData(output_file_val, val_set)

In [None]:
# Inspect one sample: {"messages": [...], "expert_label": ...}.
train_dataset[0]

In [None]:
def raw_collate(batch):
    """Collate samples without tokenizing.

    Chat messages are kept as a plain Python list (tokenization happens later,
    per expert group). Router labels are stacked into a ``torch.long`` tensor
    of shape ``(len(batch),)``.
    """
    message_lists = []
    label_values = []
    for sample in batch:
        message_lists.append(sample["messages"])
        label_values.append(sample["expert_label"])
    return {
        "messages": message_lists,
        "expert_label": torch.tensor(label_values, dtype=torch.long),
    }

In [None]:
# Samples per batch for both the train and validation loaders.
batch_ = 6

In [None]:
# Training batches are reshuffled every epoch. Validation order is fixed, so
# the reported val loss (which drives LR scheduling and checkpoint selection)
# does not depend on a random batch composition.
train_loader = DataLoader(
    train_dataset,
    batch_size = batch_,
    shuffle = True,
    collate_fn = raw_collate,
)

val_loader = DataLoader(
    val_dataset,
    batch_size = batch_,
    shuffle = False,
    collate_fn = raw_collate,
)

In [None]:
# Batch Test: print the first validation batch to check the collate output
# (prints nothing if the loader is empty).
first_batch = next(iter(val_loader), None)
if first_batch is not None:
    print(first_batch)

In [None]:
# Number of batches per epoch for each loader.
print(f'Total number of steps in: \nTrain Loader: {len(train_loader)}\nVal Loader: {len(val_loader)}')

### Helper functions

In [None]:
def find_assistant_content_sublist_indexes(l):
    '''
    Find the token spans of every assistant reply in a tokenized Qwen chat.

    A record from the training JSON may look like:
        {
            "messages": [
                {'role': 'user', 'content': [{'type': 'image', 'image': 'train_data/1.jpeg'}, {'type': 'text', 'text': 'Describe this image'}]},
                {'role': 'assistant', 'content': [{'type': 'text', 'text': 'The image shows a young woman playing with her dog on a beach. ...'}]}
            ]
        }
    After apply_chat_template, the text looks like:
        ['<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image<|im_end|>\n<|im_start|>assistant\nThe image shows ...<|im_end|>\n']

    An assistant span starts right after the header ids (151644, 77091, 198)
    = "<|im_start|>assistant\n". It ends right after the first following
    (151645, 198) = "<|im_end|>\n". Those two closing tokens are included in
    the span so the model learns to predict the end of its output.

    Args:
        l: token ids of one sequence (list of int).

    Returns:
        list[tuple[int, int]]: half-open (start, end) index pairs, in order,
        used to build labels. A header with no closing "<|im_end|>\n" after
        it yields no pair.
    '''
    spans = []
    n = len(l)

    # The header is 3 tokens long, so stop 2 before the end; this keeps l[i+2]
    # in range when a (truncated) sequence ends with "<|im_start|>assistant".
    for i in range(n - 2):
        if l[i] == 151644 and l[i+1] == 77091 and l[i+2] == 198:
            start = i + 3
            # Look for the first <|im_end|>\n after the header.
            for j in range(start, n - 1):
                if l[j] == 151645 and l[j+1] == 198:
                    spans.append((start, j + 2))
                    break  # move on to the next header

    return spans

In [None]:
@torch.no_grad()
def get_router_inputs(batch):
    """
    Embed each example's question text with the frozen SBERT encoder.

    The question is taken from the first ``"type": "text"`` item in the content
    of the first (user) message of every conversation in ``batch["messages"]``.

    Returns:
        Tensor of SBERT embeddings on ``device``, one row per example,
        computed without autograd.
    """
    questions = []
    for conversation in batch["messages"]:
        user_content = conversation[0]["content"]
        text_parts = (part["text"] for part in user_content if part["type"] == "text")
        questions.append(next(text_parts))
    return sbert.encode(questions, convert_to_tensor=True, device=device)

### Define Router & Load Checkpoint

In [None]:
class Router(nn.Module):
    """Small MLP mapping a sentence embedding to one logit per expert.

    Layer widths: hidden -> hidden//2 -> hidden//4 -> n_experts, with GELU and
    Dropout(0.1) after each hidden layer. All layers live in ``self.net`` (an
    ``nn.Sequential``), so state_dict keys are ``net.0.*``, ``net.3.*`` and
    ``net.6.*``.
    """

    def __init__(self, hidden=768, n_experts=4):
        super().__init__()
        widths = [hidden, hidden // 2, hidden // 4]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.GELU(), nn.Dropout(0.1)]
        layers.append(nn.Linear(widths[-1], n_experts))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised expert logits for the embeddings ``x``."""
        return self.net(x)

In [None]:
# Build the router and restore its saved weights (also usable for testing a
# saved router checkpoint on other test sets).
router = Router(n_experts=len(expert_names)).to(device)

# SBERT for frozen text embeddings
sbert = SentenceTransformer("all-mpnet-base-v2", device=device)

for p in sbert.parameters():
    p.requires_grad_(False)

# Load the router checkpoint (a plain state_dict). map_location puts tensors on
# the router's device regardless of where they were saved; weights_only=True
# restricts unpickling to tensors/containers.
ROUTER_CKPT_PATH = '/home/aritrad/moe-directory/moe-datasets/TDIUC/custom-moe/using-sbert/router_best.pt'
checkpoint = torch.load(ROUTER_CKPT_PATH, map_location=device, weights_only=True)
router.load_state_dict(checkpoint)

print("Router Initialized ✓")

### Qwen Backbone Model Loading:

In [None]:
# 4-bit NF4 weight quantization with nested (double) quantization of the
# quantization constants; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_compute_dtype = torch.bfloat16,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_use_double_quant = True,
)

In [None]:
# Qwen2.5-VL-7B-Instruct backbone: 4-bit quantized via bnb_config, FlashAttention-2
# kernels, bfloat16 dtype, and automatic device placement.
backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct", 
    attn_implementation = "flash_attention_2", 
    torch_dtype=torch.bfloat16, 
    device_map="auto", 
    quantization_config = bnb_config,
)

In [None]:
# Load the processor (chat template + image preprocessing).
# The model's default range of visual tokens per image is 4-16384; min_pixels /
# max_pixels bound it, at 28*28 pixels per token. Here it is limited to
# 256-512 tokens per image to balance speed and memory usage; padding is on the left.
# (Model-card example: min_pixels = 256*28*28, max_pixels = 1280*28*28.)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=256*28*28, max_pixels=512*28*28, padding_side="left", use_fast = True)

In [None]:
# Freeze every base-model weight of the Qwen backbone. The LoRA expert adapters
# are loaded afterwards with is_trainable=True, so they are still trained.

backbone.eval()
for p in backbone.parameters(): 
    p.requires_grad_(False)  

### Load Experts:

In [None]:
# Load one LoRA adapter per expert from ADAPTER_ROOT/<name>, register it under
# that name and keep it trainable (is_trainable=True).

ADAPTER_ROOT = '/home/aritrad/moe-directory/moe-datasets/TDIUC/best_experts-r32-2.5/'

# NOTE(review): this re-binds `expert_names` from router class names to adapter
# names. The order must match label2id so router class i selects expert_names[i].
expert_names = ["physical", "quantitative", "spatial", "social"]

for name in expert_names:
    path = os.path.join(ADAPTER_ROOT, name)
    backbone.load_adapter(path, adapter_name=name, is_trainable=True)

### Hyperparams:

In [None]:
from torch.nn.utils   import clip_grad_norm_

In [None]:
# Joint optimiser with two parameter groups (order matters, it is used by index):
#   group 0 -> router MLP
#   group 1 -> every backbone parameter that still requires grad, i.e. the
#              LoRA adapters, since the base weights were frozen above.
EPOCHS       = 5
router_lr    = 1e-5
adapter_lr   = 5e-6               # lower step helps convergence
weight_decay = 1e-2

# NOTE(review): only parameters with requires_grad=True *at this moment* are
# collected -- confirm every expert's LoRA weights are trainable here.
adapter_params = list(filter(lambda param: param.requires_grad, backbone.parameters()))

param_groups = [
    {"params": router.parameters(), "lr": router_lr},
    {"params": adapter_params,      "lr": adapter_lr},
]
optimizer = AdamW(param_groups, weight_decay=weight_decay)

ce_router = nn.CrossEntropyLoss()   # router classification loss

In [None]:
# Plateau scheduler on validation loss. It scales every param group, so the
# router group (index 0) is reset after each step and only the adapter group
# (index 1) is effectively decayed.
def scheduler_step(val_loss):
    """Step the plateau scheduler on ``val_loss`` and pin the router LR.

    The router LR is restored to ``router_lr`` (single source of truth instead
    of a duplicated literal). Reductions of the adapter LR are logged explicitly,
    replacing the scheduler's deprecated ``verbose`` flag.
    """
    adapter_lr_before = optimizer.param_groups[1]["lr"]
    plateau.step(val_loss)
    optimizer.param_groups[0]["lr"] = router_lr
    adapter_lr_after = optimizer.param_groups[1]["lr"]
    if adapter_lr_after < adapter_lr_before:
        print(f"Adapter LR reduced: {adapter_lr_before:.2e} -> {adapter_lr_after:.2e}")


plateau = ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=1, min_lr=5e-7
)

### Training & Validation:

In [None]:
@torch.no_grad()
def validate():
    """
    Hard-routing evaluation over ``val_loader``.

    Per batch the loss is the router cross-entropy against the gold expert
    label, plus the LM loss of the adapter the router picked (argmax) for each
    sample. Each expert's loss is weighted by the fraction of the *current*
    batch it handled, so the weights sum to 1 even for a short final batch.
    No scheduler step happens inside this function.

    Returns:
        float: mean over batches of (router CE + weighted expert loss).
    """
    backbone.eval()
    router.eval()
    total_val = 0.0

    for batch in tqdm(val_loader, leave=False):
        # actual batch length (the last batch may be smaller than batch_size)
        n_in_batch = len(batch["messages"])

        # ---- router part -------------------------------------------------
        emb    = get_router_inputs(batch)
        logits = router(emb)
        r_loss = ce_router(logits, batch["expert_label"].to(device))

        # ---- expert part -------------------------------------------------
        gate_idx     = logits.argmax(dim=-1)       # (B,)
        exp_loss_acc = 0.0

        # expert_names[i] is assumed to be the adapter for router class i.
        for i, expert in enumerate(expert_names):
            sel = (gate_idx == i).nonzero(as_tuple=True)[0]
            if sel.numel() == 0:
                continue
            backbone.set_adapter(expert)

            msgs = [batch["messages"][j] for j in sel.tolist()]

            # a) text prompt (full conversation, including the answer)
            chat_texts = [
                processor.apply_chat_template(m, tokenize=False, add_generation_prompt=False)
                for m in msgs
            ]
            # b) vision tensors
            image_inputs, video_inputs = process_vision_info(msgs)
            inputs = processor(
                text   = chat_texts,
                images = image_inputs,
                videos = video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(device)

            # c) labels: -100 everywhere except the assistant reply spans
            lbl = []
            for ids in inputs.input_ids.tolist():
                msk = [-100]*len(ids)
                for s,e in find_assistant_content_sublist_indexes(ids):
                    msk[s:e] = ids[s:e]
                lbl.append(msk)
            label_ids = torch.tensor(lbl, dtype=torch.long, device=device)

            out = backbone(**inputs, labels=label_ids)

            # weight by this expert's share of the samples in this batch
            exp_loss_acc += (sel.numel() / n_in_batch) * out.loss

        total_val += (r_loss + exp_loss_acc).item()

    return total_val / len(val_loader)

In [None]:
def train_epoch():
    """
    One epoch of joint router + adapter training over ``train_loader``.

      * the router is trained with cross-entropy against the gold expert label;
      * each sample is hard-routed (argmax) to one LoRA expert, whose LM loss
        is weighted by that expert's share of the *current* batch (correct
        also for a short final batch);
      * gradients are clipped to a global norm of 1.0 before the step.

    The argmax is not differentiable, so the expert loss sends no gradient to
    the router; the router learns only from its CE term.

    Returns:
        float: mean total loss over all batches of the epoch.
    """
    backbone.train()
    router.train()
    total_loss = 0.0
    # The same parameter set is clipped every step; build the list once.
    clip_params = adapter_params + list(router.parameters())

    for step, batch in enumerate(train_loader, 1):
        # actual batch length (the last batch may be smaller than batch_size)
        n_in_batch = len(batch["messages"])

        # 1) Router forward
        emb     = get_router_inputs(batch)                 # SBERT embeddings, one row per sample
        logits  = router(emb)                              # (B, n_experts)
        r_loss  = ce_router(logits, batch["expert_label"].to(device))

        # 2) Expert forward (hard routing)
        gate_idx     = logits.argmax(dim=-1)               # (B,)
        exp_loss_acc = 0.0       # weighted sum over experts

        # expert_names[i] is assumed to be the adapter for router class i.
        for i, expert in enumerate(expert_names):
            sel = (gate_idx == i).nonzero(as_tuple=True)[0]
            if sel.numel() == 0:
                continue

            backbone.set_adapter(expert)

            group_msgs = [batch["messages"][j] for j in sel.tolist()]

            # 2a) build chat text including the answer (teacher forcing)
            chat_texts = [
                processor.apply_chat_template(
                    m, tokenize=False, add_generation_prompt=False
                )
                for m in group_msgs
            ]
            # 2b) vision-to-tensor
            image_inputs, video_inputs = process_vision_info(group_msgs)
            inputs = processor(
                text   = chat_texts,
                images = image_inputs,
                videos = video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(device)

            # 2c) build labels: -100 everywhere except the assistant reply spans
            lbl = []
            for ids in inputs.input_ids.tolist():
                msk = [-100] * len(ids)
                for s, e in find_assistant_content_sublist_indexes(ids):
                    msk[s:e] = ids[s:e]
                lbl.append(msk)
            label_ids = torch.tensor(lbl, dtype=torch.long, device=device)

            # 2d) forward & accumulate, scaled by this expert's share of the batch
            out = backbone(**inputs, labels=label_ids)
            exp_loss_acc += (sel.numel() / n_in_batch) * out.loss

        # 3) Total loss = router + weighted expert
        loss = r_loss + exp_loss_acc
        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(clip_params, max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

        # ---- logging every 100 steps
        if step % 100 == 0:
            print(f"[step {step:4d}] router={r_loss.item():.4f}  "
                  f"expert={float(exp_loss_acc):.4f}")

    return total_loss / len(train_loader)

### Driver Code

In [None]:
# Show the current working directory (plain Python instead of IPython's `pwd`
# automagic, so the cell also runs with automagic disabled).
os.getcwd()

In [None]:
# Paths: root directory for the best adapter and router checkpoints.
OUTPUT_DIR   = "/home/aritrad/moe-directory/moe-datasets/TDIUC/custom-moe/using-sbert/moe-end2end"

In [None]:
# Train for EPOCHS epochs. After each epoch: validate, step the LR scheduler,
# and save every adapter plus the router whenever validation loss improves.
best_val = float("inf")

for epoch in tqdm(range(1, EPOCHS + 1)):
    tr_loss = train_epoch()
    val_loss = validate()
    scheduler_step(val_loss)

    print(f"Epoch {epoch}  train_loss={tr_loss:.4f}  val_loss={val_loss:.4f}")

    # Guard clause: nothing to save unless this epoch beats the best so far.
    if not val_loss < best_val:
        continue
    best_val = val_loss

    ckpt_root = os.path.join(OUTPUT_DIR, "best-adapters-2.5-7B")
    os.makedirs(ckpt_root, exist_ok=True)

    for adapter_name in expert_names:
        # Activate the adapter first, then save only that adapter.
        backbone.set_adapter(adapter_name)
        adapter_dir = os.path.join(ckpt_root, adapter_name)
        os.makedirs(adapter_dir, exist_ok=True)
        backbone.save_pretrained(adapter_dir, adapter_name=adapter_name)

    router_path = os.path.join(OUTPUT_DIR, "best-router-2.5-7B.pt")
    torch.save(router.state_dict(), router_path)
    print(f" ↳ Saved checkpoint @ epoch {epoch}, val_loss={best_val:.4f}")