In [20]:
def calc_saved_block_activations(B,S,D,NH, amp=False, dropout=False, bias=False):
    bo = 1
    fp = 4
    mp = 2 if amp else 4
    hp = 2
    i64 = 8

    total = 0

    ## layer norm1

    # NativeLayerNormBackward0-input
    total += B*S*D * fp # input tensor, typically not cached previously. 
    # NativeLayerNormBackward0-result1
    total += B*S * fp
    # NativeLayerNormBackward0-result2
    total += B*S * fp
    # NativeLayerNormBackward0-weight
    total += 0 # model parameter at full precision, no need to account for. 
    # NativeLayerNormBackward0-bias
    total += int(bias) * 0 * fp # amp does not affect ln and full precision weights are not added to total, because it is part of parameter weights. 
    

    print("LN1", total)
    last = total

    ## layer norm2

    # NativeLayerNormBackward0-input
    total += B*S*D * fp # this tensor comes from addition of attention and input tensor
    # NativeLayerNormBackward0-result1
    total += B*S * fp
    # NativeLayerNormBackward0-result2
    total += B*S * fp
    # NativeLayerNormBackward0-weight
    total += 0 # model parameter at full precision, no need to account for. 
    # NativeLayerNormBackward0-bias
    total +=int(bias) *  0 * fp # amp does not affect ln and full precision weights are not added to total, because it is part of parameter weights.
    
    print("LN2", total-last)
    last = total

    ## scaled_dot_product_attention efficient implmenentation

    # MmBackward0-self (c_attn)
    total += B*S*D * mp
    # MmBackward0-mat2 (c_attn)
    total += 3*D*D * hp if amp else 0 # model parameter (attn.c_attn.weight) at half precision, or full precision

        
    print("attention-c_attn", total-last)
    last = total

    # ScaledDotProductEfficientAttentionBackward0-value-query-key
    total += B*S*D * mp
    total += B*S*D * mp
    total += B*S*D * mp
    # ScaledDotProductEfficientAttentionBackward0-log_sumexp
    total += B*NH*S * fp
    # ScaledDotProductEfficientAttentionBackward0-philox_offset
    total += 1 * i64
    # ScaledDotProductEfficientAttentionBackward0-philox_seed
    total += 1 * i64
    # ScaledDotProductEfficientAttentionBackward0-output
    ## total += B*S*D * mp # this is accounted in c_proj matmul

    print("scaled dot product attention", total-last)
    last = total

    # MmBackward0-mat2 (c_proj)
    total += D * D * hp if amp else 0  # model parameter (attn.c_proj.weight) at half precision, or full precision.
    # MmBackward0-self (c_proj)
    total += B*S*D * mp # same as ScaledDotProductEfficientAttentionBackward0-output

    print("attention-c_proj", total-last)
    last = total

    # NativeDropoutBackward0-result1
    total += B*S*D * bo if dropout else 0 # dropout mask for resid_dropout

    print("attention-droput", total-last)
    last = total

    ## MLP

    # MmBackward0-mat2 (c_fc)
    total += S*D*hp if amp else 0 # model parameter (mlp.c_fc.weight) at half precision, or full precision
    # MmBackward0-self (c_fc)
    total += B*S*D * mp # this value comes from layer norm 2

    print("c_fc", total-last)
    last = total

    # GeluBackward0-self
    total += 4 * B * S * D * mp # this tensor is result of 'x @ c_fc'

    print("gelu", total-last)
    last = total

    # MmBackward0-mat2 (c_proj)
    total += S*D*hp if amp else 0 # model parameter (mlp.c_proj.weight) at half precision, or full precision
    # MmBackward0-self (c_proj)
    total += 4*B * S * D * mp  # this value comes from GELU

    print("c_proj", total-last)
    last = total

    # NativeDropoutBackward0-result1 (mlp dropout)
    total += B*S*D * bo if dropout else 0 # dropout mask for resid_dropout

    print("mlp dropout", total-last)
    last = total

    return total

In [21]:
def calc_saved_gpt_activations(B,S,D,NH,V, n_layers=1, amp=False, dropout=False, bias=False):
    bo = 1
    fp = 4
    mp = 2 if amp else 4
    hp = 2
    i64 = 8

    total = 0

    ## token embeddings
    total += B * S * i64 # indices cached by embedding layer

    ## position embeddings
    total += S * i64 # indices cached by embedding layer

    print("embeddings", total)
    last = total

    ## x = self.transformer.drop(tok_emb + pos_emb)
    total += B*S*D * bo if dropout else 0 # dropout mask for resid_dropout
    
    print("gpt-droput", total-last)
    last = total
    
    # add blocks
    
    ## we shifted calculation to add input to first layer norm
    ## total += B *S *D * fp # cache input, normally it is cached as the input variable to the first layer norm. 
    total_block= n_layers * calc_block(B,S,D,NH,amp,dropout, bias)
    total +=total_block
    print("block_sum",str(total_block))

    print("total blocks (block+input)", total-last)
    last = total

    # final ln
    # NativeLayerNormBackward0-input
    total += B*S*D * fp # output of final attention block
    # NativeLayerNormBackward0-result1
    total += B*S * fp
    # NativeLayerNormBackward0-result2
    total += B*S * fp
    # NativeLayerNormBackward0-weight
    total += 0 # model parameter at full precision, no need to account for. 
    # NativeLayerNormBackward0-bias
    total +=int(bias) *  0 * fp # amp does not affect ln and full precision weights are not added to total, because it is part of parameter weights.
    
    print("ln_f", total-last)
    last = total

    ## lm head
    # lm_head-mat1
    total += V * D * hp if amp else 0 # in amp half preicison copy of embedding parameter weights are used. in full precision embedding parameter weights is used, no need to add to total
    # lm_head-mat1
    total += B*S*D * mp # output of final ln
    
    print("lm_head", total-last)
    last = total

    return total

