In [1]:
import torch
import torch.nn.functional as F
import pynvml
from transformers import AutoModelForCausalLM, AutoTokenizer
from ipyexperiments import IPyExperimentsPytorch
import gc

# Run parameters

In [2]:
# Run parameters: every knob a reader may want to change lives in this cell.

device = torch.device("cuda")

# Hugging Face model id to profile. Other ids listed by the author as candidates:
# microsoft/phi-1_5, microsoft/phi-2, NousResearch/Llama-2-7b-hf,
# mistralai/Mistral-7B-v0.1, gpt2, gpt2-medium, gpt2-large, gpt2-xl
model_name_or_path = "NousResearch/Llama-2-13b-hf"

dtype = torch.float32            # dtype passed to from_pretrained() for the weights
mixed_precision_training = True  # enables torch.autocast around the training forward pass
bs = 4                           # batch size of the dummy batch
seq_length = 2048                # sequence length of the dummy batch


def get_optimizer(parameters):
    """Create the optimizer whose memory footprint is measured.

    Args:
        parameters: iterable of parameters to optimize (e.g. ``model.parameters()``).

    Returns:
        A ``torch.optim.SGD`` instance with ``lr=0.1`` and ``momentum=0.9``.

    Alternatives listed by the author for comparison: ``SGD(parameters, lr=0.1)``,
    ``SGD(parameters, lr=0.1, momentum=0.9)``, ``AdamW(parameters, lr=0.1)``.
    """
    return torch.optim.SGD(parameters, lr=0.1, momentum=0.9)


if mixed_precision_training:
    # With autocast the weights themselves are expected to stay in full precision.
    assert dtype == torch.float32, "mixed_precision_training=True requires dtype=torch.float32"

In [3]:
# Bytes per element of the dtype the weights are loaded in. Asking torch for the
# element size is right for every dtype (fp16/bf16 -> 2, fp32 -> 4, fp64 -> 8, ...)
# instead of assuming that everything that is not 16-bit takes 4 bytes.
n_bytes_per_param = torch.empty((), dtype=dtype).element_size()

pynvml.nvmlInit()
# NOTE(review): NVML index 0 is assumed to be the GPU that `device` resolves to —
# confirm when several GPUs or CUDA_VISIBLE_DEVICES are involved.
handle = pynvml.nvmlDeviceGetHandleByIndex(0)


def get_vram():
    """Return the memory NVML currently reports as used on the monitored GPU, in MiB."""
    return pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20


# Reading taken before this notebook has put anything on the GPU.
start_vram = get_vram()

# Initializing CUDA kernels: move a tiny tensor to the GPU and free it again, so that
# whatever is still in use afterwards is booked as fixed CUDA overhead.
_probe = torch.ones((1, 1)).to(device)
del _probe
gc.collect()
torch.cuda.empty_cache()
cuda_kernels_vram = get_vram() - start_vram
print(f"CUDA kernels VRAM: {cuda_kernels_vram:.0f} MiB")

exp = IPyExperimentsPytorch()

CUDA kernels VRAM: 716 MiB

*** Experiment started with the Pytorch backend
Device: ID 0, NVIDIA GeForce RTX 3090 (24576 RAM)


*** Current state:
RAM:     Used     Free    Total        Util
CPU:    1,727  123,637  128,658 MB   1.34% 
GPU:    1,035   23,540   24,576 MB   4.21% 


･ RAM:  △Consumed    △Peaked    Used Total | Exec time 0:00:00.000
･ CPU:          0          0      1,727 MB |
･ GPU:          0          0      1,035 MB |


In [4]:
# Driver-side snapshot of GPU memory; in notebook order this comes before the model is loaded.
!nvidia-smi

Tue Dec 19 13:24:10 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  | 00000000:81:00.0 Off |                  N/A |
|  0%   30C    P2              97W / 370W |    718MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        On  | 00000000:82:00.0 Off |  

# Loading model

