# Playground Notebook For Quantizing VLP Models

## Initialize the Distributed Backend

In [1]:
import os
import torch.distributed as dist
import copy

# Limit the number of CPUs
os.environ["OMP_NUM_THREADS"] = "4"  # Set this to the number of CPUs you want to use
os.environ["MKL_NUM_THREADS"] = "4"  # Set this to the number of CPUs you want to use

# Set environment variables
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# Initialize the process group
dist.init_process_group(backend='gloo', init_method='env://', world_size=1, rank=0)

# Verify initialization
print(f"Initialized: {dist.is_initialized()}")

Initialized: True


In [2]:
import warnings
# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="promote has been superseded by promote_options='default'")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.helpers is deprecated")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.layers is deprecated")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.registry is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_small_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_huge_patch14_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_small_resnet26d_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet26d_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50d_224 in registry")
warnings.filterwarnings("ignore", category=FutureWarning, message="You are using `torch.load` with `weights_only=False`")

### Define helper functions

In [3]:
import torch

def print_size_of_model(model):
    """
    Function to print the size of the model.

    Args:
        model (torch.nn.Module): The model to get the size
    
    Returns:
        None
    """
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

def get_accuracy(pl_module, logits, target, device="cpu"):
        correct = 0
        total = 0
        logits, target = (
            logits.detach().to(device),
            target.detach().to(device),
        )
        preds = logits.argmax(dim=-1)
        preds = preds[target != -100]
        target = target[target != -100]
        if target.numel() == 0:
            return 1

        assert preds.shape == target.shape

        correct += torch.sum(preds == target)
        total += target.numel()

        return correct/total

### Define the Configuration to Initialize the Datamodule and Model

In [4]:
import pytorch_lightning as pl

# Define the configuration for the experiments
vilt_config_nlvr2 = {'exp_name': 'test_ood_nlvr2', 'seed': 0, 'datasets': ['ood_nlvr2'], 'loss_names': {'itm': 0, 'mlm': 0, 'mpp': 0, 'vqa': 0, 'nlvr2': 1, 'irtr': 0}, 'batch_size': 128, 'accelerator': 'gpu', 'train_transform_keys': ['pixelbert_randaug'], 'val_transform_keys': ['pixelbert'], 'image_size': 384, 'max_image_len': -1, 'patch_size': 32, 'draw_false_image': 0, 'image_only': False, 'vqav2_label_size': 3129, 'max_text_len': 40, 'tokenizer': 'bert-base-uncased', 'vocab_size': 30522, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'vit': 'vit_base_patch32_384', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 12, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 0.0001, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 10, 'max_steps': 1, 'warmup_steps': 0.1, 'end_lr': 0, 'lr_mult': 1, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 1.0, 'test_only': True, 'data_root': '/data-4/users/mileriso/datasets/OOD/arrows', 'log_dir': 'result', 'per_gpu_batchsize': 64, 'num_gpus': 1, 'num_nodes': 1, 'load_path': '/data-4/users/mileriso/models/vilt_nlvr2.ckpt', 'num_workers': 8, 'precision': 32}
vilt_config_vqav2 = {'exp_name': 'test_ood_vqa', 'seed': 0, 'datasets': ['ood_vqa'], 'loss_names': {'itm': 0, 'mlm': 0, 'mpp': 0, 'vqa': 1, 'nlvr2': 0, 'irtr': 0}, 'batch_size': 256, 'accelerator': 'gpu', 'train_transform_keys': ['pixelbert_randaug'], 'val_transform_keys': ['pixelbert'], 'image_size': 384, 'max_image_len': -1, 'patch_size': 32, 'draw_false_image': 0, 'image_only': False, 'vqav2_label_size': 3129, 'max_text_len': 40, 'tokenizer': 'bert-base-uncased', 'vocab_size': 30522, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'vit': 'vit_base_patch32_384', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 12, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 0.0001, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 10, 'max_steps': 1, 'warmup_steps': 0.1, 'end_lr': 0, 'lr_mult': 10, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 0.1, 'test_only': True, 'data_root': '/data-4/users/mileriso/datasets/OOD/arrows', 'log_dir': 'result', 'per_gpu_batchsize': 64, 'num_gpus': 1, 'num_nodes': 1, 'load_path': '/data-4/users/mileriso/models/vilt_vqa.ckpt', 'num_workers': 8, 'precision': 32}

