# Grokking : Generalization beyond Overfitting on Small Dataset

This Notebook aims at exploring the official implementation


In [1]:
import torch as th
from torch.nn import functional as F
import pandas as pd
import numpy as np
from typing import Dict, Tuple
import itertools
import time
import grok
import pytorch_lightning as pl


## Understanding the Dataset and DataLoader


### Dataset


In [2]:
from grok.data import ArithmeticDataset, ArithmeticTokenizer

train_ds, valid_ds = ArithmeticDataset.splits(50, "s5", None, "./data")


In [3]:
assert 120 ** 2 == train_ds.data.shape[0] + valid_ds.data.shape[0]


In [4]:
train_ds.data[:5].shape, train_ds.tokenizer.decode(train_ds.data[:5])


(torch.Size([5, 7]),
 ['<|eos|> 21304 s5 12430 = 42310 <|eos|>',
  '<|eos|> 34102 s5 10243 = 43012 <|eos|>',
  '<|eos|> 03124 s5 31420 = 32140 <|eos|>',
  '<|eos|> 42130 s5 41203 = 32104 <|eos|>',
  '<|eos|> 40123 s5 03142 = 20314 <|eos|>'])

## Copy Network


In [5]:
from grok.transformer import Transformer


In [14]:
net = Transformer(n_layers=2, d_model=32, n_heads=4).float()
net2 = Transformer(n_layers=2, d_model=33, n_heads=4).float()
net.n_heads, net.n_layers


(4, 2)

In [25]:
from functools import reduce

In [46]:
exp_method = "zero"
for (k1,p1),(k2,p2) in zip(net.named_parameters(), net2.named_parameters()):
    assert k1 == k2
    print(k1)
    #print(p1.shape, p2.shape)

embedding.weight
decoder.blocks.0.self_attn.attn_heads.0.Wq.weight
decoder.blocks.0.self_attn.attn_heads.0.Wk.weight
decoder.blocks.0.self_attn.attn_heads.0.Wv.weight
decoder.blocks.0.self_attn.attn_heads.1.Wq.weight
decoder.blocks.0.self_attn.attn_heads.1.Wk.weight
decoder.blocks.0.self_attn.attn_heads.1.Wv.weight
decoder.blocks.0.self_attn.attn_heads.2.Wq.weight
decoder.blocks.0.self_attn.attn_heads.2.Wk.weight
decoder.blocks.0.self_attn.attn_heads.2.Wv.weight
decoder.blocks.0.self_attn.attn_heads.3.Wq.weight
decoder.blocks.0.self_attn.attn_heads.3.Wk.weight
decoder.blocks.0.self_attn.attn_heads.3.Wv.weight
decoder.blocks.0.self_attn.Wo.weight
decoder.blocks.0.self_attn_norm.weight
decoder.blocks.0.self_attn_norm.bias
decoder.blocks.0.ffn.ffn.0.weight
decoder.blocks.0.ffn.ffn.2.weight
decoder.blocks.0.ffn_norm.weight
decoder.blocks.0.ffn_norm.bias
decoder.blocks.1.self_attn.attn_heads.0.Wq.weight
decoder.blocks.1.self_attn.attn_heads.0.Wk.weight
decoder.blocks.1.self_attn.attn_heads.

In [7]:
th.randint(5, [2, 7])


tensor([[4, 4, 2, 4, 0, 0, 1],
        [0, 1, 4, 1, 3, 0, 3]])

In [8]:
net(th.randint(5, [1, 7]))


(tensor([[[-0.3065,  0.5112,  0.0829,  ..., -0.0302,  0.2889,  0.0773],
          [-0.1615,  0.5361,  0.0981,  ..., -0.1725,  0.4206, -0.0861],
          [ 0.7925, -1.1406,  0.1976,  ...,  0.1937,  0.5750,  1.0105],
          ...,
          [ 0.9616,  0.3109,  0.2930,  ..., -0.9916, -0.0333, -0.0638],
          [ 0.2606, -0.4889,  0.5550,  ...,  0.2374,  0.6205,  0.7361],
          [-0.1961,  1.0642,  0.3349,  ..., -0.5135,  0.6176, -0.4522]]],
        grad_fn=<UnsafeViewBackward0>),
 [],
 [])

