In [1]:
from sympy import Symbol, MatrixSymbol, Function

In [2]:
l = Symbol("l")
d_model = Symbol(r"d_{model}")
d_ff = Symbol(r"d_{ff}")
n_vocab = Symbol(r"n_{vocab}")
n_layers = Symbol(r"n_{layers}")
n_heads = Symbol(r"n_{heads}")

inputs = MatrixSymbol("\text{inputs}", 1, l)

softmax = Function(r"\text{softmax}")
relu = Function(r"\text{relu}")

In [3]:
class Layer:
    @property
    def flops_dict(self):
        raise NotImplementedError()

    @property
    def flops_count(self):
        def count(d):
            return sum([count(v) if isinstance(v, dict) else v for v in d.values()])
        return count(self.flops_dict)
    
    def evaluate(self, *args, **kwargs):
        raise NotImplementedError()

In [4]:
def repeat(d, n):
    return {k: repeat(v, n) if isinstance(v, dict) else n*v for k, v in d.items()}

### Embeddings

In [5]:
class Embeddings(Layer):    
    def __init__(self, l, d_model, n_vocab):
        super().__init__()
        self.l, self.d_model, self.n_vocab = l, d_model, n_vocab

    def evaluate(self):
        X_word = MatrixSymbol(r"X_{word}", self.l, self.d_model)
        X_pos = MatrixSymbol(r"X_{pos}", self.l, self.d_model)

        return X_word + X_pos

    @property
    def flops_dict(self):
        return {
            "X_word + X_pos": self.l * self.d_model,
        }

In [6]:
Embeddings(l, d_model, n_vocab).flops_count

d_{model}*l

### Scaled Dot Product Attention

In [7]:
class Softmax(Layer):
    def __init__(self, m, n):
        super().__init__()
        self.m, self.n = m, n

    @property
    def flops_dict(self):
        return {
            "x - max(x)": self.m * self.n,
            "e^x": self.m * self.n,
            "sum": self.m * self.n,
        }


class ScaledDotProductAttention(Layer):
    def __init__(self, n_q, n_k, d_k, d_v):
        super().__init__()
        self.n_q, self.n_k, self.d_k, self.d_v = n_q, n_k, d_k, d_v

    def evaluate(self):
        # Q = MatrixSymbol("Q", self.n_q, self.d_k)
        # K = MatrixSymbol("K", self.n_k, self.d_k)
        # V = MatrixSymbol("Q", self.n_k, self.d_v)
        # mask = MatrixSymbol("Q", self.n_q, self.n_k)
        pass

    @property
    def flops_dict(self):
        return {
            "Q * K.T": 2 * self.n_q * self.d_k * self.n_k,
            "/ sqrt(d_k)": self.n_q * self.n_k,
            "+ mask": self.n_q * self.n_k,
            "Softmax": Softmax(self.n_q, self.n_k).flops_dict,
            "* V": 2 * self.n_q * self.n_k * self.d_v,
        }


In [8]:
ScaledDotProductAttention(l, l, d_model, d_model).flops_count

4*d_{model}*l**2 + 5*l**2

### Multi-head Attention

In [9]:
class MultiHeadAttention(Layer):
    def __init__(self, n_q, n_k, d_k, d_v, d_model, n_heads):
        super().__init__()
        self.n_q = n_q
        self.n_k = n_k
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        self.n_heads = n_heads
    
    def evaluate(self):
        # WQ = MatrixSymbol("WQ", d_model, d_k * n_heads)
        # WK = MatrixSymbol("WK", d_model, d_k * n_heads)
        # WV = MatrixSymbol("WV", d_model, d_v * n_heads)
        # WO = MatrixSymbol("WO", d_v * n_heads, d_model)
        pass

    @property
    def flops_dict(self):
        # the FLOPs count is the same no matter the number of heads, the heads just 
        # enable different representation subspaces, it doesn't add any computational 
        # complexity as far as FLOPs is concerned
        return {
            "Q * WQ": 2 * self.n_q * self.d_model * (self.d_k * self.n_heads),
            "K * WK": 2 * self.n_k * self.d_model * (self.d_k * self.n_heads),
            "V * WV": 2 * self.n_k * self.d_model * (self.d_v * self.n_heads),
            "ScaledDotProductAttention": repeat(ScaledDotProductAttention(self.n_q, self.n_k, self.d_k, self.d_v).flops_dict, self.n_heads),
            "* WO": 2 * self.n_q * (self.d_v * self.n_heads) * self.d_model
        }

In [10]:
MultiHeadAttention(l, l, d_model, d_model, d_model, n_heads).flops_count

8*d_{model}**2*l*n_{heads} + 4*d_{model}*l**2*n_{heads} + 5*l**2*n_{heads}

