In [1]:
import numpy as np

In [2]:
a16_ratio = 2
w8a16_ratio =  1.
w4a16_ratio = 0.5469

In [3]:
ddr_util = 1.
mac_util = 0.5 # 目前一般能达到50%的利用率
p2p_util = 1.
pcie_avg_ms = 0.1 # ms

In [4]:
def count_embedding_weights(hidden_size, vocab_size):
    num_weights = hidden_size * vocab_size
    return num_weights

def compute_lm_head_flops(hidden_size, vocab_size):
    flops = hidden_size * vocab_size * 2
    return flops

def count_lm_head_weights(hidden_size, vocab_size):
    num_weights = hidden_size * vocab_size
    return num_weights
    

def compute_block_flops(query_len, key_len, hidden_size, inter_size, num_attn_heads, num_kv_heads):
    num_qo_mm = 2
    num_kv_mm = 2
    num_attn_mm = 2
    num_mlp_mm = 3
    
    head_dim = hidden_size / num_attn_heads
    kv_dim = head_dim * num_kv_heads
    
    kv_flops = query_len * hidden_size * kv_dim * 2 * num_kv_mm
    qo_flops = query_len * hidden_size * hidden_size * 2 * num_qo_mm
    attn_flops = query_len * key_len * hidden_size * 2 * num_attn_mm
    mlp_flops = query_len * hidden_size * inter_size * 2 * num_mlp_mm
    
    total_flops = kv_flops + qo_flops + attn_flops + mlp_flops
    return total_flops

def count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads):
    num_qo_mm = 2
    num_kv_mm = 2
    num_mlp_mm = 3
    
    head_dim = hidden_size / num_attn_heads
    kv_dim = head_dim * num_kv_heads
    
    num_qo_weights = hidden_size * hidden_size * num_qo_mm
    num_kv_weights = hidden_size * kv_dim * num_kv_mm
    num_mlp_weights = hidden_size * inter_size * num_mlp_mm
    
    num_weights = num_qo_weights + num_kv_weights + num_mlp_weights
    return num_weights

def count_block_gather_weights(max_seq_len, hidden_size, num_attn_heads):
    head_dim = hidden_size / num_attn_heads
    num_weights = max_seq_len * head_dim * 2 # cos, sin table
    return num_weights

def count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads):
    head_dim = hidden_size / num_attn_heads
    kv_dim = head_dim * num_kv_heads
    
    num_kv_cache = 2 * seq_len * kv_dim * num_layers
    return num_kv_cache

def get_prefill_compute_time(total_flops, num_device, tpu_freq=1000):
    tpu_peak_flops = 16384 * num_device * tpu_freq / 1000
    
    compute_time = total_flops / 1e9 / tpu_peak_flops / mac_util
    return compute_time

def get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq=1000):
    p2p_bw = p2p_speed
    s2l_bw = s2l_speed * tpu_freq / 1000
    l2s_bw = l2s_speed * tpu_freq / 1000
    bf16_size = 2
    ring_data_ratio = (num_device - 1) * 2 / num_device
    size = seq_len * hidden_size * bf16_size * ring_data_ratio * num_layers * 2
    
    p2p_time = size / p2p_bw / p2p_util
    add_time = (size * 2 / s2l_bw + size / l2s_bw) / ddr_util
    allreduce_time = p2p_time + add_time
    return allreduce_time

def get_decode_pcie_time(num_layers):
    pcie_ms = pcie_avg_ms * num_layers * 2
    return pcie_ms

def get_decode_allreduce_time(num_layers, num_device):
    time_per_allreduce = 0
    if num_device == 2:
        time_per_allreduce = 0.12
    elif num_device == 4:
        time_per_allreduce = 0.15
    elif num_device == 6:
        time_per_allreduce = 0.36 # old allreduce
    elif num_device == 8:
        time_per_allreduce = 0.2
    elif num_device == 1:
        return 0
    else:
        print(f"****** Do Not Support num_device = {num_device} **********")
        return -1
    
    allreduce_ms = time_per_allreduce * num_layers * 2
    return allreduce_ms

def get_decode_load_weights_time(weight_bytes, num_device, tpu_freq=1000):
    s2l_bw = s2l_speed * tpu_freq / 1000
    
    load_weights_ms = weight_bytes / s2l_bw / num_device * 1000 / ddr_util
    return load_weights_ms

def get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq=1000):
    s2l_bw = s2l_speed * tpu_freq / 1000
    
    load_kv_cache_ms = kv_cache_bytes / s2l_bw / num_device * 1000 / ddr_util
    return load_kv_cache_ms

# Qwen-72B

## W4A16

In [5]:
seq_len = 8192
hidden_size = 8192
inter_size = 24576
num_attn_heads = 64
num_kv_heads = 64
num_layers = 80
vocab_size = 152064
ratio = w4a16_ratio

tpu_freq = 875
num_device = 8

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Qwen-72B-W4A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

NameError: name 'p2p_speed' is not defined

