In [1]:
from typing import Literal

from pydantic import BaseModel

In [2]:
# no bias, rmsnorm

class Config(BaseModel):
        batch_size: int
        seq_len: int
        dim_model: int
        dim_mlp: int
        num_heads: int
        num_layers: int
        vocab_size: int

def num_activations(config: Config) -> int:
    return(
        config.batch_size * config.seq_len * config.num_layers * (
              2*config.dim_mlp + config.dim_model     # FFN
            + 4*config.dim_model                      # RMSNorm x2
            + 5*config.dim_model                      # SelfAttention non-head
            + 2*config.dim_model                      # Residual FFN, SelfAttention
            )
        + config.batch_size * 4*config.seq_len**2 * config.num_heads                        # SelfAttention head
        + config.batch_size * config.seq_len * (2*config.dim_model + 2*config.vocab_size)   # Embedding, Unembedding, Loss
        + config.batch_size * config.seq_len * 2*config.dim_model                           # Final RMSNorm
    )

def num_weights(config: Config) -> int:
    return(
        config.num_layers * (
                2*config.dim_model*config.dim_mlp      # FFN
                + 4*config.dim_model*config.dim_model  # SelfAttention
                + 2*config.dim_model                   # RMSNorm x2
                )
        + 2*config.dim_model*config.vocab_size         # Embedding, Unembedding
        + config.dim_model                             # Final RMSNorm
    )

def num_bytes_per_activation(variant: Literal["conventional", "a", "b", "c", "d"]) -> int:
    match variant:
        case "conventional":
            bytes_per_value = 2
        case "a":
            bytes_per_value = 2
        case "b":
            bytes_per_value = 2
        case "c":
            bytes_per_value = 1
        case "d":
            bytes_per_value = 1
    return bytes_per_value * 2  # 1 per activation, 1 per activation gradient

def num_bytes_per_weight(variant: Literal["conventional", "a", "b", "c", "d"]) -> int:
    match variant:
        case "conventional":
            weight = 6
            weight_grad = 2
            moment1 = 4
            moment2 = 4
        case "a":
            weight = 4
            weight_grad = 2
            moment1 = 2
            moment2 = 2
        case "b":
            weight = 4
            weight_grad = 2
            moment1 = 2
            moment2 = 0
        case "c":
            weight = 2
            weight_grad = 1
            moment1 = 1
            moment2 = 1
        case "d":
            weight = 2
            weight_grad = 1
            moment1 = 1
            moment2 = 0
    return weight + weight_grad + moment1 + moment2

def num_bytes(config: Config, variant: Literal["conventional", "a", "b", "c", "d"]) -> int:
    return (
          num_activations(config) * num_bytes_per_activation(variant)
        + num_weights(config) * num_bytes_per_weight(variant)
    )

In [3]:
config = Config(
    batch_size = 128,
    seq_len = 64,
    dim_model = 256,
    dim_mlp = 1024,
    num_heads = 8,
    num_layers = 6,
    vocab_size = 1000,
)
variant = "conventional"

print( f"weights            : {num_weights(config)     / 1e6:.4} M\n"
       f"activations        : {num_activations(config) / 1e6:.4} M\n"
       f"memory weights     : {num_weights(config)     * num_bytes_per_weight(variant)     / 1024/1024/1024:.4} GB\n"
       f"memory activations : {num_activations(config) * num_bytes_per_activation(variant) / 1024/1024/1024:.4} GB\n"
)

weights            : 5.234 M
activations        : 293.2 M
memory weights     : 0.07799 GB
memory activations : 1.092 GB

