In [1]:
def calc_block(B,S,D,NH, amp=False, dropout=False, bias=False):
    """Estimate the bytes autograd keeps alive for one transformer block.

    The estimate walks the backward nodes of a block (two layer norms,
    attention with the efficient scaled-dot-product kernel, and a GELU MLP)
    and adds up the size of every tensor each node saves. A per-section
    subtotal is printed as a side effect.

    Args:
        B: batch size.
        S: sequence length.
        D: embedding width.
        NH: number of attention heads.
        amp: when True, matmul/attention activations are counted at 2 bytes
            per element and a half-precision copy of each matmul weight is
            added; otherwise activations are 4 bytes and weights add nothing.
        dropout: when True, one 1-byte-per-element mask is added for each of
            the two dropout sites.
        bias: accepted for signature compatibility; layer-norm biases are
            model parameters and contribute nothing to the total.

    Returns:
        Total estimated size as an int.
    """
    bo = 1                  # element size of a dropout mask
    fp = 4                  # element size at full precision
    mp = 2 if amp else 4    # element size of activations under (no) amp
    hp = 2                  # element size at half precision
    i64 = 8                 # element size of an int64 scalar

    # Width of the MLP hidden layer. c_fc maps D -> 4*D and c_proj maps
    # 4*D -> D, so these sizes must not depend on the sequence length S.
    hidden = 4 * D

    total = 0

    ## layer norm1

    # NativeLayerNormBackward0-input: assumed to be cached by the caller
    # already, so nothing is added here.
    # NativeLayerNormBackward0-result1
    total += B*S * fp
    # NativeLayerNormBackward0-result2
    total += B*S * fp
    # NativeLayerNormBackward0-weight / -bias: model parameters at full
    # precision (amp does not affect layer norm); not part of this total.

    print("LN1", total)
    last = total

    ## layer norm2

    # NativeLayerNormBackward0-input: sum of attention output and block input
    total += B*S*D * fp
    # NativeLayerNormBackward0-result1
    total += B*S * fp
    # NativeLayerNormBackward0-result2
    total += B*S * fp
    # NativeLayerNormBackward0-weight / -bias: model parameters, not counted.

    print("LN2", total-last)
    last = total

    ## scaled_dot_product_attention, efficient implementation

    # MmBackward0-self (c_attn)
    total += B*S*D * mp
    # MmBackward0-mat2 (c_attn): half-precision copy of attn.c_attn.weight,
    # only materialised under amp.
    if amp:
        total += 3*D*D * hp

    print("attention-c_attn", total-last)
    last = total

    # ScaledDotProductEfficientAttentionBackward0-value-query-key
    total += 3 * (B*S*D * mp)
    # ScaledDotProductEfficientAttentionBackward0-log_sumexp
    total += B*NH*S * fp
    # ScaledDotProductEfficientAttentionBackward0-philox_offset
    total += 1 * i64
    # ScaledDotProductEfficientAttentionBackward0-philox_seed
    total += 1 * i64
    # ScaledDotProductEfficientAttentionBackward0-output
    total += B*S*D * mp

    print("scaled dot product attention", total-last)
    last = total

    # MmBackward0-mat2 (c_proj): half-precision copy of attn.c_proj.weight
    if amp:
        total += D*D * hp
    # MmBackward0-self (c_proj): this is the attention output that was
    # already counted in the previous section.

    print("attention-c_proj", total-last)
    last = total

    # NativeDropoutBackward0-result1: mask of the attention residual dropout
    if dropout:
        total += B*S*D * bo

    print("attention-droput", total-last)
    last = total

    ## MLP

    # MmBackward0-mat2 (c_fc): half-precision copy of mlp.c_fc.weight (D x 4D)
    if amp:
        total += D*hidden * hp
    # MmBackward0-self (c_fc): output of layer norm 2
    total += B*S*D * mp

    print("c_fc", total-last)
    last = total

    # GeluBackward0-self: result of 'x @ c_fc', shaped B x S x 4D
    total += B*S*hidden * mp

    print("gelu", total-last)
    last = total

    # MmBackward0-mat2 (c_proj): half-precision copy of mlp.c_proj.weight (4D x D)
    if amp:
        total += hidden*D * hp
    # MmBackward0-self (c_proj): output of GELU, shaped B x S x 4D
    total += B*S*hidden * mp

    print("c_proj", total-last)
    last = total

    # NativeDropoutBackward0-result1: mask of the MLP residual dropout
    if dropout:
        total += B*S*D * bo

    print("mlp dropout", total-last)
    last = total

    return total

