# FLOPS Calculation

## Question 2 (c)

In [None]:
def calculate_flops(num_layers, seq_length, d_model):
    """
    Estimate forward- and backward-pass FLOPS for a stack of transformer layers.

    The cost of one layer is the sum of four closed-form terms (attention,
    feedforward, RMSNorm, residual connection). The forward total scales
    linearly with ``num_layers``; the backward pass is counted as exactly
    twice the forward pass.

    Args:
        num_layers (int): Number of transformer layers.
        seq_length (int): Input sequence length.
        d_model (int): Hidden size.

    Returns:
        tuple: ``(forward_flops, backward_flops)``. Both values are plain
        ints when every argument is an int.
    """
    # Squares are reused by several terms below, so compute them once.
    seq_sq = seq_length**2
    width_sq = d_model**2

    # --- Attention ---
    # NOTE(review): the individual terms are not labelled in the original;
    # they presumably correspond to the Q/K/V projections, the score matrix,
    # a per-score constant cost (11), the weighted sum over values, and the
    # output projection -- confirm against the assignment's derivation.
    attn_in = 3 * seq_length * width_sq
    attn_pairwise = seq_sq * d_model
    attn_per_score = seq_sq * 11
    attn_out = seq_length * width_sq
    attn_cost = attn_in + attn_pairwise + attn_per_score + attn_pairwise + attn_out

    # --- Feedforward ---
    # Two identical seq * d_model^2 terms surround one seq * d_model term.
    ff_matmul = 4 * seq_length * width_sq
    ff_elementwise = 8 * seq_length * d_model
    ff_cost = ff_matmul + ff_elementwise + ff_matmul

    # --- RMSNorm and residual connection ---
    norm_cost = seq_length * d_model * 2
    residual_cost = 2 * seq_length * d_model

    per_layer = attn_cost + ff_cost + norm_cost + residual_cost

    # Every layer costs the same; backward is taken as 2x forward.
    forward_total = num_layers * per_layer
    return forward_total, 2 * forward_total

# Example configuration for a single FLOPS estimate.
num_layers = 24      # transformer layers
seq_length = 512     # tokens in the input sequence
d_model = 768        # hidden size

# Keyword arguments make the pairing of values to parameters explicit.
forward_flops, backward_flops = calculate_flops(
    num_layers=num_layers,
    seq_length=seq_length,
    d_model=d_model,
)

# Report both totals in scientific notation with two decimals.
print(f"Forward FLOPS: {forward_flops:.2e}")
print(f"Backward FLOPS: {backward_flops:.2e}")

Forward FLOPS: 9.68e+10
Backward FLOPS: 1.94e+11
