# Why so big? Counting parameters in sequence models

PaLM has 540 billion parameters. What could they possibly all be doing? Let's figure out where the parameter budget in sequence models goes.

## Setup

In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

def num_parameters(model):
    """Return the total number of elements across all parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

We'll use the same input setup as last time.

In [73]:
# Example text; the following cells convert it character by character into model inputs.
sentence = "This will be the input to a language model."

In [74]:
# Encode every character as its code point; the outer list adds a leading dimension of size 1.
sentence_tensor = torch.tensor([list(map(ord, sentence))])

In [75]:
# Next-token setup: the input at position i is paired with the token at position i + 1.
input_ids, targets = sentence_tensor[:, :-1], sentence_tensor[:, 1:]
assert input_ids.shape == targets.shape

## Embeddings

A big chunk of the parameters of a model ends up in the word embeddings.

In [76]:
# 256 token ids (the inputs are character code points) and a 5-dimensional embedding.
n_vocab, emb_dim = 256, 5
embedder = nn.Embedding(num_embeddings=n_vocab, embedding_dim=emb_dim)
# The embedding table holds one row per vocabulary entry.
embedder.weight.shape

torch.Size([256, 5])

In [77]:
# Count the weights in the embedding table (its shape is n_vocab x emb_dim).
num_parameters(embedder)

1280

In [78]:
# Look up an embedding vector for every input id; reused below as sample input for the attention layers.
input_embeds = embedder(input_ids)

In [79]:
def num_params_for_embedding(n_vocab, emb_dim):
    """Closed-form parameter count of an embedding table: one `emb_dim`-sized vector per vocabulary entry."""
    rows, cols = n_vocab, emb_dim
    return rows * cols

# Check the closed form against PyTorch, for the toy sizes and for a larger configuration.
assert all(
    num_params_for_embedding(vocab_size, width)
    == num_parameters(nn.Embedding(vocab_size, width))
    for vocab_size, width in [(n_vocab, emb_dim), (50000, 2048)]
)

In [80]:
class Model1(nn.Module):
    """Minimal language model: embed each token id, then project straight back to vocabulary logits.

    Args:
        n_vocab: number of distinct token ids.
        emb_dim: size of each embedding vector.
        tie_weights: when True, the output projection reuses the embedding
            matrix, so both layers are backed by one shared parameter.
    """

    def __init__(self, n_vocab, emb_dim, tie_weights=True):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)

        if not tie_weights:
            return
        embedding_matrix = self.word_to_embedding.weight
        # Both weights must have the same shape before one can stand in for the other.
        assert self.lm_head.weight.shape == embedding_matrix.shape
        self.lm_head.weight = embedding_matrix

    def forward(self, input_ids):
        embedded = self.word_to_embedding(input_ids)
        logits = self.lm_head(embedded)
        return logits


# With tied weights, the LM head adds no parameters beyond the embedding table.
num_parameters(Model1(n_vocab=n_vocab, emb_dim=emb_dim))

1280

*Discussion question (optional): why is the LM head created with `bias=False`?*

*Discussion question: what happens when `tie_weights=False`?*


In [81]:
class MLP(nn.Module):
    """Two-layer feedforward block: emb_dim -> n_hidden -> emb_dim with a ReLU in between.

    Args:
        emb_dim: size of the input and output features.
        n_hidden: size of the hidden layer.
    """

    def __init__(self, emb_dim, n_hidden):
        super().__init__()

        expand = nn.Linear(emb_dim, n_hidden)
        activation = nn.ReLU()  # nn.GELU() or others would also work here
        contract = nn.Linear(n_hidden, emb_dim)
        self.model = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        transformed = self.model(x)
        return transformed

# Two weight matrices plus two bias vectors.
num_parameters(MLP(emb_dim=emb_dim, n_hidden=16))

181

Don't forget the biases.

In [82]:
def num_parameters_for_mlp(emb_dim, n_hidden):
    """Closed-form parameter count for a two-layer MLP (emb_dim -> n_hidden -> emb_dim), biases included."""
    # Two weight matrices of emb_dim * n_hidden entries each.
    weights = 2 * emb_dim * n_hidden
    # One bias per hidden unit, plus one bias per output feature.
    biases = n_hidden + emb_dim
    return weights + biases
# Should agree with num_parameters(MLP(emb_dim=emb_dim, n_hidden=16)).
num_parameters_for_mlp(emb_dim, 16)

181

In [83]:
class FeedForwardLM(nn.Module):
    """Language model that embeds token ids, passes the embeddings through an MLP, and projects to vocabulary logits.

    Args:
        n_vocab: number of distinct token ids.
        emb_dim: size of each embedding vector.
        n_hidden: hidden width of the MLP.
        tie_weights: when True, the output projection shares its weight with
            the word-embedding table.
    """

    def __init__(self, n_vocab, emb_dim, n_hidden, tie_weights=True):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.model = MLP(emb_dim=emb_dim, n_hidden=n_hidden)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)

        if not tie_weights:
            return
        shared_weight = self.word_to_embedding.weight
        # Both weights must have the same shape before one can stand in for the other.
        assert self.lm_head.weight.shape == shared_weight.shape
        self.lm_head.weight = shared_weight

    def forward(self, input_ids):
        # Embed, transform, then score every vocabulary entry.
        return self.lm_head(self.model(self.word_to_embedding(input_ids)))


# Expected total: the shared embedding / LM-head matrix plus the MLP.
ff_lm = FeedForwardLM(n_vocab=n_vocab, emb_dim=emb_dim, n_hidden=16)
num_parameters(ff_lm)

1461

In [84]:
# Logits: n_vocab scores for every input position.
ff_lm(input_ids).shape

torch.Size([1, 42, 256])

## Transformer

We're going to implement an oversimplified Transformer layer. If you're interested, here are a few reference implementations of the real thing:

- [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html)
- PyTorch's builtin implementation: [`TransformerEncoderLayer`](https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer) and [`MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention).
- [minGPT](https://github.com/karpathy/minGPT/blob/master/mingpt/model.py)

In [85]:
class BareBonesSelfAttention(nn.Module):
    '''Implements *single-head* attention, no masking, no dropout, no scaling, no init.

    Args:
        emb_dim: size of the last dimension of the input. Queries, keys, values
            and the output all keep this size.

    `forward` takes a tensor of shape (n_batch, seq_len, emb_dim) and returns a
    tensor of the same shape.
    '''
    def __init__(self, emb_dim):
        super().__init__()
        self.get_query = nn.Linear(emb_dim, emb_dim)
        self.get_key = nn.Linear(emb_dim, emb_dim)
        self.get_value = nn.Linear(emb_dim, emb_dim)
        self.to_output = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):
        n_batch, seq_len, emb_dim = x.shape

        # Compute query, key, and value vectors.
        q = self.get_query(x) # (n_batch, seq_len, emb_dim)
        k = self.get_key(x)
        v = self.get_value(x)

        # Compute attention weights: every position scores every other position.
        k_transpose = k.transpose(-2, -1)
        assert k_transpose.shape == (n_batch, emb_dim, seq_len)
        scores = q @ k_transpose
        assert scores.shape == (n_batch, seq_len, seq_len)
        # Normalize over the key axis so each query's weights sum to 1.
        attention_weights = scores.softmax(dim=-1)

        # Compute weighted sum of values.
        attended = attention_weights @ v
        assert attended.shape == x.shape

        # Apply the output projection. Without this step `to_output` would be
        # counted as parameters but never used or trained.
        out = self.to_output(attended)
        assert out.shape == x.shape
        return out

# Sanity check: self-attention returns a tensor with the same shape as its input.
BareBonesSelfAttention(emb_dim)(input_embeds).shape


torch.Size([1, 42, 5])

In [86]:
class BareBonesTransformerLayer(nn.Module):
    '''Bare-bones transformer layer: self-attention, layer norm, MLP, layer norm.

    No residual connections and no dropout.

    Args:
        emb_dim: feature size of the input and output.
        dim_feedforward: hidden width of the MLP.
    '''
    def __init__(self, emb_dim, dim_feedforward):
        super().__init__()
        self.self_attention = BareBonesSelfAttention(emb_dim)
        self.mlp = MLP(emb_dim, n_hidden=dim_feedforward)
        self.norm_after_attn = nn.LayerNorm(emb_dim)
        self.norm_after_mlp = nn.LayerNorm(emb_dim)

    def forward(self, x):
        # Attention sub-block followed by its normalization.
        attended = self.norm_after_attn(self.self_attention(x))
        # Feedforward sub-block followed by its normalization.
        return self.norm_after_mlp(self.mlp(attended))

# dim_feedforward is set equal to emb_dim here, so the MLP hidden layer is as wide as the embeddings.
xformer_layer = BareBonesTransformerLayer(emb_dim, emb_dim)
xformer_layer(input_embeds).shape

torch.Size([1, 42, 5])

In [87]:
# Total for one layer: self-attention + MLP + two layer norms.
num_parameters(xformer_layer)

200

In [88]:
class TransformerLM(nn.Module):
    """Language model with word + learned position embeddings, one bare-bones transformer layer, and an LM head.

    Args:
        n_vocab: number of distinct token ids.
        max_len: number of rows in the position-embedding table. Inputs with
            more positions than this fail in the position-embedding lookup.
        emb_dim: size of each embedding vector.
        n_hidden: hidden width of the transformer layer's MLP.
        tie_weights: when True (the default), the output projection shares its
            weight with the word-embedding table, as in the other models here.
    """

    def __init__(self, n_vocab, max_len, emb_dim, n_hidden, tie_weights=True):
        super().__init__()
        self.word_to_embedding = nn.Embedding(n_vocab, emb_dim)
        self.pos_to_embedding = nn.Embedding(max_len, emb_dim)
        self.model = BareBonesTransformerLayer(emb_dim=emb_dim, dim_feedforward=n_hidden)
        self.lm_head = nn.Linear(emb_dim, n_vocab, bias=False)

        if tie_weights:
            # Both weights must have the same shape before one can stand in for the other.
            assert self.lm_head.weight.shape == self.word_to_embedding.weight.shape
            self.lm_head.weight = self.word_to_embedding.weight

    def forward(self, input_ids):
        input_embeds = self.word_to_embedding(input_ids)
        # Compute position embeddings. The ids are created on the same device as
        # the inputs so the lookup also works when the model is not on the CPU.
        position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)
        pos_embeds = self.pos_to_embedding(position_ids)
        # pos_embeds has no batch dimension; the addition broadcasts it over the batch.
        x = input_embeds + pos_embeds
        x = self.model(x)
        return self.lm_head(x)


# max_len=50 adds a 50 x emb_dim position-embedding table on top of the tied word embeddings and the layer.
xformer_lm = TransformerLM(n_vocab=n_vocab, max_len=50, emb_dim=emb_dim, n_hidden=16)
num_parameters(xformer_lm)

1851

In [89]:
# Four emb_dim -> emb_dim linear layers with biases: query, key, value, and output.
num_parameters(BareBonesSelfAttention(emb_dim))

120

In [90]:
# Parameter count of a single layer norm over emb_dim features.
num_parameters(nn.LayerNorm(emb_dim))

10

In [91]:
(
    # Self-attention: the query, key, value and output projections are each
    # an emb_dim -> emb_dim linear layer with a bias.
    4 * (emb_dim * emb_dim + emb_dim)
    # Feedforward: 2-layer MLP whose hidden width equals emb_dim here.
    + num_parameters_for_mlp(emb_dim, emb_dim)
    # Two layer norms, each with two vectors of size emb_dim.
    + 2 * 2 * emb_dim
)

200

In [92]:
# A freshly built layer with the same sizes should match the hand-computed total.
num_parameters(BareBonesTransformerLayer(emb_dim, emb_dim))

200