# Playground Notebook For Quantizing VLP Models

## Define Helper Functions

In [12]:
import os
import torch.distributed as dist

# Set environment variables
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# Initialize the process group
dist.init_process_group(backend='gloo', init_method='env://', world_size=1, rank=0)

# Verify initialization
print(f"Initialized: {dist.is_initialized()}")

Initialized: True


In [13]:
import os
import torch

def print_size_of_model(model):
    """
    Function to print the size of the model.

    Args:
        model (torch.nn.Module): The model to get the size
    
    Returns:
        None
    """
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

def get_accuracy(pl_module, logits, target, device="cpu"):
        correct = 0
        total = 0
        logits, target = (
            logits.detach().to(device),
            target.detach().to(device),
        )
        preds = logits.argmax(dim=-1)
        preds = preds[target != -100]
        target = target[target != -100]
        if target.numel() == 0:
            return 1

        assert preds.shape == target.shape

        correct += torch.sum(preds == target)
        total += target.numel()

        return correct/total

In [14]:
import pytorch_lightning as pl

# Define the configuration for the experiments
vilt_config_nlvr2 = {'exp_name': 'test_ood_nlvr2', 'seed': 0, 'datasets': ['ood_nlvr2'], 'loss_names': {'itm': 0, 'mlm': 0, 'mpp': 0, 'vqa': 0, 'nlvr2': 1, 'irtr': 0}, 'batch_size': 128, 'accelerator': 'gpu', 'train_transform_keys': ['pixelbert_randaug'], 'val_transform_keys': ['pixelbert'], 'image_size': 384, 'max_image_len': -1, 'patch_size': 32, 'draw_false_image': 0, 'image_only': False, 'vqav2_label_size': 3129, 'max_text_len': 40, 'tokenizer': 'bert-base-uncased', 'vocab_size': 30522, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'vit': 'vit_base_patch32_384', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 12, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 0.0001, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 10, 'max_steps': 1, 'warmup_steps': 0.1, 'end_lr': 0, 'lr_mult': 1, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 1.0, 'test_only': True, 'data_root': '/data-4/users/mileriso/datasets/OOD/arrows', 'log_dir': 'result', 'per_gpu_batchsize': 64, 'num_gpus': 1, 'num_nodes': 1, 'load_path': '/data-4/users/mileriso/models/vilt_nlvr2.ckpt', 'num_workers': 8, 'precision': 32}
vilt_config_vqav2 = {'exp_name': 'test_ood_vqa', 'seed': 0, 'datasets': ['ood_vqa'], 'loss_names': {'itm': 0, 'mlm': 0, 'mpp': 0, 'vqa': 1, 'nlvr2': 0, 'irtr': 0}, 'batch_size': 256, 'accelerator': 'gpu', 'train_transform_keys': ['pixelbert_randaug'], 'val_transform_keys': ['pixelbert'], 'image_size': 384, 'max_image_len': -1, 'patch_size': 32, 'draw_false_image': 0, 'image_only': False, 'vqav2_label_size': 3129, 'max_text_len': 40, 'tokenizer': 'bert-base-uncased', 'vocab_size': 30522, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'vit': 'vit_base_patch32_384', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 12, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 0.0001, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 10, 'max_steps': 1, 'warmup_steps': 0.1, 'end_lr': 0, 'lr_mult': 10, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 0.1, 'test_only': True, 'data_root': '/data-4/users/mileriso/datasets/OOD/arrows', 'log_dir': 'result', 'per_gpu_batchsize': 64, 'num_gpus': 1, 'num_nodes': 1, 'load_path': '/data-4/users/mileriso/models/vilt_vqa.ckpt', 'num_workers': 8, 'precision': 32}

