# FLOPS Calculation

In [14]:
from transformers import AutoModelForCausalLM
# Hugging Face model id of the checkpoint whose architecture is inspected.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# Loads the model named by model_name.
model = AutoModelForCausalLM.from_pretrained(model_name)

# Print the module tree (layer types and their dimensions), not the parameter values
print(model)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

In [4]:
# parameters: model dimensions passed to calculate_flops below
num_layers = 24      # number of transformer layers
seq_length = 512     # token input length
D_model = 896        # hidden size
d_FFN = 4864         # hidden size of the feed-forward (MLP) block
num_heads = 14       # number of attention heads
rank = 4             # presumably the LoRA rank; the visible experiment calls pass rank explicitly instead of reading this

## Question 2(c): FLOPS for Qwen model

In [10]:
def calculate_flops(num_layers, seq_length, D_model, D_FFN, num_heads):
    """
    Estimate the combined forward + backward FLOPS of the Qwen model.

    A per-layer cost is assembled from five closed-form terms (attention,
    feed-forward, RMSNorm, residual connection, positional embedding) and
    scaled by the number of layers. The backward pass is counted as twice
    the forward pass, so the result is three times the forward-pass count.

    Args:
        num_layers (int): Number of transformer layers.
        seq_length (int): Input sequence length.
        D_model (int): Hidden size.
        D_FFN (int): Hidden size of the feed-forward network.
        num_heads (int): Number of transformer (attention) heads.

    Returns:
        float: Forward plus backward FLOPS. The value is a float even for
        integer inputs because the positional-embedding term uses true
        division.
    """
    # Short aliases keep the formulas readable.
    S, D, H = seq_length, D_model, num_heads

    # Attention: five additive terms, summed in their original order.
    # The term names are a reviewer's reading of the formula -- confirm
    # against the written derivation.
    proj_term = 6 * S * D**2
    score_term = S**2 * (2 * D - H)
    head_term = H * (11 * S**2 + S * (10 * S + S - 1) + S**2 + 10)
    mix_term = D * S * (2 * S - 1)
    out_term = S * D * (2 * D - 1)
    attention = proj_term + score_term + head_term + mix_term + out_term

    # Feed-forward block.
    feed_forward = S * (6 * D * D_FFN + 13 * D_FFN)

    # RMSNorm.
    rms_norm = (3 * D + 12) * S

    # Residual connections.
    residual = 2 * S * D

    # Positional embedding; the true division makes this term a float.
    positional = 3 * D * S * (2 * D / H - 1)

    # Cost of a single layer, then of the whole stack.
    per_layer = attention + feed_forward + rms_norm + residual + positional
    forward = num_layers * per_layer

    # Backward pass is taken to cost twice the forward pass.
    backward = 2 * forward

    return forward + backward

# Evaluate the estimate for the settings defined in the parameters cell.
total_flops = calculate_flops(
    num_layers=num_layers,
    seq_length=seq_length,
    D_model=D_model,
    D_FFN=d_FFN,
    num_heads=num_heads,
)

print(f"Total FLOPS: {total_flops:.2e}")

Total FLOPS: 1.29e+12


## Question 3(b): FLOPS for Qwen with LoRA

In [15]:
def calculate_flops_with_lora(num_layers, seq_length, D_model, D_FFN, num_heads, rank, num_steps):
    """
    Estimate total training FLOPS for the Qwen model with LoRA adapters.

    The per-step cost is the base-model estimate from ``calculate_flops``
    plus a LoRA term; that sum is multiplied by the number of steps.

    Args:
        num_layers (int): Number of transformer layers.
        seq_length (int): Input sequence length.
        D_model (int): Hidden size.
        D_FFN (int): Hidden size of the feed-forward network.
        num_heads (int): Number of transformer (attention) heads.
        rank (int): LoRA rank.
        num_steps (int): Number of steps the per-step cost is multiplied by.

    Returns:
        float: ``num_steps * (LoRA FLOPS + base forward/backward FLOPS)``.
    """
    # The leading factor 3 folds forward and backward together.
    # NOTE(review): this term does not depend on seq_length -- confirm
    # that this is intended by the derivation.
    adapter_flops = 3 * (4 * rank * num_layers * D_model**2)

    # Forward + backward cost of the base model for one step.
    base_flops = calculate_flops(num_layers, seq_length, D_model, D_FFN, num_heads)

    per_step_flops = adapter_flops + base_flops
    return num_steps * per_step_flops

# Three LoRA runs; only the sequence length and rank differ between them
# (first argument line). The remaining settings are shared (second line).
experiment_1 = calculate_flops_with_lora(
    seq_length=128, rank=2,
    num_layers=24, D_model=896, D_FFN=4864, num_heads=14, num_steps=5000,
)
experiment_2 = calculate_flops_with_lora(
    seq_length=512, rank=4,
    num_layers=24, D_model=896, D_FFN=4864, num_heads=14, num_steps=5000,
)
experiment_3 = calculate_flops_with_lora(
    seq_length=768, rank=8,
    num_layers=24, D_model=896, D_FFN=4864, num_heads=14, num_steps=5000,
)


print(f"Total FLOPS with LoRA for experiment 1: {experiment_1:.2e}")
print(f"Total FLOPS with LoRA for experiment 2: {experiment_2:.2e}")
print(f"Total FLOPS with LoRA for experiment 3: {experiment_3:.2e}")

Total FLOPS with LoRA for experiment 1: 1.54e+15
Total FLOPS with LoRA for experiment 2: 6.45e+15
Total FLOPS with LoRA for experiment 3: 9.95e+15