### Position-Wise Feed Forward Network

In [11]:
class PositionWiseFFN(Layer):
    def __init__(self, l, d_in, d_ff, d_out):
        super().__init__()
        self.l, self.d_in, self.d_ff, self.d_out = l, d_in, d_ff, d_out
    
    def evaluate(self):
        # W1 = MatrixSymbol("W1", d_in, d_ff)
        # b1 = MatrixSymbol("b1", d_ff, 1)

        # W1 = MatrixSymbol("W1", d_ff, d_out)
        # b1 = MatrixSymbol("b1", d_out, 1)

        # X = MatrixSymbol("X", l, d_in)
        # output = relu(X * W1 + b1) * W2 + b2
        pass

    @property
    def flops_dict(self):
        return {
            "* W1": 2 * self.l * self.d_in * self.d_ff,
            "+ b1": self.d_ff ,
            "* W2": 2 * self.l * self.d_ff * self.d_out,
            "+ b2": self.d_out,
        }

In [12]:
PositionWiseFFN(l, d_model, d_ff, d_model).flops_count

4*d_{ff}*d_{model}*l + d_{ff} + d_{model}

### Layer Norm

In [13]:
class LayerNorm(Layer):
    def __init__(self, m, n):
        super().__init__()
        self.m, self.n = m, n
    
    def evaluate(self):
        # X = MatrixSymbol("X", n, m)
        #
        # mean_x = sum(X, axis=-1) / m
        # var_x = sum(X**2, axis=-1) / m - mean_x**2
        #
        # numer = x - mean_x
        # denom = sqrt(var_x + eps)
        #
        # gamma * (numer / denom) + beta
        pass
    
    @property
    def flops_dict(self):
        n, m = self.n, self.m
        return {
            "mean(x)": m*n + n, # element wise add and then a divide over the resulting vector
            "var_x": { 
                "X**2": m*n,
                "sum": m*n,
                "/ m": n,
                "mean_x**2": n,
                "m - mean_x**2": n
            },
            "numerator": n, # x - mean_x
            "denominator": 2*n, # sqrt(var_x + eps), elem wise addition then elem wise sqrt
            "gamma*(numerator/denominator)+beta": 3*m*n # after broadcasting, 1 elem wise divide, multiply, then add
        }

In [14]:
LayerNorm(l, d_model).flops_count

6*d_{model}*l + 7*d_{model}

### Transformer Block

In [33]:
class Block(Layer):
    def __init__(self, l, d_model, d_ff, n_heads):
        super().__init__()
        self.l, self.d_model, self.d_ff, self.n_heads = l, d_model, d_ff, n_heads
    
    def evaluate(self):
        # X = MatrixSymbol("X", l, d_model)
        # X = X + MultiHeadAttention(LayerNorm(X))
        # X = X + PositionWiseFFN(LayerNorm(X))
        pass
    
    @property
    def flops_dict(self):
        return {
            "LayerNorms": repeat(LayerNorm(self.l, self.d_model).flops_dict, 2),
            "ResidualConnections": 2 * self.l * self.d_model,
            "MultiHeadAttention": MultiHeadAttention(self.l, self.l, self.d_model / self.n_heads, self.d_model / self.n_heads, self.d_model, self.n_heads).flops_dict,
            "PositionWiseFFN": PositionWiseFFN(self.l, self.d_model, self.d_ff, self.d_model).flops_dict
        }

In [16]:
Block(l, d_model, d_ff, n_heads).flops_count

4*d_{ff}*d_{model}*l + d_{ff} + 8*d_{model}**2*l + 4*d_{model}*l**2 + 14*d_{model}*l + 15*d_{model} + 5*l**2*n_{heads}

### Transformer

In [17]:
class Transformer(Layer):
    def __init__(self, l, d_model, d_ff, n_vocab, n_layers, n_heads):
        super().__init__()
        self.l, self.d_model, self.d_ff, self.n_vocab, self.n_layers, self.n_heads = l, d_model, d_ff, n_vocab,n_layers,  n_heads
    
    def evaluate(self):
        # X = embeddings
        # for _ in range(n_layers):
        #   X = block(X)
        # X = layer_norm(X)
        # W_lm = MatrixSymbol("LM", d_model, n_vocab) # projection_to_vocab (lm head)
        # outputs = X * W_lm
        pass
    
    @property
    def flops_dict(self):
        return {
            "Embeddings": Embeddings(self.l, self.d_model, self.n_vocab).flops_dict,
            "LayerNorm_Final": LayerNorm(self.l, self.d_model).flops_dict,
            "ProjectToVocab": 2 * self.l * self.d_model * self.n_vocab,
            "Blocks": repeat(Block(self.l, self.d_model, self.d_ff, self.n_heads).flops_dict, self.n_layers)
        }