# Qwen-14B
注意：Qwen-14B的inter_size不要直接用config.json文件里的intermediate_size，这个size并不是Block中的MLP的inter_size，所以inter_size最好是直接看一下Block中的MLP的参数

## W4A16

In [None]:
seq_len = 2048
hidden_size = 5120
inter_size = 13696
num_attn_heads = 40
num_kv_heads = 40
num_layers = 40
vocab_size = 152064
ratio = w4a16_ratio

tpu_freq = 875
num_device = 8

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Qwen-14B-W4A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

## Qwen-14B-W4A16-seq2048-8dev-875MHz:
### Basic Information
embedding: 1485.0 MiB
block_gather: 1.0 MiB
lm_head: 1.55713536 GFLOPs, 406.07325000000003 MiB
blocks: 1377.07388928 GFLOPs, 164.41181250000002 MiB
all blocks: 55082.9555712 GFLOPs, 6576.472500000001 MiB
mm_weights: 6982.545750000001 MiB
total flops: 55084.51270656 GFLOPs
### Time Information
prefill_time = 2.125682113015873 s
decode_time = 45.42726545798095 ms, Speed = 22.013211447318444 token/s


## W8A16

In [None]:
seq_len = 2048
hidden_size = 5120
inter_size = 13696
num_heads = 40
num_layers = 40
vocab_size = 152064
ratio = w8a16_ratio

tpu_freq = 875
num_device = 8

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Qwen-14B-W8A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

## Qwen-14B-W8A16-seq2048-8dev-875MHz:
### Basic Information
embedding: 1485.0 MiB
block_gather: 1.0 MiB
lm_head: 1.55713536 GFLOPs, 742.5 MiB
blocks: 1377.07388928 GFLOPs, 300.625 MiB
all blocks: 55082.9555712 GFLOPs, 12025.0 MiB
mm_weights: 12767.5 MiB
total flops: 55084.51270656 GFLOPs
### Time Information
prefill_time = 2.125682113015873 s
decode_time = 59.87003733333333 ms, Speed = 16.7028457729596 token/s


# Llama2-13B

## W4A16

In [None]:
seq_len = 512
hidden_size = 5120
inter_size = 13824
num_attn_heads = 40
num_kv_heads = 40
num_layers = 40
vocab_size = 32000
ratio = w4a16_ratio

tpu_freq = 875
num_device = 8

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Llama2-13B-W4A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

## Llama2-13B-W4A16-seq512-8dev-875MHz:
### Basic Information
embedding: 312.5 MiB
block_gather: 0.25 MiB
lm_head: 0.32768 GFLOPs, 85.45312500000001 MiB
blocks: 330.17561088 GFLOPs, 165.43725 MiB
all blocks: 13207.0244352 GFLOPs, 6617.49 MiB
mm_weights: 6702.943125 MiB
total flops: 13207.3521152 GFLOPs
### Time Information
prefill_time = 0.521589053968254 s
decode_time = 41.733275452952384 ms, Speed = 23.961694574568934 token/s


## W8A16

In [None]:
seq_len = 512
hidden_size = 5120
inter_size = 13824
num_attn_heads = 40
num_kv_heads = 40
num_layers = 40
vocab_size = 32000
ratio = w8a16_ratio

tpu_freq = 875
num_device = 8

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Llama2-13B-W8A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

## Llama2-13B-W8A16-seq512-8dev-875MHz:
### Basic Information
embedding: 312.5 MiB
block_gather: 0.25 MiB
lm_head: 0.32768 GFLOPs, 156.25 MiB
blocks: 330.17561088 GFLOPs, 302.5 MiB
all blocks: 13207.0244352 GFLOPs, 12100.0 MiB
mm_weights: 12256.25 MiB
total flops: 13207.3521152 GFLOPs
### Time Information
prefill_time = 0.521589053968254 s
decode_time = 55.59771428571429 ms, Speed = 17.9863509291235 token/s


# Qwen-7B

In [None]:
seq_len = 8192
hidden_size = 4096
inter_size = 11008
num_attn_heads = 32
num_kv_heads = 32
num_layers = 32
vocab_size = 151936
ratio = w4a16_ratio

tpu_freq = 950
num_device = 1

p2p_speed = 3e9
s2l_speed = 60e9
l2s_speed = 45e9

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Qwen-7B-W4A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

## Qwen-7B-W4A16-seq8192-1dev-950MHz:
### Basic Information
embedding: 1187.0 MiB
block_gather: 4.0 MiB
lm_head: 1.244659712 GFLOPs, 324.58515000000006 MiB
blocks: 4415.226380288 GFLOPs, 105.55170000000001 MiB
all blocks: 141287.244169216 GFLOPs, 3377.6544000000004 MiB
mm_weights: 3702.2395500000002 MiB
total flops: 141288.488828928 GFLOPs
### Time Information
prefill_time = 18.154873667368424 s
decode_time = 149.8569620066807 ms, Speed = 6.673029978783498 token/s


