# FLOPs Calculation

In [5]:
# Model / input parameters shared by every cell below
num_layers = 24          # number of transformer layers
seq_length = 512         # tokens per input sequence
D_model = 896            # hidden size
num_heads = 14           # attention heads (per-head size = D_model / num_heads = 64)
# FFN hidden size, taken here as 4x the hidden size.
# NOTE(review): the published Qwen2.5-0.5B config lists intermediate_size=4864 — confirm which value the assignment intends.
d_FFN = 4 * D_model
rank = 8                 # LoRA rank, used in Question 3(b)

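## Question 2(c): FLOPs for the Qwen model

Total training FLOPs for one sequence: the forward pass summed over all layers, plus a backward pass approximated as 2× the forward pass.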

In [9]:
def calculate_flops(num_layers, seq_length, D_model, D_FFN, num_heads):
    """
    Estimate total training FLOPs (forward + backward) for one sequence through the Qwen model.

    Each layer's forward cost is the sum of attention, feed-forward, RMSNorm,
    residual-connection and positional-embedding terms. The backward pass is
    approximated as twice the forward pass, so the result is 3x the forward cost.

    Args:
        num_layers (int): Number of transformer layers.
        seq_length (int): Input sequence length in tokens.
        D_model (int): Hidden size.
        D_FFN (int): Hidden size of the feed-forward network.
        num_heads (int): Number of attention heads. Must be non-zero; the
            per-head size is D_model / num_heads.

    Returns:
        float: Total FLOPs = forward_flops + backward_flops, where
        backward_flops = 2 * forward_flops. Only this single combined number
        is returned; forward and backward are not returned separately.
    """

    # Attention FLOPs (per layer)
    attention_flops = (
        # 3 x (2 * seq_length * D_model^2): presumably the Q, K, V projections
        6 * seq_length * D_model**2 +
        # equals num_heads * seq_length^2 * (2*d_head - 1): presumably Q @ K^T over all heads
        seq_length**2 * (2 * D_model - num_heads) +
        # per-head elementwise work on the score matrix; constants follow this
        # notebook's own derivation (TODO confirm breakdown)
        num_heads * (11 * seq_length**2 + seq_length*(10*seq_length+seq_length-1) + seq_length**2 + 10) +
        # presumably attention weights @ V
        D_model * seq_length * (2*seq_length - 1) +
        # presumably the output projection
        seq_length * D_model * (2*D_model - 1)
    )

    # Feedforward FLOPs: 6*D_model*D_FFN is 3 matmuls' worth of 2*D_model*D_FFN per token
    # (presumably a gated FFN); 13*D_FFN is per-element work from the notebook's derivation.
    FFN_flops = seq_length * (6*D_model*D_FFN + 13*D_FFN)

    # RMSNorm FLOPs
    norm_flops = (3*D_model + 12) * seq_length

    # Residual-connection FLOPs: 2 * seq_length * D_model elementwise additions
    residual_flops = 2 * seq_length * D_model

    # Positional embedding: 2*D_model/num_heads is 2 * per-head size. True division
    # makes this term a float, so the function's result is always a float.
    pos_embedding_flops = 3 * D_model * seq_length * (2*D_model/num_heads - 1)

    # Forward FLOPs for a single layer
    flops_per_layer = attention_flops + FFN_flops + norm_flops + residual_flops + pos_embedding_flops

    # Whole model: backward pass approximated as 2x the forward pass
    forward_flops = num_layers * flops_per_layer
    backward_flops = 2 * forward_flops

    total_flops = forward_flops + backward_flops

    return total_flops

total_flops = calculate_flops(num_layers, seq_length, D_model, d_FFN, num_heads)

# calculate_flops returns forward + backward combined, so label the value as the total.
print(f"Total FLOPs (forward + backward): {total_flops:.2e}")

Forward FLOPS: 1.03e+12


## Question 3(b): FLOPs for Qwen with LoRA

In [10]:
def calculate_flops_with_lora(num_layers, seq_length, D_model, D_FFN, num_heads, rank):
    """
    Estimate total training FLOPs (forward + backward) for Qwen with LoRA adapters.

    The result is the base-model total from ``calculate_flops`` plus a LoRA
    overhead term.

    Args:
        num_layers (int): Number of transformer layers.
        seq_length (int): Input sequence length; used only by the base-model estimate.
        D_model (int): Hidden size.
        D_FFN (int): Hidden size of the feed-forward network.
        num_heads (int): Number of attention heads.
        rank (int): LoRA rank.

    Returns:
        float: Base-model total FLOPs plus the LoRA overhead.
    """
    base_total = calculate_flops(num_layers, seq_length, D_model, D_FFN, num_heads)

    # LoRA cost for one pass over all layers: 4 * rank * D_model**2 per layer.
    # NOTE(review): this does not scale with seq_length; it matches e.g. forming B @ A
    # (~2 * rank * D_model**2) for two adapted D_model x D_model matrices — confirm intent.
    single_pass_lora = 4 * rank * num_layers * D_model**2
    # x3: forward + backward together (backward taken as 2x forward)
    lora_overhead = 3 * single_pass_lora

    return lora_overhead + base_total

total_flops_with_lora = calculate_flops_with_lora(num_layers, seq_length, D_model, d_FFN, num_heads, rank)

# The value covers both the forward and backward passes, so label it as the total.
print(f"Total FLOPs with LoRA (forward + backward): {total_flops_with_lora:.2e}")

Forward FLOPS with LoRA: 1.04e+12
