In [1]:
from sympy import Symbol, MatrixSymbol, Function

from model import (
    repeat,
    count,
    tree_map,
    Embeddings,
    ScaledDotProductAttention,
    MultiHeadAttention,
    PositionWiseFFN,
    LayerNorm,
    Block,
    Transformer,
)


In [2]:
# Symbolic model dimensions. The LaTeX-style names only affect how the
# symbols are rendered in the cell outputs below.
l = Symbol("l")  # presumably the sequence length — confirm against model.py
d_model = Symbol(r"d_{model}")
d_ff = Symbol(r"d_{ff}")
n_vocab = Symbol(r"n_{vocab}")
n_layers = Symbol(r"n_{layers}")
n_heads = Symbol(r"n_{heads}")

# Raw string is required here: in a normal string literal "\t" is a TAB
# escape, which would silently turn the name into "<TAB>ext{inputs}".
inputs = MatrixSymbol(r"\text{inputs}", 1, l)

# Undefined (uninterpreted) sympy functions used purely as named placeholders.
softmax = Function(r"\text{softmax}")
relu = Function(r"\text{relu}")

In [3]:
# Symbolic FLOP count reported by the Embeddings component for the symbols defined above.
Embeddings(l, d_model, n_vocab).flops_count

d_{model}*l

In [4]:
# Symbolic FLOP count for ScaledDotProductAttention on its own.
# NOTE(review): the positional arguments are presumably (query length, key length,
# key width, value width) — confirm against model.py.
ScaledDotProductAttention(l, l, d_model, d_model).flops_count

4*d_{model}*l**2 + 5*l**2

In [5]:
# Symbolic FLOP count for MultiHeadAttention on its own.
# NOTE(review): unlike the Block result further down, the d_model**2 and
# d_model*l**2 terms here carry a factor of n_heads — presumably because
# d_model is passed as the per-head width in this call; confirm against model.py.
MultiHeadAttention(l, l, d_model, d_model, d_model, n_heads).flops_count

8*d_{model}**2*l*n_{heads} + 4*d_{model}*l**2*n_{heads} + 5*l**2*n_{heads}

In [6]:
# Symbolic FLOP count for the position-wise feed-forward component.
PositionWiseFFN(l, d_model, d_ff, d_model).flops_count

4*d_{ff}*d_{model}*l + d_{ff} + d_{model}

In [7]:
# Symbolic FLOP count for a single LayerNorm.
LayerNorm(l, d_model).flops_count

6*d_{model}*l + 7*d_{model}

In [8]:
# Symbolic FLOP count for one Block.
Block(l, d_model, d_ff, n_heads).flops_count

4*d_{ff}*d_{model}*l + d_{ff} + 8*d_{model}**2*l + 4*d_{model}*l**2 + 14*d_{model}*l + 15*d_{model} + 5*l**2*n_{heads}

In [9]:
# Symbolic FLOP count for the whole Transformer; kept in `flops` because the
# following cells rename its symbols and regroup it.
flops = Transformer(l, d_model, d_ff, n_vocab, n_layers, n_heads).flops_count
flops

4*d_{ff}*d_{model}*l*n_{layers} + d_{ff}*n_{layers} + 8*d_{model}**2*l*n_{layers} + 4*d_{model}*l**2*n_{layers} + 12*d_{model}*l*n_{layers} + 2*d_{model}*l*n_{vocab} + 7*d_{model}*l + 13*d_{model}*n_{layers} + 7*d_{model} + 5*l**2*n_{heads}*n_{layers} + n_{layers}*(2*d_{model}*l + 2*d_{model})

In [10]:
# Replace the LaTeX-style symbols with single-letter ones (d, f, n, v, h) so
# the expression is shorter to read; `l` is left as it is.
flops = (
    flops
    .subs(d_model, Symbol("d"))
    .subs(d_ff, Symbol("f"))
    .subs(n_layers, Symbol("n"))
    .subs(n_vocab, Symbol("v"))
    .subs(n_heads, Symbol("h"))
)
flops

8*d**2*l*n + 4*d*f*l*n + 4*d*l**2*n + 12*d*l*n + 2*d*l*v + 7*d*l + 13*d*n + 7*d + f*n + 5*h*l**2*n + n*(2*d*l + 2*d)

