In [1]:
import os
# Limit the number of CPUs
os.environ["OMP_NUM_THREADS"] = "10"  # Set this to the number of CPUs you want to use
os.environ["MKL_NUM_THREADS"] = "10"  # Set this to the number of CPUs you want to use


import torch
import random
random.seed(42)
from torch.utils.data import Subset
from vilt.datamodules.multitask_datamodule import MTDataModule as MTDataModuleVILT
from meter.datamodules.multitask_datamodule import MTDataModule as MTDataModuleMeter

from vilt.modules import ViLTransformerSS
from meter.modules import METERTransformerSS

# Set the configuration
import pytorch_lightning as pl
import configs
_config = configs.meter_config_nlvr2_original
_config["batch_size"] = 10
_config["per_gpu_batchsize"] = 10
pl.seed_everything(_config["seed"])

# Set the GPU device
gpu_id = 0
# device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
device = "cpu"
torch.cuda.set_device(gpu_id)

print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.current_device())



class SmallMTDataModuleVILT(MTDataModuleVILT):
    def __init__(self, _config, dist=False, num_samples=5, start_idx=100):
        super().__init__(_config, dist)
        self.num_samples = num_samples
        self.start_idx = start_idx

    def setup(self, stage, is_random):
        super().setup(stage)
        
        # Limit the number of samples in the datasets
        if is_random:
            self.train_dataset = self._get_random_subset(self.train_dataset, self.num_samples)
            self.val_dataset = self._get_random_subset(self.val_dataset, self.num_samples)
            self.test_dataset = self._get_random_subset(self.test_dataset, self.num_samples)
        else:    
            self.train_dataset = Subset(self.train_dataset, range(self.start_idx, self.start_idx+self.num_samples))
            self.val_dataset = Subset(self.val_dataset, range(self.start_idx, self.start_idx+self.num_samples))
            self.test_dataset = Subset(self.test_dataset, range(self.start_idx, self.start_idx+self.num_samples))
        
    def _get_random_subset(self, dataset, num_samples):
        indices = random.sample(range(len(dataset)), num_samples)
        return Subset(dataset, indices)

class SmallMTDataModuleMETER(MTDataModuleMeter):
    def __init__(self, _config, dist=False, num_samples=10, start_idx=100):
        super().__init__(_config, dist)
        self.num_samples = num_samples
        self.start_idx = start_idx

    def setup(self, stage, is_random):
        super().setup(stage)
        
        # Limit the number of samples in the datasets
        if is_random:
            self.train_dataset = self._get_random_subset(self.train_dataset, self.num_samples)
            self.val_dataset = self._get_random_subset(self.val_dataset, self.num_samples)
            self.test_dataset = self._get_random_subset(self.test_dataset, self.num_samples)
        else:    
            self.train_dataset = Subset(self.train_dataset, range(self.start_idx, self.start_idx+self.num_samples))
            self.val_dataset = Subset(self.val_dataset, range(self.start_idx, self.start_idx+self.num_samples))
            self.test_dataset = Subset(self.test_dataset, range(self.start_idx, self.start_idx+self.num_samples))
        
    
    def _get_random_subset(self, dataset, num_samples):
        indices = random.sample(range(len(dataset)), num_samples)
        return Subset(dataset, indices)