In [9]:
t1 = th.randn([5, 2])
t2 = th.randn([7, 5])
t_ = t1.clone()
for dim in range(len(t2.shape)):
    m = t2.shape[dim] - t1.shape[dim]
    idx = th.tensor(np.random.choice(range(t1.shape[dim]), size=m, replace=True))
    m_ = th.index_select(t_, dim, idx)
    t_ = th.cat((t_, m_), dim=dim)


In [10]:
def expand_model(
    teacher_model: grok.transformer.Transformer,
    add_dmodel: int,
    exp_method: str = "random",
) -> grok.transformer.Transformer:
    """Expand a Transformer to a multiple of its size.

    Args:
        parent_net:(grok.transformer.Transformer) The parent model to expand from.
        add_dmodel:(int) increase in the size of d_model.
        exp_method:(str) [duplicate | random | zero] Method used to initialize new parameter.

    Returns:
        student_model: (grok.transformer.Transformer) The new transformer with d_model = parent_model.d_model + add_dmodel
    """
    params1 = teacher_model.state_dict()
    student_model = type(teacher_model)(
        n_layers=teacher_model.n_layers,
        n_heads=teacher_model.n_heads,
        d_model=teacher_model.d_model + add_dmodel,
    )
    params2 = student_model.state_dict()

    assert exp_method in ["duplicate", "random", "zero"], "Invalid expansion method."
    
    return student_model


In [12]:
net2 = Transformer(n_heads=4, n_layers=2, d_model=64)
print(type(net2)(n_heads=4, n_layers=2, d_model=16))
expand_model(net, 2, exp_method="random")