In [2]:
def calc_gpt(B,S,D,NH,V, n_layers=1, amp=False, dropout=False, bias=False):
    """Estimate the bytes autograd keeps alive for a full GPT forward pass.

    Adds the tensors saved by the embeddings, the embedding dropout,
    ``n_layers`` transformer blocks (via ``calc_block``), the final layer
    norm and the lm head. A per-section subtotal is printed as a side effect;
    ``calc_block`` is evaluated once and its result multiplied by
    ``n_layers``, so the per-block breakdown is printed a single time.

    Args:
        B: batch size.
        S: sequence length.
        D: embedding width.
        NH: number of attention heads (forwarded to ``calc_block``).
        V: number of rows of the embedding / lm-head weight.
        n_layers: number of transformer blocks.
        amp: forwarded to ``calc_block``; also selects the activation element
            size and adds a half-precision copy of the lm-head weight.
        dropout: forwarded to ``calc_block``; also adds the embedding
            dropout mask.
        bias: forwarded to ``calc_block``; contributes nothing here because
            layer-norm biases are model parameters.

    Returns:
        Total estimated size as an int.
    """
    bo = 1                  # element size of a dropout mask
    fp = 4                  # element size at full precision
    mp = 2 if amp else 4    # element size of activations under (no) amp
    hp = 2                  # element size at half precision
    i64 = 8                 # element size of an int64 index

    total = 0

    ## token embeddings: indices cached by the embedding layer
    total += B * S * i64

    ## position embeddings: indices cached by the embedding layer
    total += S * i64

    print("embeddings", total)
    last = total

    ## x = self.transformer.drop(tok_emb + pos_emb)
    # dropout mask of the embedding dropout
    if dropout:
        total += B*S*D * bo

    print("gpt-droput", total-last)
    last = total

    # add blocks
    # calc_block counts its own first-layer-norm input as zero ("already
    # cached"), so every block's LN1 input has to be accounted for here.
    # The first block's input is the embedding output ...
    total += B*S*D * fp
    # ... and each later block's input is the previous block's output, which
    # is a distinct full-precision tensor. (The last block's output is
    # counted below as the input of the final layer norm.)
    total += max(n_layers - 1, 0) * B*S*D * fp
    total_block = n_layers * calc_block(B,S,D,NH,amp,dropout, bias)
    total += total_block
    print("total block",str(total_block))

    print("block", total-last)
    last = total

    # final ln
    # NativeLayerNormBackward0-input: output of the final block
    total += B*S*D * fp
    # NativeLayerNormBackward0-result1
    total += B*S * fp
    # NativeLayerNormBackward0-result2
    total += B*S * fp
    # NativeLayerNormBackward0-weight / -bias: model parameters at full
    # precision (amp does not affect layer norm); not part of this total.

    print("ln_f", total-last)
    last = total

    ## lm head
    # lm_head weight operand: under amp a half-precision copy of the
    # embedding weights is used; at full precision the parameter itself is
    # used and adds nothing.
    if amp:
        total += V * D * hp
    # lm_head activation operand: output of the final layer norm
    total += B*S*D * mp

    print("lm_head", total-last)
    last = total

    return total

In [3]:
# Per-block estimate at full precision, with amp, dropout and bias all disabled.
calc_block(B=2,S=4096,D=1024,NH=2, amp=False, dropout=False, bias=False)

LN1 65536
LN2 33619968
attention-c_attn 33554432
scaled dot product attention 134283280
attention-c_proj 0
attention-droput 0
c_fc 33554432
gelu 134217728
c_proj 134217728
mlp dropout 0


503513104

In [4]:
# Whole-model estimate for a single-layer configuration with amp, dropout and bias enabled.
# V is presumably the vocabulary size (it sizes the lm-head weight copy) — confirm.
calc_gpt(B=2,S=4096,D=1024,NH=2,V=16, n_layers=1, amp=True, dropout=True, bias=True)

embeddings 98304
gpt-droput 8388608
LN1 65536
LN2 33619968
attention-c_attn 23068672
scaled dot product attention 67174416
attention-c_proj 2097152
attention-droput 8388608
c_fc 25165824
gelu 67108864
c_proj 75497472
mlp dropout 8388608
total block 310575120
block 344129552
ln_f 33619968
lm_head 16809984


403046416