# Playground Notebook For Quantizing VLP Models

## Initialize the Distributed Backend

In [1]:
import os

# Number of CPU threads to use. The OpenMP/MKL runtimes may read these
# environment variables only once, when torch is first loaded, so they are set
# *before* `import torch` below; assigning them afterwards can have no effect.
NUM_CPU_THREADS = 2  # Set this to the number of CPUs you want to use
os.environ["OMP_NUM_THREADS"] = str(NUM_CPU_THREADS)
os.environ["MKL_NUM_THREADS"] = str(NUM_CPU_THREADS)

import copy

import torch
import torch.distributed as dist

# If torch was already imported in this kernel (e.g. when re-running), the
# environment variables above come too late, so also cap the intra-op pool.
torch.set_num_threads(NUM_CPU_THREADS)

# Rendezvous address/port read by init_method='env://'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# Single-process group (world_size=1, rank=0) on the CPU "gloo" backend.
# Guarded so that re-running this cell does not try to initialize the default
# process group a second time, which raises a RuntimeError.
if not dist.is_initialized():
    dist.init_process_group(backend='gloo', init_method='env://', world_size=1, rank=0)

# Verify initialization
print(f"Initialized: {dist.is_initialized()}")

Initialized: True


In [2]:
import warnings
# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="promote has been superseded by promote_options='default'")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.helpers is deprecated")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.layers is deprecated")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.registry is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_small_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_huge_patch14_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_small_resnet26d_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet26d_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50d_224 in registry")
warnings.filterwarnings("ignore", category=FutureWarning, message="You are using `torch.load` with `weights_only=False`")

### Define helper functions

In [3]:
import torch

def print_size_of_model(model):
    """
    Print the serialized size of a model's ``state_dict`` in megabytes.

    The state dict is written with ``torch.save`` to a file inside a private
    temporary directory. This means the measurement can never overwrite an
    unrelated ``temp.p`` in the working directory, and the file is always
    removed, even if saving fails. The file keeps the name ``temp.p`` because
    ``torch.save`` derives its internal archive name from the file name, so the
    reported size stays identical to saving ``temp.p`` directly.

    Args:
        model (torch.nn.Module): The model whose size to measure.

    Returns:
        None
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size_mb = os.path.getsize(path) / 1e6
    print('Size (MB):', size_mb)

def get_accuracy(pl_module, logits, target, device="cpu", ignore_index=-100):
    """
    Fraction of positions where ``argmax(logits, dim=-1)`` equals ``target``.

    Positions whose target equals ``ignore_index`` are excluded from both the
    numerator and the denominator.

    Args:
        pl_module: Unused; kept so existing call sites keep working.
        logits (torch.Tensor): Scores whose last dimension indexes the classes;
            ``logits.argmax(dim=-1)`` must have the same shape as ``target``.
        target (torch.Tensor): Class labels, compared element-wise with the
            arg-max predictions.
        device (str | torch.device): Device on which the comparison runs.
        ignore_index (int): Label value marking positions to skip. Defaults to
            ``-100``, the value that was previously hard-coded.

    Returns:
        torch.Tensor | int: 0-d float tensor with the accuracy, or the int ``1``
        when every position is ignored (no valid targets).
    """
    logits = logits.detach().to(device)
    target = target.detach().to(device)

    valid = target != ignore_index
    preds = logits.argmax(dim=-1)[valid]
    target = target[valid]
    if target.numel() == 0:
        # Nothing to score; preserved legacy behaviour.
        return 1

    assert preds.shape == target.shape

    return torch.sum(preds == target) / target.numel()

### Define the Configuration to Initialize the Datamodule and Model

In [4]:
import pytorch_lightning as pl
import configs

# Set the configuration. Work on a deep copy so the overrides below do not
# mutate the shared dict inside the `configs` module; otherwise re-running this
# cell, or any other code reading `configs.vilt_config_vqav2`, sees them.
_config = copy.deepcopy(configs.vilt_config_vqav2)
_config["model_"] = "vilt"
_config["batch_size"] = 32

pl.seed_everything(_config["seed"])

Seed set to 0


0

## Initialize the Datamodule

Define child datamodules (one for ViLT, one for METER) that restrict every split to a small, contiguous slice of the full dataset, for quick calibration and inference.

In [5]:
from torch.utils.data import Subset
from vilt.datamodules.multitask_datamodule import MTDataModule as MTDataModuleVILT
from meter.datamodules.multitask_datamodule import MTDataModule as MTDataModuleMeter

def _limit_dataset_splits(datamodule, start_idx, num_samples):
    """Replace a datamodule's train/val/test datasets with a contiguous slice of each.

    Keeps indices ``[start_idx, start_idx + num_samples)``, clipped to each
    dataset's length. A slice that runs past the end yields a shorter (possibly
    empty) ``Subset`` plus a warning, instead of an ``IndexError`` the first
    time a dataloader reaches a missing index.
    """
    import warnings

    for attr in ("train_dataset", "val_dataset", "test_dataset"):
        dataset = getattr(datamodule, attr)
        stop = min(start_idx + num_samples, len(dataset))
        kept = max(stop - start_idx, 0)
        if kept < num_samples:
            warnings.warn(
                f"{attr}: requested {num_samples} samples starting at index {start_idx}, "
                f"but the dataset has only {len(dataset)} samples; keeping {kept}."
            )
        setattr(datamodule, attr, Subset(dataset, range(start_idx, stop)))


class SmallMTDataModuleVILT(MTDataModuleVILT):
    """ViLT multitask datamodule restricted to ``num_samples`` samples per split.

    Every split (train/val/test) keeps the samples starting at ``start_idx``.
    Used to build the small calibration and inference datamodules.
    """

    def __init__(self, _config, dist=False, num_samples=5, start_idx=100):
        super().__init__(_config, dist)
        self.num_samples = num_samples
        self.start_idx = start_idx

    def setup(self, stage):
        # Build the full datasets first, then shrink each split.
        super().setup(stage)
        _limit_dataset_splits(self, self.start_idx, self.num_samples)


class SmallMTDataModuleMETER(MTDataModuleMeter):
    """METER counterpart of ``SmallMTDataModuleVILT`` (default ``num_samples=10``).

    Every split (train/val/test) keeps the samples starting at ``start_idx``.
    """

    def __init__(self, _config, dist=False, num_samples=10, start_idx=100):
        super().__init__(_config, dist)
        self.num_samples = num_samples
        self.start_idx = start_idx

    def setup(self, stage):
        # Build the full datasets first, then shrink each split.
        super().setup(stage)
        _limit_dataset_splits(self, self.start_idx, self.num_samples)

2024-12-25 09:10:06.212024: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1735114206.231243  836915 cuda_dnn.cc:8498] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1735114206.237047  836915 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


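Depending on the selected model, initialize the full datamodule plus two small ones: a calibration datamodule (samples starting at index 100) and an inference datamodule (samples starting at index 0).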

In [6]:
# ==========================================
# ========= Create full datamodule =========
# ==========================================
# Pick the datamodule classes once, then build every datamodule the same way
# for either model family.
if "meter" in _config["model_"]:
    FullDataModuleCls, SmallDataModuleCls = MTDataModuleMeter, SmallMTDataModuleMETER
else:
    FullDataModuleCls, SmallDataModuleCls = MTDataModuleVILT, SmallMTDataModuleVILT

full_dm = FullDataModuleCls(_config, dist=False)

# Small slice starting at index 100: used for the float baseline / calibration.
calibrarte_dm = SmallDataModuleCls(_config, dist=False, num_samples=5, start_idx=100)

# Small slice starting at index 0: used for the inference comparisons.
infer_dm = SmallDataModuleCls(_config, dist=False, num_samples=5, start_idx=0)
infer_dm.setup("test")
infer_dataloader = infer_dm.test_dataloader()




Loaded names: ['vqa_vlue_test']
Loaded names: ['vqa_vlue_test']
Loaded names: ['vqa_vlue_test']


## Initialize The Model

In [7]:
from vilt.modules import ViLTransformerSS
from meter.modules import METERTransformerSS

# Map the configured model name to its module class and display name.
_MODEL_REGISTRY = {
    "vilt": (ViLTransformerSS, "ViLT"),
    "meter": (METERTransformerSS, "METER"),
}

if _config["model_"] not in _MODEL_REGISTRY:
    raise ValueError("Model not supported")

_model_cls, _model_label = _MODEL_REGISTRY[_config["model_"]]
model = _model_cls(_config)
print(f"Initialized {_model_label} model")

Initialized ViLT model


## Initialize The Trainer

In [8]:
# ========== Initialize the trainer for full precision ==========
exp_name = str(_config["exp_name"])

# Checkpoint file name without directory or extension (e.g. "vilt_vqa.ckpt" ->
# "vilt_vqa"). This replaces `split("/")[-1][:-5]`, which blindly chopped 5
# characters and was wrong for any extension other than ".ckpt".
load_name = os.path.splitext(os.path.basename(_config["load_path"]))[0]

os.makedirs(_config["log_dir"], exist_ok=True)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=1,
    verbose=True,
    monitor="val/the_metric",
    mode="max",
    save_last=True,
)
logger = pl.loggers.TensorBoardLogger(
    _config["log_dir"],
    name=f'{exp_name}_seed{_config["seed"]}_from_{load_name}',
)

lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
callbacks = [checkpoint_callback, lr_callback]

num_gpus = (
    _config["num_gpus"]
    if isinstance(_config["num_gpus"], int)
    else len(_config["num_gpus"])
)

# Gradient-accumulation factor for the configured effective batch size.
# Currently unused because `accumulate_grad_batches` is commented out below.
# The trainer runs on one CPU device, so a configured `num_gpus` of 0 is
# treated as 1 to avoid a ZeroDivisionError.
grad_steps = _config["batch_size"] // (
    _config["per_gpu_batchsize"] * max(num_gpus, 1) * _config["num_nodes"]
)

# `None` in the config means "no step limit"; in that case training length is
# governed by `max_epoch`. Lightning itself uses -1 (not None) for "no limit".
max_steps = _config["max_steps"]


trainer = pl.Trainer(
        accelerator="cpu",
        devices=1,
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        # strategy="ddp",
        benchmark=True,
        deterministic=False,
        max_epochs=_config["max_epoch"] if max_steps is None else 1000,
        max_steps=max_steps if max_steps is not None else -1,
        callbacks=callbacks,
        logger=logger,
        # accumulate_grad_batches=grad_steps,
        log_every_n_steps=10,
        fast_dev_run=_config["fast_dev_run"],
        val_check_interval=_config["val_check_interval"],
    )

# Evaluate the full-precision model on the small calibration datamodule to get
# a float baseline before quantizing.
trainer.test(model, datamodule=calibrarte_dm)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Loaded names: ['vqa_vlue_test']


/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


Loaded names: ['vqa_vlue_test']
Loaded names: ['vqa_vlue_test']


Testing: |          | 0/? [00:00<?, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('vqa/val/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('vqa/val/score', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('vqa/val/score_epoch', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/data-4/users/mileriso/envs/.dev/

[{'vqa/val/loss': 3.7881855964660645,
  'vqa/val/score': 0.0,
  'vqa/val/score_epoch': 0.0,
  'vqa/val/loss_epoch': 3.7881853580474854,
  'val/the_metric': 0.0}]

# Quantization | Post-Training Quantization (PTQ) to 8 bits

### Dynamic Quantization

In [9]:
import dynamic_quantization as dq

# Baseline: PyTorch's built-in dynamic quantization on a copy of the float model.
# Embeddings are quantized with quint8 first, then Linear/LayerNorm with qint8;
# both passes modify `default_dynamic` in place.
# NOTE(review): nn.LayerNorm does not appear to have an entry in PyTorch's default
# dynamic-quantization module mapping, so it is likely left in float — confirm.
default_dynamic = copy.deepcopy(model)

for module_types, q_dtype in (
    ({torch.nn.Embedding}, torch.quint8),
    ({torch.nn.Linear, torch.nn.LayerNorm}, torch.qint8),
):
    torch.quantization.quantize_dynamic(
        default_dynamic, module_types, dtype=q_dtype, inplace=True
    )

print("Size after quantization:")
print_size_of_model(default_dynamic)

Size after quantization:
Size (MB): 125.73422


In [10]:
import dynamic_quantization as dq

# The project's own 8-bit dynamic quantization, applied to a fresh copy so the
# float `model` itself is left untouched.
custom_8bit = dq.quantize_model_dynamic(copy.deepcopy(model), 8)

print("Size after quantization:")
print_size_of_model(custom_8bit)

Size after quantization:
Size (MB): 125.73422


# Numeric Suite Analysis

In [11]:
import torch.quantization._numeric_suite as ns

def compute_error(x, y):
    """
    Signal-to-noise ratio (SNR), in dB, of ``y`` as an approximation of ``x``.

    SNR = 20 * log10(||x|| / ||x - y||), with the L2 norm taken over all
    elements. Higher means a smaller error. Identical tensors give ``+inf``,
    an all-zero ``x`` with a non-zero difference gives ``-inf``, and two
    all-zero tensors give ``nan``.

    Args:
        x (torch.Tensor): Reference (float) tensor.
        y (torch.Tensor): Approximation of ``x``, with the same shape.

    Quantized inputs are dequantized, and integer inputs are cast to float,
    before comparing; ``torch.norm`` rejects both as-is.

    Returns:
        torch.Tensor: 0-d tensor holding the SNR in dB.
    """
    def _to_float(t):
        if t.is_quantized:
            return t.dequantize()
        if not (t.is_floating_point() or t.is_complex()):
            return t.float()
        return t

    x, y = _to_float(x), _to_float(y)
    signal_norm = torch.norm(x)
    noise_norm = torch.norm(x - y)
    return 20 * torch.log10(signal_norm / noise_norm)

## Dynamic Model

In [12]:
# ======== Dynamic quantization comparison ========
# Per-weight SNR (dB) between the float model and the custom 8-bit model.
# Higher SNR means a smaller quantization error, so despite the names,
# `total_error` sums SNRs and `max_err` is the best (highest) SNR.
wt_compare_dict_dynamic = ns.compare_weights(model.state_dict(), custom_8bit.state_dict())

total_error = 0  # sum of finite per-weight SNRs
inf_count = 0    # weights whose SNR is +/-inf or nan
max_err = 0      # highest finite SNR seen
for i, key in enumerate(wt_compare_dict_dynamic):
    float_weight = wt_compare_dict_dynamic[key]['float']
    quantized_weight = wt_compare_dict_dynamic[key]['quantized']
    # Quantized weights have to be dequantized before they can be compared.
    if quantized_weight.is_quantized:
        quantized_weight = quantized_weight.dequantize()
    err = compute_error(float_weight, quantized_weight)

    print(f"{i} - {key}")
    print(f"{i} - {err:.2f}")

    # Keep non-finite values (inf for identical tensors, nan for all-zero ones)
    # out of the running total so a single nan cannot poison it.
    if torch.isfinite(err):
        total_error += err
        if err > max_err:
            max_err = err
    else:
        inf_count += 1

print(f"Total error: {total_error:.2f}")
print(f"Total inf/nan: {inf_count}")
print(f"Max error: {max_err}")

0 - text_embeddings.LayerNorm.weight
0 - inf
1 - transformer.patch_embed.proj.weight
1 - inf
2 - transformer.blocks.0.norm1.weight
2 - inf
3 - transformer.blocks.0.attn.qkv._packed_params._packed_params
3 - 27.72
4 - transformer.blocks.0.attn.proj._packed_params._packed_params
4 - 22.11
5 - transformer.blocks.0.norm2.weight
5 - inf
6 - transformer.blocks.0.mlp.fc1._packed_params._packed_params
6 - 26.74
7 - transformer.blocks.0.mlp.fc2._packed_params._packed_params
7 - 18.21
8 - transformer.blocks.1.norm1.weight
8 - inf
9 - transformer.blocks.1.attn.qkv._packed_params._packed_params
9 - 32.24
10 - transformer.blocks.1.attn.proj._packed_params._packed_params
10 - 29.18
11 - transformer.blocks.1.norm2.weight
11 - inf
12 - transformer.blocks.1.mlp.fc1._packed_params._packed_params
12 - 27.66
13 - transformer.blocks.1.mlp.fc2._packed_params._packed_params
13 - 16.82
14 - transformer.blocks.2.norm1.weight
14 - inf
15 - transformer.blocks.2.attn.qkv._packed_params._packed_params
15 - 34.03
1

In [14]:
# First batch of the small inference datamodule.
infer_batch = next(iter(infer_dataloader))
# First batch of the calibration datamodule's validation split.
calibration_loader = calibrarte_dm.val_dataloader()
calibration_batch = next(iter(calibration_loader))

In [15]:
# Record per-layer outputs of the float and custom 8-bit models on the same
# batch. Deep copies are passed so the loggers the numeric suite attaches are
# not left on `model` / `custom_8bit` themselves.
act_compare_dict_dynamic = ns.compare_model_outputs(
    copy.deepcopy(model),
    copy.deepcopy(custom_8bit),
    infer_batch,
)
print(act_compare_dict_dynamic.keys())

/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/core/module.py:445: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


dict_keys(['text_embeddings.LayerNorm.stats', 'text_embeddings.quant.stats', 'transformer.patch_embed.proj.stats', 'transformer.blocks.0.norm1.stats', 'transformer.blocks.0.attn.qkv.stats', 'transformer.blocks.0.attn.proj.stats', 'transformer.blocks.0.norm2.stats', 'transformer.blocks.0.mlp.fc1.stats', 'transformer.blocks.0.mlp.fc2.stats', 'transformer.blocks.1.norm1.stats', 'transformer.blocks.1.attn.qkv.stats', 'transformer.blocks.1.attn.proj.stats', 'transformer.blocks.1.norm2.stats', 'transformer.blocks.1.mlp.fc1.stats', 'transformer.blocks.1.mlp.fc2.stats', 'transformer.blocks.2.norm1.stats', 'transformer.blocks.2.attn.qkv.stats', 'transformer.blocks.2.attn.proj.stats', 'transformer.blocks.2.norm2.stats', 'transformer.blocks.2.mlp.fc1.stats', 'transformer.blocks.2.mlp.fc2.stats', 'transformer.blocks.3.norm1.stats', 'transformer.blocks.3.attn.qkv.stats', 'transformer.blocks.3.attn.proj.stats', 'transformer.blocks.3.norm2.stats', 'transformer.blocks.3.mlp.fc1.stats', 'transformer.bl

In [16]:
# Per-layer activation SNR (dB) between float and custom 8-bit outputs.
# Higher SNR means a smaller error.
# NOTE(review): `[0]` selects the first recorded output and the second `[0]` its
# first entry along dim 0, so presumably only the first sample of `infer_batch`
# is compared. Confirm this is intended rather than the whole output tensor.
total_err = 0  # sum of finite per-layer SNRs
inf_count = 0  # layers whose SNR is +/-inf or nan
max_err = 0    # highest finite SNR seen
for idx, key in enumerate(act_compare_dict_dynamic):
    float_act = act_compare_dict_dynamic[key]['float'][0][0]
    quant_act = act_compare_dict_dynamic[key]['quantized'][0][0]
    err = compute_error(float_act, quant_act)
    # Keep inf (identical outputs) and nan out of the running total.
    if torch.isfinite(err):
        total_err += err
        if err > max_err:
            max_err = err
    else:
        inf_count += 1
    print(f"{idx} - {key}")
    print(f"{idx} - {err:.2f}")

print(f"Total error: {total_err:.2f}")
print(f"Total inf/nan: {inf_count}")
print(f"Max error: {max_err}")

0 - text_embeddings.LayerNorm.stats
0 - 39.83
1 - text_embeddings.quant.stats
1 - 39.09
2 - transformer.patch_embed.proj.stats
2 - inf
3 - transformer.blocks.0.norm1.stats
3 - 9.85
4 - transformer.blocks.0.attn.qkv.stats
4 - 21.50
5 - transformer.blocks.0.attn.proj.stats
5 - 11.46
6 - transformer.blocks.0.norm2.stats
6 - 8.84
7 - transformer.blocks.0.mlp.fc1.stats
7 - 18.13
8 - transformer.blocks.0.mlp.fc2.stats
8 - 9.53
9 - transformer.blocks.1.norm1.stats
9 - 7.37
10 - transformer.blocks.1.attn.qkv.stats
10 - 17.58
11 - transformer.blocks.1.attn.proj.stats
11 - 7.76
12 - transformer.blocks.1.norm2.stats
12 - 8.30
13 - transformer.blocks.1.mlp.fc1.stats
13 - 14.61
14 - transformer.blocks.1.mlp.fc2.stats
14 - 4.42
15 - transformer.blocks.2.norm1.stats
15 - 6.03
16 - transformer.blocks.2.attn.qkv.stats
16 - 12.27
17 - transformer.blocks.2.attn.proj.stats
17 - 8.19
18 - transformer.blocks.2.norm2.stats
18 - 7.17
19 - transformer.blocks.2.mlp.fc1.stats
19 - 13.23
20 - transformer.blocks.2

In [20]:
# Float (unquantized) architecture, for comparison with the quantized variants.
print(repr(model))

ViLTransformerSS(
  (text_embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(40, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (quant): QuantStub()
    (dequant): DeQuantStub()
  )
  (token_type_embeddings): Embedding(3, 768)
  (transformer): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      (quant): QuantStub()
      (dequant): DeQuantStub()
    )
    (pos_drop): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features

In [21]:
# NOTE: no cell above this one defines `model_static` (it appears to come from a
# removed or later-run static-quantization cell), so a bare `print(model_static)`
# raises NameError on Restart & Run All. Guarded until that cell is restored above.
if "model_static" in globals():
    print(model_static)
else:
    print("`model_static` is not defined; run the static quantization cell first.")

ViLTransformerSS(
  (text_embeddings): BertEmbeddings(
    (word_embeddings): QuantizedEmbedding(num_embeddings=30522, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (position_embeddings): QuantizedEmbedding(num_embeddings=40, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (token_type_embeddings): QuantizedEmbedding(num_embeddings=2, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
    (LayerNorm): QuantizedLayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): QuantizedDropout(p=0.1, inplace=False)
    (quant): Quantize(scale=tensor([0.0120]), zero_point=tensor([63]), dtype=torch.quint8)
    (dequant): DeQuantize()
  )
  (token_type_embeddings): QuantizedEmbedding(num_embeddings=3, embedding_dim=768, dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams)
  (transformer): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): 