# Transformer Accounting

## Basic Memory Calculations
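Let $E$ be the embedding parameter count, $L$ the number of layers, $d$ the model dimension, and $d_{ff}$ the FFN hidden dimension. The total parameter count $M$ (multiply by bytes per parameter to get memory) is: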
$$ 
M = \underbrace{|Vocab|d}_{E} + L\left(\underbrace{3d^2 + d^2}_{W_{QKV} + W_O} + \underbrace{3d_{ff}d}_{FFN} + 2\underbrace{d}_{\text{LN}}\right) + \underbrace{d}_{\text{LN}} + \underbrace{|Vocab|d}_{\text{Output}}
$$
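As a worked check of the formula, plugging in the GPT-2 XL shape used later in this document ($|Vocab| = 50257$, $d = 1600$, $d_{ff} = 6400$, $L = 48$) gives about 2.13B parameters, or roughly 8.5 GB at 4 bytes per parameter:

$$
M = 2\cdot 50257\cdot 1600 + 48\left(4\cdot 1600^2 + 3\cdot 6400\cdot 1600 + 2\cdot 1600\right) + 1600 = 2{,}127{,}057{,}600
$$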


## Basic FLOPs Calculations

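The Transformer blocks dominate the cost, so we focus there and split each block into MHA and FFN. Multiplying an $m \times n$ matrix by an $n \times p$ matrix costs $2mnp$ FLOPs. Let $c$ be the context length and $h$ the number of heads.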
$$
MHA = \underbrace{6cd^2}_{Q, K, V} + h\left(\underbrace{2c^2(d/h)}_{Q_hK_h^T} + \underbrace{2c^2(d/h)}_{\text{softmax}\cdot V_h}\right) + \underbrace{2d^2c}_{W_O \text{ mult}}
$$
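Summing the per-head terms over the $h$ heads cancels the $1/h$ factor, so under this accounting the head count drops out of the total:

$$
MHA = 6cd^2 + 4c^2d + 2cd^2 = 8cd^2 + 4c^2d
$$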
In the FFN, each of the three matrix multiplies ($W_1x$, $W_3x$, and the $W_2$ projection) costs the same $2d_{ff}dc$ FLOPs:
$$
FFN = \underbrace{6d_{ff} d c}_{W_2(\sigma(W_1x)\odot W_3x)}
$$

### FLOPs Conclusion

When the context length dominates, MHA is more expensive because its attention terms grow as $c^2$. When $d_{ff}$ dominates, the FFN is more expensive.
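To make the trade-off concrete, dividing the FFN count by the MHA count gives $3d_{ff} / (4d + 2c)$, which equals 1 at $c = (3d_{ff} - 4d)/2$. The following is a minimal sketch of that calculation under the same accounting (matrix multiplies only); the function names `ffn_to_mha_ratio` and `crossover_context` are illustrative and not part of the notebook.

```python
def ffn_to_mha_ratio(d_model, d_ff, context_length):
    """Per-block FFN/MHA FLOP ratio: 6*d_ff*d*c / (8*c*d^2 + 4*c^2*d)."""
    ffn = 6 * d_ff * d_model * context_length
    mha = 8 * context_length * d_model**2 + 4 * context_length**2 * d_model
    return ffn / mha


def crossover_context(d_model, d_ff):
    """Context length at which FFN and MHA FLOPs are equal: (3*d_ff - 4*d) / 2."""
    return (3 * d_ff - 4 * d_model) / 2


if __name__ == "__main__":
    # GPT-2 XL shape from this document: d_model=1600, d_ff=6400
    print(ffn_to_mha_ratio(1600, 6400, 1024))  # about 2.27, FFN dominates
    print(crossover_context(1600, 6400))       # 6400.0, i.e. 4 * d_model when d_ff = 4 * d_model
```

The ratio is independent of the number of layers and heads, so it can be evaluated per block.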

In [73]:

def nb_trainable_params(param_dict):
    """Count trainable parameters from a dict of model hyperparameters."""
    vocab_size = param_dict['vocab_size']
    num_layers = param_dict['num_layers']
    d_model = param_dict['d_model']
    d_ff = param_dict['d_ff']

    # Embedding layer (counted once, so kept out of the per-block total)
    embedding = d_model*vocab_size

    # Transformer Block layers
    block_params = 0
    ## Feed forward layer (three weight matrices: W1, W2, W3)
    ffn_w1 = d_ff*d_model
    ffn_w2 = d_model*d_ff
    ffn_w3 = d_ff*d_model
    block_params += ffn_w1 + ffn_w2 + ffn_w3

    ## MHA params
    W_QKV = 3*d_model**2
    W_O = d_model**2
    block_params += W_QKV + W_O

    ## Layer Norm (two per block, d_model parameters each)
    ln_in_block = 2*d_model
    block_params += ln_in_block

    ## Only the block parameters are multiplied by num_layers
    nb_params = embedding + num_layers*block_params

    # final LN
    ln_final = d_model
    nb_params += ln_final

    # final output layer
    output = vocab_size * d_model
    nb_params += output

    return nb_params

# FLOPs for one forward pass over one sequence, counting only the matrix multiplies
def nb_flops(param_dict): 
    nb_flops = 0
    nb_flops_by_part = {'MHA' : 0, 'FFN' : 0, 'output' : 0} 
    vocab_size = param_dict['vocab_size']
    num_layers = param_dict['num_layers']
    context_length = param_dict['context_length']
    d_model = param_dict['d_model']
    num_heads = param_dict['num_heads']
    d_ff = param_dict['d_ff']
    
    # Transformer block layers
    
    ## MHA (ignoring softmax)
    ### Project input to weights QKV
    proj_flops = 2*(3*d_model**2)*(context_length)
    ### Per heads QK^T
    QKT_flops = 2*(context_length**2)*(d_model//num_heads)
    ### Value flops
    V_flops = 2*(context_length**2)*(d_model//num_heads)
    ### W_O flops
    W_O_flops  = 2*d_model**2 * context_length
    nb_flops_by_part['MHA'] = proj_flops+ num_heads*(QKT_flops+V_flops)+W_O_flops
    nb_flops_by_part['MHA'] *= num_layers
    
    ## FFN
    ### Multiply input (ignoring nonlinearity and the odot) 
    W1x_flops = 2*d_ff*d_model*context_length
    W3x_flops = W1x_flops
    W2_sigma_W1x_dot_W3x = W1x_flops
    nb_flops_by_part['FFN'] = W1x_flops + W3x_flops + W2_sigma_W1x_dot_W3x
    nb_flops_by_part['FFN'] *= num_layers

    ## Output layer
    nb_flops_by_part['output'] = 2*vocab_size*d_model*context_length

    for count in nb_flops_by_part.values():
        nb_flops += count
    return nb_flops, nb_flops_by_part



## GPT-2 XL

In [76]:
# GPT-2 XL 
gpt_2_xl = {}
gpt_2_xl['vocab_size'] = 50257
gpt_2_xl['context_length'] = 1024
gpt_2_xl['num_layers'] = 48
gpt_2_xl['d_model'] = 1600
gpt_2_xl['num_heads'] = 25
gpt_2_xl['d_ff'] = 6400

nb_trainable = nb_trainable_params(gpt_2_xl)
print("Number of trainable parameters", nb_trainable)

print("Memory consumed", nb_trainable*4 /(10**9), "GBs")

nb_flops_total, nb_flops_by_part = nb_flops(gpt_2_xl)
print("Number of flops total", nb_flops_total/(10**12), "Teraflops")
print("Number of MHA flops", nb_flops_by_part['MHA']/(10**12), "Teraflops")
print("Number of FFN flops", nb_flops_by_part['FFN']/(10**12), "Teraflops")
print("Number of output flops", nb_flops_by_part['output']/(10**12), "Teraflops")

print("FFN/MHA", nb_flops_by_part['FFN']/nb_flops_by_part['MHA'])

Number of trainable parameters 2127057600
Memory consumed 8.5082304 GBs
Number of flops total 4.5133365248 Teraflops
Number of MHA flops 1.3287555072 Teraflops
Number of FFN flops 3.01989888 Teraflops
Number of output flops 0.1646821376 Teraflops
FFN/MHA 2.272727272727273
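The `*4` in the memory line corresponds to 4 bytes per parameter (fp32) and covers the weights only. As a small sketch for other precisions, the helper below generalizes that line; `param_memory_gb` and `BYTES_PER_PARAM` are hypothetical names introduced here, not part of the notebook.

```python
# Bytes per parameter for common floating-point formats
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2}


def param_memory_gb(nb_params, dtype="fp32"):
    """Memory for the weights alone, in GB (10**9 bytes), at the given precision."""
    return nb_params * BYTES_PER_PARAM[dtype] / 10**9


if __name__ == "__main__":
    # One billion parameters as a round illustrative figure
    print(param_memory_gb(10**9, "fp32"))  # 4.0
    print(param_memory_gb(10**9, "bf16"))  # 2.0
```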


## GPT-2 Small

In [77]:
# GPT-2 Small
gpt_2_small = {}
gpt_2_small['vocab_size'] = 50257
gpt_2_small['context_length'] = 1024
gpt_2_small['num_layers'] = 12
gpt_2_small['d_model'] = 768
gpt_2_small['num_heads'] = 12
gpt_2_small['d_ff'] = 3072

nb_trainable = nb_trainable_params(gpt_2_small)
print("Number of trainable parameters", nb_trainable)

print("Memory consumed", nb_trainable*4 /(10**9), "GBs")

nb_flops_total, nb_flops_by_part = nb_flops(gpt_2_small)
print("Number of flops total", nb_flops_total/(10**12), "Teraflops")
print("Number of MHA flops", nb_flops_by_part['MHA']/(10**12), "Teraflops")
print("Number of FFN flops", nb_flops_by_part['FFN']/(10**12), "Teraflops")
print("Number of output flops", nb_flops_by_part['output']/(10**12), "Teraflops")

print("FFN/MHA", nb_flops_by_part['FFN']/nb_flops_by_part['MHA'])

Number of trainable parameters 190460160
Memory consumed 0.76184064 GBs
Number of flops total 0.349630365696 Teraflops
Number of MHA flops 0.09663676416 Teraflops
Number of FFN flops 0.173946175488 Teraflops
Number of output flops 0.079047426048 Teraflops
FFN/MHA 1.8


## GPT-2 Large

In [78]:
# GPT-2 Large
gpt_2_large = {}
gpt_2_large['vocab_size'] = 50257
gpt_2_large['context_length'] = 1024
gpt_2_large['num_layers'] = 24
gpt_2_large['d_model'] = 1024
gpt_2_large['num_heads'] = 16
gpt_2_large['d_ff'] = 5120

nb_trainable = nb_trainable_params(gpt_2_large)
print("Number of trainable parameters", nb_trainable)

print("Memory consumed", nb_trainable*4 /(10**9), "GBs")

nb_flops_total, nb_flops_by_part = nb_flops(gpt_2_large)
print("Number of flops total", nb_flops_total/(10**12), "Teraflops")
print("Number of MHA flops", nb_flops_by_part['MHA']/(10**12), "Teraflops")
print("Number of FFN flops", nb_flops_by_part['FFN']/(10**12), "Teraflops")
print("Number of output flops", nb_flops_by_part['output']/(10**12), "Teraflops")

print("FFN/MHA", nb_flops_by_part['FFN']/nb_flops_by_part['MHA'])

Number of trainable parameters 581127168
Memory consumed 2.324508672 GBs
Number of flops total 1.187728326656 Teraflops
Number of MHA flops 0.309237645312 Teraflops
Number of FFN flops 0.77309411328 Teraflops
Number of output flops 0.105396568064 Teraflops
FFN/MHA 2.5


## GPT-2 XL Long Context

In [80]:
# GPT-2 XL with a 16384-token context
gpt_2_xl = {}
gpt_2_xl['vocab_size'] = 50257
gpt_2_xl['context_length'] = 16384
gpt_2_xl['num_layers'] = 48
gpt_2_xl['d_model'] = 1600
gpt_2_xl['num_heads'] = 25
gpt_2_xl['d_ff'] = 6400

nb_trainable = nb_trainable_params(gpt_2_xl)
print("Number of trainable parameters", nb_trainable)

print("Memory consumed", nb_trainable*4 /(10**9), "GBs")

nb_flops_total, nb_flops_by_part = nb_flops(gpt_2_xl)
print("Number of flops total", nb_flops_total/(10**12), "Teraflops")
print("Number of MHA flops", nb_flops_by_part['MHA']/(10**12), "Teraflops")
print("Number of FFN flops", nb_flops_by_part['FFN']/(10**12), "Teraflops")
print("Number of output flops", nb_flops_by_part['output']/(10**12), "Teraflops")

print("FFN/MHA", nb_flops_by_part['FFN']/nb_flops_by_part['MHA'])

Number of trainable parameters 2127057600
Memory consumed 8.5082304 GBs
Number of flops total 149.5227957248 Teraflops
Number of MHA flops 98.5694994432 Teraflops
Number of FFN flops 48.31838208 Teraflops
Number of output flops 2.6349142016 Teraflops
FFN/MHA 0.49019607843137253
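The ratio flips at long context because the $c^2$ attention terms take over the MHA total. Under the same accounting, their share of per-block MHA FLOPs is $4c^2d / (8cd^2 + 4c^2d) = c / (2d + c)$. The sketch below evaluates that share for the two GPT-2 XL context lengths used above; `attention_quadratic_share` is an illustrative name, not a notebook function.

```python
def attention_quadratic_share(d_model, context_length):
    """Fraction of per-block MHA FLOPs spent in the c^2 terms."""
    # 2*c^2*d for QK^T plus 2*c^2*d for the attention-weighted sum over V
    quadratic = 4 * context_length**2 * d_model
    # Q, K, V projections (6*c*d^2) plus the W_O projection (2*c*d^2)
    linear = 8 * context_length * d_model**2
    return quadratic / (quadratic + linear)


if __name__ == "__main__":
    for c in (1024, 16384):
        print(c, round(attention_quadratic_share(1600, c), 3))
```

For $d = 1600$ this share is about 24% at $c = 1024$ and about 84% at $c = 16384$, while the FFN cost grows only linearly in $c$.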
