# Expanding Transformer Architecture

In [1]:
import torch as th
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor
from numpy import cos, sin
import numpy as np
from grok.transformer import MultiHeadAttention, LayerNorm, FFN, Transformer
from typing import *
from copy import deepcopy

## Expanding Embedding Layer

In [2]:
vocab_len = 100
x = th.tensor([5,11,6, 66])
embedding_weight = th.rand((vocab_len,3))
F.embedding(x, embedding_weight)

tensor([[0.0678, 0.4101, 0.8137],
        [0.1975, 0.9752, 0.2417],
        [0.3470, 0.9060, 0.8533],
        [0.6832, 0.1224, 0.2610]])

In [3]:
embedding_weight_exp = th.zeros((vocab_len,4))
size = embedding_weight.shape
size = [slice(x) for x in size]
print(size)
embedding_weight_exp[size] = embedding_weight
F.embedding(x, embedding_weight_exp)

[slice(None, 100, None), slice(None, 3, None)]


tensor([[0.0678, 0.4101, 0.8137, 0.0000],
        [0.1975, 0.9752, 0.2417, 0.0000],
        [0.3470, 0.9060, 0.8533, 0.0000],
        [0.6832, 0.1224, 0.2610, 0.0000]])

## Expanding Embedding Layer with positional encoding 

In [4]:
def gen_position_encoding(context_len: int, d_model: int) -> th.Tensor:
        rows = [
            th.tensor(
                [
                    sin(pos / (10000 ** (i / d_model)))
                    if i % 2 == 0
                    else cos(pos / (10000 ** ((i - 1) / d_model)))
                    for i in range(d_model)
                ]
            )
            for pos in range(context_len)
        ]
        stack = th.stack(rows, dim=1)

        return stack.T  # type: ignore
gen_position_encoding(5,4)

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0100,  1.0000],
        [ 0.9093, -0.4161,  0.0200,  0.9998],
        [ 0.1411, -0.9900,  0.0300,  0.9996],
        [-0.7568, -0.6536,  0.0400,  0.9992]], dtype=torch.float64)

In [5]:
def embed(indices: th.Tensor, embedding_weight:th.Tensor, position_encoding:th.Tensor) -> th.Tensor:
        context_len = indices.shape[-1]
        pe = position_encoding[:context_len, :]  # type: ignore
        embedded = F.embedding(indices,embedding_weight)
        return pe + embedded

In [6]:
embed(x, embedding_weight, gen_position_encoding(10,3))

tensor([[ 0.0678,  1.4101,  0.8137],
        [ 1.0390,  1.5155,  0.2439],
        [ 1.2563,  0.4899,  0.8576],
        [ 0.8243, -0.8675,  0.2674]], dtype=torch.float64)

In [7]:
embed(x, embedding_weight_exp, gen_position_encoding(10,4))

tensor([[ 0.0678,  1.4101,  0.8137,  1.0000],
        [ 1.0390,  1.5155,  0.2517,  1.0000],
        [ 1.2563,  0.4899,  0.8733,  0.9998],
        [ 0.8243, -0.8675,  0.2910,  0.9996]], dtype=torch.float64)

## Expanding Head

## Adding Decoder Block

## Head Expansion logic

In [8]:
def new_emb(self, i):
    return self.embedding(i)

Transformer.embed = new_emb

In [9]:
net1 = Transformer(n_layers=1, n_heads=3, d_model=12)
net1.d_model//net1.n_heads

4

In [10]:
th.save(net1, "./checkpoints/net1.th")

