In [1]:
import json
import sys
from collections import OrderedDict
from contextlib import contextmanager
import time

from src.arguments import ModelArguments, DataArguments, TrainingArguments
from transformers import HfArgumentParser, AutoConfig

from src.model.model import MMEBModel
from src.data.dataset.mmeb_dataset import EvalDataset
from src.data.collator.eval_collator import EvalCollator
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import numpy as np
import pickle
import os
from datasets import load_dataset
from evaluation.mmeb_baselines.eval_utils import get_pred
from src.utils import print_rank
from src.model.processor import get_backbone_name, load_processor, COLPALI
from torch.nn.utils.rnn import pad_sequence
import shutil 

  from .autonotebook import tqdm as notebook_tqdm
[2025-11-13 17:41:53,569] DEBUG [datasets:54] PyTorch version 2.7.1 available.


In [2]:
model_args = ModelArguments(
    # model_name="apple/FastVLM-0.5B",
    model_name="dangnguyens1/sft-fastvlm-2e",
    lora=True,
    # model_name="raghavlite/B3_Qwen2_2B",
    pooling="eos",
    normalize=True
)

data_args = DataArguments(
    encode_output_path="encode_output_path",
    dataset_name="TIGER-Lab/MMEB-eval",
    subset_name=["WebQA"],
    dataset_split="test",
    tgt_prefix_mod=True,
    image_dir="../VLMEmbed/eval-data",
)

training_args = TrainingArguments(
    per_device_eval_batch_size=2,
)

os.makedirs(data_args.encode_output_path, exist_ok=True)

hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
    model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
    setattr(model_args, 'model_backbone', model_backbone)
    setattr(training_args, 'model_backbone', model_backbone)
print_rank(f'model_backbone: {model_args.model_backbone}')
processor = load_processor(model_args, data_args)
model = MMEBModel.build(model_args, is_trainable=True)
# model.eval()
# model = model.to(training_args.device, dtype=torch.bfloat16)
model = model.to(training_args.device)

eval_collator = EvalCollator(
    data_args=data_args,
    model_args=model_args,
    processor=processor,
)

[2025-11-13 17:41:57,215] DEBUG [urllib3.connectionpool:1049] Starting new HTTPS connection (1): huggingface.co:443
[2025-11-13 17:41:57,550] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /dangnguyens1/sft-fastvlm-2e/resolve/main/config.json HTTP/1.1" 307 0
[2025-11-13 17:41:57,559] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /api/resolve-cache/models/dangnguyens1/sft-fastvlm-2e/fd902c53ee56461d6859fe863529c73551f2b683/config.json HTTP/1.1" 200 0
[2025-11-13 17:41:57,813] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /dangnguyens1/sft-fastvlm-2e/resolve/main/llava_qwen.py HTTP/1.1" 307 0
[2025-11-13 17:41:57,834] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /api/resolve-cache/models/dangnguyens1/sft-fastvlm-2e/fd902c53ee56461d6859fe863529c73551f2b683/llava_qwen.py HTTP/1.1" 200 0
  @register_model
[2025-11-13 17:41:58,618] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /dangn

Detected model type: llava_qwen2
Determined model backbone: llava_qwen2
Processor load here for LLAVA-QWEN2


