In [1]:
# Suppress specific warnings
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# General imports
import os
import torch
import torch.quantization
import pytorch_lightning as pl
from copy import deepcopy
import random
random.seed(42)

# Model Specific imports
from vilt.datamodules.multitask_datamodule import MTDataModule as MTDataModuleVILT
from meter.datamodules.multitask_datamodule import MTDataModule as MTDataModuleMeter
from vilt.modules import ViLTransformerSS
from meter.modules import METERTransformerSS

# Custom imports
import configs
from quantization_utils import get_quantization_config
from quantization_utils import  SmallMTDataModuleMETER, SmallMTDataModuleVILT

from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Set the configuration
_config = configs.vilt_config_vqav2
_config["batch_size"] = 32
_config["per_gpu_batchsize"] = 32
_config["learning_rate"] = 0.01

# Set the PyTorch Lightning seed
pl.seed_everything(_config["seed"])

# Limit the number of CPUs
os.environ["OMP_NUM_THREADS"] = "3"  # Set this to the number of CPUs you want to use
os.environ["MKL_NUM_THREADS"] = "3"  # Set this to the number of CPUs you want to use

# Set environment variables
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

Seed set to 0


In [8]:
def print_size_of_model(model):
    """
    Function to print the size of the model.

    Args:
        model (torch.nn.Module): The model to get the size
    
    Returns:
        None
    """
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

In [9]:
# ==========================================
# ========= Create full datamodule =========
# ==========================================
if "meter" in _config["model"]:
    full_dm = MTDataModuleMeter(_config, dist=False)
    
    test_dm = SmallMTDataModuleMETER(_config, dist=False, percentage=0.01)
    test_dm.setup("test", is_random=True)
    test_dataloader = test_dm.test_dataloader()
    
    # fine_tune_dm = SmallMTDataModuleMETER(_config, dist=False, percentage=0.25)
    fine_tune_dm = SmallMTDataModuleMETER(_config, dist=False, num_samples=8, start_idx=0)
    fine_tune_dm.setup("fit", is_random=True)
    fine_tune_dataloader = fine_tune_dm.test_dataloader()

elif "vilt" in _config["model"]:
    full_dm = MTDataModuleVILT(_config, dist=False)

    test_dm = SmallMTDataModuleVILT(_config, dist=False, num_samples=50)
    test_dm.setup("test", is_random=True)
    test_dataloader = test_dm.test_dataloader()

    fine_tune_dm = SmallMTDataModuleVILT(_config, dist=False, num_samples=50)
    fine_tune_dm.setup("test", is_random=True)
    fine_tune_dataloader = fine_tune_dm.test_dataloader()

else:
    raise ValueError("Model not supported: ", _config["model"])

print(f"Batch size: {_config['batch_size']}")
print(f"Lenght of the finetune dataloader: {len(fine_tune_dataloader)}")
print(f"Length of test dataloader: {len(test_dataloader)}")

Batch size: 32
Lenght of the finetune dataloader: 2
Length of test dataloader: 2


In [10]:
if _config["model"] == "vilt":
    model = ViLTransformerSS(_config)
    print("Initialized ViLT model")

elif _config["model"] == "meter":
    model = METERTransformerSS(_config)
    print("Initialized METER model")

else:
    raise ValueError("Model not supported: ", _config["model"])

Initialized ViLT model


In [11]:
model_half = model.half()
print("Size of Half Precision Model")
print_size_of_model(model_half)

Size of Half Precision Model
Size (MB): 235.229554


In [12]:
def create_quantization_config_dict(bits, module_name_list):
    """
    Creates a dictionary of quantization configurations for specific modules in a model.
    
    Args:
        bits (int): The number of bits to quantize the model to. Available options are 8, 4, and 2.
        module_name_list (list of str): A list of module names (or dot-separated paths) within the model to quantize.
        
    Returns:
        dict: A dictionary of quantization configurations for the specified modules.
    """

    quantization_config, embedding_config = get_quantization_config(bits)
    modules_config = {}

    for module_name in module_name_list:
        if "embedding" in module_name:
            modules_config[module_name] = embedding_config
        else:
            modules_config[module_name] = quantization_config
    
    return modules_config


def quantize_modules(model, bits, module_name_list, inplace=True):
    """
    Quantizes specific modules in a deep copy of the input model using dynamic quantization.

    Args:
        model (torch.nn.Module): The PyTorch model to quantize.
        bits (int): The number of bits to quantize the model to. Available options are 8, 4, and 2.
        module_names_to_quantize (list of str): A list of module names (or dot-separated paths)
                                                 within the model to apply dynamic quantization to.

    Returns:
        torch.nn.Module: A deep copy of the input model with specified modules dynamically quantized.
                         Returns None if no modules are provided to quantize.
    """

    modules_config = create_quantization_config_dict(bits, module_name_list)

    model_quantized = deepcopy(model)
    
    
    torch.quantization.quantize_dynamic(
        model_quantized, modules_config, inplace=True
    )

    return model_quantized

