In [None]:
import torch
from transformers import AutoTokenizer, AutoConfig
from llm2vec.models import LlamaBiForMNTP
from peft import LoraConfig, get_peft_model

def initialize_peft(
    model,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    lora_modules=None
):
    """Wrap ``model`` in a LoRA adapter via PEFT and return the wrapped model.

    Args:
        model: Model to adapt. Its ``config`` class name is inspected to pick
            default target modules.
        lora_r: LoRA rank.
        lora_alpha: LoRA alpha; callers in this file pass ``2 * lora_r``.
        lora_dropout: Dropout probability used inside the LoRA layers.
        lora_modules: Names of the submodules to adapt. When ``None`` and the
            config is a ``LlamaConfig``, all attention and MLP projections are
            targeted. Otherwise ``None`` is handed to ``LoraConfig`` unchanged
            (PEFT then presumably uses its own per-architecture defaults;
            confirm against the installed PEFT version).

    Returns:
        The object returned by ``get_peft_model``.
    """
    if lora_modules is None and model.config.__class__.__name__ == "LlamaConfig":
        attention_projections = ["q_proj", "v_proj", "k_proj", "o_proj"]
        mlp_projections = ["gate_proj", "up_proj", "down_proj"]
        lora_modules = attention_projections + mlp_projections

    peft_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_modules,
        lora_dropout=lora_dropout,
        bias="none",
        # No task type here; the caller wraps the inner decoder, not a task head.
        task_type=None,
    )
    return get_peft_model(model, peft_config)

def _configure_tokenizer(tokenizer, mask_token_type):
    """Ensure ``tokenizer`` has a mask token and a pad token.

    The mask token is set only when the tokenizer lacks one. ``mask_token_type``
    selects the replacement:

    - ``"blank"``: ``"_"``
    - ``"eos"``: the EOS token
    - ``"mask"``: a newly added ``"<mask>"`` token

    A missing pad token is replaced by the EOS token.

    Raises:
        ValueError: if a mask token is needed and ``mask_token_type`` is unknown.
    """
    if tokenizer.mask_token is None:
        if mask_token_type == "blank":
            tokenizer.mask_token = "_"
        elif mask_token_type == "eos":
            tokenizer.mask_token = tokenizer.eos_token
        elif mask_token_type == "mask":
            tokenizer.add_tokens(["<mask>"])
            tokenizer.mask_token = "<mask>"
        else:
            raise ValueError(
                f"Unknown mask_token_type {mask_token_type!r}; "
                "expected 'blank', 'eos' or 'mask'."
            )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # EOS doubles as the PAD token


def _resolve_torch_dtype(name):
    """Map a config dtype name to the value passed as ``torch_dtype``.

    Raises:
        ValueError: for names other than bfloat16/float16/float32/auto. (The
            previous code silently fell back to float16 for these.)
    """
    dtypes = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
        "auto": "auto",
    }
    if name not in dtypes:
        raise ValueError(
            f"Unsupported torch_dtype {name!r}; expected one of {sorted(dtypes)}."
        )
    return dtypes[name]


def measure_gpu_memory(config_args):
    """Measure the GPU memory one no-grad forward pass of the LoRA-wrapped model needs.

    Loads the tokenizer and an LLM2Vec ``LlamaBiForMNTP`` model, moves the model
    to the GPU, and wraps its inner decoder with LoRA. It then runs a single
    forward pass under ``torch.no_grad()`` on a padded dummy batch.

    Args:
        config_args: Mapping with the following keys:

            - ``model_name_or_path``
            - ``mask_token_type``: only read when the tokenizer has no mask token
            - ``torch_dtype``: ``"bfloat16"``, ``"float16"``, ``"float32"`` or ``"auto"``
            - ``attn_implementation``
            - ``lora_r``
            - ``per_device_train_batch_size``
            - ``max_seq_length``

    Returns:
        Peak PyTorch-allocated GPU memory during the forward pass, in MiB
        (integer division by 2**20). This is measured on top of what was
        already allocated before the pass (model weights, adapters, inputs).

    Raises:
        RuntimeError: if no CUDA device is available.
        ValueError: on an unknown ``torch_dtype``, or an unknown
            ``mask_token_type`` when a mask token has to be chosen.

    Note:
        Because the pass runs without autograd, this excludes activations kept
        for backward, gradients and optimizer state, which a training step also
        needs. nvidia-smi additionally counts memory cached by PyTorch's
        allocator and the CUDA context, so it will report more.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("measure_gpu_memory needs a CUDA device, but none is available.")

    print("Loading Model and Tokenizer...")
    model_path = config_args["model_name_or_path"]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    config = AutoConfig.from_pretrained(model_path)
    _configure_tokenizer(tokenizer, config_args.get("mask_token_type"))

    model = LlamaBiForMNTP.from_pretrained(
        model_path,
        config=config,
        torch_dtype=_resolve_torch_dtype(config_args["torch_dtype"]),
        attn_implementation=config_args["attn_implementation"],
    )
    # A "<mask>" token added by the tokenizer lies past the pretrained vocabulary;
    # grow the embedding matrix so its id is a valid index.
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))

    model.to("cuda")  # Flash Attention requires the entire model to be on GPU.

    # Apply LoRA to the inner decoder after moving the model to GPU.
    model.model = initialize_peft(
        model.model,
        lora_r=config_args["lora_r"],
        lora_alpha=2 * config_args["lora_r"],
    )

    batch_size = config_args["per_device_train_batch_size"]
    max_seq_length = config_args["max_seq_length"]

    # padding="max_length" makes every row exactly max_seq_length tokens, so the
    # pass runs on the full configured batch shape despite the short sentence.
    dummy_input = ["This is a test sentence for benchmarking."] * batch_size
    inputs = tokenizer(
        dummy_input,
        return_tensors="pt",
        max_length=max_seq_length,
        truncation=True,
        padding="max_length",
    ).to("cuda")

    mib = 2**20
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Everything already on the GPU: model weights, adapters, tokenized inputs.
    baseline = torch.cuda.memory_allocated()

    print("Running forward pass...")
    with torch.no_grad():
        outputs = model(**inputs)

    # Read the PEAK, not the current allocation: intermediate activations are
    # released by the time the pass returns, so memory_allocated() here would
    # understate what the pass actually required.
    peak_allocated = torch.cuda.max_memory_allocated()
    peak_reserved = torch.cuda.max_memory_reserved()
    del outputs

    forward_mib = (peak_allocated - baseline) // mib
    print(f"Estimated memory usage for batch_size={batch_size}, seq_length={max_seq_length}: {forward_mib} MB")
    print(
        f"Peak allocated incl. weights and inputs: {peak_allocated // mib} MB; "
        f"peak reserved by PyTorch's caching allocator: {peak_reserved // mib} MB "
        "(nvidia-smi also counts the CUDA context)"
    )
    return forward_mib

# Example inputs, shaped like the fields a JSON training config would supply.
config_args = dict(
    model_name_or_path="princeton-nlp/Sheared-LLaMA-1.3B",
    lora_r=16,
    per_device_train_batch_size=256,
    max_seq_length=512,
    mask_token_type="blank",
    torch_dtype="bfloat16",
    attn_implementation="flash_attention_2",
)

# Run the check; as the cell's final expression, its return value is echoed.
measure_gpu_memory(config_args)

Loading Model and Tokenizer...
Running forward pass...
Estimated memory usage for batch_size=256, seq_length=512: 40576 MB


40576

: 

In [None]:
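# The estimate above is not correct, but running the script and calling nvidia-smi in a console does reveal the actual GPU footprint.
# Why they differ: memory_allocated() read after the forward pass misses the transient peak during the pass,
# and nvidia-smi also counts the model weights and inputs, memory cached by PyTorch's allocator, and the CUDA context.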