In [18]:
flops = Transformer(l, d_model, d_ff, n_vocab, n_layers, n_heads).flops_count
flops

4*d_{ff}*d_{model}*l*n_{layers} + d_{ff}*n_{layers} + 8*d_{model}**2*l*n_{layers} + 4*d_{model}*l**2*n_{layers} + 12*d_{model}*l*n_{layers} + 2*d_{model}*l*n_{vocab} + 7*d_{model}*l + 13*d_{model}*n_{layers} + 7*d_{model} + 5*l**2*n_{heads}*n_{layers} + n_{layers}*(2*d_{model}*l + 2*d_{model})

In [19]:
flops = flops.subs(d_model, Symbol("d"))
flops = flops.subs(d_ff, Symbol("f"))
flops = flops.subs(n_layers, Symbol("n"))
flops = flops.subs(n_vocab, Symbol("v"))
flops = flops.subs(n_heads, Symbol("h"))
flops

8*d**2*l*n + 4*d*f*l*n + 4*d*l**2*n + 12*d*l*n + 2*d*l*v + 7*d*l + 13*d*n + 7*d + f*n + 5*h*l**2*n + n*(2*d*l + 2*d)

In [21]:
flops.collect("l")

13*d*n + 7*d + f*n + l**2*(4*d*n + 5*h*n) + l*(8*d**2*n + 4*d*f*n + 12*d*n + 2*d*v + 7*d) + n*(2*d*l + 2*d)

In [20]:
flops.collect("d")

8*d**2*l*n + d*(4*f*l*n + 4*l**2*n + 12*l*n + 2*l*v + 7*l + n*(2*l + 2) + 13*n + 7) + f*n + 5*h*l**2*n

In [22]:
flops.collect("f")

8*d**2*l*n + 4*d*l**2*n + 12*d*l*n + 2*d*l*v + 7*d*l + 13*d*n + 7*d + f*(4*d*l*n + n) + 5*h*l**2*n + n*(2*d*l + 2*d)

In [24]:
flops.collect("v")

8*d**2*l*n + 4*d*f*l*n + 4*d*l**2*n + 12*d*l*n + 2*d*l*v + 7*d*l + 13*d*n + 7*d + f*n + 5*h*l**2*n + n*(2*d*l + 2*d)

In [25]:
flops.collect("h")

8*d**2*l*n + 4*d*f*l*n + 4*d*l**2*n + 12*d*l*n + 2*d*l*v + 7*d*l + 13*d*n + 7*d + f*n + 5*h*l**2*n + n*(2*d*l + 2*d)

In [44]:
def count_flops(N, l, d_model, d_ff, n_vocab, n_layers, n_heads):
    return int(N*Transformer(l, d_model, d_ff, n_vocab, n_layers, n_heads).flops_count)

count_flops(
    N=1,
    l=512,
    d_model=512,
    d_ff=4096,
    n_vocab=30000,
    n_layers=6,
    n_heads=8,
)

51248964096

In [52]:
def count_dict(N, l, d_model, d_ff, n_vocab, n_layers, n_heads):
    return Transformer(l, d_model, d_ff, n_vocab, n_layers, n_heads).flops_dict

count_dict(
    N=1,
    l=512,
    d_model=512,
    d_ff=4096,
    n_vocab=30000,
    n_layers=6,
    n_heads=8,
)