In [18]:
print("Size of Full Precision Model")
print_size_of_model(model)

Size of Full Precision Model
Size (MB): 1296.258138


In [9]:
# Quantize the model
bit_precision = 4
quantization_config, embedding_config = get_quantization_config(bit_precision)

q_dict = {torch.nn.Linear: quantization_config, torch.nn.LayerNorm: quantization_config, torch.nn.Conv2d: quantization_config, torch.nn.Embedding: embedding_config}

model_dynamic = torch.quantization.quantize_dynamic(
    model, q_dict, inplace=False
)

In [10]:
print("Size of Quantized Model")
print_size_of_model(model_dynamic)

Size of Quantized Model


Size (MB): 412.013446


Size of Half Precision Model
Size (MB): 648.241754


In [None]:
# Count the nu ber of parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params}")

# Print in millions
print(f"Number of parameters (in millions): {num_params / 1e6}")

Number of parameters: 113962754
Number of parameters (in millions): 113.962754


In [11]:
# COunt the total number of layers
num_layers = sum(1 for p in model.parameters())
print(f"Number of layers: {num_layers}")

Number of layers: 164


In [12]:
# COunt number of parameters in a given parent module
num_params = sum(p.numel() for p in model.transformer.blocks.parameters())
print(f"Number of parameters in millions: {num_params / 1e6}")

Number of parameters in millions: 85.054464


In [13]:
# ========== Initialize the trainer for full precision ==========
def init_trainer(_config, accelerator, num_devices, max_epochs, max_steps):
    exp_name = f'{_config["exp_name"]}'

    os.makedirs(_config["log_dir"], exist_ok=True)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor="val/the_metric",
        mode="max",
        save_last=True,
    )
    logger = pl.loggers.TensorBoardLogger(
        _config["log_dir"],
        name=f'{exp_name}_seed{_config["seed"]}_from_{_config["load_path"].split("/")[-1][:-5]}',
    )

    lr_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    # callbacks = [lr_callback]
    callbacks = [checkpoint_callback, lr_callback]

    num_gpus = (
        _config["num_gpus"]
        if isinstance(_config["num_gpus"], int)
        else len(_config["num_gpus"])
    )

    grad_steps = max(_config["batch_size"] // (
        _config["per_gpu_batchsize"] * num_gpus * _config["num_nodes"]
    ), 1)

    # max_steps = _config["max_steps"] if _config["max_steps"] is not None else None


    trainer = pl.Trainer(
            accelerator=accelerator,
            devices=num_devices,
            num_nodes=_config["num_nodes"],
            precision=_config["precision"],
            # strategy="ddp",
            benchmark=True,
            deterministic=False,
            max_epochs=max_epochs,
            max_steps=max_steps,
            callbacks=callbacks,
            logger=logger,
            accumulate_grad_batches=grad_steps,
            log_every_n_steps=10,
            fast_dev_run=_config["fast_dev_run"],
            val_check_interval=_config["val_check_interval"],
        )
    
    return trainer

trainer = init_trainer(_config, accelerator="cpu", num_devices=1, max_epochs=3, max_steps=1)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


In [14]:
trainer.fit(model, test_dataloader)


   | Name                  | Type              | Params | Mode 
---------------------------------------------------------------------
0  | text_embeddings       | BertEmbeddings    | 23.5 M | train
1  | token_type_embeddings | Embedding         | 2.3 K  | train
2  | transformer           | VisionTransformer | 87.5 M | train
3  | pooler                | Pooler            | 590 K  | train
4  | nlvr2_classifier      | Sequential        | 2.4 M  | train
5  | train_nlvr2_accuracy  | Accuracy          | 0      | train
6  | train_nlvr2_loss      | Scalar            | 0      | train
7  | dev_nlvr2_accuracy    | Accuracy          | 0      | train
8  | dev_nlvr2_loss        | Scalar            | 0      | train
9  | test_nlvr2_accuracy   | Accuracy          | 0      | train
10 | test_nlvr2_loss       | Scalar            | 0      | train
11 | quant                 | QuantStub         | 0      | train
12 | dequant               | DeQuantStub       | 0      | train
---------------------------------

Epoch 0:  50%|█████     | 1/2 [00:54<00:54,  0.02it/s, v_num=96]

TypeError: ViLTransformerSS.on_train_epoch_end() missing 1 required positional argument: 'outs'