
question about the memory calculation #19

Closed
ShouyangDong opened this issue Dec 26, 2023 · 2 comments

Comments


ShouyangDong commented Dec 26, 2023

Hello @cli99, Thank you very much for open-sourcing your library for analyzing large language models. This is very helpful for us to understand various optimization algorithms and parallel configuration strategies. After going through the code, I have encountered a few questions.

In the paper "Reducing Activation Recomputation in Large Transformer Models", there are two LayerNorms in each transformer layer. But in the code:

        weight_memory_per_layer = (
            weight_memory_attn_per_layer
            + weight_memory_mlp_per_layer
            + weight_memory_layernorm_per_layer
        )

only one "weight_memory_layernorm_per_layer" is added.

Also, in this paper the blocks parallelized with tensor parallelism are the attention and MLP blocks. But in the code below, the LayerNorm weights are also divided by tp_size when tensor parallelism is applied:

        weight_memory_layernorm_per_layer = (
            self.get_num_params_per_layer_layernorm()
            * self.dtype_config.weight_bits
            / BITS_PER_BYTE
            / self.parallelism_config.tp_size
            / sharded_dp_size
        )
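For reference, this is how I currently read the paper's accounting (a minimal sketch with my own hypothetical helper, not code from llm-analysis): two LayerNorm weights of size hidden_dim per transformer layer, replicated rather than sharded across tensor-parallel ranks.

    BITS_PER_BYTE = 8

    def expected_layernorm_weight_memory_per_layer(hidden_dim: int, weight_bits: int = 16) -> float:
        # Two LayerNorms per transformer layer (before attention and before the MLP),
        # each with a hidden_dim-sized weight, not divided by tp_size because the
        # LayerNorm weights are replicated on every tensor-parallel rank.
        num_layernorm_params = 2 * hidden_dim
        return num_layernorm_params * weight_bits / BITS_PER_BYTE

    # e.g. a llama2-70B-like hidden size of 8192 with fp16 weights:
    print(expected_layernorm_weight_memory_per_layer(8192))  # 32768.0 bytes per layer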

When I print the summary_dict using the provided Python script in the llama2 folder, it gives me the following result:

    {
        'batch_size_per_gpu': 1,
        'seq_len': 512,
        'tp_size': 2,
        'ep_size': 1,
        'pp_size': 1,
        'num_tokens_to_generate': 32,
        'flops_efficiency': 0.6,
        'hbm_memory_efficiency': 0.6,
        'layernorm_dtype_bytes': 2,
        'use_kv_cache': True,
        'kv_cache_latency': 0.00014570698054601933,
        'kv_cache_memory_per_gpu': 89128960.0,
        'weight_memory_per_gpu': 55292731392.0,
        'weight_memory_embedding_per_gpu': 262144000.0,
        'prefill_activation_memory_per_gpu': 16777216.0,
        'prefill_max_batch_size_per_gpu': 1824,
        'prefill_num_flops_fwd_total': 57305601146880.0,
        'decode_activation_memory_per_gpu': 32768.0,
        'decode_max_batch_size_per_gpu': 343,
        'decode_num_flops_fwd_total': 110585446400.0,
        'prefill_latency': 0.15908735621271564,
        'prefill_latency_fwd_attn': 0.03487366607863248,
        'prefill_latency_fwd_mlp': 0.11746919100170941,
        'prefill_latency_fwd_layernorm': 0.001097087853522969,
        'prefill_latency_fwd_tp_comm': 0.004473924266666667,
        'prefill_latency_fwd_sharded_dp_comm': 0.0,
        'prefill_latency_fwd_input_embedding': 0.00045651196944907646,
        'prefill_latency_fwd_output_embedding_loss': 0.0007169750427350427,
        'decode_latency': 0.04685015182136376,
        'decode_latency_fwd_attn': 0.009875397743992155,
        'decode_latency_fwd_mlp': 0.03510895406244892,
        'decode_latency_fwd_layernorm': 2.1427497139120487e-06,
        'decode_latency_fwd_tp_comm': 0.0012799999999999999,
        'decode_latency_fwd_sharded_dp_comm': 0.0,
        'decode_latency_fwd_input_embedding': 0.00043654994278240976,
        'decode_latency_fwd_output_embedding_loss': 1.4003418803418803e-06,
        'total_decode_latency': 1.4992048582836404,
        'total_latency': 1.658292214496356,
        'total_per_token_latency': 0.05182163170301113,
        'prefill_tokens_per_sec': 3218.3575878613824,
        'decode_tokens_per_sec': 21.344648013370964,
        'total_tokens_per_sec': 19.296960885581193,
        'prefill_cost_per_1k_tokens': 0.00038149203258474565,
        'decode_cost_per_1k_tokens': 0.057521575291785504,
        'total_cost_per_1k_tokens': 0.06362544781314144
    }

weight_memory_per_gpu is 55292731392.0 (~55.3 GB), which is less than the expected 70B params * 2 bytes / tp_size 2 = 70 GB.
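For reference, the back-of-envelope number I am comparing against (assuming roughly 70B parameters stored as 2-byte weights and sharded over tp_size = 2):

    num_params = 70e9        # approximate llama2-70B parameter count
    weight_bytes = 2         # fp16 / bf16 weights
    tp_size = 2

    expected_weight_memory_per_gpu = num_params * weight_bytes / tp_size
    reported_weight_memory_per_gpu = 55292731392.0  # from the summary_dict above

    print(expected_weight_memory_per_gpu / 1e9)   # 70.0 GB
    print(reported_weight_memory_per_gpu / 1e9)   # ~55.3 GB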

Could you provide more information about this? I look forward to your response, and thank you once again.

cli99 (Owner) commented Jan 3, 2024

Hi @ShouyangDong, thanks for bringing up these questions/issues.

  1. There is a *2 when calculating weight_memory_layernorm_per_layer (https://github.com/cli99/llm-analysis/blob/main/llm_analysis/analysis.py#L306); we count the weight of each layernorm as hidden_dim.
  2. LayerNorm is not tensor parallelized. Thanks for catching this; I removed the division by tp_size.
  3. The llama2 model uses gated linear units, and I recently added a field in ModelConfig to set mlp_gated_linear_units to true in the model json files. If the model config is pulled from HF by name, there is no equivalent information to use, so I added https://github.com/cli99/llm-analysis/blob/main/llm_analysis/config.py#L218-L223 and made the assumption that if the expansion ratio (intermediate dim / model dim) is 3.5, the model uses gated linear units. The parameter estimation for llama2 should now be correct (see the sketch below for the difference gating makes).
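For a rough sense of what point 3 changes, here is a back-of-envelope sketch (my own illustration, not code from llm-analysis; llama2-70B-like shapes are assumed):

    hidden_dim = 8192          # llama2-70B model (hidden) dim
    intermediate_dim = 28672   # llama2-70B MLP dim, expansion ratio 28672 / 8192 = 3.5

    # Plain 2-matrix MLP (up projection + down projection):
    mlp_params_plain = 2 * hidden_dim * intermediate_dim

    # Gated-linear-unit MLP (gate + up + down projections), as used by llama2:
    mlp_params_gated = 3 * hidden_dim * intermediate_dim

    print(mlp_params_plain)  # 469762048  (~0.47B params per layer)
    print(mlp_params_gated)  # 704643072  (~0.70B params per layer)

Over 80 layers, that difference alone is on the order of 19B parameters.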

ShouyangDong (Author) commented

Thank you for your clarification.

cli99 closed this as completed Mar 1, 2024