In [1]:
def embedding(seq_len, vocab_size, d_model):
    return 2 * seq_len * vocab_size * d_model


def attention(seq_len, d_model, key_size, num_heads):
    projections = 2 * 3 * seq_len * d_model * (key_size * num_heads)
    logits = 2 * seq_len * seq_len * (key_size * num_heads)
    softmax = 3 * num_heads * seq_len * seq_len
    softmax_query_reduction = 2 * seq_len * seq_len * (key_size * num_heads)
    final_layer = 2 * seq_len * (key_size * num_heads) * d_model
    return projections + logits + softmax + softmax_query_reduction + final_layer


def dense(seq_len, d_model, ffw_size, swiglu=False):
    if not swiglu:
        return 2 * seq_len * (2 * d_model * ffw_size)
    else:
        return 2 * seq_len * (3 * d_model * ffw_size)


def moe(dense, n_experts, top_k, seq_len, d_model, ffw_size, swiglu=False):
    dense_flops = top_k * dense(seq_len, d_model, ffw_size, swiglu)
    gate_flops = 3 * seq_len * n_experts
    return dense_flops + gate_flops


def final_logits(seq_len, d_model, vocab_size):
    return 2 * seq_len * d_model * vocab_size


def get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads,
    ffw_size,
    swiglu=False,
    **kwargs,
):
    return (
        embedding(seq_len, vocab_size, d_model)
        + n_layers
        * (
            attention(seq_len, d_model, key_size, num_heads)
            + dense(seq_len, d_model, ffw_size, swiglu=swiglu)
        )
        + final_logits(seq_len, d_model, vocab_size)
    )


def flops_moe(
    seq_len,
    vocab_size,
    n_layers,
    d_model,
    key_size,
    num_heads,
    ffw_size,
    n_experts,
    top_k,
    swiglu=False,
):
    return (
        embedding(seq_len, vocab_size, d_model)
        + n_layers
        * (
            attention(seq_len, d_model, key_size, num_heads)
            + moe(dense, n_experts, top_k, seq_len, d_model, ffw_size, swiglu=swiglu)
        )
        + final_logits(seq_len, d_model, vocab_size)
    )


def parameter_count(
    vocab_size,
    n_layers,
    d_model,
    key_size,
    num_heads,
    num_kv_heads,
    ffw_size,
    n_experts=1,
    swiglu_or_geglu=False,
    **kwargs,
):
    mul_factor_ffn = 3 if swiglu_or_geglu else 2
    attn = 2 * d_model * num_heads * key_size + 2 * d_model * num_kv_heads * key_size
    return vocab_size * d_model + n_layers * (
        attn + mul_factor_ffn * n_experts * d_model * ffw_size
    )

In [2]:
multiple_of = 256

tiny = {
    "d_model": 384,
    "key_size": 64,
    "num_heads": 6,
    "num_kv_heads": 6,
    "ffw_size": int(8 / 3 * 384),
    "n_layers": 8,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}
tiny["ffw_size"] = multiple_of * (
    (tiny["ffw_size"] + multiple_of - 1) // multiple_of
)
mini = {
    "d_model": 512,
    "key_size": 64,
    "num_heads": 8,
    "num_kv_heads": 8,
    "ffw_size": int(8 / 3 * 512),
    "n_layers": 10,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}
mini["ffw_size"] = multiple_of * (
    (mini["ffw_size"] + multiple_of - 1) // multiple_of
)

small = {
    "d_model": 768,
    "key_size": 64,
    "num_heads": 12,
    "num_kv_heads": 12,
    "ffw_size": int(8 / 3 * 768),
    "n_layers": 12,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}
small["ffw_size"] = multiple_of * (
    (small["ffw_size"] + multiple_of - 1) // multiple_of
)


_210M = {
    "d_model": 768,
    "key_size": 64,
    "num_heads": 12,
    "num_kv_heads": 12,
    "ffw_size": int(8 / 3 * 768),
    "n_layers": 24,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}
