In [1]:
def gpt_parameter_count(
    vocab_size=50000,
    max_seq_len=512,
    d_model=512,
    d_ff=None,
    n_layers=6,
    tied_output=True,
    *,
    bias=False,
    final_layernorm=False,
    verbose=True,
):
    """
    Calculate total parameters in a GPT-like (decoder-only) transformer model.

    With the default flags, only the weight matrices and the two LayerNorms per
    block are counted (learned positional embeddings included). Real GPT-2-style
    models also have biases on every linear layer inside the blocks and a final
    LayerNorm after the last block; enable ``bias`` and ``final_layernorm`` to
    count them. For example, GPT-2 small (vocab_size=50257, max_seq_len=1024,
    d_model=768, n_layers=12, tied_output=True, bias=True,
    final_layernorm=True) gives 124,439,808.

    Args:
    - vocab_size (int): Size of the vocabulary.
    - max_seq_len (int): Maximum sequence length (rows of the learned
      positional-embedding table).
    - d_model (int): Embedding/hidden size.
    - d_ff (int or None): FFN size (if None, uses 4 * d_model).
    - n_layers (int): Number of transformer blocks.
    - tied_output (bool): If True, output projection shares weights with input
      embeddings and adds no parameters. If False, a bias-free
      d_model x vocab_size projection is counted.
    - bias (bool, keyword-only): If True, count biases of the Q, K, V, and
      attention-output projections and of both FFN linear layers.
    - final_layernorm (bool, keyword-only): If True, count one extra LayerNorm
      (gamma + beta) applied after the last block.
    - verbose (bool, keyword-only): If True (default), print a breakdown.

    Returns:
    - total_params (int): Total number of parameters.

    Raises:
    - ValueError: If any size argument is negative.
    """
    if d_ff is None:
        d_ff = 4 * d_model

    # Negative sizes would silently produce a meaningless (possibly negative) count.
    sizes = {
        "vocab_size": vocab_size,
        "max_seq_len": max_seq_len,
        "d_model": d_model,
        "d_ff": d_ff,
        "n_layers": n_layers,
    }
    for name, value in sizes.items():
        if value < 0:
            raise ValueError(f"{name} must be non-negative, got {value}")

    # Embedding layers
    token_embeddings = vocab_size * d_model
    positional_embeddings = max_seq_len * d_model
    embedding_params = token_embeddings + positional_embeddings

    # Transformer block parameters per layer
    mhsa_params_per_layer = 4 * d_model * d_model  # Q, K, V, Out weights
    ffn_params_per_layer = 2 * d_model * d_ff      # Linear1 + Linear2 weights
    layernorm_params_per_layer = 4 * d_model       # 2 LayerNorms (gamma + beta)

    if bias:
        mhsa_params_per_layer += 4 * d_model       # one bias vector per projection
        ffn_params_per_layer += d_ff + d_model     # Linear1 -> d_ff, Linear2 -> d_model

    transformer_params_per_layer = (
        mhsa_params_per_layer + ffn_params_per_layer + layernorm_params_per_layer
    )

    transformer_total = n_layers * transformer_params_per_layer

    # LayerNorm applied once after the final block (gamma + beta)
    final_ln_params = 2 * d_model if final_layernorm else 0

    # Output projection (no bias, as in GPT-2's language-modelling head)
    output_projection = 0 if tied_output else d_model * vocab_size

    total_params = (
        embedding_params + transformer_total + final_ln_params + output_projection
    )

    if verbose:
        print(f"Token Embeddings:      {token_embeddings:,}")
        print(f"Positional Embeddings: {positional_embeddings:,}")
        print(f"Transformer Blocks:    {transformer_total:,}")
        if final_layernorm:
            print(f"Final LayerNorm:       {final_ln_params:,}")
        print(f"Output Projection:     {output_projection:,}")
        print(f"{'-'*40}")
        print(f"Total Parameters:      {total_params:,}")

    return total_params


# Example: a 6-layer, 512-wide model with a 50k-token vocabulary and an
# untied output head. Pass tied_output=True instead when the output
# projection reuses the token-embedding matrix. Left as the cell's final
# expression so the notebook also displays the returned total.
gpt_parameter_count(
    d_model=512,
    n_layers=6,
    max_seq_len=512,
    vocab_size=50_000,
    tied_output=False,
)


Token Embeddings:      25,600,000
Positional Embeddings: 262,144
Transformer Blocks:    18,886,656
Output Projection:     25,600,000
----------------------------------------
Total Parameters:      70,348,800


70348800