In [35]:
def calc_params(sequenceLength, hiddenDim, vocabSize, numLayers=1, includeBias=False, includeEmbeddings=False, weightTying=True):
    wte = vocabSize * hiddenDim # Token embedding
    wpe = hiddenDim * sequenceLength # Positional embedding
    
    # Per layer
    ln_1 = hiddenDim # layer norm 1
    attn_c_attn = 3 * hiddenDim * hiddenDim # Query, Key, Value projections
    attn_c_proj = hiddenDim * hiddenDim # Output projection
    ln_2 = hiddenDim # layer norm 2
    mlp_c_fc = 4 * hiddenDim * hiddenDim # First MLP layer
    mlp_c_proj = 4 * hiddenDim * hiddenDim # Second MLP layer

    paramsPerLayer = ln_1 + attn_c_attn + attn_c_proj + ln_1 + mlp_c_fc + mlp_c_proj
    final_ln = hiddenDim
    final_dense = vocabSize * hiddenDim

    totalParamsWOEmbedding = (paramsPerLayer * numLayers) + final_ln + final_dense
    totalParams = totalParamsWOEmbedding
    
    if includeBias:
        # each layer has 6 biases
        # - ln_1.bias
        # - c_attn.bias (3 * hidden dimension)
        # - c_proj.bias
        # - ln_2.bias
        # - c_fc.bias (4 * hidden dimension)
        # - c_proj.bias
        totalParams += 11 * numLayers * hiddenDim
        
        # final ln has a bias
        totalParams += hiddenDim

    if includeEmbeddings:
        totalParams += wpe
        if not weightTying:
            # if weight tying, wte and final_dense is shared. Only count wte when weight tying is off
            totalParams += wte
    
    return totalParams

In [36]:
calc_params(512, 384, 8192, 1, True)

4920960

In [43]:
def calc_all(B,S,D,NH,V, n_layers=1, amp=False, dropout=False, bias=False):
    bo = 1
    fp = 4
    mp = 2 if amp else 4
    hp = 2
    i64 = 8

    total = 0
    saved_activations = calc_gpt(B,S,D,NH,V, n_layers, amp, dropout, bias)
    print("gpt activations", saved_activations)

    total += saved_activations
    last = total
    
    # x and y
    total += B*S*i64 * 2

    print("inputs", total-last)
    last = total

    # output of lm head, before loss is calculated.
    # this tensor will be destroyed after loss calculation, however, it is still a significant contribution.
    total += B*S*V*mp

    print("lm_head_output", total-last)
    last = total

    params = calc_params(S,D,V,n_layers, includeBias=bias, includeEmbeddings=True, weightTying=True)
    print("Number of Params", params)

    total += params*fp

    print("params_mem", total-last)
    last = total

    total += params*fp

    print("gradients", total-last)
    last = total

    total += params*fp*2
    
    print("optimizer", total-last)
    last = total

    return total


In [44]:
calc_saved_block_activations(B=48,S=512,D=384,NH=4, amp=False, dropout=True, bias=False)

LN1 37945344
LN2 37945344
attention-c_attn 37748736
scaled dot product attention 113639440
attention-c_proj 37748736
attention-droput 9437184
c_fc 37748736
gelu 150994944
c_proj 150994944
mlp dropout 9437184


623640592

In [49]:
calc_saved_gpt_activations(B=48,S=1024,D=512,NH=4,V=8192, n_layers=8, amp=False, dropout=True, bias=False)

embeddings 401408
gpt-droput 25165824
LN1 101056512
LN2 101056512
attention-c_attn 100663296
scaled dot product attention 302776336
attention-c_proj 100663296
attention-droput 25165824
c_fc 100663296
gelu 402653184
c_proj 402653184
mlp dropout 25165824
block_sum 13300138112
total blocks (block+input) 13300138112
ln_f 101056512
lm_head 100663296


13527425152

In [55]:
calc_all(B=24,S=1024,D=384,NH=4,V=8192, n_layers=4, amp=True, dropout=True, bias=False)

embeddings 204800
gpt-droput 9437184
LN1 37945344
LN2 37945344
attention-c_attn 19759104
scaled dot product attention 57016336
attention-c_proj 19169280
attention-droput 9437184
c_fc 19660800
gelu 75497472
c_proj 76283904
mlp dropout 9437184
block_sum 1448607808
total blocks (block+input) 1448607808
ln_f 37945344
lm_head 25165824
gpt activations 1521360960
inputs 393216
lm_head_output 402653184
Number of Params 10620288
params_mem 42481152
gradients 42481152
optimizer 84962304


2094331968