In [12]:
# SECURITY NOTE: trust_remote_code=True lets the Hub repository run its own Python
# code inside this kernel — keep it only for repositories you trust.
# NOTE(review): device_map="auto" may spread the weights over other GPUs or offload
# them to CPU, while get_vram() watches a single GPU; the "actual" figure printed
# below can then be far smaller than the estimate — confirm this is intended.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=dtype,
    trust_remote_code=True,
    device_map="auto",
)
model.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Fall back to EOS for tokenizers that ship without a padding token.
    tokenizer.pad_token = tokenizer.eos_token

# Only parameters that receive gradients count as "training" parameters.
n_training_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Everything stored with the model: all parameters plus all registered buffers.
n_parameters = sum(p.numel() for p in model.parameters()) + sum(b.numel() for b in model.buffers())
model_estimated_vram = n_parameters * n_bytes_per_param / 2**20
# VRAM added on top of the pre-notebook reading and the fixed CUDA overhead.
model_actual_vram = get_vram() - cuda_kernels_vram - start_vram

# Relative estimation error; NaN instead of ZeroDivisionError when nothing was measured.
model_vram_error_pct = (
    (model_actual_vram - model_estimated_vram) * 100 / model_actual_vram
    if model_actual_vram
    else float("nan")
)

print(model.config)
print("=" * 75)
print(model)
print("=" * 75)
print(f"Number of trainable parameters: {(n_training_parameters / 1e9):.3f} B ({n_training_parameters})")
print(f"Number of parameters incl. buffers: {(n_parameters / 1e9):.3f} B ({n_parameters})")
print(f"Model VRAM usage: {(model_actual_vram):.0f} MiB (expected {(model_estimated_vram):.0f} MiB, error {model_vram_error_pct:.1f} %)")
print("=" * 75)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

big_modeling.py:411 -       dispatch_model() | Some parameters are on the meta device device because they were offloaded to the cpu.


LlamaConfig {
  "_name_or_path": "NousResearch/Llama-2-13b-hf",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 5120,
  "initializer_range": 0.02,
  "intermediate_size": 13824,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 40,
  "num_hidden_layers": 40,
  "num_key_value_heads": 40,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.36.2",
  "use_cache": false,
  "vocab_size": 32000
}

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 5120, padding_idx=0)
    (layers): ModuleList(
      (0-39): 40 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
  

In [6]:
# Driver-side snapshot of GPU memory; in notebook order this follows the model-loading cell.
!nvidia-smi

Tue Dec 19 13:24:30 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  | 00000000:81:00.0 Off |                  N/A |
|  0%   32C    P2              97W / 370W |  12276MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        On  | 00000000:82:00.0 Off |  

In [7]:
# Dummy batch: random token ids drawn from [0, len(tokenizer)) and an all-ones
# attention mask, both of shape (bs, seq_length), built on CPU and moved to `device`.
batch_shape = (bs, seq_length)
input_ids = torch.randint(0, len(tokenizer), batch_shape).to(device)
attention_mask = torch.ones(batch_shape).to(device)

･ RAM:  △Consumed    △Peaked    Used Total | Exec time 0:00:00.004
･ CPU:          0          0      3,896 MB |
･ GPU:          0          0     12,593 MB |


# Inference forward pass

## warmup

In [8]:
# Warm-up: run one forward pass in eval mode without gradients and throw the result away.
model.eval()

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

# Free the output and release cached blocks so the next reading only contains
# memory that stays allocated after the pass.
del out
gc.collect()
torch.cuda.empty_cache()

# Anything still in use beyond the starting reading, the CUDA overhead and the model
# weights is booked as one-off inference warm-up.
inference_warmup = get_vram() - (start_vram + cuda_kernels_vram + model_actual_vram)
warmup = inference_warmup
print(f"Inference warmup took: {inference_warmup:.0f} MiB")
print("=" * 75)

Inference warmup took: 376 MiB
･ RAM:  △Consumed    △Peaked    Used Total | Exec time 0:00:16.860
･ CPU:      2,306          0      6,202 MB |
･ GPU:        376      6,786     12,969 MB |


In [9]:
# Driver-side snapshot of GPU memory; in notebook order this follows the inference warm-up cell.
!nvidia-smi

Tue Dec 19 13:24:48 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  | 00000000:81:00.0 Off |                  N/A |
|  0%   34C    P2              99W / 370W |  12652MiB / 24576MiB |     18%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        On  | 00000000:82:00.0 Off |  

## actual

In [10]:
# Measured inference forward pass. The order of statements matters: VRAM is read
# (1) right after the pass, (2) after emptying the cache while `out` is still alive,
# and only then is `out` deleted.
_ = model.eval()

with torch.no_grad():
    # return_dict=False makes the model return a tuple; only element 0 is kept.
    # NOTE(review): for a causal LM element 0 is presumably the logits, in which case the
    # last dimension is the vocabulary size rather than an embedding size — confirm.
    out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)[0]
    # probs = F.softmax(out.logits[:, -1, :], dim=-1) # for inference we need probabilities only over the last token; omit this as it is very small