_210M["ffw_size"] = multiple_of * (
    (_210M["ffw_size"] + multiple_of - 1) // multiple_of
)


In [3]:
all_flops = []
all_params = []

In [4]:
model = tiny

n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)
# lr 0.002 for cos, 0.001 for wsd
print(params / 1e6)
print(flops)
iters = [2400 / 0.8, 5600 / 0.8, 8000 / 0.8]#, 9600 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{200 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{200 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))

33.454464
171834605568
iters [3.0, 7.0, 10.0]
tokens [0.3, 0.7, 1.0]
ratio [9.2, 21.4, 30.6]
flops [0.1031007633408, 0.2405684477952, 0.343669211136]
flop savings 0.6


In [5]:
tiny2 = {
    "d_model": 512,
    "key_size": 64,
    "num_heads": 8,
    "num_kv_heads": 8,
    "ffw_size": int(8 / 3 * 512),
    "n_layers": 8,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}
tiny2["ffw_size"] = multiple_of * (
    (tiny2["ffw_size"] + multiple_of - 1) // multiple_of
)

model = tiny2
n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)
# lr 0.002 for cos, 0.001 for wsd
print(params / 1e6)
print(flops)
iters = [3000 / 0.8, 6000 / 0.8, 9000 / 0.8]# 12000 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{200 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{200 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))

52.99456
254882611200
iters [3.8, 7.5, 11.2]
tokens [0.4, 0.8, 1.2]
ratio [7.2, 14.5, 21.7]
flops [0.1911619584, 0.3823239168, 0.5734858752]
flop savings 0.6


In [6]:
model = mini

n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)

print(params / 1e6)
print(flops)
iters = [6000 / 0.8, 10000 / 0.8, 14000 / 0.8]# 18000 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{200 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{200 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))
# lr 0.002 for cos, 0.001 for wsd


59.810304
279079550976
iters [7.5, 12.5, 17.5]
tokens [0.8, 1.3, 1.8]
ratio [12.8, 21.4, 30.0]
flops [0.418619326464, 0.69769887744, 0.976778428416]
flop savings 0.5733333333333334


In [7]:
mini2 = {
    "d_model": 640,
    "key_size": 64,
    "num_heads": 10,
    "num_kv_heads": 10,
    "ffw_size": int(8 / 3 * 640),
    "n_layers": 12,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}
mini2["ffw_size"] = multiple_of * (
    (mini2["ffw_size"] + multiple_of - 1) // multiple_of
)


model = mini2

n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)

print(params / 1e6)
print(flops)
iters = [8000 / 0.8, 14000 / 0.8, 20000 / 0.8]# 26000 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{200 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{200 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))
# lr 0.002 for cos, 0.001 for wsd

93.11296
409294602240
iters [10.0, 17.5, 25.0]
tokens [1.0, 1.8, 2.6]
ratio [11.0, 19.2, 27.5]
flops [0.81858920448, 1.43253110784, 2.0464730112]
flop savings 0.5809523809523809


In [8]:
model = small

n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)
print(params / 1e6)
print(flops)
iters = [12000 / 0.8, 20000 / 0.8, 28000 / 0.8]# 36000 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{200 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{200 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))
# lr 0.001

123.532032
527392309248
iters [15.0, 25.0, 35.0]
tokens [1.5, 2.6, 3.6]
ratio [12.4, 20.7, 29.0]
flops [1.582176927744, 2.63696154624, 3.691746164736]
flop savings 0.5733333333333334


In [9]:
_151M = {
    "d_model": 768,
    "key_size": 64,
    "num_heads": 12,
    "num_kv_heads": 12,
    "ffw_size": int(8 / 3 * 768),
    "n_layers": 16,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}

_151M["ffw_size"] = multiple_of * (
    (_151M["ffw_size"] + multiple_of - 1) // multiple_of
)