In [11]:
# Group the expression by powers of l.
# NOTE: the product n*(2*d*l + 2*d) is not expanded beforehand, so its l-term
# stays outside the l*(...) group in the result below.
flops.collect("l")

13*d*n + 7*d + f*n + l**2*(4*d*n + 5*h*n) + l*(8*d**2*n + 4*d*f*n + 12*d*n + 2*d*v + 7*d) + n*(2*d*l + 2*d)

In [12]:
# Group the expression by powers of d.
flops.collect("d")

8*d**2*l*n + d*(4*f*l*n + 4*l**2*n + 12*l*n + 2*l*v + 7*l + n*(2*l + 2) + 13*n + 7) + f*n + 5*h*l**2*n

In [13]:
# Group the expression by powers of f.
flops.collect("f")

8*d**2*l*n + 4*d*l**2*n + 12*d*l*n + 2*d*l*v + 7*d*l + 13*d*n + 7*d + f*(4*d*l*n + n) + 5*h*l**2*n + n*(2*d*l + 2*d)

In [14]:
# Group by v. It occurs in a single term (2*d*l*v), so the result is unchanged.
flops.collect("v")

8*d**2*l*n + 4*d*f*l*n + 4*d*l**2*n + 12*d*l*n + 2*d*l*v + 7*d*l + 13*d*n + 7*d + f*n + 5*h*l**2*n + n*(2*d*l + 2*d)

In [15]:
# Group by h. It occurs in a single term (5*h*l**2*n), so the result is unchanged.
flops.collect("h")

8*d**2*l*n + 4*d*f*l*n + 4*d*l**2*n + 12*d*l*n + 2*d*l*v + 7*d*l + 13*d*n + 7*d + f*n + 5*h*l**2*n + n*(2*d*l + 2*d)

---

In [16]:
def count_flops(N, l, d_model, d_ff, n_vocab, n_layers, n_heads):
    """Return the Transformer's total FLOP count multiplied by ``N``.

    Args:
        N: Multiplier applied to the model's ``flops_count``.
        l, d_model, d_ff, n_vocab, n_layers, n_heads: Forwarded unchanged, in
            this order, to ``Transformer``.

    Returns:
        ``N * flops_count`` converted with ``int`` (any fractional part is
        truncated).
    """
    model = Transformer(l, d_model, d_ff, n_vocab, n_layers, n_heads)
    scaled_total = N * model.flops_count
    return int(scaled_total)

# Total FLOP count for one concrete configuration (N=1).
count_flops(
    N=1,
    l=512,
    d_model=512,
    d_ff=4096,
    n_vocab=30000,
    n_layers=6,
    n_heads=8,
)

51248964096

In [17]:
def count_dict(N, l, d_model, d_ff, n_vocab, n_layers, n_heads):
    """Return the nested per-operation FLOP breakdown, scaled by ``N``.

    ``N`` used to be accepted but ignored, so the breakdown disagreed with
    ``count_flops`` (which multiplies by ``N``) whenever ``N != 1`` and the
    fractions derived from the two no longer added up.

    Args:
        N: Multiplier for the breakdown, matching ``count_flops``.
        l, d_model, d_ff, n_vocab, n_layers, n_heads: Forwarded unchanged, in
            this order, to ``Transformer``.

    Returns:
        The model's ``flops_dict`` passed through ``repeat(..., N)``.
    """
    breakdown = Transformer(l, d_model, d_ff, n_vocab, n_layers, n_heads).flops_dict
    # `repeat` is imported from `model` and is called the same way (dict, N)
    # by the later `analysis` cell. It is assumed to multiply every leaf by N
    # while keeping the dict structure — confirm against model.py.
    return repeat(breakdown, N)

# Nested per-operation breakdown for the same configuration as the total above.
count_dict(
    N=1,
    l=512,
    d_model=512,
    d_ff=4096,
    n_vocab=30000,
    n_layers=6,
    n_heads=8,
)

