Jaisidh Singh, Diganta Misra, Antonio Orvieto
We investigate grokking in transformers through the lens of inductive bias: dispositions, arising from architecture or optimization, that lead the network to prefer one solution over another. We first show that architectural choices such as the position of Layer Normalization (LN) strongly modulate grokking speed. This modulation is explained by isolating how LN on specific pathways shapes shortcut learning and attention entropy. We then study how different optimization settings modulate grokking, inducing distinct interpretations of previously proposed controls such as readout scale. In particular, we find that using readout scale as a control for lazy training can be confounded by learning rate and weight decay in our setting. Accordingly, we show that features evolve continuously throughout training, suggesting that grokking in transformers can be more nuanced than a lazy-to-rich transition of the learning regime. Finally, we show how generalization predictably emerges with feature compressibility in grokking, across different modulators of inductive bias.
This repository releases our code for explainably modulating grokking by varying the position of layer normalization (LN) in transformers learning modular addition. The experiments in the submission can be reproduced by running the files in the `scripts` folder. We use wandb to log all metrics across different seeds; these metrics are aggregated across seeds on the wandb console, from where each metric presented in the experiments is downloaded in CSV format. Note that the experiments measuring attention entropy are not included in these scripts. To measure attention-score entropy, save checkpoints and then compute the entropy in the forward pass as follows:
```python
@torch.no_grad()
def measure(cfg, model, tokens):
    x = model.embed(tokens.to(cfg.device))
    x = x + model.pos_emb
    for layer in model.layers:
        # Apply LN to the attention inputs according to the configured LN position.
        if cfg.ln_type == "off" or cfg.ln_type == "mlp-pre":
            a, w = layer.attn(x, x, x, need_weights=True)
        elif cfg.ln_type == "attn-pre" or cfg.ln_type == "pre":
            a, w = layer.attn(layer.ln1(x), layer.ln1(x), layer.ln1(x), need_weights=True)
        elif cfg.ln_type == "attn-pre-qk":
            a, w = layer.attn(layer.ln1(x), layer.ln1(x), x, need_weights=True)
        elif cfg.ln_type == "attn-pre-v" or cfg.ln_type == "pre-v":
            a, w = layer.attn(x, x, layer.ln1(x), need_weights=True)
    # Gram matrix of the last-token attention outputs (for feature analysis).
    X = a[:, -1, :] @ a[:, -1, :].T
    # Clamp the attention weights to avoid log(0), then average the
    # Shannon entropy of each attention distribution.
    attn = w.clamp(cfg.eps)
    entropy = -(attn * attn.log()).sum(dim=-1).mean()
    return entropy
```

Note that the code uses different specifiers for the LN configurations named in our preprint. These specifiers are passed to the `ln_type` subargument of the model subconfig and map as follows.
- No LN: `off`
- A$^*$: `attn-pre`
- M+A$^*$: `pre`
- M: `mlp-pre`
- A$^{qk}$: `attn-pre-qk`
- A$^{v}$: `attn-pre-v`
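The entropy metric itself is independent of the model: given a batch of row-stochastic attention weights, it is the mean Shannon entropy of the attention distributions. A minimal standalone sketch on a toy weight tensor (the shapes and `eps` value here are illustrative assumptions, not the repo's defaults):

```python
import torch

eps = 1e-8
# Toy attention weights of shape (batch, query, key); softmax makes each row a distribution.
w = torch.softmax(torch.randn(4, 6, 6), dim=-1)

attn = w.clamp(eps)                                # clamp to avoid log(0)
entropy = -(attn * attn.log()).sum(dim=-1).mean()  # mean entropy over all attention rows

# Sanity check: entropy is non-negative and at most log(num_keys),
# which is attained by a uniform distribution over the 6 keys.
assert 0.0 <= entropy.item() <= torch.log(torch.tensor(6.0)).item() + 1e-4
```

Low entropy indicates sharply peaked (near one-hot) attention, high entropy near-uniform attention, which is what the LN-position experiments track over training.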