meter_config_nlvr2 = {'exp_name': 'test_ood_nlvr2', 'seed': 0, 'datasets': ['ood_nlvr2'], 'loss_names': {'itm': 0, 'mlm': 0, 'mpp': 0, 'vqa': 0, 'vcr': 0, 'vcr_qar': 0, 'nlvr2': 1, 'irtr': 0, 'contras': 0, 'snli': 0}, 'batch_size': 256, 'accelerator': 'gpu', 'train_transform_keys': ['clip'], 'val_transform_keys': ['clip'], 'image_size': 288, 'patch_size': 16, 'draw_false_image': 0, 'image_only': False, 'resolution_before': 224, 'vqav2_label_size': 3129, 'max_text_len': 50, 'tokenizer': 'roberta-base', 'vocab_size': 50265, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'num_top_layer': 6, 'input_image_embed_size': 768, 'input_text_embed_size': 768, 'vit': 'ViT-B/16', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 6, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 1e-05, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 10, 'max_steps': 1, 'warmup_steps': 0.1, 'end_lr': 0, 'lr_mult_head': 10, 'lr_mult_cross_modal': 5, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 1.0, 'test_only': True, 'data_root': '/data-4/users/mileriso/datasets/OOD/arrows', 'log_dir': 'result', 'per_gpu_batchsize': 64, 'num_gpus': 1, 'num_nodes': 1, 'load_path': '/data-4/users/mileriso/models/meter_nlvr2.ckpt', 'num_workers': 8, 'precision': 32}
meter_config_vqav2 = {'exp_name': 'test_ood_vqa', 'seed': 0, 'datasets': ['ood_vqa'], 'loss_names': {'itm': 0, 'mlm': 0, 'mpp': 0, 'vqa': 1, 'vcr': 0, 'vcr_qar': 0, 'nlvr2': 0, 'irtr': 0, 'contras': 0, 'snli': 0}, 'batch_size': 512, 'accelerator': 'gpu', 'train_transform_keys': ['clip'], 'val_transform_keys': ['clip'], 'image_size': 576, 'patch_size': 16, 'draw_false_image': 0, 'image_only': False, 'resolution_before': 224, 'vqav2_label_size': 3129, 'max_text_len': 50, 'tokenizer': 'roberta-base', 'vocab_size': 50265, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'num_top_layer': 6, 'input_image_embed_size': 768, 'input_text_embed_size': 768, 'vit': 'ViT-B/16', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 6, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 5e-06, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 10, 'max_steps': 1, 'warmup_steps': 0.1, 'end_lr': 0, 'lr_mult_head': 50, 'lr_mult_cross_modal': 5, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 0.1, 'test_only': True, 'data_root': '/data-4/users/mileriso/datasets/OOD/arrows', 'log_dir': 'result', 'per_gpu_batchsize': 4, 'num_gpus': 1, 'num_nodes': 1, 'load_path': '/data-4/users/mileriso/models/meter_vqa.ckpt', 'num_workers': 8, 'precision': 32}

pl.seed_everything(0)

Seed set to 0


0

In [15]:
import warnings
# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="promote has been superseded by promote_options='default'")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.helpers is deprecated")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.layers is deprecated")
warnings.filterwarnings("ignore", category=FutureWarning, message="Importing from timm.models.registry is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_small_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch16_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_patch32_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch16_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_large_patch32_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_huge_patch14_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50_224_in21k in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50_384 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_small_resnet26d_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet26d_224 in registry")
warnings.filterwarnings("ignore", category=UserWarning, message="Overwriting vit_base_resnet50d_224 in registry")
warnings.filterwarnings("ignore", category=FutureWarning, message="You are using `torch.load` with `weights_only=False`")

## Initialize the Datamodule

In [16]:
from torch.utils.data import Subset
from vilt.datamodules.multitask_datamodule import MTDataModule as MTDataModuleVILT
from meter.datamodules.multitask_datamodule import MTDataModule as MTDataModuleMeter


class SmallMTDataModuleVILT(MTDataModuleVILT):
    def __init__(self, _config, dist=False, num_samples=10):
        super().__init__(_config, dist)
        self.num_samples = num_samples

    def setup(self, stage):
        super().setup(stage)
        
        # Limit the number of samples in the datasets
        self.train_dataset = Subset(self.train_dataset, range(self.num_samples))
        self.val_dataset = Subset(self.val_dataset, range(self.num_samples))
        self.test_dataset = Subset(self.test_dataset, range(self.num_samples))

class SmallMTDataModuleMETER(MTDataModuleMeter):
    def __init__(self, _config, dist=False, num_samples=10):
        super().__init__(_config, dist)
        self.num_samples = num_samples

    def setup(self, stage):
        super().setup(stage)
        
        # Limit the number of samples in the datasets
        self.train_dataset = Subset(self.train_dataset, range(self.num_samples))
        self.val_dataset = Subset(self.val_dataset, range(self.num_samples))
        self.test_dataset = Subset(self.test_dataset, range(self.num_samples))

In [17]:
# Set the configuration
_config = vilt_config_vqav2
_config["model_"] = "vilt"

# ==========================================
# ========= Create full datamodule =========
# ==========================================
if "meter" in _config["model_"]:
    full_dm = MTDataModuleMeter(_config, dist=False)
    full_dm.setup("test")
    full_dataloader = full_dm.test_dataloader()
else:
    full_dm = MTDataModuleVILT(_config, dist=False)
    full_dm.setup("test")
    full_dataloader = full_dm.test_dataloader()


test_dm = SmallMTDataModuleMETER(_config, dist=False, num_samples=5)
test_dm.setup("test")
test_dataloader = test_dm.test_dataloader()

print(f"Length of the test dataloader: {len(test_dataloader.dataset)}")
print(f"Length of the full dataloader: {len(full_dataloader.dataset)}")

Length of the test dataloader: 5
Length of the full dataloader: 11942


## Initialize Model

In [18]:
from vilt.modules import ViLTransformerSS
from meter.modules import METERTransformerSS

if _config["model_"] == "vilt":
    model = ViLTransformerSS(_config)
    print("Initialized ViLT model")

elif _config["model_"] == "meter":
    model = METERTransformerSS(_config)
    print("Initialized METER model")

else:
    raise ValueError("Model not supported")

Initialized ViLT model


## Initialize The Trainer