{'Embeddings': {'X_word + X_pos': 262144},
 'LayerNorm_Final': {'mean(x)': 262656,
  'var_x': {'X**2': 262144,
   'sum': 262144,
   '/ m': 512,
   'mean_x**2': 512,
   'm - mean_x**2': 512},
  'numerator': 512,
  'denominator': 1024,
  'gamma*(numerator/denominator)+beta': 786432},
 'ProjectToVocab': 15728640000,
 'Blocks': {'LayerNorms': {'mean(x)': 3151872,
   'var_x': {'X**2': 3145728,
    'sum': 3145728,
    '/ m': 6144,
    'mean_x**2': 6144,
    'm - mean_x**2': 6144},
   'numerator': 6144,
   'denominator': 12288,
   'gamma*(numerator/denominator)+beta': 9437184},
  'ResidualConnections': 3145728,
  'MultiHeadAttention': {'Q * WQ': 1610612736.0,
   'K * WK': 1610612736.0,
   'V * WV': 1610612736.0,
   'ScaledDotProductAttention': {'Q * K.T': 1610612736.0,
    '/ sqrt(d_k)': 12582912,
    '+ mask': 12582912,
    'Softmax': {'x - max(x)': 12582912, 'e^x': 12582912, 'sum': 12582912},
    '* V': 1610612736.0},
   '* WO': 1610612736.0},
  'PositionWiseFFN': {'* W1': 12884901888,
   '

In [18]:
def tree_map(d, f):
    """Apply ``f`` to every non-dict value of a (possibly nested) dict.

    Args:
        d: Dict whose values are either further dicts or leaf values.
        f: Callable applied to each leaf value.

    Returns:
        A new dict with the same keys and nesting as ``d``; ``d`` itself is
        not modified.
    """
    # NOTE: this definition shadows the `tree_map` imported from `model` at
    # the top of the notebook.
    mapped = {}
    for key, value in d.items():
        if isinstance(value, dict):
            mapped[key] = tree_map(value, f)
        else:
            mapped[key] = f(value)
    return mapped

def percent_flops(N, l, d_model, d_ff, n_vocab, n_layers, n_heads):
    """Express every entry of the FLOP breakdown as a share of the total.

    Despite the name, the values are plain fractions (``value / total``), not
    numbers multiplied by 100.

    Args:
        N, l, d_model, d_ff, n_vocab, n_layers, n_heads: Forwarded unchanged to
            both ``count_flops`` and ``count_dict``.

    Returns:
        A dict with the same nesting as ``count_dict(...)`` whose leaves are
        divided by ``count_flops(...)``.
    """
    total_flops = count_flops(N, l, d_model, d_ff, n_vocab, n_layers, n_heads)
    breakdown = count_dict(N, l, d_model, d_ff, n_vocab, n_layers, n_heads)

    def as_fraction(value):
        return value / total_flops

    return tree_map(d=breakdown, f=as_fraction)

In [19]:
# Same configuration as above, with each entry shown as a fraction of the total.
percent_flops(
    N=1,
    l=512,
    d_model=512,
    d_ff=4096,
    n_vocab=30000,
    n_layers=6,
    n_heads=8,
)

{'Embeddings': {'X_word + X_pos': 5.115108268509576e-06},
 'LayerNorm_Final': {'mean(x)': 5.125098714346509e-06,
  'var_x': {'X**2': 5.115108268509576e-06,
   'sum': 5.115108268509576e-06,
   '/ m': 9.990445836932766e-09,
   'mean_x**2': 9.990445836932766e-09,
   'm - mean_x**2': 9.990445836932766e-09},
  'numerator': 9.990445836932766e-09,
  'denominator': 1.9980891673865532e-08,
  'gamma*(numerator/denominator)+beta': 1.534532480552873e-05},
 'ProjectToVocab': 0.3069064961105746,
 'Blocks': {'LayerNorms': {'mean(x)': 6.150118457215811e-05,
   'var_x': {'X**2': 6.138129922211491e-05,
    'sum': 6.138129922211491e-05,
    '/ m': 1.198853500431932e-07,
    'mean_x**2': 1.198853500431932e-07,
    'm - mean_x**2': 1.198853500431932e-07},
   'numerator': 1.198853500431932e-07,
   'denominator': 2.397707000863864e-07,
   'gamma*(numerator/denominator)+beta': 0.00018414389766634474},
  'ResidualConnections': 6.138129922211491e-05,
  'MultiHeadAttention': {'Q * WQ': 0.031427225201722836,
   '

In [20]:
def analysis(
    N,
    l,
    d_model,
    d_ff,
    n_vocab,
    n_layers,
    n_heads,
):
    """Summarise the FLOP breakdown into a flat dict of coarse categories.

    The nested breakdown from ``count_dict`` is reduced as follows:

    * every sub-tree whose key is listed in ``keys`` is summed to one number;
    * ``LayerNorm_Final`` and the blocks' ``LayerNorms`` are added together
      under a single ``LayerNorms`` entry;
    * the contents of ``Blocks`` and then of ``MultiHeadAttention`` are lifted
      to the top level.

    NOTE(review): the breakdown comes from ``count_dict`` and the total from
    ``count_flops``; the fractions are only meaningful when both apply ``N``
    in the same way.

    Args:
        N, l, d_model, d_ff, n_vocab, n_layers, n_heads: Forwarded unchanged to
            ``count_dict`` and ``count_flops``.

    Returns:
        ``(counts, percentages)``: the flat dict of totals, and the same dict
        with every value divided by the overall FLOP count.
    """
    # Sub-trees that are reported as one summed number.
    keys = {
        "Embeddings",
        "LayerNorm_Final",
        "LayerNorms",
        "PositionWiseFFN",
        "ScaledDotProductAttention",
    }

    def _count(tree):
        """Sum all leaf values of a nested dict."""
        return sum(
            _count(child) if isinstance(child, dict) else child
            for child in tree.values()
        )

    def gather(node):
        """Copy ``node``, replacing the sub-trees named in ``keys`` by their sum."""
        if not isinstance(node, dict):
            return node
        collapsed = {}
        for name, child in node.items():
            collapsed[name] = _count(child) if name in keys else gather(child)
        return collapsed

    flops = count_dict(N, l, d_model, d_ff, n_vocab, n_layers, n_heads)
    total_flops = count_flops(N, l, d_model, d_ff, n_vocab, n_layers, n_heads)

    counts = gather(flops)

    # Merge the final LayerNorm with the per-block ones. Both entries are
    # removed before the combined key is inserted, which fixes the key order.
    final_layernorm = counts.pop("LayerNorm_Final")
    block_layernorms = counts["Blocks"].pop("LayerNorms")
    counts["LayerNorms"] = final_layernorm + block_layernorms

    # Flatten: first the contents of "Blocks", then those of "MultiHeadAttention".
    block_entries = counts.pop("Blocks")
    counts.update(block_entries)
    attention_entries = counts.pop("MultiHeadAttention")
    counts.update(attention_entries)

    def as_fraction(value):
        return value / total_flops

    percentages = tree_map(d=counts, f=as_fraction)
    return counts, percentages


In [21]:
# Share of FLOPs per category for a large configuration; `[1]` selects the
# fractions (second element of the returned tuple).
analysis(
    N=1,
    l=2048,
    d_model=8192,
    d_ff=32768,
    n_vocab=51200,
    n_layers=64,
    n_heads=64,
)[1]

{'Embeddings': 7.566801465307276e-08,
 'ProjectToVocab': 0.007748404700474651,
 'LayerNorms': 5.86004067287985e-05,
 'ResidualConnections': 9.685505875593314e-06,
 'PositionWiseFFN': 0.6347493248860107,
 'Q * WQ': 0.07934366413286043,
 'K * WK': 0.07934366413286043,
 'V * WV': 0.07934366413286043,
 'ScaledDotProductAttention': 0.04005925230145394,
 '* WO': 0.07934366413286043}

In [22]:
# Same breakdown for a smaller configuration (same l and n_vocab, smaller
# d_model, d_ff, n_layers and n_heads); `[1]` again selects the fractions.
analysis(
    N=1,
    l=2048,
    d_model=1024,
    d_ff=4096,
    n_vocab=51200,
    n_layers=24,
    n_heads=16
)[1]

{'Embeddings': 1.119801735515862e-06,
 'ProjectToVocab': 0.11466769771682427,
 'LayerNorms': 0.00032940925516123466,
 'ResidualConnections': 5.375048330476138e-05,
 'PositionWiseFFN': 0.44032402484598815,
 'Q * WQ': 0.05504049490407565,
 'K * WK': 0.05504049490407565,
 'V * WV': 0.05504049490407565,
 'ScaledDotProductAttention': 0.2244620182806835,
 '* WO': 0.05504049490407565}

In [36]:
import plotly.graph_objects as go



# NOTE(review): `x` and `y` were not defined anywhere in the notebook, so this
# cell only ran thanks to leftover kernel state and fails on Restart & Run All.
# They are assumed to be the category names and FLOP fractions returned by
# `analysis`; the configuration below mirrors the previous cell — confirm that
# this is the one the plot was meant to show.
percentages = analysis(
    N=1,
    l=2048,
    d_model=1024,
    d_ff=4096,
    n_vocab=51200,
    n_layers=24,
    n_heads=16,
)[1]
x = list(percentages.keys())
y = list(percentages.values())

fig = go.Figure(go.Bar(x=x, y=y))
# The values are fractions of the total, so pin the axis to the full [0, 1] range.
fig.update_yaxes(range=[0, 1])
fig.update_layout(
    title="Share of total FLOPs per component",
    xaxis_title="Component",
    yaxis_title="Fraction of total FLOPs",
)

fig.show()


In [24]:
from model import Transformer, count, repeat, tree_map

def analysis(
    N,
    l,
    d_model,
    d_ff,
    n_vocab,
    n_layers,
    n_heads,
):
    """Summarise the FLOP breakdown into a flat dict of coarse categories.

    The breakdown is ``repeat(Transformer(...).flops_dict, N)`` and the total
    is ``count`` of that same dict, so both are derived from one source.

    The nested breakdown is reduced as follows:

    * every sub-tree whose key is listed in ``keys`` is summed to one number;
    * ``LayerNorm_Final`` and the blocks' ``LayerNorms`` are added together
      under a single ``LayerNorms`` entry;
    * the contents of ``Blocks`` and then of ``MultiHeadAttention`` are lifted
      to the top level.

    Args:
        N: Passed as the second argument to ``repeat``.
        l, d_model, d_ff, n_vocab, n_layers, n_heads: Forwarded unchanged, in
            this order, to ``Transformer``.

    Returns:
        ``(counts, percentages)``: the flat dict of totals, and the same dict
        with every value divided by the overall total.
    """
    # Sub-trees that are reported as one summed number.
    keys = {
        "Embeddings",
        "LayerNorm_Final",
        "LayerNorms",
        "PositionWiseFFN",
        "ScaledDotProductAttention",
    }

    def _count(tree):
        """Sum all leaf values of a nested dict."""
        return sum(
            _count(child) if isinstance(child, dict) else child
            for child in tree.values()
        )

    def gather(node):
        """Copy ``node``, replacing the sub-trees named in ``keys`` by their sum."""
        if not isinstance(node, dict):
            return node
        collapsed = {}
        for name, child in node.items():
            collapsed[name] = _count(child) if name in keys else gather(child)
        return collapsed

    model = Transformer(l, d_model, d_ff, n_vocab, n_layers, n_heads)
    flops = repeat(model.flops_dict, N)
    total_flops = count(flops)

    counts = gather(flops)

    # Merge the final LayerNorm with the per-block ones. Both entries are
    # removed before the combined key is inserted, which fixes the key order.
    final_layernorm = counts.pop("LayerNorm_Final")
    block_layernorms = counts["Blocks"].pop("LayerNorms")
    counts["LayerNorms"] = final_layernorm + block_layernorms

    # Flatten: first the contents of "Blocks", then those of "MultiHeadAttention".
    block_entries = counts.pop("Blocks")
    counts.update(block_entries)
    attention_entries = counts.pop("MultiHeadAttention")
    counts.update(attention_entries)

    def as_fraction(value):
        return value / total_flops

    percentages = tree_map(d=counts, f=as_fraction)
    return counts, percentages