model = _151M

n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)
print(params / 1e6)
print(flops)
iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{200 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{200 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))
# double batch size?

151.843584
624142319616
iters [25.0, 37.5, 50.0]
tokens [2.6, 3.8, 5.1]
ratio [16.9, 25.3, 33.7]
flops [3.12071159808, 4.68106739712, 6.24142319616]
flop savings 0.5555555555555556


In [10]:
_166M = {
    "d_model": 896,
    "key_size": 64,
    "num_heads": 14,
    "num_kv_heads": 14,
    "ffw_size": int(8 / 3 * 896),
    "n_layers": 12,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}

_166M["ffw_size"] = multiple_of * (
    (_166M["ffw_size"] + multiple_of - 1) // multiple_of
)

model = _166M

n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)
print(params / 1e6)
print(flops)
iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{200 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{200 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))
# double batch size?

166.1408
682936762368
iters [25.0, 37.5, 50.0]
tokens [2.6, 3.8, 5.1]
ratio [15.4, 23.1, 30.8]
flops [3.41468381184, 5.12202571776, 6.82936762368]
flop savings 0.5555555555555556


In [11]:
model = _210M

n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)
print(params / 1e6)
print(flops)
# iters = [22222, 44444, 66666]
iters = [30000 / 0.8, 40000 / 0.8, 50000 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{200 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{200 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))

208.466688
817642340352
iters [37.5, 50.0, 62.5]
tokens [3.8, 5.1, 6.4]
ratio [18.4, 24.6, 30.7]
flops [6.13231755264, 8.17642340352, 10.2205292544]
flop savings 0.5333333333333333


In [12]:
_350M = {
    "d_model": 1024,
    "key_size": 64,
    "num_heads": 16,
    "num_kv_heads": 16,
    "ffw_size": int(8 / 3 * 1024),
    "n_layers": 24,
    "vocab_size": 50257,
    "swiglu": True,
    "seq_len": 512,
}

_350M["ffw_size"] = multiple_of * (
    (_350M["ffw_size"] + multiple_of - 1) // multiple_of
)

model = _350M

n_layers = model["n_layers"]
d_model = model["d_model"]
key_size = model["key_size"]
num_heads = model["num_heads"]
num_kv_heads = model["num_kv_heads"]
ffw_size = model["ffw_size"]
vocab_size = model["vocab_size"]
swiglu = model["swiglu"]
n_experts = 8
top_k = 2
seq_len = model["seq_len"]


flops = 3 * get_flops(
    n_layers,
    seq_len,
    vocab_size,
    d_model,
    key_size,
    num_heads=num_heads,
    ffw_size=ffw_size,
    swiglu=swiglu,
)
params = parameter_count(
    vocab_size=vocab_size,
    n_layers=n_layers,
    d_model=d_model,
    key_size=key_size,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    ffw_size=ffw_size,
    swiglu_or_geglu=swiglu,
)
print(params / 1e6)
print(flops)
iters = [20000 / 0.8, 30000 / 0.8, 40000 / 0.8]
print("iters", [float(f"{i / 1e3:.1f}") for i in iters])
print("tokens", [float(f"{400 * 512 * i / 1e9:.1f}") for i in iters])
print("ratio", [float(f"{400 * 512 * i / params:.1f}") for i in iters])
flops_all = [flops * 200 * i / 1e18 for i in iters]
print("flops", flops_all)
all_flops.append(flops_all)
all_params.append(params)
print("flop savings", (flops_all[-1] + 0.2 * sum(flops_all[:-1])) / sum(flops_all))
# double batch size?

359.744512
1341445373952
iters [25.0, 37.5, 50.0]
tokens [5.1, 7.7, 10.2]
ratio [14.2, 21.3, 28.5]
flops [6.70722686976, 10.06084030464, 13.41445373952]
flop savings 0.5555555555555556