In [23]:
def knowledge_transfer(net2:th.nn.Module, old_state_path:str):
    net1 = th.load(old_state_path)
    old_state = net1.state_dict()
    n_layers_old = net1.n_layers
    n_head_old = net1.n_heads

    dk_old = net1.d_model//net1.n_heads
    dk_new = net2.d_model//net2.n_heads

    new_state = net2.state_dict() 
    updated_state = deepcopy(new_state)
    for k in new_state:
        # print(k)
        if k == "position_encoding" or k == "self_attn_mask":
            continue
        elif "self_attn_norm" in k.split(".") or "ffn_norm" in k.split("."):
            continue
        elif "attn_heads" in k.split("."):
            updated_state[k] = th.zeros_like(new_state[k])
            weight_name = k.split(".")
            layer_idx = int(weight_name[2])
            if layer_idx < n_layers_old:
                head_idx = int(weight_name[5])   
                lst = [(i//dk_old, i%dk_old) for i in (head_idx*dk_new, head_idx*dk_new +dk_new)]
                w = []
                if lst[0][0] == lst[1][0]:
                    w.append(old_state[k][ lst[0][1]: lst[1][1], :])
                else:
                    for prev_head_idx in range(lst[0][0], lst[1][0]+1):
                        if not prev_head_idx < n_head_old:
                            continue
                        weight_name_old = weight_name.copy()
                        weight_name_old[5] = str(prev_head_idx)
                        k_old = ".".join(weight_name_old)

                        if prev_head_idx == lst[0][0]:
                            w_dash = old_state[k_old][lst[0][1]: , :]
                            # print(rng,w_dash.shape)
                            w.append(w_dash)

                        elif prev_head_idx == lst[1][0]:
                            w_dash = old_state[k_old][ :lst[1][1], :]
                            # print(rng, w_dash.shape)
                            w.append(w_dash)
                        else:
                            w.append(old_state[k_old])
                    if w:
                        final_old_w = th.cat(w)
                        dice = [slice(dim) for dim in final_old_w.shape]
                        updated_state[k][dice] = final_old_w
        else:
            updated_state[k] = th.zeros_like(new_state[k])
            if k in old_state:
                dice = [slice(dim) for dim in old_state[k].shape]
                updated_state[k][dice] = old_state[k]
        
    net2.load_state_dict(updated_state)

net2 = Transformer(n_layers=2, n_heads=4, d_model=16)
knowledge_transfer(net2, "./checkpoints/net1.th")

position_encoding
self_attn_mask
embedding.weight
decoder.blocks.0.self_attn.attn_heads.0.Wq.weight
decoder.blocks.0.self_attn.attn_heads.0.Wk.weight
decoder.blocks.0.self_attn.attn_heads.0.Wv.weight
decoder.blocks.0.self_attn.attn_heads.1.Wq.weight
decoder.blocks.0.self_attn.attn_heads.1.Wk.weight
decoder.blocks.0.self_attn.attn_heads.1.Wv.weight
decoder.blocks.0.self_attn.attn_heads.2.Wq.weight
decoder.blocks.0.self_attn.attn_heads.2.Wk.weight
decoder.blocks.0.self_attn.attn_heads.2.Wv.weight
decoder.blocks.0.self_attn.attn_heads.3.Wq.weight
decoder.blocks.0.self_attn.attn_heads.3.Wk.weight
decoder.blocks.0.self_attn.attn_heads.3.Wv.weight
decoder.blocks.0.self_attn.Wo.weight
decoder.blocks.0.self_attn_norm.weight
decoder.blocks.0.self_attn_norm.bias
decoder.blocks.0.ffn.ffn.0.weight
decoder.blocks.0.ffn.ffn.2.weight
decoder.blocks.0.ffn_norm.weight
decoder.blocks.0.ffn_norm.bias
decoder.blocks.1.self_attn.attn_heads.0.Wq.weight
decoder.blocks.1.self_attn.attn_heads.0.Wk.weight
decod

In [24]:
# for k in updated_state:
#     print(k, updated_state[k].shape)

In [25]:
x = th.tensor([11])
em1 = net1.embed(x)
em2 = net2.embed(x)
em1, em2

(tensor([[ 0.8888, -0.2345, -0.5484,  0.7895, -0.6686, -0.7335,  0.7221,  0.9016,
           0.0201, -0.4690, -1.0916, -0.4233]], grad_fn=<EmbeddingBackward0>),
 tensor([[ 0.8888, -0.2345, -0.5484,  0.7895, -0.6686, -0.7335,  0.7221,  0.9016,
           0.0201, -0.4690, -1.0916, -0.4233,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>))

In [26]:
net1.decoder.blocks[0].self_attn(em1,em1,em1), net2.decoder.blocks[0].self_attn(em2,em2,em2) 

((tensor([[-0.1011,  0.1441,  0.2920, -0.2969,  0.0155, -0.1035, -0.0443,  0.3187,
            0.0517, -0.1360, -0.1360,  0.1196]], grad_fn=<MmBackward0>),
  [],
  []),
 (tensor([[-0.1011,  0.1441,  0.2920, -0.2969,  0.0155, -0.1035, -0.0443,  0.3187,
            0.0517, -0.1360, -0.1360,  0.1196,  0.0000,  0.0000,  0.0000,  0.0000]],
         grad_fn=<MmBackward0>),
  [],
  []))

In [27]:
net1.decoder.blocks[0](em1), net2.decoder.blocks[0](em2)

((tensor([[ 1.1587,  0.1066, -0.3665,  0.5919, -0.7335, -0.8723,  0.8522,  2.0088,
            0.1278, -0.6590, -1.8812, -0.3333]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []),
 (tensor([[ 1.2905,  0.0695, -0.2614,  0.6270, -0.9105, -1.0636,  0.9250,  2.2703,
            0.0870, -0.8132, -2.2329, -0.4385,  0.1127,  0.1127,  0.1127,  0.1127]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []))

In [28]:
em2, net2.decoder.blocks[1](em2)

(tensor([[ 0.8888, -0.2345, -0.5484,  0.7895, -0.6686, -0.7335,  0.7221,  0.9016,
           0.0201, -0.4690, -1.0916, -0.4233,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 (tensor([[ 1.5869, -0.3060, -0.8350,  1.4196, -1.0375, -1.1468,  1.3061,  1.6084,
            0.1231, -0.7012, -1.7502, -0.6241,  0.0892,  0.0892,  0.0892,  0.0892]],
         grad_fn=<NativeLayerNormBackward0>),
  [],
  []))

In [29]:
net1(x), net2(x)

((tensor([[ 0.3740,  0.0282, -0.2062,  ..., -0.5324, -0.6836, -0.3312]],
         grad_fn=<MmBackward0>),
  [],
  []),
 (tensor([[ 0.5021,  0.0963, -0.3351,  ..., -0.6437, -0.8326, -0.3108]],
         grad_fn=<MmBackward0>),
  [],
  []))

## Layer Normalization

In [18]:
norm = th.nn.LayerNorm(16)

In [19]:
em2, norm(em2)

(tensor([[ 0.8888, -0.2345, -0.5484,  0.7895, -0.6686, -0.7335,  0.7221,  0.9016,
           0.0201, -0.4690, -1.0916, -0.4233,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 tensor([[ 1.5869, -0.3060, -0.8350,  1.4196, -1.0375, -1.1468,  1.3061,  1.6084,
           0.1231, -0.7012, -1.7502, -0.6241,  0.0892,  0.0892,  0.0892,  0.0892]],
        grad_fn=<NativeLayerNormBackward0>))

In [22]:
net2.decoder.blocks[0].self_attn_norm.weight

Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
       requires_grad=True)

In [21]:
em2, F.layer_norm(em2, (16,), norm_w)

(tensor([[ 0.8888, -0.2345, -0.5484,  0.7895, -0.6686, -0.7335,  0.7221,  0.9016,
           0.0201, -0.4690, -1.0916, -0.4233,  0.0000,  0.0000,  0.0000,  0.0000]],
        grad_fn=<EmbeddingBackward0>),
 tensor([[ 1.5869, -0.3060, -0.8350,  1.4196, -1.0375, -1.1468,  1.3061,  1.6084,
           0.1231, -0.7012, -1.7502, -0.6241,  0.0892,  0.0892,  0.0892,  0.0892]],
        grad_fn=<NativeLayerNormBackward0>))