# Expected size of the output tensor in MiB, from its shape and element width
# (2 bytes for fp16/bf16, otherwise 4 bytes is assumed).
out_bs, out_sequence_length, out_embedding_size = out.shape
n_bytes_per_param_out = 2 if out.dtype in (torch.float16, torch.bfloat16) else 4
output_estimated_vram = out_bs * out_sequence_length * out_embedding_size * n_bytes_per_param_out / 2**20
print(f"Out tensor dtype: {out.dtype}")

!nvidia-smi

# Reading (1): taken before empty_cache(), so it still contains memory that was used
# during the pass and has not been handed back to the driver yet.
total_forward_pass_vram = get_vram() - warmup - model_actual_vram - cuda_kernels_vram - start_vram
gc.collect(); torch.cuda.empty_cache()

!nvidia-smi

# Reading (2): cache emptied but `out` still referenced, so what remains above the
# baseline is attributed to the output tensor.
output_vram = get_vram() - warmup - model_actual_vram - cuda_kernels_vram - start_vram
del out; gc.collect(); torch.cuda.empty_cache()

# Activations = everything the pass held beyond the output tensor itself.
activations_actual_vram = total_forward_pass_vram - output_vram

print(f"Total forward pass VRAM usage: {total_forward_pass_vram:.0f} MiB")
print(f"Output tensor with bs {out_bs}, seq length {out_sequence_length} and emb size {out_embedding_size} VRAM usage: {output_vram:.0f} MiB (expected {output_estimated_vram:.0f} MiB)")
print(f"Activations VRAM usage: {activations_actual_vram:.0f} MiB")
#print(torch.cuda.memory_summary())
print("=" * 75)

Out tensor dtype: torch.float32
Tue Dec 19 13:25:03 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  | 00000000:81:00.0 Off |                  N/A |
|  0%   34C    P2              98W / 370W |  19454MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090      

In [11]:
# Driver-side snapshot of GPU memory; in notebook order this follows the measured inference pass.
!nvidia-smi

Tue Dec 19 13:25:04 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  | 00000000:81:00.0 Off |                  N/A |
|  0%   34C    P2              99W / 370W |  12652MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        On  | 00000000:82:00.0 Off |  

# Training step

## warmup

In [None]:
# Training warm-up: run one complete step (forward, loss, backward, optimizer step),
# then free everything the step created, so that memory which stays allocated
# afterwards can be booked as one-off training warm-up.
model.train()
optimizer = get_optimizer(model.parameters())

with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    # NOTE(review): F.cross_entropy expects unnormalized logits, so feeding softmax
    # output normalizes twice. Kept as is because the measured training cell does the
    # same and its accounting relies on a separate probs tensor; do not reuse this loss
    # for real training.
    probs = F.softmax(out.logits, dim=-1)
    loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