[2025-11-13 17:41:58,919] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /dangnguyens1/sft-fastvlm-2e/resolve/main/tokenizer_config.json HTTP/1.1" 307 0
[2025-11-13 17:41:58,933] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /api/resolve-cache/models/dangnguyens1/sft-fastvlm-2e/fd902c53ee56461d6859fe863529c73551f2b683/tokenizer_config.json HTTP/1.1" 200 0
[2025-11-13 17:41:59,258] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "GET /api/models/dangnguyens1/sft-fastvlm-2e/tree/main/additional_chat_templates?recursive=False&expand=False HTTP/1.1" 404 64
[2025-11-13 17:41:59,906] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /dangnguyens1/sft-fastvlm-2e/resolve/main/config.json HTTP/1.1" 307 0
[2025-11-13 17:41:59,916] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /api/resolve-cache/models/dangnguyens1/sft-fastvlm-2e/fd902c53ee56461d6859fe863529c73551f2b683/config.json HTTP/1.1" 200 0
[20

Detected model type: llava_qwen2
Determined model backbone: llava_qwen2


[2025-11-13 17:42:00,876] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /apple/FastVLM-0.5B/resolve/main/model.safetensors HTTP/1.1" 302 0
[2025-11-13 17:42:01,929] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /apple/FastVLM-0.5B/resolve/main/generation_config.json HTTP/1.1" 307 0
[2025-11-13 17:42:01,938] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /api/resolve-cache/models/apple/FastVLM-0.5B/16375720c2d673fa583e57e9876afde27549c7d0/generation_config.json HTTP/1.1" 200 0
[2025-11-13 17:42:02,194] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /apple/FastVLM-0.5B/resolve/main/custom_generate/generate.py HTTP/1.1" 404 0
[2025-11-13 17:42:02,465] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /dangnguyens1/sft-fastvlm-2e/resolve/main/adapter_config.json HTTP/1.1" 307 0
[2025-11-13 17:42:02,479] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /api/resolve-cach

In [4]:
for a, b in model.encoder.named_parameters():
    if b.requires_grad:
        print(a)

base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.q_proj.lora_magnitude_vector.default.weight
base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.k_proj.lora_magnitude_vector.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_magnitude_vector.default.weight
base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.o_proj.lora_magnitude_vector.default.weight
base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight
base_model.model.mod

In [None]:
from transformers import AutoProcessor

repo_id = "dangnguyens1/sft-fastvlm-1e" 
processor = AutoProcessor.from_pretrained(repo_id)

print(f"Processor loaded from {repo_id}")

[2025-11-08 10:31:11,846] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /dangnguyens1/sft-fastvlm-1e/resolve/main/processor_config.json HTTP/1.1" 404 0
[2025-11-08 10:31:12,156] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /dangnguyens1/sft-fastvlm-1e/resolve/main/preprocessor_config.json HTTP/1.1" 307 0
[2025-11-08 10:31:12,492] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /api/resolve-cache/models/dangnguyens1/sft-fastvlm-1e/eb3df95553c7a0a6cb1667bb715a3db89ba464f4/preprocessor_config.json HTTP/1.1" 200 0
[2025-11-08 10:31:12,496] DEBUG [filelock:331] Attempting to acquire lock 140179351359312 on /home/user2/.cache/huggingface/hub/.locks/models--dangnguyens1--sft-fastvlm-1e/12032f8aaa74bfa77b08688d5508981733fed85e.lock
[2025-11-08 10:31:12,498] DEBUG [filelock:334] Lock 140179351359312 acquired on /home/user2/.cache/huggingface/hub/.locks/models--dangnguyens1--sft-fastvlm-1e/12032f8aaa74bfa77b08688d5508981733fed85e.l

Processor loaded from dangnguyens1/sft-fastvlm-1e


In [7]:
from huggingface_hub import login
token="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"
login(token=token)

ckpt_dir = "test-save"
# processor.tokenizer.save_pretrained(ckpt_dir)
processor.save_pretrained(ckpt_dir)
model.encoder.save_pretrained(ckpt_dir)

[2025-11-08 10:31:22,606] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "GET /api/whoami-v2 HTTP/1.1" 401 47


HTTPError: Invalid user token.

In [None]:

from huggingface_hub import HfApi, HfFolder, Repository, create_repo

def push_to_hub(repo_name=None, token=None, commit_message="Upload model", 
                local_dir="./temp_model", private=False):
    try:
        if not repo_name:
            raise ValueError("must specify a repo name to push to hub")
        
        if not os.path.exists(local_dir):
            raise ValueError(f"local_dir {local_dir} does not exist")
        
        print_rank(f"Pushing model to the hub at {repo_name}...")
        api = HfApi()
        create_repo(repo_name, token=token, private=private, exist_ok=True)
        api.upload_folder(
            folder_path=local_dir,
            repo_id=repo_name, 
            token=token, 
            commit_message=commit_message
        )

        print_rank(f"Model has been pushed to the hub at: {repo_name}")
        return True
        
    except Exception as e:
        print_rank(f"Error pushing to hub: {str(e)}")
        return False

In [None]:
push_to_hub(
    repo_name="dangnguyens1/sft-fastvlm-kd_final_e",
    token=token,
    local_dir="test-save",
    
)

[2025-11-07 10:55:20,306] INFO [src.utils:12] Pushing model to the hub at dangnguyens1/sft-fastvlm-kd_final_e...
[2025-11-07 10:55:20,310] DEBUG [urllib3.connectionpool:289] Resetting dropped connection: huggingface.co
[2025-11-07 10:55:23,132] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "POST /api/repos/create HTTP/1.1" 200 145
[2025-11-07 10:55:23,562] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "POST /api/models/dangnguyens1/sft-fastvlm-kd_final_e/preupload/main HTTP/1.1" 200 878
[2025-11-07 10:55:23,867] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "GET /api/models/dangnguyens1/sft-fastvlm-kd_final_e/revision/main?expand=xetEnabled HTTP/1.1" 200 95
[2025-11-07 10:55:24,162] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "GET /api/models/dangnguyens1/sft-fastvlm-kd_final_e/xet-write-token/main HTTP/1.1" 200 418
Processing Files (2 / 2): 100%|██████████| 20.8MB / 20.8MB,  0.00B/s  
New Data Upload: |          |  0.00

True

In [4]:
POS_MOD_CLASS_LABEL = "Represent the class label: "
POS_MOD_IMAGE_CAPTION = "Represent the image caption: "
POS_MOD_ANSWER = "Represent the answer: "

POS_MOD_DICT = {
                "ImageNet-1K": POS_MOD_CLASS_LABEL,"HatefulMemes":POS_MOD_CLASS_LABEL,"SUN397":POS_MOD_CLASS_LABEL,"N24News":POS_MOD_CLASS_LABEL,"VOC2007":POS_MOD_CLASS_LABEL, "Place365":POS_MOD_CLASS_LABEL,"ImageNet-A":POS_MOD_CLASS_LABEL,"ImageNet-R":POS_MOD_CLASS_LABEL,"ObjectNet":POS_MOD_CLASS_LABEL,"Country211":POS_MOD_CLASS_LABEL,
                
                "OK-VQA":POS_MOD_ANSWER, "A-OKVQA":POS_MOD_ANSWER, "DocVQA":POS_MOD_ANSWER, "InfographicsVQA":POS_MOD_ANSWER, "ChartQA":POS_MOD_ANSWER, "Visual7W":POS_MOD_ANSWER,"ScienceQA":POS_MOD_ANSWER, "GQA":POS_MOD_ANSWER, "TextVQA":POS_MOD_ANSWER, "VizWiz":POS_MOD_ANSWER,
                
                "MSCOCO_i2t":POS_MOD_IMAGE_CAPTION, "VisualNews_i2t":POS_MOD_IMAGE_CAPTION,
                }

eval_qry_dataset = EvalDataset(
    data_args=data_args,
    model_args=model_args,
    subset=data_args.subset_name[0],
    text_field="qry_text",
    img_path_field="qry_img_path",
)
eval_tgt_dataset = EvalDataset(
    data_args=data_args,
    model_args=model_args,
    subset=data_args.subset_name[0],
    text_field="tgt_text",
    img_path_field="tgt_img_path",
    mod_instruction=POS_MOD_DICT.get(data_args.subset_name[0], None) if data_args.tgt_prefix_mod else None
)

[2025-11-09 16:34:32,394] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /datasets/TIGER-Lab/MMEB-eval/resolve/main/README.md HTTP/1.1" 307 0
[2025-11-09 16:34:32,454] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /api/resolve-cache/datasets/TIGER-Lab/MMEB-eval/2f069730be515ea60778413777816b53e2d2a697/README.md HTTP/1.1" 200 0
[2025-11-09 16:34:32,746] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "HEAD /datasets/TIGER-Lab/MMEB-eval/resolve/2f069730be515ea60778413777816b53e2d2a697/MMEB-eval.py HTTP/1.1" 404 0
[2025-11-09 16:34:32,751] DEBUG [urllib3.connectionpool:1049] Starting new HTTPS connection (1): s3.amazonaws.com:443
[2025-11-09 16:34:38,828] DEBUG [urllib3.connectionpool:544] https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/TIGER-Lab/MMEB-eval/TIGER-Lab/MMEB-eval.py HTTP/1.1" 404 0
[2025-11-09 16:34:39,195] DEBUG [urllib3.connectionpool:544] https://huggingface.co:443 "GET /api/datasets/TIG

AttributeError: 'list' object has no attribute 'replace'

In [None]:
eval_qry_dataset[0]

('<image>\nFind a Wikipedia image that answers this question: Does a Minnetonka Rhododendron flower have petals in a cup shape?\n',
 None)

In [4]:
eval_qry_loader = DataLoader(
    eval_qry_dataset,
    batch_size=training_args.per_device_eval_batch_size,
    collate_fn=eval_collator,
    shuffle=False,
    drop_last=False,
    num_workers=0,
)
eval_tgt_loader = DataLoader(
    eval_tgt_dataset,
    batch_size=training_args.per_device_eval_batch_size,
    collate_fn=eval_collator,
    shuffle=False,
    drop_last=False,
    num_workers=0,
)

def batch_to_device(batch, device):
    _batch = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            _batch[key] = value.to(device)
        else:
            _batch[key] = value
    return _batch

In [5]:
for batch in eval_qry_loader:
    print(batch)
    break

<|image_1|>
Find a Wikipedia image that answers this question: Does a Minnetonka Rhododendron flower have petals in a cup shape?
 
<|image_1|>
Find a Wikipedia image that answers this question: What water-related object is sitting in front of the Torre del Reloj?
 
[{'text': ['<image>\nFind a Wikipedia image that answers this question: Does a Minnetonka Rhododendron flower have petals in a cup shape?\n'], 'image': [None]}, {'text': ['<image>\nFind a Wikipedia image that answers this question: What water-related object is sitting in front of the Torre del Reloj?\n'], 'image': [None]}]
{'input_ids': tensor([[  -200,    198,   9885,    264,  26587,   2168,    429,  11253,    419,
           3405,     25,  12553,    264,   3386,   4711,    263,   4554,  17968,
            347,    347,    408,   2248,  22351,    614,  95640,    304,    264,
          10525,   6083,   5267],
        [  -200,    198,   9885,    264,  26587,   2168,    429,  11253,    419,
           3405,     25,   3555,   30

In [11]:
x[0]

{'qry_text': '<|image_1|>\nFind a Wikipedia image that answers this question: Does a Minnetonka Rhododendron flower have petals in a cup shape?\n',
 'qry_img_path': '',
 'tgt_text': ['<|image_1|>\nRepresent the given Wikipedia image with related text information: 2020-05-08 15 17 05 Minnetonka Rhododendron flower along Tranquility Court in the Franklin Farm section of Oak Hill, Fairfax County, Virginia Minnetonka Rhododendron flower along Tranquility Court in the Franklin Farm section of Oak Hill, Fairfax County, Virginia.\n',
  '<|image_1|>\nRepresent the given Wikipedia image with related text information: Louvre - French sculptures - Room 25 - 03.\n',
  '<|image_1|>\nRepresent the given Wikipedia image with related text information: Mission Point Studio building, Mackinac Island, 1960.jpeg Construction of the Mission Point Film Studio and Fine Arts Building was completed on Mackinac Island, MI in 1960. It has been used in filming several films (including Somewhere in Time), as a per

In [6]:
for batch in eval_qry_loader:
    batch = batch_to_device(batch, "cuda")
    with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
    # batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)
    
        pooled_output, hidden_states, image_features, all_layers_embeds, attention_matrix = model.encode_input(batch)

/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [30,0,0], thread: [0,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [30,0,0], thread: [1,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [30,0,0], thread: [2,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [30,0,0], thread: [3,0,0] Assertion `ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernelUtils.cu:16: vectorized_gather_kernel: block: [30,0,0], thread: [4,0,0] Assertion 

AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
with open(os.path.join("test", "output_t_qry.pkl"), "wb") as f:
    pickle.dump({
        "pooled_output": pooled_output,
        "hidden_states": hidden_states,
        "image_features": image_features,
        "all_layers_embeds": all_layers_embeds,
        "attention_matrix": attention_matrix,
        "input_data": batch,
    }, f)

In [None]:

with open(os.path.join("test", "output_s_qry.pkl"), "rb") as f:
    x = pickle.load(f)

NameError: name 'os' is not defined

In [None]:
with open(os.path.join("test", "output_t_qry.pkl"), "rb") as f:
    y = pickle.load(f)

In [None]:
y['attention_matrix'][0].size()

torch.Size([2, 12, 278, 278])

In [None]:
len(x['image_features'])

2

In [None]:
with open(os.path.join("test", "output_s_qry.pkl"), "wb") as f:
    pickle.dump({
        "pooled_output": pooled_output,
        "hidden_states": hidden_states,
        "image_features": image_features,
        "all_layers_embeds": all_layers_embeds,
        "attention_matrix": attention_matrix,
        "input_data": batch,
    }, f)

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
!export CUDA_LAUNCH_BLOCKING=1

In [None]:
import gc, torch

# delete references to large tensors/models you created
for name in ("batch", "pooled_output", "image_features", "attention_matrix", "model"):
    if name in globals():
        try:
            del globals()[name]
        except Exception:
            pass

gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()

# optional: reset tracking and print status
for i in range(torch.cuda.device_count()):
    torch.cuda.reset_peak_memory_stats(i)
print("allocated:", torch.cuda.memory_allocated(), "cached:", torch.cuda.memory_reserved())

AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
with open("/home/user2/dangnh/VLM_Embed/test/output_t_qry.pkl", 'rb') as file:
    # Use pickle.load() to read the byte stream and deserialize the object
    loaded_object = pickle.load(file)
    # load to cuda
    loaded_object = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in loaded_object.items()}
    t_qry_reps, t_qry_hidden_states, t_qry_img_feats, t_qry_layers_embeds, t_qry_attention, t_qry_input_data = \
                                                                                loaded_object.values()
    t_pos_reps, t_pos_hidden_states, t_pos_img_feats, t_pos_layers_embeds, t_pos_attention, t_pos_input_data = \
                                                                                loaded_object.values()
    
with open("/home/user2/dangnh/VLM_Embed/test/output_s_qry.pkl", 'rb') as file:
    # Use pickle.load() to read the byte stream and deserialize the object
    loaded_object = pickle.load(file)
    loaded_object = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in loaded_object.items()}

    s_qry_reps, s_qry_hidden_states, s_qry_img_feats, s_qry_layers_embeds, s_qry_attention, s_qry_input_data = \
                                                                loaded_object.values()
    s_pos_reps, s_pos_hidden_states, s_pos_img_feats, s_pos_layers_embeds, s_pos_attention, s_pos_input_data = \
                                                                loaded_object.values()

NameError: name 'model' is not defined

# Test

In [10]:
import torch
import torch.nn as nn 
import torch.distributed as dist
import torch.nn.functional as F
from src.criterions.soft_DTW import SoftDTW
import ot
# ot.backend.get_backend('pytorch')

def create_semi_orthogonal_matrix(tensor):
    rows, cols = tensor.shape
    if rows >= cols:
        # QR trực tiếp
        a = torch.randn(rows, cols, device=tensor.device, dtype=tensor.dtype)
        q, _ = torch.linalg.qr(a, mode='reduced')
        tensor.data[:] = q[:, :cols]
    else:
        # QR trên ma trận transpose để đảm bảo W W^T = I
        a = torch.randn(cols, rows, device=tensor.device, dtype=tensor.dtype)
        q, _ = torch.linalg.qr(a, mode='reduced')
        tensor.data[:] = q.T[:rows, :]
    return tensor

class Distiller(nn.Module):
    def __init__(self, model_args, training_args, device):
        super(Distiller, self).__init__()
        self.model_args = model_args
        self.training_args = training_args
        self.device = device

        self.student_hidden_dim = 896
        self.teacher_hidden_dim = 1536
        self.temperature = 0.02
        self.set_projector()
        print("Projectors set.")

        self.t2s_img_align = nn.Sequential(
            nn.Linear(self.teacher_hidden_dim, self.student_hidden_dim),
            nn.ReLU()
        )
        self.t2s_img_align.to(device="cuda")

        # for simple kd
        self.last_layer_projector = nn.Sequential(
            nn.Linear(self.teacher_hidden_dim, self.student_hidden_dim),
            nn.ReLU()
        )
        self.last_layer_projector.to(device="cuda")

        # for Soft-DTW
        self.num_chosen_hidden_states = 3
        self.t2s = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.teacher_hidden_dim, self.student_hidden_dim),
                nn.ReLU()
            )       
        ] * self.num_chosen_hidden_states)

        self.t2s.to(device="cuda")
        
    def set_projector(self):
        self.projectors = nn.ModuleDict()
        projector_config = json.load(open("/home/user2/dangnh/VLM_Embed/config/projector_config.json", 'r'))
        
        name_dict = {
            "s": self.student_hidden_dim,
            "t": self.teacher_hidden_dim,
            "relu": nn.ReLU()
        }
        
        for name, cfg in projector_config.items():
            if not cfg.get("enabled", False):
                continue
            seq = nn.Sequential()
            parts = cfg["structure"].split("-")
            parsed = []
            
            for p in parts:
                if p == "relu":
                    parsed.append("relu")
                else:
                    coef = int(p[:-1]) if len(p) > 1 and p[:-1].isdigit() else 1
                    parsed.append(coef * name_dict[p[-1]])
            for i in range(len(parsed) -1):
                a, b = parsed[i], parsed[i+1]
                if isinstance(a, int) and isinstance(b, int):
                    layer = nn.Linear(a, b)
                    create_semi_orthogonal_matrix(layer.weight)
                    seq.append(layer)
                elif b == "relu":
                    seq.append(name_dict[b])
                elif a =="relu" and isinstance(b, int):
                    prev_out = parsed[i-1] if isinstance(parsed[i-1], int) else None
                    layer = nn.Linear(prev_out, b)
                    create_semi_orthogonal_matrix(layer.weight)
                    seq.append(layer)
            self.projectors[name] = seq
            print(f"Projector {name} created with structure: {seq}")
    
    def add_optimizer_param_group(self, optimizer):
        if hasattr(self, 'projectors'):
            lr = 0.001
            optimizer.add_param_group({
                "params": [p for proj in self.projectors.values() for p in proj.parameters()],
                "lr": lr
            })
        print("Projector parameters added to optimizer.")
        return optimizer

class StrongerKD(nn.Module):
    def __init__(self, args):
        super(StrongerKD, self).__init__()
        self.args = args
        self.rkd_loss_weight = 0.5
        self.simple_kd_weight = 0.5
        self.intra_rkd_weight = 0.5
        self.cross_modal_kd_weight = 0.01
        self.ot_loss_weight = 0.5
        self.num_chosen_hidden_states = 3

        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.img_align_loss_weight = 0.1
        self.sdtw = SoftDTW(use_cuda=True, gamma=0.001)
        self.mse_loss = nn.MSELoss(reduction="mean")


    def forward(self, distiller, input_data):
        self.distiller = Distiller(model_args, training_args, "cuda")
        # student_model = distiller.student
        # teacher_model = distiller.teacher
        
        with open("/home/user2/dangnh/VLM_Embed/test/output_t_qry.pkl", 'rb') as file:
            # Use pickle.load() to read the byte stream and deserialize the object
            loaded_object = pickle.load(file)
            # load to cuda
            loaded_object = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in loaded_object.items()}
            t_qry_reps, t_qry_hidden_states, t_qry_img_feats, t_qry_layers_embeds, t_qry_attention, t_qry_input_data = \
                                                                                        loaded_object.values()
            t_pos_reps, t_pos_hidden_states, t_pos_img_feats, t_pos_layers_embeds, t_pos_attention, t_pos_input_data = \
                                                                                        loaded_object.values()
            
            # t_qry_hidden_states, t_qry_img_feats = torch.stack(t_qry_hidden_states, dim=0), torch.stack(t_qry_img_feats, dim=0)
            # t_pos_hidden_states, t_pos_img_feats = torch.stack(t_pos_hidden_states, dim=0), torch.stack(t_pos_img_feats, dim=0)
        with open("/home/user2/dangnh/VLM_Embed/test/output_s_qry.pkl", 'rb') as file:
            # Use pickle.load() to read the byte stream and deserialize the object
            loaded_object = pickle.load(file)
            loaded_object = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in loaded_object.items()}

            s_qry_reps, s_qry_hidden_states, s_qry_img_feats, s_qry_layers_embeds, s_qry_attention, s_qry_input_data = \
                                                                        loaded_object.values()
            s_pos_reps, s_pos_hidden_states, s_pos_img_feats, s_pos_layers_embeds, s_pos_attention, s_pos_input_data = \
                                                                        loaded_object.values()
            # s_qry_hidden_states, s_qry_img_feats = torch.stack(s_qry_hidden_states, dim=0), torch.stack(s_qry_img_feats, dim=0)
            # s_pos_hidden_states, s_pos_img_feats = torch.stack(s_pos_hidden_states, dim=0), torch.stack(s_pos_img_feats, dim=0)

        ## contrastive
        # scores = student_model.compute_similarity(s_qry_reps, s_pos_reps)
        # scores = scores.view(s_qry_reps.size(0), -1)
        # target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        # target = target * (s_qry_reps.size(0) // s_pos_reps.size(0))
        # contrastive_loss = self.cross_entropy_loss(scores / self.distiller.temperature, target)
        contrastive_loss = torch.tensor(0.0).to("cuda")
        ## image alignments
        img_align_loss = 0.0
        cur_idx_qry_img = 0
        cur_idx_pos_img = 0
        batch_size = s_qry_reps.size(0)

        for i in range(batch_size):
            if s_qry_img_feats is not None and t_qry_img_feats is not None:
                if cur_idx_qry_img < len(s_qry_img_feats) and cur_idx_qry_img < len(t_qry_img_feats):
                    tmp_s_qry_img_feats = F.normalize(s_qry_img_feats[i], p=2, dim=-1)
                    tmp_t_qry_img_feats = self.distiller.t2s_img_align(t_qry_img_feats[i])

                    tmp_t_qry_image_features = F.normalize(tmp_t_qry_img_feats, p=2, dim=-1)

                    img_align_loss += self.alignment_loss_mmd(tmp_t_qry_image_features, tmp_s_qry_img_feats)
                    cur_idx_qry_img += 1

            if s_pos_img_feats is not None and t_pos_img_feats is not None:
                if cur_idx_pos_img < len(s_pos_img_feats) and cur_idx_pos_img < len(t_pos_img_feats):
                    tmp_s_pos_img_feats = F.normalize(s_pos_img_feats[i], p=2, dim=-1)
                    tmp_t_pos_img_feats = self.distiller.t2s_img_align(t_pos_img_feats[i])
                    
                    tmp_t_pos_image_features = F.normalize(tmp_t_pos_img_feats, p=2, dim=-1)

                    img_align_loss += self.alignment_loss_mmd(tmp_t_pos_image_features, tmp_s_pos_img_feats)
                    cur_idx_pos_img += 1

        img_align_loss = img_align_loss / batch_size

        ## data-points rkd 
        # s_qry_reps = F.normalize(s_qry_reps, p=2, dim=-1)
        # s_pos_reps = F.normalize(s_pos_reps, p=2, dim=-1)
        # t_qry_reps = F.normalize(t_qry_reps, p=2, dim=-1)
        # t_pos_reps = F.normalize(t_pos_reps, p=2, dim=-1)

        # qry_distance_loss = self.compute_distance_loss(s_qry_reps, t_qry_reps)
        # pos_distance_loss = self.compute_distance_loss(s_pos_reps, t_pos_reps)
        # distance_loss = 0.5 * qry_distance_loss + 0.5 * pos_distance_loss

        # qry_angle_loss = self.compute_angle_loss(s_qry_reps, t_qry_reps)
        # pos_angle_loss = self.compute_angle_loss(s_pos_reps, t_pos_reps)
        # angle_loss = 0.5 * qry_angle_loss + 0.5 * pos_angle_loss

        # rkd_loss = (0.5 * distance_loss + 0.5 * angle_loss)
        rkd_loss = torch.tensor(0.0)

        ## simple kd
        simple_kd_loss = self.simple_kd_logit_loss(s_qry_reps, s_pos_reps, t_qry_reps, t_pos_reps)

        ## intra rkd
        intra_rkd_loss = self.intra_rkd(t_qry_layers_embeds, t_pos_layers_embeds,
                                        s_qry_layers_embeds, s_pos_layers_embeds)
        
        ## cross modal kd
        # num_s_img_tokens, num_t_img_tokens = s_qry_img_feats.size(1), t_qry_img_feats.size(1)
        qry_cross_modal_kd_loss = self.cross_modal_kd_loss(s_qry_hidden_states,
                                                       t_qry_hidden_states,
                                                       s_qry_img_feats,
                                                       t_qry_img_feats)

        # num_s_img_tokens, num_t_img_tokens = s_pos_img_feats.size(1), t_pos_img_feats.size(1)
        pos_cross_modal_kd_loss = self.cross_modal_kd_loss(s_pos_hidden_states,
                                                       t_pos_hidden_states,
                                                       s_pos_img_feats,
                                                       t_pos_img_feats)
        
        cross_modal_kd_loss = 0.5 * qry_cross_modal_kd_loss + 0.5 * pos_cross_modal_kd_loss

        ## optimal transport loss
        ot_loss = self.compute_ot(s_qry_hidden_states, s_qry_attention,
                                  t_qry_hidden_states, t_qry_attention)
        ot_loss += self.compute_ot(s_pos_hidden_states, s_pos_attention,
                                  t_pos_hidden_states, t_pos_attention)
        ot_loss = ot_loss / 2.0

        total_loss = contrastive_loss + \
                     self.rkd_loss_weight * rkd_loss + \
                     self.simple_kd_weight * simple_kd_loss + \
                     self.intra_rkd_weight * intra_rkd_loss + \
                     self.cross_modal_kd_weight * cross_modal_kd_loss + \
                     self.ot_loss_weight * ot_loss + \
                     self.img_align_loss_weight * img_align_loss
        
        return {
            "total_loss": total_loss,
            "contrastive_loss": contrastive_loss,
            "rkd_loss": rkd_loss,
            "simple_kd_loss": simple_kd_loss,
            "intra_rkd_loss": intra_rkd_loss,
            "cross_modal_kd_loss": cross_modal_kd_loss,
            "ot_loss": ot_loss,
            "img_align_loss": img_align_loss
        }

    def gaussian_kernel(self, x, y, sigma=2.0):
        """
        Computes the RBF (Gaussian) kernel between two sets of vectors.
        k(x, y) = exp(-||x - y||^2 / (2 * sigma^2))
        
        Args:
            x (torch.Tensor): Shape (n, dim)
            y (torch.Tensor): Shape (m, dim)
            sigma (float): Kernel bandwidth.
        """
        beta = 1.0 / (2.0 * (sigma ** 2))
        # (n, m) matrix of squared pairwise distances
        dist_sq = torch.cdist(x.unsqueeze(0), y.unsqueeze(0), p=2).pow(2)
        return torch.exp(-beta * dist_sq)

    def alignment_loss_mmd(self, t_feats, s_feats, sigma=2.0):
        """
        Computes the Maximum Mean Discrepancy (MMD) loss using a Gaussian kernel.

        Args:
            x_teacher (torch.Tensor): Teacher features, shape (n, dim)
            x_student (torch.Tensor): Student features, shape (m, dim)
            sigma (float): Kernel bandwidth.
        """
        
        # Compute kernel matrices
        k_tt = self.gaussian_kernel(t_feats, t_feats, sigma) # (n, n)
        k_ss = self.gaussian_kernel(s_feats, s_feats, sigma) # (m, m)
        k_ts = self.gaussian_kernel(t_feats, s_feats, sigma) # (n, m)
        
        # This is the (biased) MMD^2 statistic
        # E[k(t, t')] + E[k(s, s')] - 2 * E[k(t, s)]
        mmd_loss = k_tt.mean() + k_ss.mean() - 2 * k_ts.mean()
        
        return mmd_loss
    
    def compute_ot(self, s_hidden_states, s_attn, t_hidden_states, t_attn):
        
        loss = 0.0
        num_student_layers = len(s_hidden_states)
        num_teacher_layers = len(t_hidden_states)
        scale = num_teacher_layers // num_student_layers
        start_layer = num_student_layers - self.distiller.num_chosen_hidden_states

        for l in range(start_layer, num_student_layers):
            s_dist = F.softmax(s_attn[l - 1].mean(dim=1)[:, -1], dim=-1) # (b, n)
            t_dist = F.softmax(t_attn[l - 1].mean(dim=1)[:, -1], dim=-1) # (b, m)

            s_hidden_state = s_hidden_states[l] # (b, n, emb_dim)
            proj_t_hidden_state = self.distiller.t2s[l - start_layer](t_hidden_states[scale * l]) # (b, m, emb_dim)

            for b in range(s_dist.size(0)):
                cost_matrix = 1 - torch.matmul(s_hidden_state[b], proj_t_hidden_state[b].T) # (n, m)
                
                cost_matrix = cost_matrix / cost_matrix.mean()

                transport = self.sinkhorn(s_dist[b], t_dist[b], cost_matrix) 
                loss += torch.sum(transport * cost_matrix)

        return loss
    
    def sinkhorn(self, a, b, cost_matrix, reg=0.1, num_iters=100, eps=1e-9, stopThr = 1e-7):
        """
        a: (m,) or (m,1) torch tensor (source weights)
        b: (n,) or (n,1) torch tensor (target weights)
        cost_matrix: (m, n) torch tensor
        reg: regularization (>=0) -- larger reg -> smoother K = exp(-C/reg)
        num_iters: number of Sinkhorn iterations
        """
        device = cost_matrix.device
        dtype = cost_matrix.dtype

        a = a.view(-1, 1)
        b = b.view(-1, 1)
        C = cost_matrix

        m, n = C.shape
        if m == 0 or n == 0:
            return torch.zeros((m, n), device=device, dtype=dtype)

        # ensure shapes
        if a.shape[0] != m:
            a = torch.ones((m, 1), device=device, dtype=dtype) / m
        if b.shape[0] != n:
            b = torch.ones((n, 1), device=device, dtype=dtype) / n

        suma = a.sum()
        sumb = b.sum()
        if suma <= eps or sumb <= eps:
            a = torch.ones((m, 1), device=device, dtype=dtype) / m
            b = torch.ones((n, 1), device=device, dtype=dtype) / n
        else:
            a = a / suma
            b = b / sumb

        K = torch.exp(-C / (reg + 1e-12))

        u = torch.ones((m, 1), device=device, dtype=dtype)
        v = torch.ones((n, 1), device=device, dtype=dtype)

        for i in range(num_iters):
            u_prev = u.clone()
            KTv = (K.t() @ u)  # shape (n,1)
            v = b / (KTv + eps)
            Kv = (K @ v)       # shape (m,1)
            u = a / (Kv + eps)

            err = torch.max(torch.abs(u - u_prev))
            if err.item() < stopThr:
                break

        # transport plan
        U = torch.diag_embed(u.squeeze())   # (m,m) diag(u)
        V = torch.diag_embed(v.squeeze())   # (n,n) diag(v)
        P = U @ K @ V                       # (m,n)
        return P
    
    def cross_modal_kd_loss(self, s_hidden_states, t_hidden_states, s_img_feats, t_img_feats):
        """
            hidden_states: list of (n_layers, b, n, dim)
            img_feats: (b, n_img_tokens, dim)
        """

        loss = 0.0
        cur_idx_img = 0
        num_student_layers = len(s_hidden_states)
        num_teacher_layers = len(t_hidden_states)
        scale = num_teacher_layers // num_student_layers
        batch_size = s_hidden_states[0].size(0)
        start_layer = num_student_layers - self.distiller.num_chosen_hidden_states

        if s_img_feats is None or t_img_feats is None:
            return loss
        
        if s_img_feats is not None and t_img_feats is not None:
            for b in range(batch_size):
                if cur_idx_img < len(s_img_feats) and cur_idx_img < len(t_img_feats):
                    for l in range(start_layer, num_student_layers):

                        num_s_img_tokens = s_img_feats[b].size(0)
                        num_t_img_tokens = t_img_feats[b].size(0)

                        s_img_hidden_states = F.normalize(s_hidden_states[l][b][:num_s_img_tokens]).to(torch.float32)
                        s_text_hidden_states = F.normalize(s_hidden_states[l][b][num_s_img_tokens:]).to(torch.float32)

                        proj_t_img_hidden_states = F.normalize(self.distiller.t2s[l - start_layer](t_hidden_states[scale * l][b][:num_t_img_tokens])).to(torch.float32)
                        proj_t_text_hidden_states = F.normalize(self.distiller.t2s[l - start_layer](t_hidden_states[scale * l][b][num_t_img_tokens:])).to(torch.float32)
                        
                        
                        loss += 0.5 * self.sdtw(s_img_hidden_states.unsqueeze(0), proj_t_text_hidden_states.unsqueeze(0)).mean()
                        loss += 0.5 * self.sdtw(s_text_hidden_states.unsqueeze(0), proj_t_img_hidden_states.unsqueeze(0)).mean()
                    cur_idx_img += 1
        loss = loss.to(torch.bfloat16)
        return loss / batch_size
    
    def simple_kd_logit_loss(self, student_qry_reps, student_pos_reps, teacher_qry_reps, teacher_pos_reps):
            projector_teacher_qry_reps = self.distiller.last_layer_projector(teacher_qry_reps)
            projector_teacher_pos_reps = self.distiller.last_layer_projector(teacher_pos_reps)

            loss = (
                    self.mse_loss(student_qry_reps, projector_teacher_qry_reps) +  
                    self.mse_loss(student_pos_reps, projector_teacher_pos_reps)
                   ) / 2.0
            return loss
    
    def intra_rkd(self, 
                  teacher_qry_layers_embeds, # (b, n_layers, dim), 
                  teacher_pos_layers_embeds,
                  student_qry_layers_embeds,
                  student_pos_layers_embeds):
        
        loss = 0.0
        batch_size = student_pos_layers_embeds.size(0)

        for b in range(batch_size):

            qry_dist_loss = self.compute_distance_loss(student_qry_layers_embeds[b], teacher_qry_layers_embeds[b])
            pos_dist_loss = self.compute_distance_loss(student_pos_layers_embeds[b], teacher_pos_layers_embeds[b])
            dist_loss = 0.5 * qry_dist_loss + 0.5 * pos_dist_loss
            
            qry_angle_loss = self.compute_angle_loss(student_qry_layers_embeds[b], teacher_qry_layers_embeds[b])
            pos_angle_loss = self.compute_angle_loss(student_pos_layers_embeds[b], teacher_pos_layers_embeds[b])
            angle_loss = 0.5 * qry_angle_loss + 0.5 * pos_angle_loss

            loss += 0.5 * dist_loss + 0.5 * angle_loss

        return loss / batch_size

    def pairwise_distance(self, x):
        norm = (x**2).sum(dim=1, keepdim=True)
        dist = norm + norm.t() - 2.0 * torch.mm(x, x.t())
        return dist
    
    def compute_distance_loss(self, student_repr, teacher_repr):
        
        num_student_layers = student_repr.size(0)
        num_teacher_layers = teacher_repr.size(0)
        scale = num_teacher_layers // num_student_layers

        teacher_repr = teacher_repr[
            torch.tensor([i * scale for i in range(num_student_layers)], device=teacher_repr.device)
        ]

        dist_student = self.pairwise_distance(student_repr)
        dist_teacher = self.pairwise_distance(teacher_repr)
        
        mask = torch.triu(torch.ones_like(dist_student), diagonal=1).bool()
        dist_student = dist_student[mask]
        dist_teacher = dist_teacher[mask]
        
        mean_td = dist_teacher.mean().detach() + 1e-8
        mean_sd = dist_student.mean().detach() + 1e-8
        
        dist_student = dist_student / mean_sd
        dist_teacher = dist_teacher / mean_td
        
        diff = dist_student - dist_teacher
        abs_diff = torch.abs(diff)
        quadratic = 0.5 * (abs_diff ** 2)
        linear = abs_diff - 0.5
        
        loss = torch.where(abs_diff < 1.0, quadratic, linear)
        loss = loss.mean()
        return loss
    
    def angle_potentials(self, x):
        n = x.size(0)
        diffs = x.unsqueeze(0) - x.unsqueeze(1)
        norms = torch.norm(diffs, dim=-1, keepdim=True) + 1e-8
        e = diffs / norms
        
        cos_angles = torch.einsum('ijd,kjd->ijk', e, e)
        return cos_angles
    
    def compute_angle_loss(self, student_repr, teacher_repr):
        
        num_student_layers = student_repr.size(0)
        num_teacher_layers = teacher_repr.size(0)
        scale = num_teacher_layers // num_student_layers

        teacher_repr = teacher_repr[
            torch.tensor([i * scale for i in range(num_student_layers)], device=teacher_repr.device)
        ]

        psi_student = self.angle_potentials(student_repr)
        psi_teacher = self.angle_potentials(teacher_repr)
        
        n = psi_student.size(0)
        mask = torch.ones((n, n, n), dtype=torch.bool, device=psi_student.device)
        idx = torch.arange(n, device=psi_student.device)
        mask[idx, idx, :] = 0
        mask[idx, :, idx] = 0
        mask[:, idx, idx] = 0
        
        psi_teacher = psi_teacher[mask]
        psi_student = psi_student[mask]
        
        diff = psi_student - psi_teacher
        abs_diff = torch.abs(diff)
        quadratic = 0.5 * (abs_diff ** 2)
        linear = abs_diff - 0.5
        loss = torch.where(abs_diff < 1.0, quadratic, linear)
        loss = loss.mean()
        return loss
             

In [16]:
criterion = StrongerKD(training_args)

In [17]:
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
    loss = criterion.forward(None, None)

loss

Projector t2s_img created with structure: Sequential(
  (0): Linear(in_features=1536, out_features=896, bias=True)
)
Projector t2s_txt created with structure: Sequential(
  (0): Linear(in_features=1536, out_features=896, bias=True)
)
Projectors set.


[2025-11-08 10:34:22,863] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuCtxGetCurrent
[2025-11-08 10:34:22,864] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuCtxGetDevice
[2025-11-08 10:34:22,865] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuPointerGetAttribute
[2025-11-08 10:34:22,866] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuCtxGetCurrent
[2025-11-08 10:34:22,866] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuCtxGetDevice
[2025-11-08 10:34:22,867] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuPointerGetAttribute
[2025-11-08 10:34:22,868] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuCtxGetCurrent
[2025-11-08 10:34:22,868] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuCtxGetDevice
[2025-11-08 10:34:22,869] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuLaunchKernel
[2025-11-08 10:34:22,872] DEBUG [numba.cuda.cudadrv.driver:325] call driver api: cuCtxGetCurrent
[2025-11-08 10:34:22,8

{'total_loss': tensor(nan, device='cuda:0', grad_fn=<AddBackward0>),
 'contrastive_loss': tensor(0., device='cuda:0'),
 'rkd_loss': tensor(0.),
 'simple_kd_loss': tensor(0.0013, device='cuda:0', grad_fn=<DivBackward0>),
 'intra_rkd_loss': tensor(0.0345, device='cuda:0', grad_fn=<DivBackward0>),
 'cross_modal_kd_loss': tensor(1464., device='cuda:0', dtype=torch.bfloat16, grad_fn=<AddBackward0>),
 'ot_loss': tensor(nan, device='cuda:0', grad_fn=<DivBackward0>),
 'img_align_loss': tensor(0.3848, device='cuda:0', grad_fn=<DivBackward0>)}

In [None]:
def pairwise_distance(x):
        norm = (x**2).sum(dim=1, keepdim=True)
        dist = norm + norm.t() - 2.0 * torch.mm(x, x.t())
        return dist
    
def compute_distance_loss(student_qry, student_pos, teacher_qry, teacher_pos):
    
    num_student_layers = student_qry.size(0)
    num_teacher_layers = teacher_qry.size(0)
    scale = num_teacher_layers // num_student_layers

    student_repr = torch.cat([student_qry, student_pos], dim=1)
    teacher_repr = torch.cat([teacher_qry, teacher_pos], dim=1)[
        torch.tensor([i * scale for i in range(num_student_layers)], device=teacher_qry.device)
    ]

    dist_student = pairwise_distance(student_repr)
    dist_teacher = pairwise_distance(teacher_repr)
    
    mask = torch.triu(torch.ones_like(dist_student), diagonal=1).bool()
    dist_student = dist_student[mask]
    dist_teacher = dist_teacher[mask]
    
    mean_td = dist_teacher.mean().detach() + 1e-8
    mean_sd = dist_student.mean().detach() + 1e-8
    
    dist_student = dist_student / mean_sd
    dist_teacher = dist_teacher / mean_td
    
    diff = dist_student - dist_teacher
    abs_diff = torch.abs(diff)
    quadratic = 0.5 * (abs_diff ** 2)
    linear = abs_diff - 0.5
    
    loss = torch.where(abs_diff < 1.0, quadratic, linear)
    loss = loss.mean()
    return loss

def angle_potentials(x):
    x = torch.clamp(x, min=-1e10, max=1e10)

    n = x.size(0)
    diffs = x.unsqueeze(0) - x.unsqueeze(1)
    norms = torch.norm(diffs, dim=-1, keepdim=True) + 1e-8

    bfloat16_max_safe = 1e38 
    
    safe_norms = torch.where(torch.isinf(norms), 
                             torch.tensor(bfloat16_max_safe, dtype=x.dtype, device=x.device), 
                             norms)
    e = diffs / safe_norms
    
    cos_angles = torch.einsum('ijd,kjd->ijk', e, e)
    return cos_angles

def compute_angle_loss(student_qry, student_pos, teacher_qry, teacher_pos):
    
    num_student_layers = student_qry.size(0)
    num_teacher_layers = teacher_qry.size(0)
    scale = num_teacher_layers // num_student_layers

    student_repr = torch.cat([student_qry, student_pos], dim=1)
    teacher_repr = torch.cat([teacher_qry, teacher_pos], dim=1)[
        torch.tensor([i * scale for i in range(num_student_layers)], device=teacher_qry.device)
    ]

    psi_student = angle_potentials(student_repr)
    psi_teacher = angle_potentials(teacher_repr)
    
    n = psi_student.size(0)
    mask = torch.ones((n, n, n), dtype=torch.bool, device=psi_student.device)
    idx = torch.arange(n, device=psi_student.device)
    mask[idx, idx, :] = 0
    mask[idx, :, idx] = 0
    mask[:, idx, idx] = 0
    
    psi_teacher = psi_teacher[mask]
    psi_student = psi_student[mask]
    
    diff = psi_student - psi_teacher
    abs_diff = torch.abs(diff)
    quadratic = 0.5 * (abs_diff ** 2)
    linear = abs_diff - 0.5
    loss = torch.where(abs_diff < 1.0, quadratic, linear)
    loss = loss.mean()
    return loss

In [68]:
distiller  = Distiller(model_args, training_args, "cuda")

def compute_ot(s_hidden_states, s_attn, t_hidden_states, t_attn):
    
    loss = 0.0
    num_student_layers = len(s_hidden_states)
    num_teacher_layers = len(t_hidden_states)
    scale = num_teacher_layers // num_student_layers
    start_layer = num_student_layers - 3

    for l in range(start_layer, num_student_layers):
        s_dist = F.softmax(s_attn[l - 1].mean(dim=1)[:, -1], dim=-1) # (b, n)
        t_dist = F.softmax(t_attn[l - 1].mean(dim=1)[:, -1], dim=-1) # (b, m)

        s_hidden_state = s_hidden_states[l] # (b, n, emb_dim)
        proj_t_hidden_state = distiller.t2s[l - start_layer](t_hidden_states[scale * l]) # (b, m, emb_dim)

        for b in range(s_dist.size(0)):
            norm_s_hs = torch.norm(s_hidden_state[b], dim=-1, keepdim=True) + 1e-8
            norm_t_hs = torch.norm(proj_t_hidden_state[b], dim=-1, keepdim=True) + 1e-8

            cost_matrix = torch.cdist(norm_s_hs.unsqueeze(0), 
                                      norm_t_hs.unsqueeze(0)).squeeze(0)
            cost_matrix /= cost_matrix.mean().item()

            transport = sinkhorn(s_dist[b], t_dist[b], cost_matrix) 
            print(transport)
            loss += torch.sum(transport * cost_matrix)

            print("Cost mean:", cost_matrix.mean().item())
            print("s_dist sum:", s_dist[b].sum().item())
            print("t_dist sum:", t_dist[b].sum().item())
            print("Transport mean:", transport.mean())
            print("OT cost:", torch.sum(transport * cost_matrix).item())
    return loss / s_dist.size(0), cost_matrix

epsilon = 1e-9
stopThr = 1e-7
sinkhorn_alpha = 0.1

def sinkhorn(a, b, cost_matrix, reg=10, num_iters=100, eps=1e-9):
    """
    a: (m,) or (m,1) torch tensor (source weights)
    b: (n,) or (n,1) torch tensor (target weights)
    cost_matrix: (m, n) torch tensor
    reg: regularization (>=0) -- larger reg -> smoother K = exp(-C/reg)
    num_iters: number of Sinkhorn iterations
    """
    device = cost_matrix.device
    # use float32 for numeric stability
    dtype = torch.float32
    a = a.detach().to(device=device, dtype=dtype).view(-1, 1)
    b = b.detach().to(device=device, dtype=dtype).view(-1, 1)
    C = cost_matrix.detach().to(device=device, dtype=dtype)

    m, n = C.shape
    if m == 0 or n == 0:
        return torch.zeros((m, n), device=device, dtype=dtype)

    # ensure shapes
    if a.shape[0] != m:
        a = torch.ones((m, 1), device=device, dtype=dtype) / m
    if b.shape[0] != n:
        b = torch.ones((n, 1), device=device, dtype=dtype) / n

    suma = a.sum()
    sumb = b.sum()
    if suma <= eps or sumb <= eps:
        a = torch.ones((m, 1), device=device, dtype=dtype) / m
        b = torch.ones((n, 1), device=device, dtype=dtype) / n
    else:
        a = a / suma
        b = b / sumb

    K = torch.exp(-C / (reg + 1e-12))
    K = torch.clamp(K, min=1e-10)

    u = torch.ones((m, 1), device=device, dtype=dtype)
    v = torch.ones((n, 1), device=device, dtype=dtype)

    for i in range(num_iters):
        u_prev = u.clone()
        KTv = (K.t() @ u)  # shape (n,1)
        v = b / (KTv + eps)
        Kv = (K @ v)       # shape (m,1)
        u = a / (Kv + eps)

        err = torch.max(torch.abs(u - u_prev))
        if err.item() < stopThr:
            break

    # transport plan
    U = torch.diag_embed(u.squeeze())   # (m,m) diag(u)
    V = torch.diag_embed(v.squeeze())   # (n,n) diag(v)
    P = U @ K @ V                       # (m,n)
    return P

Projector t2s_img created with structure: Sequential(
  (0): Linear(in_features=1536, out_features=896, bias=True)
)
Projector t2s_txt created with structure: Sequential(
  (0): Linear(in_features=1536, out_features=896, bias=True)
)
Projectors set.


In [69]:
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
    x = compute_ot(s_qry_hidden_states, s_qry_attention,
                t_qry_hidden_states, t_qry_attention)
x

tensor([[2.1100e-05, 2.1100e-05, 2.1815e-05,  ..., 2.1100e-05, 2.2292e-05,
         2.4676e-05],
        [1.3292e-05, 1.3232e-05, 1.3828e-05,  ..., 1.3292e-05, 1.4067e-05,
         1.5497e-05],
        [1.3232e-05, 1.3232e-05, 1.3769e-05,  ..., 1.3232e-05, 1.4007e-05,
         1.5497e-05],
        ...,
        [1.3292e-05, 1.3292e-05, 1.3769e-05,  ..., 1.3292e-05, 1.4007e-05,
         1.5497e-05],
        [1.3232e-05, 1.3232e-05, 1.3828e-05,  ..., 1.3292e-05, 1.4067e-05,
         1.5497e-05],
        [1.4186e-05, 1.4126e-05, 1.4663e-05,  ..., 1.4186e-05, 1.4901e-05,
         1.6570e-05]], device='cuda:0', dtype=torch.bfloat16)
Cost mean: 1.0
s_dist sum: 0.9999999403953552
t_dist sum: 1.0
Transport mean: tensor(1.3113e-05, device='cuda:0', dtype=torch.bfloat16)
OT cost: 1.0381078720092773
tensor([[2.2411e-05, 2.2411e-05, 2.2173e-05,  ..., 2.2531e-05, 2.3246e-05,
         2.6584e-05],
        [1.3292e-05, 1.3232e-05, 1.3053e-05,  ..., 1.3292e-05, 1.3709e-05,
         1.5616e-05],
       

(tensor(3.0542, device='cuda:0', grad_fn=<DivBackward0>),
 tensor([[0.7022, 0.7447, 0.9193,  ..., 0.8969, 0.9255, 0.9026],
         [0.9866, 1.0291, 1.2037,  ..., 1.1813, 1.2100, 1.1871],
         [0.9766, 1.0191, 1.1937,  ..., 1.1713, 1.1999, 1.1770],
         ...,
         [1.0242, 1.0667, 1.2413,  ..., 1.2189, 1.2476, 1.2247],
         [0.7187, 0.7611, 0.9358,  ..., 0.9134, 0.9420, 0.9191],
         [1.3639, 1.4064, 1.5810,  ..., 1.5586, 1.5873, 1.5644]],
        device='cuda:0', grad_fn=<AsStridedBackward0>))

In [16]:
s_qry_hidden_states[0].size()
t_qry_hidden_states[0].size()
s_qry_attention[0].size()
t_qry_attention[0].size()

torch.Size([2, 12, 278, 278])