In [None]:
seq_len = 8192
hidden_size = 3584
inter_size = 18944
num_attn_heads = 28
num_kv_heads = 2
num_layers = 28
vocab_size = 151936
ratio = w4a16_ratio

tpu_freq = 875
num_device = 8

p2p_speed = 3e9
s2l_speed = 60e9
l2s_speed = 45e9

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Qwen-7B-W4A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

# Qwen1.5-32B

In [None]:
seq_len = 8192
hidden_size = 5120
inter_size = 27392
num_attn_heads = 40
num_kv_heads = 8
num_layers = 64
vocab_size = 152064
ratio = w4a16_ratio

tpu_freq = 875
num_device = 8

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Qwen1.5-32B-W4A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

## Qwen1.5-32B-W4A16-seq8192-8dev-875MHz:
### Basic Information
embedding: 1485.0 MiB
block_gather: 4.0 MiB
lm_head: 1.55713536 GFLOPs, 406.07325000000003 MiB
blocks: 9298.60419584 GFLOPs, 252.25762500000002 MiB
all blocks: 595110.66853376 GFLOPs, 16144.488000000001 MiB
mm_weights: 16550.561250000002 MiB
total flops: 595112.22566912 GFLOPs
### Time Information
prefill_time = 17.834474067301585 s
decode_time = 84.83334514590476 ms, Speed = 11.787817612049853 token/s


# Yi-34B

In [None]:
seq_len = 512
hidden_size = 7168
inter_size = 20480
num_attn_heads = 56
num_kv_heads = 8
num_layers = 60
vocab_size = 64000
ratio = w4a16_ratio

tpu_freq = 875
num_device = 8

lm_head_flops = compute_lm_head_flops(hidden_size, vocab_size)
block_flops = compute_block_flops(seq_len, seq_len, hidden_size, inter_size, num_attn_heads, num_kv_heads)

all_block_flops = block_flops * num_layers
total_flops = lm_head_flops + all_block_flops

embedding_weight_bytes = count_embedding_weights(hidden_size, vocab_size) * a16_ratio
lm_head_weight_bytes = count_lm_head_weights(hidden_size, vocab_size) * ratio
block_mm_weight_bytes = count_block_mm_weights(hidden_size, inter_size, num_attn_heads, num_kv_heads) * ratio
block_gather_weight_bytes = count_block_gather_weights(seq_len, hidden_size, num_attn_heads) * a16_ratio
mm_weight_bytes = lm_head_weight_bytes + block_mm_weight_bytes * num_layers

kv_cache_bytes = count_kv_cache(seq_len, hidden_size, num_layers, num_attn_heads, num_kv_heads) * a16_ratio

prefill_compute_time = get_prefill_compute_time(total_flops, num_device, tpu_freq)
prefill_allreduce_time = get_prefill_allreduce_time(seq_len, hidden_size, num_layers, num_device, tpu_freq)
decode_pcie_time = get_decode_pcie_time(num_layers)
decode_allreduce_time = get_decode_allreduce_time(num_layers, num_device)
decode_load_weights_time = get_decode_load_weights_time(mm_weight_bytes, num_device, tpu_freq)
decode_load_kv_cache_time = get_decode_load_kv_cache_time(kv_cache_bytes, num_device, tpu_freq)
total_prefill_time = prefill_compute_time + prefill_allreduce_time
total_decode_time = decode_pcie_time + decode_allreduce_time + decode_load_weights_time + decode_load_kv_cache_time
tps = 1000 / total_decode_time

print(f'## Yi-34B-W4A16-seq{seq_len}-{num_device}dev-{tpu_freq}MHz:')
print(f'### Basic Information')
print(f'embedding: {embedding_weight_bytes/2**20} MiB')
print(f'block_gather: {block_gather_weight_bytes/2**20} MiB')
print(f'lm_head: {lm_head_flops/1e9} GFLOPs, {lm_head_weight_bytes/2**20} MiB')
print(f'blocks: {block_flops/1e9} GFLOPs, {block_mm_weight_bytes/2**20} MiB')
print(f'all blocks: {all_block_flops/1e9} GFLOPs, {block_mm_weight_bytes*num_layers/2**20} MiB')
print(f'mm_weights: {mm_weight_bytes/2**20} MiB')
print(f'total flops: {total_flops/1e9} GFLOPs')
print(f'### Time Information')
print(f'prefill_time = {total_prefill_time} s')
print(f'decode_time = {total_decode_time} ms, Speed = {tps} token/s')

## Yi-34B-W4A16-seq512-8dev-875MHz:
### Basic Information
embedding: 875.0 MiB
block_gather: 0.25 MiB
lm_head: 0.917504 GFLOPs, 239.26875 MiB
blocks: 578.746843136 GFLOPs, 290.9508 MiB
all blocks: 34724.81058816 GFLOPs, 17457.048000000003 MiB
mm_weights: 17696.31675 MiB
total flops: 34725.72809216 GFLOPs
### Time Information
prefill_time = 1.2172379733333334 s
decode_time = 80.48038607725715 ms, Speed = 12.425387709249431 token/s