del out, probs, loss, optimizer
gc.collect()
torch.cuda.empty_cache()

train_warmup = get_vram() - inference_warmup - model_actual_vram - cuda_kernels_vram - start_vram
# Rebuild the total from its parts instead of `warmup += train_warmup`, so re-running
# this cell does not count the training warm-up twice.
warmup = inference_warmup + train_warmup
print(f"Train warmup took: {train_warmup:.0f} MiB")
print("=" * 75)

## actual

In [None]:
# Put the model in training mode and build a fresh optimizer for the measured step.
model.train()
optimizer = get_optimizer(model.parameters())
# No gradient scaler is created; the author's note: scaler is not needed with bf16.
#scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision_training)

In [None]:
# Measured training step. After every stage the VRAM reading is taken and everything
# already attributed (baseline, warm-up and the earlier stages) is subtracted, so each
# *_vram variable holds the increase caused by that stage alone.
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=mixed_precision_training):
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    train_forward_pass_vram = get_vram() - warmup - model_actual_vram - cuda_kernels_vram - start_vram
    
    # NOTE(review): F.cross_entropy below expects unnormalized logits; applying softmax
    # first normalizes twice. It looks deliberate here (a separate probs tensor is
    # materialized and accounted for further down) — confirm before reusing as a real loss.
    probs = F.softmax(out.logits, dim=-1)
    probs_vram = get_vram() - train_forward_pass_vram - warmup - model_actual_vram - cuda_kernels_vram - start_vram
    
    loss = F.cross_entropy(probs.permute(0, 2, 1), input_ids) # mapping tokens into themselves
    loss_calculation_vram = get_vram() - probs_vram - train_forward_pass_vram - warmup - model_actual_vram - cuda_kernels_vram - start_vram

loss.backward()
optimizer.step()
#scaler.scale(loss).backward()
#scaler.step(optimizer)
#scaler.update()
# Read after both backward() and optimizer.step(), so this delta covers both.
backward_vram = get_vram() - loss_calculation_vram - probs_vram - train_forward_pass_vram - warmup - model_actual_vram - cuda_kernels_vram - start_vram

# Sum of all stage deltas, i.e. the total increase over baseline + warm-up.
total_train_forward_pass_vram = train_forward_pass_vram + probs_vram + loss_calculation_vram + backward_vram

print(f"Model gradients type: {next(model.parameters()).grad.dtype}")
# NOTE(review): the message says "fp16" while autocast above uses bfloat16; both are
# 2 bytes per element, which is what the `n_parameters * 2` estimate relies on.
print(f"Total train forward pass VRAM usage: {total_train_forward_pass_vram:.0f} MiB" + (f" (expect {(n_parameters * 2 / 2**20):.0f} MiB of these to be for fp16 weights copy)" if mixed_precision_training else ""))
#print(f"Actual probs tensor VRAM usage: {probs_vram:.0f} MiB")
#print(f"Loss calculation VRAM usage: {loss_calculation_vram:.0f} MiB")
#print(f"Backward calculation VRAM usage: {backward_vram:.0f} MiB")

del out
del probs
del loss
gc.collect(); torch.cuda.empty_cache() # calling `free` on allocated memory for activations and outputs

# With outputs freed, what remains above the baseline is gradients plus optimizer state.
gradients_optimizer_total_vram = get_vram() - warmup - model_actual_vram - cuda_kernels_vram - start_vram
# Dropping the gradients (set_to_none=True) leaves only the optimizer state.
optimizer.zero_grad(set_to_none=True); gc.collect(); torch.cuda.empty_cache()
optimizer_vram = get_vram() - warmup - model_actual_vram - cuda_kernels_vram - start_vram
del optimizer; gc.collect(); torch.cuda.empty_cache()