In [19]:
# ========== Initialize the trainer for full precision ==========
exp_name = f'{_config["exp_name"]}'

os.makedirs(_config["log_dir"], exist_ok=True)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=1,
    verbose=True,
    monitor="val/the_metric",
    mode="max",
    save_last=True,
)
logger = pl.loggers.TensorBoardLogger(
    _config["log_dir"],
    name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
)

lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
callbacks = [checkpoint_callback, lr_callback]

num_gpus = (
    _config["num_gpus"]
    if isinstance(_config["num_gpus"], int)
    else len(_config["num_gpus"])
)

grad_steps = _config["batch_size"] // (
    _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
)

max_steps = _config["max_steps"] if _config["max_steps"] is not None else None


trainer = pl.Trainer(
        accelerator="cpu",
        devices=1,
        num_nodes=_config["num_nodes"],
        precision=_config["precision"],
        # strategy="ddp",
        benchmark=True,
        deterministic=False,
        max_epochs=_config["max_epoch"] if max_steps is None else 1000,
        max_steps=max_steps,
        callbacks=callbacks,
        logger=logger,
        # accumulate_grad_batches=grad_steps,
        log_every_n_steps=10,
        fast_dev_run=_config["fast_dev_run"],
        val_check_interval=_config["val_check_interval"],
    )

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


# Quantization

### Test Inference

In [20]:
trainer.test(model, datamodule=test_dm)

Testing: |          | 0/? [00:00<?, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 5. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('vqa/val/loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('vqa/val/score', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/data-4/users/mileriso/envs/.dev/lib/python3.10/site-packages/pytor

[{'vqa/val/loss': 3.264875888824463,
  'vqa/val/score': 0.0,
  'vqa/val/score_epoch': 0.0,
  'vqa/val/loss_epoch': 3.264875888824463,
  'val/the_metric': 0.0}]

# Post Training Quantization 8-bits

### Dynamic Quantization

In [9]:
# torch.quantization.quantize_dynamic(
#         model, {torch.nn.Linear, torch.nn.Dropout, torch.nn.GELU, torch.nn.Conv2d}, dtype=torch.qint8, inplace=True
#     )

# torch.quantization.quantize_dynamic(
#         model, {torch.nn.Embedding, torch.nn.LayerNorm, model.transformer.patch_embed.proj}, dtype=torch.quint8, inplace=True
#     )



# print("Size after quantization:")
# print_size_of_model(model)
# print(model)

### Static Quantization

In [10]:
import os
# Limit the number of CPUs
os.environ["OMP_NUM_THREADS"] = "4"  # Set this to the number of CPUs you want to use
os.environ["MKL_NUM_THREADS"] = "4"  # Set this to the number of CPUs you want to use

# General quantization configuration
quantization_config = torch.quantization.get_default_qconfig('x86')

# Configuration for nn.Embedding layers
embedding_qconfig = torch.quantization.QConfig(
    activation=torch.quantization.HistogramObserver.with_args(reduce_range=True),
    weight=torch.quantization.default_float_qparams_observer.with_args(dtype=torch.quint8)
)

if _config["model_"] == "vilt":
    # Assign the quantization configurations to the model
    model.qconfig = quantization_config
    model.token_type_embeddings.qconfig = embedding_qconfig
    model.text_embeddings.word_embeddings.qconfig = embedding_qconfig
    model.text_embeddings.position_embeddings.qconfig = embedding_qconfig
    model.text_embeddings.token_type_embeddings.qconfig = embedding_qconfig

elif _config["model_"] == "meter":
    # Assign the quantization configurations to the model
    model.qconfig = quantization_config
    model.token_type_embeddings.qconfig = embedding_qconfig
    model.text_transformer.embeddings.word_embeddings.qconfig = embedding_qconfig
    model.text_transformer.embeddings.position_embeddings.qconfig = embedding_qconfig
    model.text_transformer.embeddings.token_type_embeddings.qconfig = embedding_qconfig

# Perform static quantization
torch.ao.quantization.prepare(model, inplace=True)
trainer.test(model, datamodule=test_dm)
torch.ao.quantization.convert(model, inplace=True)

print("Size after quantization:")
print_size_of_model(model)



Testing: |          | 0/? [00:00<?, ?it/s]

The input tensor is not contiguous. This will result in a copy and increase in memory usage.
Type of self:  <class 'meter.modules.clip_model.CLIP'>
not callable
Dtype of self: torch.float32
not callable
The input tensor is contiguous.
Type of self:  <class 'meter.modules.clip_model.CLIP'>
not callable
Dtype of self: torch.float32
not callable




Size after quantization:
Size (MB): 337.249982


## Test Inference After Quantization

In [11]:
# Perform test inference
trainer.test(model, datamodule=test_dm)

Testing: |          | 0/? [00:00<?, ?it/s]

The input tensor is not contiguous. This will result in a copy and increase in memory usage.
Type of self:  <class 'meter.modules.clip_model.CLIP'>


AttributeError: 'CLIP' object has no attribute 'dtype'