Transformer(
  (embedding): Embedding(2000, 16)
  (decoder): Decoder(
    (blocks): ModuleList(
      (0): DecoderBlock(
        (self_attn): MultiHeadAttention(
          (attn_heads): ModuleList(
            (0): AttentionHead(
              (Wq): Linear(in_features=16, out_features=4, bias=False)
              (Wk): Linear(in_features=16, out_features=4, bias=False)
              (Wv): Linear(in_features=16, out_features=4, bias=False)
              (softmax): Softmax(dim=-1)
            )
            (1): AttentionHead(
              (Wq): Linear(in_features=16, out_features=4, bias=False)
              (Wk): Linear(in_features=16, out_features=4, bias=False)
              (Wv): Linear(in_features=16, out_features=4, bias=False)
              (softmax): Softmax(dim=-1)
            )
            (2): AttentionHead(
              (Wq): Linear(in_features=16, out_features=4, bias=False)
              (Wk): Linear(in_features=16, out_features=4, bias=False)
              (Wv): Linear(i

Transformer(
  (embedding): Embedding(2000, 34)
  (decoder): Decoder(
    (blocks): ModuleList(
      (0): DecoderBlock(
        (self_attn): MultiHeadAttention(
          (attn_heads): ModuleList(
            (0): AttentionHead(
              (Wq): Linear(in_features=34, out_features=8, bias=False)
              (Wk): Linear(in_features=34, out_features=8, bias=False)
              (Wv): Linear(in_features=34, out_features=8, bias=False)
              (softmax): Softmax(dim=-1)
            )
            (1): AttentionHead(
              (Wq): Linear(in_features=34, out_features=8, bias=False)
              (Wk): Linear(in_features=34, out_features=8, bias=False)
              (Wv): Linear(in_features=34, out_features=8, bias=False)
              (softmax): Softmax(dim=-1)
            )
            (2): AttentionHead(
              (Wq): Linear(in_features=34, out_features=8, bias=False)
              (Wk): Linear(in_features=34, out_features=8, bias=False)
              (Wv): Linear(i

In [14]:
class TrainableExpTransformer(pl.LightningModule):
    def __init__(self, n_layers=2, d_model=4, n_heads=2):
        super().__init__()
        self.init_n_layers = n_layers
        self.d_model = d_model
        self.init_n_heads = n_heads
        self.transformer = Transformer(
            n_layers=self.init_n_layers,
            d_model=self.d_model,
            n_heads=self.init_n_heads,
        ).float()

    def expand_model(
        self,
        add_dmodel: int,
        exp_method: str = "random",
    ) -> None:
        """Expand Transformer dmodel to dmodel + add_dmodel.

        Args:
            parent_net:(grok.transformer.Transformer) The parent model to expand from.
            add_dmodel:(int) increase in the size of d_model.
            exp_method:(str) [duplicate | random | zero] Method used to initialize new parameter.
        """
        print(f"\nExpanding to size {self.d_model*add_dmodel}")

        teacher_model = self.transformer
        params1 = teacher_model.state_dict()
        student_model = type(teacher_model)(
            n_layers=teacher_model.n_layers,
            n_heads=teacher_model.n_heads,
            d_model=teacher_model.d_model + add_dmodel,
        )
        params2 = student_model.state_dict()

        assert exp_method in [
            "duplicate",
            "random",
            "zero",
        ], "Invalid expansion method."
        params_new = {}
        for k in params2:
            if k == "self_attn_mask":
                params_new.update({k: params2[k].clone()})

            elif params2[k].shape == params1[k].shape:
                params_new.update({k: params1[k].clone()})
            else:
                new_shape = params2[k].shape
                old_shape = params1[k].shape
                w_ = params1[k].clone()
                for dim in range(len(new_shape)):
                    # m is the size  to concat in dimension `dim``
                    m = new_shape[dim] - old_shape[dim]
                    if exp_method == "duplicate":
                        idx = th.tensor(
                            np.random.choice(range(t1.shape[dim]), size=m, replace=True)
                        )
                        v_ = th.index_select(w_, dim, idx)
                        w_ = th.cat((w_, v_), dim=dim)

                    elif exp_method == "random":
                        shape_of_exta = w_.shape[:dim] + (m,) + w_.shape[dim + 1 :]
                        v_ = th.randn(shape_of_exta)
                        w_ = th.cat((w_, v_), dim=dim)

                    elif exp_method == "zero":
                        m = new_shape[dim] - old_shape[dim]
                        shape_of_exta = w_.shape[:dim] + (m,) + w_.shape[dim + 1 :]
                        v_ = th.zeros(shape_of_exta)
                        w_ = th.cat((w_, v_), dim=dim)

                params_new.update({k: w_})
        student_model.load_state_dict(params_new)
        self.transformer = student_model.float()

    def forward(self, *args, **kwargs):
        """Passes all arguments directly to Tranformer.forward()"""
        return self.transformer(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        if batch_idx == 0:
            self.training_epoch_start_time = time.time()
            self.fwd_time_in_epoch = 0

        start = time.time()
        loss = self._step(batch=batch, batch_idx=batch_idx, train=True)
        self.fwd_time_in_epoch += time.time() - start

        return loss

    def _step(
        self, batch: Dict, batch_idx: int, train: bool = True, reduction: str = "mean"
    ) -> Tuple[th.Tensor, th.Tensor, float, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
        x, y = batch  # shape = batchsize * context_len
        y_hat, attentions, values = self(
            x=x
        )  # shape = batchsize * context_len * vocab_size
        y_hat = y_hat.transpose(-2, -1)  # shape = batchsize * vocab_size * context_len
        loss = F.cross_entropy(y_hat, y, reduction=reduction)
        return loss

    def configure_optimizers(self):
        """
        Used by pytorch_lighting

        :returns: optimizers and schedulers.
        """
        optimizer = th.optim.AdamW(
            self.parameters(),
            betas=(0.9, 0.98),
            eps=1e-8,
            lr=1,
        )
        return optimizer


In [15]:
import math


In [18]:
class ExpandModelCallback(pl.callbacks.Callback):
    def on_epoch_end(self, trainer: pl.Trainer, pl_module: TrainableExpTransformer):
        N = trainer.max_epochs / 4
        if pl_module.current_epoch > 0 and pl_module.current_epoch % N == 0:
            pl_module.expand_model(8, exp_method="zero")


In [19]:
dataset = th.randint(100, [1024, 7])
data = dataset[:, :-1]
target = dataset[:, 1:]
dataloader = th.utils.data.DataLoader((data, target), batch_size=512)

model = TrainableExpTransformer()
trainer = pl.Trainer(max_epochs=1000, callbacks=[ExpandModelCallback()])
trainer.fit(model, dataloader)


GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name        | Type        | Params
--------------------------------------------
0 | transformer | Transformer | 16.4 K
--------------------------------------------
16.4 K    Trainable params
0         Non-trainable params
16.4 K    Total params
0.066     Total estimated model params size (MB)


Epoch 250: 100%|██████████| 1/1 [00:00<00:00,  8.44it/s, loss=0.000158, v_num=29]Expanding model

Expanding to size 32
Epoch 436:   0%|          | 0/1 [00:00<00:00, 1209.43it/s, loss=0.000332, v_num=29] 

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