# Gradients are obtained as the difference between the two readings above.
gradients_vram = gradients_optimizer_total_vram - optimizer_vram
print(f"Gradients VRAM usage: {gradients_vram:.0f} MiB (trainable params were {(n_training_parameters * n_bytes_per_param / 2**20):.0f} MiB)")
print(f"Optimizer states VRAM usage: {optimizer_vram:.0f} MiB")
# Activations = total minus the 2-byte weight copy (mixed precision only), minus twice the
# estimated output size (presumably logits and probs — confirm), minus gradients and
# optimizer state. `output_estimated_vram` comes from the inference measurement cell.
print(f"Activations VRAM usage: {(total_train_forward_pass_vram - (n_parameters * 2 / 2**20 if mixed_precision_training else 0) - output_estimated_vram * 2 - gradients_vram - optimizer_vram):.0f} MiB")
print("=" * 75)

# Estimating activations

In [None]:
# Analytical estimate of the activation memory kept per transformer layer, in bytes.
#
# NOTE(review): this rebinds the notebook-wide `n_bytes_per_param` to the activation
# width (2 bytes under mixed precision or a 16-bit model dtype, else 4). Cells above
# that read it will see the new value if they are re-run after this one.
uses_16bit_activations = mixed_precision_training or dtype in (torch.float16, torch.bfloat16)
n_bytes_per_param = 2 if uses_16bit_activations else 4

# Architecture sizes from the model config; the two getattr() defaults cover configs
# that do not define the attribute.
hidden_size = model.config.hidden_size
num_attention_heads = model.config.num_attention_heads
# different from num_attention_heads in case of GQA
num_key_value_heads = getattr(model.config, "num_key_value_heads", model.config.num_attention_heads)
# MLP projection
intermediate_size = getattr(model.config, "intermediate_size", 4 * model.config.hidden_size)
num_hidden_layers = model.config.num_hidden_layers
head_dim = hidden_size // num_attention_heads
print(
    "Calculating size of activation for single block with:\n"
    f"batch size {bs}\n"
    f"seq length {seq_length}\n"
    f"hidden size {hidden_size}\n"
    f"num attention heads {num_attention_heads}\n"
    f"num key value heads {num_key_value_heads}\n"
    f"intermediate size {intermediate_size}\n"
    f"head dim {head_dim}\n"
    f"num hidden layers {num_hidden_layers}"
)
print("=" * 75)

# Element counts shared by several terms below.
n_tokens = bs * seq_length                               # one entry per token
n_scores = bs * num_attention_heads * seq_length ** 2    # one entry per attention score

# --- attention block ---
attention_input      = n_bytes_per_param * n_tokens * hidden_size
q                    = n_bytes_per_param * n_tokens * head_dim * num_attention_heads  # for Q @ K.T
k                    = n_bytes_per_param * n_tokens * head_dim * num_key_value_heads  # fewer KV heads under GQA
softmax_output       = n_bytes_per_param * n_scores                                   # to multiply with V
softmax_dropout_mask = n_scores                                                       # single byte per elem
dropout_output       = n_bytes_per_param * n_scores
v                    = n_bytes_per_param * n_tokens * head_dim * num_key_value_heads
out_proj_input       = n_bytes_per_param * n_tokens * num_attention_heads * head_dim
attention_dropout    = n_tokens * hidden_size                                         # single byte per elem
# Variant without the dropout-related terms:
#attention_block = attention_input + q + k + softmax_output + v + out_proj_input
attention_block = (
    attention_input + q + k + softmax_output + v + out_proj_input
    + softmax_dropout_mask + dropout_output + attention_dropout
)

# --- MLP block ---
mlp_input        = n_bytes_per_param * n_tokens * hidden_size
activation_input = n_bytes_per_param * n_tokens * intermediate_size  # SiLU
down_proj_input  = n_bytes_per_param * n_tokens * intermediate_size
dropout_mask     = n_tokens * hidden_size                            # single byte per elem
# Variant without the dropout mask:
#mlp_block = mlp_input + activation_input + down_proj_input
mlp_block = mlp_input + activation_input + down_proj_input + dropout_mask

