In [1]:
import os
from collections import defaultdict

import torch
from transformers import AutoModelForCausalLM, LlamaForCausalLM

from custom_flop_counter import PerformanceCounterMode

%load_ext autoreload
%autoreload 2

In [3]:
# Local HF-format 7B checkpoint (loaded via `from_pretrained`). Override the
# location with the LLAMA_7B_CHECKPOINT environment variable instead of editing
# the notebook; the default keeps the original path.
model_id = os.environ.get(
    "LLAMA_7B_CHECKPOINT", "/home/ubuntu/gpt-fast-dev/checkpoints/7B"
)
model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16, low_cpu_mem_usage=True
)

# Random token ids of shape (1, 16). Presumably only the shape matters for the
# FLOP / byte counts, but seed anyway so every run sees the same input.
torch.manual_seed(0)
# `device_map="auto"` decides where the weights live, so put the input on the
# model's device instead of assuming "cuda" is available.
input_ids = torch.randint(
    0, model.config.vocab_size, (1, 16), dtype=torch.int64, device=model.device
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Record counts for a single forward pass. No backward pass follows, so disable
# autograd: this avoids building a graph and keeping activations alive for
# gradients that are never computed.
# NOTE(review): `display=False` presumably suppresses the counter's own report
# and `depth=10` presumably limits module nesting depth — confirm in
# custom_flop_counter.
with torch.no_grad(), PerformanceCounterMode(display=False, depth=10) as flop_counter:
    _ = model(input_ids)

In [5]:
# Bytes per element of the first parameter (assumes one dtype for all weights),
# and the parameter count of the full model including lm_head.
# NOTE(review): `model_params` is rebound further down to count `model.model` only.
model_dtype_size = next(iter(model.parameters())).element_size()
model_params = model.num_parameters()

In [9]:
# Per-module memory counts from the counter, plus their keys in recorded order.
# NOTE(review): a later cell rebinds `mem_counts` to `get_data_counts()`;
# confirm which accessor this analysis is meant to use.
mem_counts = flop_counter.get_memory_counts()
keys = [*mem_counts.keys()]

In [18]:
# Select the lm_head entry by name rather than by position. `keys[-1]` only works
# if the counter happens to record lm_head last; otherwise it silently selects a
# different module.
lm_head_keys = [key for key in keys if key.rsplit(".", 1)[-1] == "lm_head"]
assert len(lm_head_keys) == 1, f"expected exactly one lm_head entry, got {lm_head_keys}"
lm_head_key = lm_head_keys[0]
# Total recorded for lm_head across all of its inner entries.
# NOTE(review): `lm_head_size` is rebound further down to the lm_head weight's byte size.
lm_head_size = sum(mem_counts[lm_head_key].values())

In [27]:
# Name -> module lookups: direct children only, and every submodule recursively.
model_children = {name: child for name, child in model.named_children()}
model_modules = {name: module for name, module in model.named_modules()}

In [7]:
# Byte sizes from parameter metadata: the decoder stack (`model.model`) and the
# `lm_head` projection. The next cell prints their sum.
# NOTE(review): this rebinds `model_params` (now excludes lm_head) and
# `lm_head_size` (now parameter bytes, not counter bytes) from earlier cells.
model_params = model.model.num_parameters()
model_dtypes = {param.dtype for param in model.parameters()}
model_size = sum(
    param.numel() * param.element_size() for param in model.model.parameters()
)
lm_head_dtype = next(iter(model.lm_head.parameters())).dtype
lm_head_size = model.lm_head.weight.element_size() * model.lm_head.weight.numel()

In [8]:
# Total parameter bytes = decoder stack bytes + lm_head bytes.
total_param_bytes = model_size + lm_head_size
print(f"{total_param_bytes} = {model_size} + {lm_head_size}")

13476831232 = 13214687232 + 262144000


In [9]:
# Bytes held by the token-embedding table's parameters.
embed_token_size = sum(
    param.numel() * param.element_size()
    for param in model.model.embed_tokens.parameters()
)

In [12]:
# Rebind `mem_counts` to the counter's per-module data counts. These totals are
# compared against parameter byte sizes below.
# NOTE(review): an earlier cell used `get_memory_counts()` under the same name;
# presumably these are bytes moved per module — confirm in custom_flop_counter.
mem_counts = flop_counter.get_data_counts()


def sum_counts(counts, key):
    """Total every value recorded under one key of a nested counts mapping.

    Args:
        counts: mapping from a module name to an inner mapping of numeric
            values (e.g. the result of ``flop_counter.get_data_counts()``).
        key: the module name to look up, e.g. ``"LlamaForCausalLM.lm_head"``.

    Returns:
        The sum of ``counts[key].values()``; 0 if the inner mapping is empty.

    Raises:
        KeyError: if ``key`` is not present in ``counts``.
    """
    per_key = counts[key]
    total = 0
    for value in per_key.values():
        total += value
    return total


# Counted totals for the two top-level children and the whole run.
# NOTE(review): the embedding table's size is added on top. Presumably this is
# because the lookup only touches a few rows, so the counter does not charge the
# full table — confirm.
model_count = sum_counts(mem_counts, "LlamaForCausalLM.model")
lm_head_count = sum_counts(mem_counts, "LlamaForCausalLM.lm_head")
total_count = sum_counts(mem_counts, "Global")

model_plus_head_count = model_count + lm_head_count
print(f"{model_plus_head_count} = {model_count} + {lm_head_count}")
print(
    f"{model_plus_head_count + embed_token_size} = {model_count} + {lm_head_count} + {embed_token_size}"
)

print(f"Global count: {total_count + embed_token_size}")

13312248320 = 13048949248 + 263299072
13574392320 = 13048949248 + 263299072 + 262144000
Global count: 13574392320


In [14]:
# Counted bytes minus parameter bytes, in GB (1e9 bytes). Presumably this is the
# traffic attributed to each part beyond reading its weights (activations etc.).
lm_head_diff_gb = (lm_head_count - lm_head_size) / 1e9
print(f"Lm head diff: {lm_head_diff_gb}GB")
model_diff_gb = (model_count + embed_token_size - model_size) / 1e9
print(f"Model diff: {model_diff_gb}GB")

Lm head diff: 0.001155072GB
Model diff: 0.096406016GB


In [79]:
# The decoder stack's parameters, and a lookup of every model parameter by its
# dotted name.
params = [*model.model.parameters()]
named_params = {name: param for name, param in model.named_parameters()}

In [83]:
# The same decoder / lm_head split as the earlier byte totals, expressed in GB
# (1e9 bytes).
params_sum = sum(p.numel() * p.element_size() for p in params) / 1e9
lm_head_sum = (
    sum(p.numel() * p.element_size() for p in model.lm_head.parameters()) / 1e9
)
print(f"{params_sum + lm_head_sum} = {params_sum} + {lm_head_sum}")

13.476831231999999 = 13.214687232 + 0.262144


In [48]:
# Aggregate parameter bytes by the name of the owning module, summed across all
# decoder layers (e.g. every layer's `q_proj` weight lands in one "q_proj" bucket).
module_names = defaultdict(int)
for param_name, param in named_params.items():
    # Second-to-last dotted component is the owning module,
    # e.g. "model.layers.0.self_attn.q_proj.weight" -> "q_proj".
    owner = param_name.split(".")[-2]
    module_names[owner] += param.numel() * param.element_size()

In [50]:
# Per-module parameter bytes, shown in GB (1e9 bytes).
for name, nbytes in module_names.items():
    gigabytes = nbytes / 1e9
    print(f"{name}: {gigabytes}GB")

embed_tokens: 0.262144GB
q_proj: 1.073741824GB
k_proj: 1.073741824GB
v_proj: 1.073741824GB
o_proj: 1.073741824GB
gate_proj: 2.885681152GB
up_proj: 2.885681152GB
down_proj: 2.885681152GB
input_layernorm: 0.000262144GB
post_attention_layernorm: 0.000262144GB
norm: 8.192e-06GB
lm_head: 0.262144GB


In [77]:
def _sum_module_bytes(include):
    """Sum the `module_names` byte totals whose key satisfies `include(key)`."""
    return sum(nbytes for name, nbytes in module_names.items() if include(name))


# Split the per-module byte totals into groups. The groups are disjoint for the
# module names present here: norms, the lm_head, embeddings, and everything else
# (here, the attention and MLP projections).
model_params_no_norm_embed = _sum_module_bytes(
    lambda name: "norm" not in name
    and "lm_head" not in name
    and "embed_tokens" not in name
)
embed_params = _sum_module_bytes(lambda name: "embed_tokens" in name)
norm_params = _sum_module_bytes(lambda name: "norm" in name)
lm_params = _sum_module_bytes(lambda name: "lm_head" in name)

In [78]:
# Reassemble the total from the groups and print it in GB (1e9 bytes).
grand_total = model_params_no_norm_embed + lm_params + norm_params + embed_params
print(
    f"{grand_total/1e9}GB = {model_params_no_norm_embed/1e9} + {lm_params/1e9} + {norm_params/1e9} + {embed_params/1e9}"
)

13.476831232GB = 12.952010752 + 0.262144 + 0.00053248 + 0.262144


In [58]:
# Every submodule keyed by its dotted name (the same mapping as `model_modules`).
named_modules = {name: module for name, module in model.named_modules()}

In [59]:
# Number of entries in `named_modules`.
module_count = len(named_modules)
module_count

454

In [65]:
# Aggregate FLOPs by the last component of each recorded module path, summed
# across decoder layers (e.g. every "...layers.N.self_attn.q_proj" -> "q_proj").
# Keys with no "." (e.g. a top-level "Global" entry) and purely numeric leaves
# (layer containers such as "...layers.0") are skipped.
# NOTE(review): the counts are presumably hierarchical. If so, "self_attn"
# already includes its q/k/v/o projections and "model" includes every layer,
# so these buckets must not be summed together.
flop_mods = defaultdict(int)
for module_path in flop_counter.flop_counts.keys():
    _, sep, leaf = module_path.rpartition(".")
    if sep and not leaf.isdigit():
        flop_mods[leaf] += sum_counts(flop_counter.flop_counts, module_path)

In [66]:
# Display the module-name buckets collected in `flop_mods`.
flop_mods.keys()

dict_keys(['self_attn', 'q_proj', 'model', 'k_proj', 'v_proj', 'rotary_emb', 'o_proj', 'mlp', 'gate_proj', 'up_proj', 'down_proj', 'lm_head'])

In [67]:
# Display the raw per-module parameter-byte totals (the GB printout above shows the same data).
module_names

defaultdict(int,
            {'embed_tokens': 262144000,
             'q_proj': 1073741824,
             'k_proj': 1073741824,
             'v_proj': 1073741824,
             'o_proj': 1073741824,
             'gate_proj': 2885681152,
             'up_proj': 2885681152,
             'down_proj': 2885681152,
             'input_layernorm': 262144,
             'post_attention_layernorm': 262144,
             'norm': 8192,
             'lm_head': 262144000})