meter_config_nlvr2 = {'exp_name': 'test_ood_nlvr2', 'seed': 0, 'datasets': ['ood_nlvr2'], 'loss_names': {'itm': 0, 'mlm': 0, 'mpp': 0, 'vqa': 0, 'vcr': 0, 'vcr_qar': 0, 'nlvr2': 1, 'irtr': 0, 'contras': 0, 'snli': 0}, 'batch_size': 256, 'accelerator': 'gpu', 'train_transform_keys': ['clip'], 'val_transform_keys': ['clip'], 'image_size': 288, 'patch_size': 16, 'draw_false_image': 0, 'image_only': False, 'resolution_before': 224, 'vqav2_label_size': 3129, 'max_text_len': 50, 'tokenizer': 'roberta-base', 'vocab_size': 50265, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'num_top_layer': 6, 'input_image_embed_size': 768, 'input_text_embed_size': 768, 'vit': 'ViT-B/16', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 6, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 1e-05, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 10, 'max_steps': 1, 'warmup_steps': 0.1, 'end_lr': 0, 'lr_mult_head': 10, 'lr_mult_cross_modal': 5, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 1.0, 'test_only': True, 'data_root': '/data-4/users/mileriso/datasets/OOD/arrows', 'log_dir': 'result', 'per_gpu_batchsize': 64, 'num_gpus': 1, 'num_nodes': 1, 'load_path': '/data-4/users/mileriso/models/meter_nlvr2.ckpt', 'num_workers': 8, 'precision': 32}
meter_config_vqav2 = {'exp_name': 'test_ood_vqa', 'seed': 0, 'datasets': ['ood_vqa'], 'loss_names': {'itm': 0, 'mlm': 0, 'mpp': 0, 'vqa': 1, 'vcr': 0, 'vcr_qar': 0, 'nlvr2': 0, 'irtr': 0, 'contras': 0, 'snli': 0}, 'batch_size': 512, 'accelerator': 'gpu', 'train_transform_keys': ['clip'], 'val_transform_keys': ['clip'], 'image_size': 576, 'patch_size': 16, 'draw_false_image': 0, 'image_only': False, 'resolution_before': 224, 'vqav2_label_size': 3129, 'max_text_len': 50, 'tokenizer': 'roberta-base', 'vocab_size': 50265, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'num_top_layer': 6, 'input_image_embed_size': 768, 'input_text_embed_size': 768, 'vit': 'ViT-B/16', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 6, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 5e-06, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 10, 'max_steps': 1, 'warmup_steps': 0.1, 'end_lr': 0, 'lr_mult_head': 50, 'lr_mult_cross_modal': 5, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 0.1, 'test_only': True, 'data_root': '/data-4/users/mileriso/datasets/OOD/arrows', 'log_dir': 'result', 'per_gpu_batchsize': 4, 'num_gpus': 1, 'num_nodes': 1, 'load_path': '/data-4/users/mileriso/models/meter_vqa.ckpt', 'num_workers': 8, 'precision': 32}

pl.seed_everything(0)

Seed set to 0


0

## Initialize the Datamodule

Create a child datamodule that constructs a smaller version of the full datamodule

In [5]:
from torch.utils.data import Subset
from vilt.datamodules.multitask_datamodule import MTDataModule as MTDataModuleVILT
from meter.datamodules.multitask_datamodule import MTDataModule as MTDataModuleMeter

class SmallMTDataModuleVILT(MTDataModuleVILT):
    def __init__(self, _config, dist=False, num_samples=10, start_idx=100):
        super().__init__(_config, dist)
        self.num_samples = num_samples
        self.start_idx = start_idx

    def setup(self, stage):
        super().setup(stage)
        
        # Limit the number of samples in the datasets
        self.train_dataset = Subset(self.train_dataset, range(self.start_idx, self.start_idx+self.num_samples))
        self.val_dataset = Subset(self.val_dataset, range(self.start_idx, self.start_idx+self.num_samples))
        self.test_dataset = Subset(self.test_dataset, range(self.start_idx, self.start_idx+self.num_samples))