# --- normalization layers ---
layer_norms = 2 * n_bytes_per_param * n_tokens * hidden_size  # 2 layer norms

layer = attention_block + mlp_block + layer_norms
print(f"Single layer (out of {num_hidden_layers}) estimated activations VRAM usage: {layer // 2**20} MiB")
print(f"All layers estimated activations VRAM usage: {layer * num_hidden_layers // 2**20} MiB")
print(f"Estimated activations on inference forward pass VRAM usage (softmax output + v): {(softmax_output + v) // 2**20} MiB")
print("=" * 75)

In [None]:
# https://arxiv.org/pdf/2205.05198.pdf

def calculate_attention_block():
    """Attention-block activation bytes per layer: 11*s*b*h + 5*a*s^2*b.

    Closed form from the paper linked above this cell (arXiv 2205.05198). Reads the
    notebook globals ``seq_length`` (s), ``bs`` (b), ``hidden_size`` (h) and
    ``num_attention_heads`` (a).
    """
    tokens = seq_length * bs
    per_token_terms = 11 * tokens * hidden_size
    per_score_terms = 5 * num_attention_heads * seq_length * tokens
    return per_token_terms + per_score_terms

def calculate_mlp_block():
    """MLP-block activation bytes per layer: 19*s*b*h.

    Closed form from the paper linked above this cell (arXiv 2205.05198). Reads the
    notebook globals ``seq_length`` (s), ``bs`` (b) and ``hidden_size`` (h).
    """
    tokens = seq_length * bs
    return 19 * tokens * hidden_size

def calculate_layernorms():
    """Layer-norm activation bytes per layer: 4*s*b*h.

    Closed form from the paper linked above this cell (arXiv 2205.05198). Reads the
    notebook globals ``seq_length`` (s), ``bs`` (b) and ``hidden_size`` (h).
    """
    tokens = seq_length * bs
    return 4 * tokens * hidden_size

def calculate_per_layer():
    """Total activation bytes per layer: s*b*h*(34 + 5*a*s/h).

    Closed form from the paper linked above this cell (arXiv 2205.05198). Reads the
    notebook globals ``seq_length`` (s), ``bs`` (b), ``hidden_size`` (h) and
    ``num_attention_heads`` (a).

    ``h`` is multiplied into the bracket so the whole expression stays in integer
    arithmetic: the result is an exact ``int`` byte count, whereas the literal form
    with ``/`` yields a float that is subject to rounding.
    """
    return seq_length * bs * (34 * hidden_size + 5 * num_attention_heads * seq_length)

# Sanity checks for the activation estimate.
# 1) The three per-block closed forms must always add up to the per-layer closed form.
paper_per_layer = calculate_attention_block() + calculate_mlp_block() + calculate_layernorms()
assert paper_per_layer == calculate_per_layer(), "per-block formulas do not add up to the per-layer formula"

# 2) The manual breakdown in `layer` can only equal the closed form when activations
#    take 2 bytes, the MLP is 4x wider than the hidden size, and every attention head
#    has its own key/value head. For any other configuration (e.g. a different
#    intermediate size, or GQA) the comparison would always fail, so report the
#    difference instead of raising.
matches_paper_assumptions = (
    n_bytes_per_param == 2
    and intermediate_size == 4 * hidden_size
    and num_key_value_heads == num_attention_heads
    and head_dim * num_attention_heads == hidden_size
)
if matches_paper_assumptions:
    assert layer == paper_per_layer, (
        f"manual per-layer estimate ({layer} B) differs from the closed form ({paper_per_layer} B)"
    )
else:
    print(
        "Model does not match the closed-form assumptions "
        "(2-byte activations, intermediate_size == 4 * hidden_size, no GQA); "
        f"manual per-layer estimate: {layer // 2**20} MiB, "
        f"closed form: {paper_per_layer // 2**20} MiB"
    )