{'Embeddings': {'X_word + X_pos': 262144},
 'LayerNorm_Final': {'mean(x)': 262656,
  'var_x': {'X**2': 262144,
   'sum': 262144,
   '/ m': 512,
   'mean_x**2': 512,
   'm - mean_x**2': 512},
  'numerator': 512,
  'denominator': 1024,
  'gamma*(numerator/denominator)+beta': 786432},
 'ProjectToVocab': 15728640000,
 'Blocks': {'LayerNorms': {'mean(x)': 3151872,
   'var_x': {'X**2': 3145728,
    'sum': 3145728,
    '/ m': 6144,
    'mean_x**2': 6144,
    'm - mean_x**2': 6144},
   'numerator': 6144,
   'denominator': 12288,
   'gamma*(numerator/denominator)+beta': 9437184},
  'ResidualConnections': 3145728,
  'MultiHeadAttention': {'Q * WQ': 1610612736.0,
   'K * WK': 1610612736.0,
   'V * WV': 1610612736.0,
   'ScaledDotProductAttention': {'Q * K.T': 1610612736.0,
    '/ sqrt(d_k)': 12582912,
    '+ mask': 12582912,
    'Softmax': {'x - max(x)': 12582912, 'e^x': 12582912, 'sum': 12582912},
    '* V': 1610612736.0},
   '* WO': 1610612736.0},
  'PositionWiseFFN': {'* W1': 12884901888,
   '

In [61]:
def tree_map(d, f):
    return {k: tree_map(v, f) if isinstance(v, dict) else f(v) for k, v in d.items()}

def percent_flops(N,l,d_model,d_ff,n_vocab,n_layers,n_heads):
    total_flops = count_flops(N, l, d_model, d_ff, n_vocab, n_layers, n_heads)
    return tree_map(
        d=count_dict(N, l, d_model, d_ff, n_vocab, n_layers, n_heads),
        f=lambda x: x / total_flops
    )

In [62]:
percent_flops(
    N=1,
    l=512,
    d_model=512,
    d_ff=4096,
    n_vocab=30000,
    n_layers=6,
    n_heads=8,
)

{'Embeddings': {'X_word + X_pos': 5.115108268509576e-06},
 'LayerNorm_Final': {'mean(x)': 5.125098714346509e-06,
  'var_x': {'X**2': 5.115108268509576e-06,
   'sum': 5.115108268509576e-06,
   '/ m': 9.990445836932766e-09,
   'mean_x**2': 9.990445836932766e-09,
   'm - mean_x**2': 9.990445836932766e-09},
  'numerator': 9.990445836932766e-09,
  'denominator': 1.9980891673865532e-08,
  'gamma*(numerator/denominator)+beta': 1.534532480552873e-05},
 'ProjectToVocab': 0.3069064961105746,
 'Blocks': {'LayerNorms': {'mean(x)': 6.150118457215811e-05,
   'var_x': {'X**2': 6.138129922211491e-05,
    'sum': 6.138129922211491e-05,
    '/ m': 1.198853500431932e-07,
    'mean_x**2': 1.198853500431932e-07,
    'm - mean_x**2': 1.198853500431932e-07},
   'numerator': 1.198853500431932e-07,
   'denominator': 2.397707000863864e-07,
   'gamma*(numerator/denominator)+beta': 0.00018414389766634474},
  'ResidualConnections': 6.138129922211491e-05,
  'MultiHeadAttention': {'Q * WQ': 0.031427225201722836,
   '

In [68]:
def analysis(N,l,d_model,d_ff,n_vocab,n_layers,n_heads, keys):
    _count = lambda d: sum(_count(v) if isinstance(v, dict) else v for v in d.values())
    def gather(d):
        if not isinstance(d, dict):
            return d
        return {k: _count(v) if k in keys else gather(v) for k, v in d.items()}

    flops = count_dict(N, l, d_model, d_ff, n_vocab, n_layers, n_heads)
    total_flops = count_flops(N, l, d_model, d_ff, n_vocab, n_layers, n_heads)
    
    counts = gather(flops)
    percentages = tree_map(
        d=counts,
        f=lambda x: x / total_flops
    )
    return counts, percentages

In [74]:
analysis(
    N=1,
    l=2048,
    d_model=8192,
    d_ff=32768,
    n_vocab=51200,
    n_layers=64,
    n_heads=64,
    keys={"Embeddings", "LayerNorm_Final", "LayerNorms", "PositionWiseFFN", "ScaledDotProductAttention"}
)[1]

{'Embeddings': 7.566801465307276e-08,
 'LayerNorm_Final': 4.542667188278953e-07,
 'ProjectToVocab': 0.007748404700474651,
 'Blocks': {'LayerNorms': 5.81461400099706e-05,
  'ResidualConnections': 9.685505875593314e-06,
  'MultiHeadAttention': {'Q * WQ': 0.07934366413286043,
   'K * WK': 0.07934366413286043,
   'V * WV': 0.07934366413286043,
   'ScaledDotProductAttention': 0.04005925230145394,
   '* WO': 0.07934366413286043},
  'PositionWiseFFN': 0.6347493248860107}}

In [75]:
analysis(
    N=1,
    l=2048,
    d_model=1024,
    d_ff=4096,
    n_vocab=51200,
    n_layers=24,
    n_heads=16,
    keys={"Embeddings", "LayerNorm_Final", "LayerNorms", "PositionWiseFFN", "ScaledDotProductAttention"}
)[1]

{'Embeddings': 1.119801735515862e-06,
 'LayerNorm_Final': 6.722637860433361e-06,
 'ProjectToVocab': 0.11466769771682427,
 'Blocks': {'LayerNorms': 0.00032268661730080133,
  'ResidualConnections': 5.375048330476138e-05,
  'MultiHeadAttention': {'Q * WQ': 0.05504049490407565,
   'K * WK': 0.05504049490407565,
   'V * WV': 0.05504049490407565,
   'ScaledDotProductAttention': 0.2244620182806835,
   '* WO': 0.05504049490407565},
  'PositionWiseFFN': 0.44032402484598815}}