class SmallMTDataModuleMETER(MTDataModuleMeter):
    def __init__(self, _config, dist=False, num_samples=10, start_idx=100):
        super().__init__(_config, dist)
        self.num_samples = num_samples
        self.start_idx = start_idx

    def setup(self, stage):
        super().setup(stage)
        
        # Limit the number of samples in the datasets
        self.train_dataset = Subset(self.train_dataset, range(self.start_idx, self.start_idx+self.num_samples))
        self.val_dataset = Subset(self.val_dataset, range(self.start_idx, self.start_idx+self.num_samples))
        self.test_dataset = Subset(self.test_dataset, range(self.start_idx, self.start_idx+self.num_samples))

2024-12-05 18:43:17.041490: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733420597.060177 3157510 cuda_dnn.cc:8498] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733420597.065771 3157510 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Select the configuration and initialize the test and full datamodule

In [6]:
# Set the configuration
_config = vilt_config_nlvr2
_config["model_"] = "vilt"
_config["batch_size"] = 5

# ==========================================
# ========= Create full datamodule =========
# ==========================================
if "meter" in _config["model_"]:
    full_dm = MTDataModuleMeter(_config, dist=False)
    full_dm.setup("test")
    full_dataloader = full_dm.test_dataloader()
    
    test_dm = SmallMTDataModuleMETER(_config, dist=False, num_samples=5, start_idx=100)
    test_dm.setup("test")
    test_dataloader = test_dm.test_dataloader()

else:
    full_dm = MTDataModuleVILT(_config, dist=False)
    full_dm.setup("test")
    full_dataloader = full_dm.test_dataloader()

    test_dm = SmallMTDataModuleVILT(_config, dist=False, num_samples=5, start_idx=100)
    test_dm.setup("test")
    test_dataloader = test_dm.test_dataloader()



print(f"Length of the test dataloader: {len(test_dataloader.dataset)}")
print(f"Length of the full dataloader: {len(full_dataloader.dataset)}")

test_batch = next(iter(test_dataloader))
full_batch = next(iter(full_dataloader))

print(f"Lenght of the test batch: {len(test_batch)}")
print(f"Lenght of the full batch: {len(full_batch)}")

Length of the test dataloader: 5
Length of the full dataloader: 5662




Lenght of the test batch: 10
Lenght of the full batch: 10


## Initialize The Model

In [7]:
from vilt.modules import ViLTransformerSS
from meter.modules import METERTransformerSS

if _config["model_"] == "vilt":
    model = ViLTransformerSS(_config)
    print("Initialized ViLT model")

elif _config["model_"] == "meter":
    model = METERTransformerSS(_config)
    print("Initialized METER model")

else:
    raise ValueError("Model not supported")

Initialized ViLT model


## Initialize The Trainer

In [8]:
# ========== Initialize the trainer for full precision ==========
exp_name = f'{_config["exp_name"]}'

os.makedirs(_config["log_dir"], exist_ok=True)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=1,
    verbose=True,
    monitor="val/the_metric",
    mode="max",
    save_last=True,
)
logger = pl.loggers.TensorBoardLogger(
    _config["log_dir"],
    name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
)

lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
callbacks = [checkpoint_callback, lr_callback]

num_gpus = (
    _config["num_gpus"]
    if isinstance(_config["num_gpus"], int)
    else len(_config["num_gpus"])
)

grad_steps = _config["batch_size"] // (
    _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
)

max_steps = _config["max_steps"] if _config["max_steps"] is not None else None


trainer = pl.Trainer(
        accelerator="cpu",
        devices=1,
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        # strategy="ddp",
        benchmark=True,
        deterministic=False,
        max_epochs=_config["max_epoch"] if max_steps is None else 1000,
        max_steps=max_steps,
        callbacks=callbacks,
        logger=logger,
        # accumulate_grad_batches=grad_steps,
        log_every_n_steps=10,
        fast_dev_run=_config["fast_dev_run"],
        val_check_interval=_config["val_check_interval"],
    )

trainer.test(model, datamodule=test_dm)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