2025-03-06 16:00:32.026443: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741273232.110286 2765597 cuda_dnn.cc:8498] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741273232.141552 2765597 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  def vit_small_patch16_224(pretrained=False, **kwargs):
  def vit_base_patch16_224(pretrained=False, **kwargs):
  def vit_base_patch32_224(pretrained=False, **kwargs):
  def vit_base_patch16_384(pretrained=False, **kwargs):
  def vit_base_patch32_384(pretrained=False, **kwargs):
  def vit_large_patch16_224(pretrained=False, **kwargs):
  def vit_large_patch32_224(pretrained=False, **kwargs):
  def vit_large_patch16_384(pretrained=Fal

True
2
0


In [2]:
# Hessian analysis
def compute_gradients(pl_module, batch, layer):
    pl_module.zero_grad()

    infer1 = pl_module.infer(
        batch, mask_text=False, mask_image=False, image_token_type_idx=1
    )
    infer2 = pl_module.infer(
        batch, mask_text=False, mask_image=False, image_token_type_idx=2
    )

    cls_feats = torch.cat([infer1["cls_feats"], infer2["cls_feats"]], dim=-1)
    nlvr2_logits = pl_module.nlvr2_classifier(cls_feats)

    nlvr2_labels = batch["answers"]
    nlvr2_labels = torch.tensor(nlvr2_labels).to(device).long()  # Move labels to GPU
    loss = torch.nn.functional.cross_entropy(nlvr2_logits, nlvr2_labels)

    grad_params = torch.autograd.grad(loss, layer.parameters(), create_graph=True)
    return grad_params

def hvp(layer, grad_params, v):
    # Flatten gradients and vector v
    grads = torch.cat([g.contiguous().view(-1) for g in grad_params])
    v = torch.cat([vi.contiguous().view(-1) for vi in v])
    
    # Compute g^T * v
    gTv = torch.dot(grads, v)
    
    # Compute Hv = ∇(g^T v)
    Hv = torch.autograd.grad(gTv, layer.parameters(), retain_graph=True)
    Hv = [h.detach() for h in Hv]  # Detach to stop gradient tracking
    return Hv

def compute_top_eigenvalue(model, layer, input, num_iterations=50):
    grad_params = compute_gradients(model, input, layer)
    
    # Initialize random vector v with same shape as parameters
    params = list(layer.parameters())
    v = [torch.randn_like(p).to(device) for p in params]  # Move v to GPU
    
    # Normalize v
    v_flat = torch.cat([vi.view(-1) for vi in v])
    v_norm = torch.norm(v_flat)
    v = [vi / v_norm for vi in v]
    
    for i in range(num_iterations):
        # if i % 5 == 0:
        #     print(f"Iteration: {i}")
        
        Hv = hvp(layer, grad_params, v)
        Hv_flat = torch.cat([hvi.view(-1) for hvi in Hv])
        
        # Update v and eigenvalue estimate
        v_norm = torch.norm(Hv_flat)
        v = [hvi / v_norm for hvi in Hv]
        eigenvalue = v_norm.item()
    
    return eigenvalue

def compute_layer_eigenvalues(model, input, num_iterations=50):
    eigenvalues = {}

    for name, layer in model.named_modules():
        
        if isinstance(layer, (torch.nn.Linear)):
            # print(f"Layer: {name}")
            if "encoder" not in name or "intermediate" not in name or "output" in name or "attention" in name:
                continue
            print(f"Computing eigenvalue for layer: {name}")
            eigenvalue = compute_top_eigenvalue(model, layer, input, num_iterations)

            eigenvalues[name] = eigenvalue   

            print("==============================================")
            print(f"Computed eigenvalue for layer {name} : {eigenvalue}")
            print("All eigenvalues computed so far:")
            print(f"{eigenvalues}")
            print("==============================================")

    return eigenvalues


In [3]:
# ==========================================
# ========= Create full datamodule =========
# ==========================================
if "meter" in _config["model"]:
    # full_dm = MTDataModuleMeter(_config, dist=False)
    
    # calibrarte_dm = SmallMTDataModuleMETER(_config, dist=False, num_samples=5, start_idx=100)
    
    infer_dm = SmallMTDataModuleMETER(_config, dist=False, num_samples=10, start_idx=0)
    infer_dm.setup("test", is_random=True)
    infer_dataloader = infer_dm.test_dataloader()

elif "vilt" in _config["model"]:
    # full_dm = MTDataModuleVILT(_config, dist=False)

    # calibrarte_dm = SmallMTDataModuleVILT(_config, dist=False, num_samples=5)
    
    infer_dm = SmallMTDataModuleVILT(_config, dist=False, num_samples=1, start_idx=0)
    infer_dm.setup("test", is_random=True)
    infer_dataloader = infer_dm.test_dataloader()

else:
    raise ValueError("Model not supported: ", _config["model"])

print(f"Batch size: {_config['batch_size']}")

# ==========================================
# ========= Initialize the model ===========
# ==========================================
if _config["model"] == "vilt":
    model = ViLTransformerSS(_config)
    print("Initialized ViLT model")

elif _config["model"] == "meter":
    model = METERTransformerSS(_config)
    print("Initialized METER model")

else:
    raise ValueError("Model not supported: ", _config["model"])

# Move model to GPU
model.to(device)


Loaded names: ['nlvr2_train']
Loaded names: ['nlvr2_dev', 'nlvr2_test1']
Loaded names: ['nlvr2_dev', 'nlvr2_test1']
Batch size: 10


  super().__init__(
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")


Initialized METER model


METERTransformerSS(
  (cross_modal_text_transform): Linear(in_features=768, out_features=768, bias=True)
  (cross_modal_image_transform): Linear(in_features=768, out_features=768, bias=True)
  (token_type_embeddings): Embedding(3, 768)
  (vit_model): CLIP(
    (visual): VisualTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
  

In [4]:
# ==========================================
# ======= Initialize the dataloader ========
# ==========================================
input_batch = next(iter(infer_dataloader))
num_batches = len(infer_dataloader)

print(input_batch.keys())
print(f"Number of batches: {num_batches}")
print(f"Samples in a batch: {len(input_batch['answers'])}")

# Move input data to GPU
for key in input_batch:
    if isinstance(input_batch[key], torch.Tensor):
        input_batch[key] = input_batch[key].to(device)

input_batch["image_0"][0] = input_batch["image_0"][0].to(device)
input_batch["image_1"][0] = input_batch["image_1"][0].to(device)


# ==========================================
# ========= Compute eigenvalues ============
# ==========================================
# for i in range(20):
#     infer_dm = SmallMTDataModuleMETER(_config, dist=False, num_samples=1, start_idx=0)
#     infer_dm.setup("test", is_random=True)
#     infer_dataloader = infer_dm.test_dataloader()
#     eigenvalues = compute_layer_eigenvalues(model, input_batch, num_iterations=50)

#     # Save the eigenvalues to a txt file
#     with open(f"eigenvalues_meter_{i}.txt", "w") as f:
#         f.write(str(eigenvalues))
# # eigenvalues = compute_averaged_eigenvalues(model, infer_dataloader, num_batches, num_iterations=50)

dict_keys(['text', 'image_1', 'image_0', 'table_name', 'answers', 'text_ids', 'text_labels', 'text_ids_mlm', 'text_labels_mlm', 'text_masks'])
Number of batches: 1
Samples in a batch: 10


# Hessian Eigenvalue Calculation

In [10]:
from quantization_utils import get_quantization_config
from copy import deepcopy

def quantize_weights(weight, bits=8):
    """
    Symmetric uniform quantization for weights.
    Args:
        weight (Tensor): Full-precision weights of a layer.
        bits (int): Number of quantization bits.
    Returns:
        dequantized_weight (Tensor): Quantized/dequantized weights.
    """
    # # Symmetric range based on max absolute value
    # max_val = torch.max(torch.abs(weight))
    # scale = max_val / (2 ** (bits - 1) - 1)
    
    # # Quantize and dequantize
    # quantized = torch.clamp(torch.round(weight / scale), -2**(bits-1), 2**(bits-1)-1)
    # dequantized_weight = quantized * scale
    # return dequantized_weight, scale
    min_val = weight.min()
    max_val = weight.max()
    scale = (max_val - min_val) / (2**bits - 1)
    zero_point = torch.round(-min_val / scale)
    quantized = torch.clamp(torch.round(weight / scale + zero_point), 0, 2**bits - 1)
    dequantized = (quantized - zero_point) * scale

    return dequantized, scale, zero_point

def compute_quantization_perturbation(model, eigenvalues, bits=4):
    """
    Compute quantization perturbation metric Ω_i for each layer.
    Args:
        model (nn.Module): PyTorch model.
        eigenvalues (dict): Precomputed top eigenvalues (λ_i) for each layer.
        bits (int): Number of quantization bits.
    Returns:
        omega_dict (dict): Ω_i values for each layer.
    """
    omega_dict = {}
    perturbation_dict = {}
    model_dynamic = deepcopy(model)
    # Quantize weights
    quantization_config, _ = get_quantization_config(4)
    torch.quantization.quantize_dynamic(
        model_dynamic, {"text_transformer.encoder.layer": quantization_config}, inplace=True
    )

    
    for i in range(12):
        layer_name = "text_transformer.encoder.layer." + str(i)
        if layer_name + ".output.dense" in eigenvalues:
            layer_name = layer_name + ".output.dense"
        elif layer_name + ".intermediate.dense" in eigenvalues:
            layer_name = layer_name + ".intermediate.dense"
        else:
            break
        # Get original weights
        weight = model.text_transformer.encoder.layer[i].intermediate.dense.weight
        quantized_weight = model_dynamic.text_transformer.encoder.layer[i].intermediate.dense.weight()
        

        q_weigths, scale, zero = quantize_weights(weight, bits)

        # print(weight)
        # print(f"Min: {weight.min()}, Max: {weight.max()}, Mean: {weight.mean()}")
        
        # print(quantized_weight)
        # print(f"Min: {quantized_weight.data.min()}, Max: {quantized_weight.data.max()}, Mean: {quantized_weight.data.mean()}")
        
        # print(quantized_weight.dequantize().data)
        # print(f"Min: {quantized_weight.dequantize().data.min()}, Max: {quantized_weight.dequantize().data.max()}, Mean: {quantized_weight.dequantize().data.mean()}")
        # print(q_weigths)
        # print(f"Scale: {scale}")
        # print(f"Zero Point: {zero}")
        # print("==============================================")

        # Compute L2 perturbation: ‖Q(W_i) - W_i‖²
        # perturbation = torch.norm(quantized_weight.dequantize().data - weight, p=2) ** 2
        perturbation = torch.norm(quantized_weight.dequantize().data - weight, p=2) ** 2
        
        # Compute Ω_i = λ_i * perturbation
        lambda_i = eigenvalues[layer_name]
        omega_i = lambda_i * perturbation.item()
        
        perturbation_dict[layer_name] = perturbation.item()
        omega_dict[layer_name] = omega_i


    return omega_dict, perturbation_dict

In [6]:
# eigenvalues0 = {'text_transformer.encoder.layer.0.intermediate.dense': 0.16790370643138885, 'text_transformer.encoder.layer.1.intermediate.dense': 0.1432534158229828, 'text_transformer.encoder.layer.2.intermediate.dense': 0.2982516884803772, 'text_transformer.encoder.layer.3.intermediate.dense': 0.2224932760000229, 'text_transformer.encoder.layer.4.intermediate.dense': 0.3737144470214844, 'text_transformer.encoder.layer.5.intermediate.dense': 0.5271583795547485, 'text_transformer.encoder.layer.6.intermediate.dense': 1.0165506601333618, 'text_transformer.encoder.layer.7.intermediate.dense': 1.4828656911849976, 'text_transformer.encoder.layer.8.intermediate.dense': 1.0849188566207886, 'text_transformer.encoder.layer.9.intermediate.dense': 1.3699719905853271, 'text_transformer.encoder.layer.10.intermediate.dense': 0.655511736869812, 'text_transformer.encoder.layer.11.intermediate.dense': 0.6610674858093262}
# eigenvalues1 = {'text_transformer.encoder.layer.0.intermediate.dense': 0.17076769471168518, 'text_transformer.encoder.layer.1.intermediate.dense': 0.2954607605934143, 'text_transformer.encoder.layer.2.intermediate.dense': 0.3319089114665985, 'text_transformer.encoder.layer.3.intermediate.dense': 0.22153879702091217, 'text_transformer.encoder.layer.4.intermediate.dense': 0.3949551284313202, 'text_transformer.encoder.layer.5.intermediate.dense': 0.414730429649353, 'text_transformer.encoder.layer.6.intermediate.dense': 0.5701264142990112, 'text_transformer.encoder.layer.7.intermediate.dense': 0.34018513560295105, 'text_transformer.encoder.layer.8.intermediate.dense': 1.1660056114196777, 'text_transformer.encoder.layer.9.intermediate.dense': 0.4433540105819702, 'text_transformer.encoder.layer.10.intermediate.dense': 0.37822774052619934, 'text_transformer.encoder.layer.11.intermediate.dense': 0.9691262245178223}
# eigenvalues2 = {'text_transformer.encoder.layer.0.intermediate.dense': 0.14164957404136658, 'text_transformer.encoder.layer.1.intermediate.dense': 0.1395464688539505, 'text_transformer.encoder.layer.2.intermediate.dense': 0.22589385509490967, 'text_transformer.encoder.layer.3.intermediate.dense': 0.184654101729393, 'text_transformer.encoder.layer.4.intermediate.dense': 0.5432092547416687, 'text_transformer.encoder.layer.5.intermediate.dense': 0.4408581256866455, 'text_transformer.encoder.layer.6.intermediate.dense': 0.598903238773346, 'text_transformer.encoder.layer.7.intermediate.dense': 0.4705699384212494, 'text_transformer.encoder.layer.8.intermediate.dense': 0.41199880838394165, 'text_transformer.encoder.layer.9.intermediate.dense': 1.2278794050216675, 'text_transformer.encoder.layer.10.intermediate.dense': 0.3707928955554962, 'text_transformer.encoder.layer.11.intermediate.dense': 0.2581954598426819}
# eigenvalues3 = {'text_transformer.encoder.layer.0.intermediate.dense': 0.16633661091327667, 'text_transformer.encoder.layer.1.intermediate.dense': 0.11317455768585205, 'text_transformer.encoder.layer.2.intermediate.dense': 0.32719480991363525, 'text_transformer.encoder.layer.3.intermediate.dense': 0.3101462423801422, 'text_transformer.encoder.layer.4.intermediate.dense': 0.4569089412689209, 'text_transformer.encoder.layer.5.intermediate.dense': 0.8330849409103394, 'text_transformer.encoder.layer.6.intermediate.dense': 0.7549309134483337, 'text_transformer.encoder.layer.7.intermediate.dense': 0.3048100471496582, 'text_transformer.encoder.layer.8.intermediate.dense': 0.36305004358291626, 'text_transformer.encoder.layer.9.intermediate.dense': 1.6126809120178223, 'text_transformer.encoder.layer.10.intermediate.dense': 0.4344247877597809, 'text_transformer.encoder.layer.11.intermediate.dense': 0.4249117970466614}
# eigenvalues4 = {'text_transformer.encoder.layer.0.intermediate.dense': 0.12828055024147034, 'text_transformer.encoder.layer.1.intermediate.dense': 0.12925875186920166, 'text_transformer.encoder.layer.2.intermediate.dense': 0.30379724502563477, 'text_transformer.encoder.layer.3.intermediate.dense': 0.34414660930633545, 'text_transformer.encoder.layer.4.intermediate.dense': 0.32475945353507996, 'text_transformer.encoder.layer.5.intermediate.dense': 0.3463519215583801, 'text_transformer.encoder.layer.6.intermediate.dense': 0.49876466393470764, 'text_transformer.encoder.layer.7.intermediate.dense': 0.5240191221237183, 'text_transformer.encoder.layer.8.intermediate.dense': 0.6521978378295898, 'text_transformer.encoder.layer.9.intermediate.dense': 1.50875723361969, 'text_transformer.encoder.layer.10.intermediate.dense': 0.5115640163421631, 'text_transformer.encoder.layer.11.intermediate.dense': 0.45940256118774414}
# eigenvalues5 = {'text_transformer.encoder.layer.0.output.dense': 24.300703048706055, 'text_transformer.encoder.layer.1.output.dense': 27.49468231201172, 'text_transformer.encoder.layer.2.output.dense': 23.84470558166504, 'text_transformer.encoder.layer.3.output.dense': 27.41619300842285, 'text_transformer.encoder.layer.4.output.dense': 33.994407653808594, 'text_transformer.encoder.layer.5.output.dense': 36.02847671508789, 'text_transformer.encoder.layer.6.output.dense': 37.94231414794922, 'text_transformer.encoder.layer.7.output.dense': 38.986351013183594, 'text_transformer.encoder.layer.8.output.dense': 40.81344223022461, 'text_transformer.encoder.layer.9.output.dense': 45.41384506225586, 'text_transformer.encoder.layer.10.output.dense': 32.24785232543945, 'text_transformer.encoder.layer.11.output.dense': 40.361717224121094}
# # Average the eigenvalues
# eigenvalues = {}
# for key in eigenvalues5.keys():
#     eigenvalues[key] = (eigenvalues0[key] + eigenvalues1[key] + eigenvalues2[key] + eigenvalues3[key] + eigenvalues4[key]) / 5

eigenvalues = {"text_transformer.encoder.layer.0.output.dense": 23.591533660888672, "text_transformer.encoder.layer.1.output.dense": 28.929523468017578, "text_transformer.encoder.layer.2.output.dense": 24.528627395629883, "text_transformer.encoder.layer.3.output.dense": 26.392732620239258, "text_transformer.encoder.layer.4.output.dense": 34.20820236206055, "text_transformer.encoder.layer.5.output.dense": 34.652042388916016, "text_transformer.encoder.layer.6.output.dense": 37.35274887084961, "text_transformer.encoder.layer.7.output.dense": 38.30513000488281, "text_transformer.encoder.layer.8.output.dense": 43.141258239746094, "text_transformer.encoder.layer.9.output.dense": 45.42149353027344, "text_transformer.encoder.layer.10.output.dense": 32.83196258544922, "text_transformer.encoder.layer.11.output.dense": 43.240238189697266}
eigenvalues = {'text_transformer.encoder.layer.0.output.dense': 24.300703048706055, 'text_transformer.encoder.layer.1.output.dense': 27.49468231201172, 'text_transformer.encoder.layer.2.output.dense': 23.84470558166504, 'text_transformer.encoder.layer.3.output.dense': 27.41619300842285, 'text_transformer.encoder.layer.4.output.dense': 33.994407653808594, 'text_transformer.encoder.layer.5.output.dense': 36.02847671508789, 'text_transformer.encoder.layer.6.output.dense': 37.94231414794922, 'text_transformer.encoder.layer.7.output.dense': 38.986351013183594, 'text_transformer.encoder.layer.8.output.dense': 40.81344223022461, 'text_transformer.encoder.layer.9.output.dense': 45.41384506225586, 'text_transformer.encoder.layer.10.output.dense': 32.24785232543945, 'text_transformer.encoder.layer.11.output.dense': 40.361717224121094}
print(f"Average eigenvalues: {eigenvalues}")


Average eigenvalues: {'text_transformer.encoder.layer.0.output.dense': 24.300703048706055, 'text_transformer.encoder.layer.1.output.dense': 27.49468231201172, 'text_transformer.encoder.layer.2.output.dense': 23.84470558166504, 'text_transformer.encoder.layer.3.output.dense': 27.41619300842285, 'text_transformer.encoder.layer.4.output.dense': 33.994407653808594, 'text_transformer.encoder.layer.5.output.dense': 36.02847671508789, 'text_transformer.encoder.layer.6.output.dense': 37.94231414794922, 'text_transformer.encoder.layer.7.output.dense': 38.986351013183594, 'text_transformer.encoder.layer.8.output.dense': 40.81344223022461, 'text_transformer.encoder.layer.9.output.dense': 45.41384506225586, 'text_transformer.encoder.layer.10.output.dense': 32.24785232543945, 'text_transformer.encoder.layer.11.output.dense': 40.361717224121094}


In [11]:
# 1. Compute eigenvalues for all layers
# eigenvalues = compute_layer_eigenvalues(model, input_batch, num_iterations=50)

# 2. Compute quantization perturbation metrics (Ω_i)
bits = 4  # Target quantization precision
omega_dict, perturbation_dict = compute_quantization_perturbation(model, eigenvalues, bits=bits)

# 3. Sort layers by Ω_i to determine fine-tuning order
sorted_layers = sorted(omega_dict.items(), key=lambda x: x[1], reverse=True)
print("Fine-tuning order (descending Ω_i):")
for layer, omega in sorted_layers:
    print(f"{layer}: {omega:.4f} || {perturbation_dict[layer]:.4f} || {eigenvalues[layer]:.4f}")

Fine-tuning order (descending Ω_i):
text_transformer.encoder.layer.1.output.dense: 83961.5639 || 3053.7383 || 27.4947
text_transformer.encoder.layer.0.output.dense: 73070.9029 || 3006.9460 || 24.3007
text_transformer.encoder.layer.7.output.dense: 71301.5530 || 1828.8850 || 38.9864
text_transformer.encoder.layer.6.output.dense: 57149.1147 || 1506.2106 || 37.9423
text_transformer.encoder.layer.4.output.dense: 52108.8623 || 1532.8657 || 33.9944
text_transformer.encoder.layer.8.output.dense: 47796.0532 || 1171.0861 || 40.8134
text_transformer.encoder.layer.9.output.dense: 44863.6595 || 987.8851 || 45.4138
text_transformer.encoder.layer.5.output.dense: 44211.4665 || 1227.1256 || 36.0285
text_transformer.encoder.layer.3.output.dense: 44066.4031 || 1607.3130 || 27.4162
text_transformer.encoder.layer.10.output.dense: 39600.2013 || 1227.9950 || 32.2479
text_transformer.encoder.layer.2.output.dense: 34007.0793 || 1426.1899 || 23.8447
text_transformer.encoder.layer.11.output.dense: 21534.3286 || 

In [None]:
eigenvalues = {'text_transformer.encoder.layer.0.output.dense': 0.135345126, 'text_transformer.encoder.layer.1.output.dense': 0.088416712, 'text_transformer.encoder.layer.2.output.dense': 0.096227547, 'text_transformer.encoder.layer.3.output.dense': 0.107769994, 'text_transformer.encoder.layer.4.output.dense': 0.188330392, 'text_transformer.encoder.layer.5.output.dense': 0.191571182, 'text_transformer.encoder.layer.6.output.dense': 0.257129359, 'text_transformer.encoder.layer.7.output.dense': 0.202422408, 'text_transformer.encoder.layer.8.output.dense': 0.278533794, 'text_transformer.encoder.layer.9.output.dense': 0.280634147, 'text_transformer.encoder.layer.10.output.dense': 0.202529321, 'text_transformer.encoder.layer.11.output.dense': 0.368850504}
eigenvalues = {"text_transformer.encoder.layer.0.output.dense": 23.591533660888672, "text_transformer.encoder.layer.1.output.dense": 28.929523468017578, "text_transformer.encoder.layer.2.output.dense": 24.528627395629883, "text_transformer.encoder.layer.3.output.dense": 26.392732620239258, "text_transformer.encoder.layer.4.output.dense": 34.20820236206055, "text_transformer.encoder.layer.5.output.dense": 34.652042388916016, "text_transformer.encoder.layer.6.output.dense": 37.35274887084961, "text_transformer.encoder.layer.7.output.dense": 38.30513000488281, "text_transformer.encoder.layer.8.output.dense": 43.141258239746094, "text_transformer.encoder.layer.9.output.dense": 45.42149353027344, "text_transformer.encoder.layer.10.output.dense": 32.83196258544922, "text_transformer.encoder.layer.11.output.dense": 43.240238189697266}
eigenvalues = {'text_transformer.encoder.layer.0.output.dense': 24.300703048706055, 'text_transformer.encoder.layer.1.output.dense': 27.49468231201172, 'text_transformer.encoder.layer.2.output.dense': 23.84470558166504, 'text_transformer.encoder.layer.3.output.dense': 27.41619300842285, 'text_transformer.encoder.layer.4.output.dense': 33.994407653808594, 'text_transformer.encoder.layer.5.output.dense': 36.02847671508789, 'text_transformer.encoder.layer.6.output.dense': 37.94231414794922, 'text_transformer.encoder.layer.7.output.dense': 38.986351013183594, 'text_transformer.encoder.layer.8.output.dense': 40.81344223022461, 'text_transformer.encoder.layer.9.output.dense': 45.41384506225586, 'text_transformer.encoder.layer.10.output.dense': 32.24785232543945, 'text_transformer.encoder.layer.11.output.dense': 40.361717224121094}
# 1. Compute eigenvalues for all layers
# eigenvalues = compute_layer_eigenvalues(model, input_batch, num_iterations=50)

# 2. Compute quantization perturbation metrics (Ω_i)
bits = 4  # Target quantization precision
omega_dict, perturbation_dict = compute_quantization_perturbation(model, eigenvalues, bits=bits)

# 3. Sort layers by Ω_i to determine fine-tuning order
sorted_layers = sorted(omega_dict.items(), key=lambda x: x[1], reverse=True)
# sorted_layers = sorted(eigenvalues.items(), key=lambda x: x[1], reverse=True)
print("Fine-tuning order (descending Ω_i):")
for layer, omega in sorted_layers:
    print(f"{layer}: {omega:.4f} || {perturbation_dict[layer]:.4f} || {eigenvalues[layer]:.4f}")

Fine-tuning order (descending Ω_i):
text_transformer.encoder.layer.1.output.dense: 88343.1933 || 3053.7383 || 28.9295
text_transformer.encoder.layer.0.output.dense: 70938.4688 || 3006.9460 || 23.5915
text_transformer.encoder.layer.7.output.dense: 70055.6781 || 1828.8850 || 38.3051
text_transformer.encoder.layer.6.output.dense: 56261.1052 || 1506.2106 || 37.3527
text_transformer.encoder.layer.4.output.dense: 52436.5808 || 1532.8657 || 34.2082
text_transformer.encoder.layer.8.output.dense: 50522.1261 || 1171.0861 || 43.1413
text_transformer.encoder.layer.9.output.dense: 44871.2154 || 987.8851 || 45.4215
text_transformer.encoder.layer.5.output.dense: 42522.4087 || 1227.1256 || 34.6520
text_transformer.encoder.layer.3.output.dense: 42421.3819 || 1607.3130 || 26.3927
text_transformer.encoder.layer.10.output.dense: 40317.4857 || 1227.9950 || 32.8320
text_transformer.encoder.layer.2.output.dense: 34982.4817 || 1426.1899 || 24.5286
text_transformer.encoder.layer.11.output.dense: 23070.1160 || 

# Perform Ablation

In [9]:
def compute_ablated_loss(pl_module, batch, target_layer_name):
    """
    Compute loss when a specific layer is ablated (zeroed out).
    Integrates with your existing pipeline.
    """
    original_loss = None
    delta_loss = None

    # ----------------------------------------------
    # Step 1: Compute original loss with all layers
    # ----------------------------------------------
    with torch.no_grad():
        infer1 = pl_module.infer(
            batch, mask_text=False, mask_image=False, image_token_type_idx=1
        )
        infer2 = pl_module.infer(
            batch, mask_text=False, mask_image=False, image_token_type_idx=2
        )

        cls_feats = torch.cat([infer1["cls_feats"], infer2["cls_feats"]], dim=-1)
        nlvr2_logits = pl_module.nlvr2_classifier(cls_feats)
        
        nlvr2_labels = batch["answers"]
        nlvr2_labels = torch.tensor(nlvr2_labels).to(device).long()
        original_loss = torch.nn.functional.cross_entropy(nlvr2_logits, nlvr2_labels).item()

    # ----------------------------------------------------
    # Step 2: Compute loss with target layer ablated
    # ----------------------------------------------------
    def zero_output_hook(module, input, output):
        """Hook function to zero the layer's output"""
        return torch.zeros_like(output)

    # Register hook on target layer
    target_layer = dict(pl_module.named_modules())[target_layer_name]
    handle = target_layer.register_forward_hook(zero_output_hook)

    with torch.no_grad():
        infer1_ablated = pl_module.infer(
            batch, mask_text=False, mask_image=False, image_token_type_idx=1
        )
        infer2_ablated = pl_module.infer(
            batch, mask_text=False, mask_image=False, image_token_type_idx=2
        )

        cls_feats_ablated = torch.cat([infer1_ablated["cls_feats"], infer2_ablated["cls_feats"]], dim=-1)
        nlvr2_logits_ablated = pl_module.nlvr2_classifier(cls_feats_ablated)
        
        ablated_loss = torch.nn.functional.cross_entropy(nlvr2_logits_ablated, nlvr2_labels).item()

    # Cleanup and return
    handle.remove()
    delta_loss = ablated_loss - original_loss

    return {
        "layer": target_layer_name,
        "original_loss": original_loss,
        "ablated_loss": ablated_loss,
        "delta_loss": delta_loss
    }

def ablation_analysis(pl_module, dataloader, layer_names):
    """
    Full ablation study pipeline
    """
    results = []
    
    for idx, batch in enumerate(dataloader):
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device)

        batch["image_0"][0] = batch["image_0"][0].to(device)
        batch["image_1"][0] = batch["image_1"][0].to(device)
        print("==============================================")
        print(f"Batch: {idx}")
        
        for layer_name in layer_names:
            if any(substr in layer_name for substr in ["encoder", "intermediate", "output", "attention"]):
                result = compute_ablated_loss(pl_module, batch, layer_name)
                results.append(result)
                print(f"Ablated {layer_name} | ΔLoss: {result['delta_loss']:.4f}")
    
    return results

In [10]:
# Define layers to analyze (example for transformer model)
target_layers = [
    "text_transformer.encoder.layer.0.intermediate.dense",
    "text_transformer.encoder.layer.0.output.dense",
    "text_transformer.encoder.layer.1.intermediate.dense",
    "text_transformer.encoder.layer.1.output.dense",
    "text_transformer.encoder.layer.2.intermediate.dense",
    "text_transformer.encoder.layer.2.output.dense",
    "text_transformer.encoder.layer.3.intermediate.dense",
    "text_transformer.encoder.layer.3.output.dense",
    "text_transformer.encoder.layer.8.intermediate.dense",
    "text_transformer.encoder.layer.8.output.dense",
    "text_transformer.encoder.layer.9.intermediate.dense",
    "text_transformer.encoder.layer.9.output.dense",
]

device = "cuda"
model.to(device)

# input_batch = next(iter(infer_dataloader))
# num_batches = len(infer_dataloader)

# print(input_batch.keys())
# print(f"Number of batches: {num_batches}")
# print(f"Samples in a batch: {len(input_batch['answers'])}")

# # Move input data to GPU
# for key in input_batch:
#     if isinstance(input_batch[key], torch.Tensor):
#         input_batch[key] = input_batch[key].to(device)

# input_batch["image_0"][0] = input_batch["image_0"][0].to(device)
# input_batch["image_1"][0] = input_batch["image_1"][0].to(device)


# Run ablation study
ablation_results = ablation_analysis(
    pl_module=model,
    dataloader=infer_dataloader,
    layer_names=target_layers
)

# Analyze results
for result in ablation_results:
    print(f"Layer {result['layer']}:")
    print(f"  Original Loss: {result['original_loss']:.4f}")
    print(f"  Ablated Loss:  {result['ablated_loss']:.4f}")
    print(f"  ΔLoss:         {result['delta_loss']:.4f}\n")

Batch: 0




Ablated text_transformer.encoder.layer.0.intermediate.dense | ΔLoss: 0.0790
Ablated text_transformer.encoder.layer.0.output.dense | ΔLoss: 0.0510
Ablated text_transformer.encoder.layer.1.intermediate.dense | ΔLoss: 0.0227
Ablated text_transformer.encoder.layer.1.output.dense | ΔLoss: 0.0438
Ablated text_transformer.encoder.layer.2.intermediate.dense | ΔLoss: 0.0411
Ablated text_transformer.encoder.layer.2.output.dense | ΔLoss: 0.0567
Ablated text_transformer.encoder.layer.3.intermediate.dense | ΔLoss: -0.0165
Ablated text_transformer.encoder.layer.3.output.dense | ΔLoss: -0.0236
Ablated text_transformer.encoder.layer.8.intermediate.dense | ΔLoss: 0.0570
Ablated text_transformer.encoder.layer.8.output.dense | ΔLoss: 0.0575
Ablated text_transformer.encoder.layer.9.intermediate.dense | ΔLoss: -0.0088
Ablated text_transformer.encoder.layer.9.output.dense | ΔLoss: -0.0140
Layer text_transformer.encoder.layer.0.intermediate.dense:
  Original Loss: 0.4641
  Ablated Loss:  0.5432
  ΔLoss:     

# Gemini Conn Sens

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim

def compute_connection_sensitivity(pl_module, batch, target_layers):
    """
    Computes the connection sensitivity for specified layers in a PyTorch model
    following the user's forward pass implementation.

    Args:
        pl_module (nn.Module): The PyTorch Lightning Module (or your model).
        batch (dict): A dictionary representing the input batch (as expected by pl_module).
        target_layers (list of str): List of layer names to compute sensitivity for.

    Returns:
        dict: Dictionary of layer sensitivities, keyed by layer names.
    """
    sensitivities = {}
    device = next(pl_module.parameters()).device # Get device from model params

    # Make sure gradients are enabled
    for param in pl_module.parameters():
        param.requires_grad_(True)
    pl_module.train() # Set model to training mode for gradient calculation

    # Zero out gradients
    pl_module.zero_grad()

    # Forward pass (following user's implementation)
    infer1 = pl_module.infer(
        batch, mask_text=False, mask_image=False, image_token_type_idx=1
    )
    infer2 = pl_module.infer(
        batch, mask_text=False, mask_image=False, image_token_type_idx=2
    )

    cls_feats = torch.cat([infer1["cls_feats"], infer2["cls_feats"]], dim=-1)
    nlvr2_logits = pl_module.nlvr2_classifier(cls_feats)

    nlvr2_labels = batch["answers"]
    target = torch.tensor(nlvr2_labels).to(model.device).long()

    # Loss calculation (following user's implementation)
    loss = torch.nn.functional.cross_entropy(nlvr2_logits, target)

    # Backward pass to compute gradients
    loss.backward()

    for layer_name in target_layers:
        try:
            layer = pl_module.get_submodule(layer_name)
            if isinstance(layer, nn.Linear): # Assuming target layers are Linear layers
                weight_grad = layer.weight.grad
                if weight_grad is not None:
                    abs_grad = torch.abs(weight_grad)
                    sensitivity = torch.mean(abs_grad).item() # Average absolute gradient
                    sensitivities[layer_name] = sensitivity
                else:
                    print(f"Warning: Gradients not found for layer: {layer_name}. "
                          f"Ensure layer parameters are part of the computational graph "
                          f"and require gradients.")
                    sensitivities[layer_name] = None # Indicate no sensitivity calculated
            else:
                print(f"Warning: Layer {layer_name} is not a Linear layer. "
                      f"Connection sensitivity calculation is implemented for Linear layers in this example.")
                sensitivities[layer_name] = None # Indicate not applicable for this layer type
        except AttributeError:
            print(f"Warning: Layer '{layer_name}' not found in the model.")
            sensitivities[layer_name] = None # Indicate layer not found

    return sensitivities


# 1. Define your pl_module (replace DummyTransformerModel with your actual pl_module)
pl_module = model

# 2. Define target layers
target_layers = [
    "text_transformer.encoder.layer.0.intermediate.dense",
    "text_transformer.encoder.layer.0.output.dense",
    "text_transformer.encoder.layer.1.intermediate.dense",
    "text_transformer.encoder.layer.1.output.dense",
    "text_transformer.encoder.layer.2.intermediate.dense",
    "text_transformer.encoder.layer.2.output.dense",
    "text_transformer.encoder.layer.3.intermediate.dense",
    "text_transformer.encoder.layer.3.output.dense",
    "text_transformer.encoder.layer.8.intermediate.dense",
    "text_transformer.encoder.layer.8.output.dense",
    "text_transformer.encoder.layer.9.intermediate.dense",
    "text_transformer.encoder.layer.9.output.dense",
]

# 3. Initialize the dataloader
input_batch = next(iter(infer_dataloader))
num_batches = len(infer_dataloader)

print(input_batch.keys())
print(f"Number of batches: {num_batches}")
print(f"Samples in a batch: {len(input_batch['answers'])}")

# Move input data to GPU
for key in input_batch:
    if isinstance(input_batch[key], torch.Tensor):
        input_batch[key] = input_batch[key].to(device)

input_batch["image_0"][0] = input_batch["image_0"][0].to(device)
input_batch["image_1"][0] = input_batch["image_1"][0].to(device)


# 4. Compute connection sensitivities
layer_sensitivities = compute_connection_sensitivity(pl_module, input_batch, target_layers)

# 5. Print the results
print("Connection Sensitivities:")
for layer_name, sensitivity in layer_sensitivities.items():
    print(f"Layer: {layer_name}, Sensitivity: {sensitivity}")

dict_keys(['table_name', 'answers', 'text', 'image_0', 'image_1', 'text_ids', 'text_labels', 'text_ids_mlm', 'text_labels_mlm', 'text_masks'])
Number of batches: 1
Samples in a batch: 10
Connection Sensitivities:
Layer: text_transformer.encoder.layer.0.intermediate.dense, Sensitivity: 0.001178910257294774
Layer: text_transformer.encoder.layer.0.output.dense, Sensitivity: 0.0014642142923548818
Layer: text_transformer.encoder.layer.1.intermediate.dense, Sensitivity: 0.001485239015892148
Layer: text_transformer.encoder.layer.1.output.dense, Sensitivity: 0.0016155705088749528
Layer: text_transformer.encoder.layer.2.intermediate.dense, Sensitivity: 0.0011649816296994686
Layer: text_transformer.encoder.layer.2.output.dense, Sensitivity: 0.0012712652096524835
Layer: text_transformer.encoder.layer.3.intermediate.dense, Sensitivity: 0.001026810728944838
Layer: text_transformer.encoder.layer.3.output.dense, Sensitivity: 0.0012590724509209394
Layer: text_transformer.encoder.layer.8.intermediate.d

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim

def compute_connection_sensitivity(pl_module, batch, target_layers):
    """
    Computes the connection sensitivity for specified layers in a PyTorch model
    following the user's forward pass implementation.

    Args:
        pl_module (nn.Module): The PyTorch Lightning Module (or your model).
        batch (dict): A dictionary representing the input batch (as expected by pl_module).
        target_layers (list of str): List of layer names to compute sensitivity for.

    Returns:
        dict: Dictionary of layer sensitivities, keyed by layer names.
    """
    sensitivities = {}
    device = next(pl_module.parameters()).device # Get device from model params

    # Make sure gradients are enabled
    for param in pl_module.parameters():
        param.requires_grad_(True)
    pl_module.train() # Set model to training mode for gradient calculation

    # Zero out gradients
    pl_module.zero_grad()

    # Forward pass (following user's implementation)
    infer1 = pl_module.infer(
        batch, mask_text=False, mask_image=False, image_token_type_idx=1
    )
    infer2 = pl_module.infer(
        batch, mask_text=False, mask_image=False, image_token_type_idx=2
    )

    cls_feats = torch.cat([infer1["cls_feats"], infer2["cls_feats"]], dim=-1)
    nlvr2_logits = pl_module.nlvr2_classifier(cls_feats)

    nlvr2_labels = batch["answers"]
    target = torch.tensor(nlvr2_labels).to(device).long() # Use 'device' here

    # Loss calculation (following user's implementation)
    loss = torch.nn.functional.cross_entropy(nlvr2_logits, target)

    # Backward pass to compute gradients
    loss.backward()

    for layer_name in target_layers:
        try:
            layer = pl_module.get_submodule(layer_name)
            if isinstance(layer, nn.Linear): # Assuming target layers are Linear layers
                weight_grad = layer.weight.grad
                if weight_grad is not None:
                    abs_grad = torch.abs(weight_grad)
                    sensitivity = torch.mean(abs_grad).item() # Average absolute gradient
                    sensitivities[layer_name] = sensitivity
                else:
                    print(f"Warning: Gradients not found for layer: {layer_name}. "
                          f"Ensure layer parameters are part of the computational graph "
                          f"and require gradients.")
                    sensitivities[layer_name] = None # Indicate no sensitivity calculated
            else:
                print(f"Warning: Layer {layer_name} is not a Linear layer. "
                      f"Connection sensitivity calculation is implemented for Linear layers in this example.")
                sensitivities[layer_name] = None # Indicate not applicable for this layer type
        except AttributeError:
            print(f"Warning: Layer '{layer_name}' not found in the model.")
            sensitivities[layer_name] = None # Indicate layer not found

    return sensitivities


if __name__ == '__main__':
    # Assuming 'model', 'infer_dataloader', and 'device' are already defined in your script

    pl_module = model # Use your actual 'model' (pl_module)
    device = "cuda"
    model.to(device)

    # 2. Define target layers
    target_layers = [
        "text_transformer.encoder.layer.0.intermediate.dense",
        "text_transformer.encoder.layer.0.output.dense",
        "text_transformer.encoder.layer.1.intermediate.dense",
        "text_transformer.encoder.layer.1.output.dense",
        "text_transformer.encoder.layer.2.intermediate.dense",
        "text_transformer.encoder.layer.2.output.dense",
        "text_transformer.encoder.layer.3.intermediate.dense",
        "text_transformer.encoder.layer.3.output.dense",
        "text_transformer.encoder.layer.8.intermediate.dense",
        "text_transformer.encoder.layer.8.output.dense",
        "text_transformer.encoder.layer.9.intermediate.dense",
        "text_transformer.encoder.layer.9.output.dense",
    ]

    # 3. Initialize the dataloader (assuming infer_dataloader is already defined)
    input_batch = next(iter(infer_dataloader))
    num_batches = len(infer_dataloader)

    print(input_batch.keys())
    print(f"Number of batches: {num_batches}")
    print(f"Samples in a batch: {len(input_batch['answers'])}")

    # Move input data to GPU
    for key in input_batch:
        if isinstance(input_batch[key], torch.Tensor):
            input_batch[key] = input_batch[key].to(device)

    if "image_0" in input_batch and isinstance(input_batch["image_0"], list) and input_batch["image_0"]:
        input_batch["image_0"][0] = input_batch["image_0"][0].to(device)
    if "image_1" in input_batch and isinstance(input_batch["image_1"], list) and input_batch["image_1"]:
        input_batch["image_1"][0] = input_batch["image_1"][0].to(device)


    # 4. Compute connection sensitivities
    layer_sensitivities = compute_connection_sensitivity(pl_module, input_batch, target_layers)

    # 5. Print the results in decreasing order of sensitivity
    sorted_sensitivities = sorted(layer_sensitivities.items(), key=lambda item: item[1], reverse=True)

    print("Connection Sensitivities (Decreasing Order):")
    for layer_name, sensitivity in sorted_sensitivities:
        print(f"Layer: {layer_name}, Sensitivity: {sensitivity}")

dict_keys(['table_name', 'answers', 'text', 'image_0', 'image_1', 'text_ids', 'text_labels', 'text_ids_mlm', 'text_labels_mlm', 'text_masks'])
Number of batches: 1
Samples in a batch: 10




: 