Testing: |          | 0/? [00:00<?, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 5. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('nlvr2/test/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('nlvr2/test/accuracy', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packa

[{'nlvr2/test/loss': 1.2987868785858154,
  'nlvr2/test/accuracy': 0.6000000238418579,
  'nlvr2/dev/accuracy_epoch': nan,
  'nlvr2/dev/loss_epoch': nan,
  'nlvr2/test/accuracy_epoch': 0.6000000238418579,
  'nlvr2/test/loss_epoch': 1.2987868785858154,
  'val/the_metric': 0.6000000238418579}]

# Quantization | PTQ to 8-bits

### Dynamic Quantization

In [9]:
model_dynamic = copy.deepcopy(model)


torch.quantization.quantize_dynamic(
        model_dynamic, {torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.Conv2d}, dtype=torch.quint8, inplace=True
    )

torch.quantization.quantize_dynamic(
        model_dynamic, {torch.nn.Linear, torch.nn.Dropout, torch.nn.GELU, torch.nn.Conv2d}, dtype=torch.qint8, inplace=True
    )

print("Size after quantization:")
print_size_of_model(model_dynamic)
# print(model_dynamic)

Size after quantization:
Size (MB): 122.102028


### Static Quantization

In [10]:
model_static = copy.deepcopy(model)

# General quantization configuration
quantization_config = torch.quantization.get_default_qconfig('x86')

# Configuration for nn.Embedding layers
embedding_qconfig = torch.quantization.QConfig(
    activation=torch.quantization.HistogramObserver.with_args(reduce_range=True),
    weight=torch.quantization.default_float_qparams_observer.with_args(dtype=torch.quint8)
)

if _config["model_"] == "vilt":
    # Assign the quantization configurations to the model
    model_static.qconfig = quantization_config
    model_static.token_type_embeddings.qconfig = embedding_qconfig
    model_static.text_embeddings.word_embeddings.qconfig = embedding_qconfig
    model_static.text_embeddings.position_embeddings.qconfig = embedding_qconfig
    model_static.text_embeddings.token_type_embeddings.qconfig = embedding_qconfig

elif _config["model_"] == "meter":
    # Assign the quantization configurations to the model
    model_static.qconfig = quantization_config
    model_static.token_type_embeddings.qconfig = embedding_qconfig
    model_static.text_transformer.embeddings.word_embeddings.qconfig = embedding_qconfig
    model_static.text_transformer.embeddings.position_embeddings.qconfig = embedding_qconfig
    model_static.text_transformer.embeddings.token_type_embeddings.qconfig = embedding_qconfig

# Perform static quantization
torch.quantization.prepare(model_static, inplace=True)
trainer.test(model_static, datamodule=test_dm)
torch.quantization.convert(model_static, inplace=True)

print("Size after quantization:")
print_size_of_model(model_static)
# print(model_static)



Testing: |          | 0/? [00:00<?, ?it/s]

Size after quantization:
Size (MB): 116.453972


# Numeric Suite Analysis

In [None]:
import torch.quantization._numeric_suite as ns

def compute_error(x, y):
    """
    Signal to Noise Ratio (SNR)    
    """
    Ps = torch.norm(x)
    Pn = torch.norm(x-y)
    return 20*torch.log10(Ps/Pn)

## Static Model

In [12]:
# ======== Static quantization comparison ========
wt_compare_dict_static = ns.compare_weights(model.state_dict(), model_static.state_dict())

print('keys of wt_compare_dict:')
print(wt_compare_dict_static.keys())

key = 'text_embeddings.LayerNorm.weight'

# print(f"\nkeys of wt_compare_dict entry for {key} weight:")
# print(wt_compare_dict_static[key].keys())
# print(wt_compare_dict_static[key]['float'].shape)
# print(wt_compare_dict_static[key]['quantized'].shape)

for key in wt_compare_dict_static:
    print(key, compute_error(wt_compare_dict_static[key]['float'], wt_compare_dict_static[key]['quantized'].dequantize()))

keys of wt_compare_dict:
dict_keys(['text_embeddings.LayerNorm.weight', 'transformer.patch_embed.proj.weight', 'transformer.blocks.0.norm1.weight', 'transformer.blocks.0.attn.qkv._packed_params._packed_params', 'transformer.blocks.0.attn.proj._packed_params._packed_params', 'transformer.blocks.0.norm2.weight', 'transformer.blocks.0.mlp.fc1._packed_params._packed_params', 'transformer.blocks.0.mlp.fc2._packed_params._packed_params', 'transformer.blocks.1.norm1.weight', 'transformer.blocks.1.attn.qkv._packed_params._packed_params', 'transformer.blocks.1.attn.proj._packed_params._packed_params', 'transformer.blocks.1.norm2.weight', 'transformer.blocks.1.mlp.fc1._packed_params._packed_params', 'transformer.blocks.1.mlp.fc2._packed_params._packed_params', 'transformer.blocks.2.norm1.weight', 'transformer.blocks.2.attn.qkv._packed_params._packed_params', 'transformer.blocks.2.attn.proj._packed_params._packed_params', 'transformer.blocks.2.norm2.weight', 'transformer.blocks.2.mlp.fc1._packed_

In [13]:
# Take in floating point and quantized model as well as input data, and returns a dict, with keys
# corresponding to the quantized module names and each entry being a dictionary with two keys 'float' and
# 'quantized', containing the activations of floating point and quantized model at matching locations.

# act_compare_dict_static = ns.compare_model_outputs(copy.deepcopy(model), copy.deepcopy(model_static), full_batch)
act_compare_dict_static = ns.compare_model_outputs(copy.deepcopy(model), copy.deepcopy(model_static), test_batch)

print('keys of act_compare_dict:')
print(act_compare_dict_static.keys())

key_act = "transformer.blocks.0.attn.qkv.stats"

print(f"\nkeys of act_compare_dict entry for {key} output:")
print(act_compare_dict_static[key_act].keys())
print(act_compare_dict_static[key_act]['float'][0].shape)
print(act_compare_dict_static[key_act]['quantized'][0].shape)

keys of act_compare_dict:
dict_keys(['text_embeddings.LayerNorm.stats', 'text_embeddings.quant.stats', 'transformer.patch_embed.proj.stats', 'transformer.patch_embed.quant.stats', 'transformer.blocks.0.norm1.stats', 'transformer.blocks.0.attn.qkv.stats', 'transformer.blocks.0.attn.proj.stats', 'transformer.blocks.0.attn.quant.stats', 'transformer.blocks.0.norm2.stats', 'transformer.blocks.0.mlp.fc1.stats', 'transformer.blocks.0.mlp.fc2.stats', 'transformer.blocks.0.quant.stats', 'transformer.blocks.1.norm1.stats', 'transformer.blocks.1.attn.qkv.stats', 'transformer.blocks.1.attn.proj.stats', 'transformer.blocks.1.attn.quant.stats', 'transformer.blocks.1.norm2.stats', 'transformer.blocks.1.mlp.fc1.stats', 'transformer.blocks.1.mlp.fc2.stats', 'transformer.blocks.1.quant.stats', 'transformer.blocks.2.norm1.stats', 'transformer.blocks.2.attn.qkv.stats', 'transformer.blocks.2.attn.proj.stats', 'transformer.blocks.2.attn.quant.stats', 'transformer.blocks.2.norm2.stats', 'transformer.blocks.

In [24]:
total_err = 0.0
for idx, key in enumerate(act_compare_dict_static):
    err = compute_error(act_compare_dict_static[key]['float'][0], act_compare_dict_static[key]['quantized'][0].dequantize())
    total_err += err
    print(f"{idx} - {key}")
    print(f"{idx} - {err}")

print(f"Total error: {total_err}")

0 - text_embeddings.LayerNorm.stats
0 - 17.82866859436035
1 - text_embeddings.quant.stats
1 - 18.955486297607422
2 - transformer.patch_embed.proj.stats
2 - 26.32302474975586
3 - transformer.patch_embed.quant.stats
3 - 44.364158630371094
4 - transformer.blocks.0.norm1.stats
4 - 0.4864632785320282
5 - transformer.blocks.0.attn.qkv.stats
5 - 5.93009090423584
6 - transformer.blocks.0.attn.proj.stats
6 - 0.6434845328330994
7 - transformer.blocks.0.attn.quant.stats
7 - 0.4597845673561096
8 - transformer.blocks.0.norm2.stats
8 - 0.7093193531036377
9 - transformer.blocks.0.mlp.fc1.stats
9 - 6.044439315795898
10 - transformer.blocks.0.mlp.fc2.stats
10 - -1.2007348537445068
11 - transformer.blocks.0.quant.stats
11 - -0.07495744526386261
12 - transformer.blocks.1.norm1.stats
12 - 0.08032231032848358
13 - transformer.blocks.1.attn.qkv.stats
13 - 5.3930511474609375
14 - transformer.blocks.1.attn.proj.stats
14 - 0.07086136937141418
15 - transformer.blocks.1.attn.quant.stats
15 - 0.02302350290119648


In [15]:
# import matplotlib.pyplot as plt

# q = wt_compare_dict[key]['quantized'].flatten().dequantize()
# f = wt_compare_dict[key]['float'].flatten()

# plt.hist(q, bins=100, alpha=0.5, label='Quantized')
# plt.hist(f, bins=100, alpha=0.5, label='Floating Point')


# plt.title(f"Model Weights of {key}")
# plt.legend()
# plt.show()

## Dynamic Model

In [16]:
# ======== Dynamic quantization comparison ========
wt_compare_dict_dynamic = ns.compare_weights(model.state_dict(), model_dynamic.state_dict())


print('keys of wt_compare_dict:')
print(wt_compare_dict_dynamic.keys())

# key = 'text_embeddings.LayerNorm.weight'

# print(f"\nkeys of wt_compare_dict entry for {key} weight:")
# print(wt_compare_dict_dynamic[key].keys())
# print(wt_compare_dict_dynamic[key]['float'].shape)
# print(wt_compare_dict_dynamic[key]['quantized'].shape)

for key in wt_compare_dict_dynamic:
    if wt_compare_dict_dynamic[key]['quantized'].is_quantized:
        print(key, compute_error(wt_compare_dict_dynamic[key]['float'], wt_compare_dict_dynamic[key]['quantized'].dequantize()))
    else:
        print(key, compute_error(wt_compare_dict_dynamic[key]['float'], wt_compare_dict_dynamic[key]['quantized']))

keys of wt_compare_dict:
dict_keys(['text_embeddings.LayerNorm.weight', 'transformer.patch_embed.proj.weight', 'transformer.blocks.0.norm1.weight', 'transformer.blocks.0.attn.qkv._packed_params._packed_params', 'transformer.blocks.0.attn.proj._packed_params._packed_params', 'transformer.blocks.0.norm2.weight', 'transformer.blocks.0.mlp.fc1._packed_params._packed_params', 'transformer.blocks.0.mlp.fc2._packed_params._packed_params', 'transformer.blocks.1.norm1.weight', 'transformer.blocks.1.attn.qkv._packed_params._packed_params', 'transformer.blocks.1.attn.proj._packed_params._packed_params', 'transformer.blocks.1.norm2.weight', 'transformer.blocks.1.mlp.fc1._packed_params._packed_params', 'transformer.blocks.1.mlp.fc2._packed_params._packed_params', 'transformer.blocks.2.norm1.weight', 'transformer.blocks.2.attn.qkv._packed_params._packed_params', 'transformer.blocks.2.attn.proj._packed_params._packed_params', 'transformer.blocks.2.norm2.weight', 'transformer.blocks.2.mlp.fc1._packed_

In [17]:
# import matplotlib.pyplot as plt

# q = wt_compare_dict[key]['quantized'].flatten().dequantize()
# f = wt_compare_dict[key]['float'].flatten()

# plt.hist(q, bins=100, alpha=0.5, label='Quantized')
# plt.hist(f, bins=100, alpha=0.5, label='Floating Point')


# plt.title(f"Model Weights of {key}")
# plt.legend()
# plt.show()

In [18]:
# act_compare_dict_dynamic = ns.compare_model_outputs(copy.deepcopy(model), copy.deepcopy(model_dynamic), full_batch)
act_compare_dict_dynamic = ns.compare_model_outputs(copy.deepcopy(model), copy.deepcopy(model_dynamic), test_batch)
print(act_compare_dict_dynamic.keys())

dict_keys(['text_embeddings.LayerNorm.stats', 'text_embeddings.quant.stats', 'transformer.patch_embed.proj.stats', 'transformer.patch_embed.quant.stats', 'transformer.blocks.0.norm1.stats', 'transformer.blocks.0.attn.qkv.stats', 'transformer.blocks.0.attn.proj.stats', 'transformer.blocks.0.attn.quant.stats', 'transformer.blocks.0.norm2.stats', 'transformer.blocks.0.mlp.fc1.stats', 'transformer.blocks.0.mlp.fc2.stats', 'transformer.blocks.0.quant.stats', 'transformer.blocks.1.norm1.stats', 'transformer.blocks.1.attn.qkv.stats', 'transformer.blocks.1.attn.proj.stats', 'transformer.blocks.1.attn.quant.stats', 'transformer.blocks.1.norm2.stats', 'transformer.blocks.1.mlp.fc1.stats', 'transformer.blocks.1.mlp.fc2.stats', 'transformer.blocks.1.quant.stats', 'transformer.blocks.2.norm1.stats', 'transformer.blocks.2.attn.qkv.stats', 'transformer.blocks.2.attn.proj.stats', 'transformer.blocks.2.attn.quant.stats', 'transformer.blocks.2.norm2.stats', 'transformer.blocks.2.mlp.fc1.stats', 'transfo

/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/core/module.py:445: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [19]:
total_err = 0
for idx, key in enumerate(act_compare_dict_dynamic):
    err = compute_error(act_compare_dict_dynamic[key]['float'][0][0], act_compare_dict_dynamic[key]['quantized'][0][0])
    # print(type(err))
    if torch.isinf(err):
        pass
    else:
        total_err += err
    print(f"{idx} - {key}")
    print(f"{idx} - {err}")

print(f"Total error: {total_err}")

0 - text_embeddings.LayerNorm.stats
0 - 38.937477111816406
1 - text_embeddings.quant.stats
1 - 38.74382019042969
2 - transformer.patch_embed.proj.stats
2 - inf
3 - transformer.patch_embed.quant.stats
3 - inf
4 - transformer.blocks.0.norm1.stats
4 - -0.9947296977043152
5 - transformer.blocks.0.attn.qkv.stats
5 - 4.067282676696777
6 - transformer.blocks.0.attn.proj.stats
6 - 0.0086312685161829
7 - transformer.blocks.0.attn.quant.stats
7 - -0.9947296977043152
8 - transformer.blocks.0.norm2.stats
8 - -0.7651967406272888
9 - transformer.blocks.0.mlp.fc1.stats
9 - 3.797595262527466
10 - transformer.blocks.0.mlp.fc2.stats
10 - -2.599177360534668
11 - transformer.blocks.0.quant.stats
11 - -1.2061665058135986
12 - transformer.blocks.1.norm1.stats
12 - -1.2639917135238647
13 - transformer.blocks.1.attn.qkv.stats
13 - 3.756920099258423
14 - transformer.blocks.1.attn.proj.stats
14 - -0.4064253270626068
15 - transformer.blocks.1.attn.quant.stats
15 - -1.2639917135238647
16 - transformer.blocks.1.no

In [20]:
print(model)

ViLTransformerSS(
  (text_embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(40, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (quant): QuantStub()
    (dequant): DeQuantStub()
  )
  (token_type_embeddings): Embedding(3, 768)
  (transformer): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      (quant): QuantStub()
      (dequant): DeQuantStub()
    )
    (pos_drop): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features

In [21]:
print(model_static)

ViLTransformerSS(
  (text_embeddings): BertEmbeddings(
    (word_embeddings): QuantizedEmbedding(num_embeddings=30522, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (position_embeddings): QuantizedEmbedding(num_embeddings=40, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (token_type_embeddings): QuantizedEmbedding(num_embeddings=2, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (LayerNorm): QuantizedLayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): QuantizedDropout(p=0.1, inplace=False)
    (quant): Quantize(scale=tensor([0.0120]), zero_point=tensor([63]), dtype=torch.quint8)
    (dequant): DeQuantize()
  )
  (token_type_embeddings): QuantizedEmbedding(num_embeddings=3, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
  (transformer): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): 

In [22]:
print(model_dynamic)

ViLTransformerSS(
  (text_embeddings): BertEmbeddings(
    (word_embeddings): QuantizedEmbedding(num_embeddings=30522, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (position_embeddings): QuantizedEmbedding(num_embeddings=40, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (token_type_embeddings): QuantizedEmbedding(num_embeddings=2, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (quant): QuantStub()
    (dequant): DeQuantStub()
  )
  (token_type_embeddings): QuantizedEmbedding(num_embeddings=3, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
  (transformer): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      (quant): QuantStub()
  