In [2]:
from transformers import GPTNeoXForCausalLM, AutoTokenizer

# Single source of truth for the checkpoint id, so the tokenizer and the model
# can never be loaded from two different repositories by accident.
CHECKPOINT = "EleutherAI/pythia-70m"

# Both calls fetch from the Hugging Face hub on first use (see the download
# progress bars recorded below) and hit the local cache afterwards.
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = GPTNeoXForCausalLM.from_pretrained(CHECKPOINT)

  from .autonotebook import tqdm as notebook_tqdm
Downloading tokenizer_config.json: 100%|██████████| 396/396 [00:00<00:00, 2.71MB/s]
Downloading tokenizer.json: 100%|██████████| 2.11M/2.11M [00:00<00:00, 9.73MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 99.0/99.0 [00:00<00:00, 948kB/s]
Downloading config.json: 100%|██████████| 567/567 [00:00<00:00, 4.62MB/s]
Downloading model.safetensors: 100%|██████████| 166M/166M [00:01<00:00, 102MB/s]  


In [None]:
# Display the repr of the loaded GPT-NeoX model (its nested module tree).
model

In [1]:
import torch

# Square, randomly initialised dense layer used as the SVD test subject below.
linear = torch.nn.Linear(in_features=768, out_features=768)

In [7]:
# Alias for the layer's weight Parameter (768 x 768). It is NOT detached, which
# is why the results computed from it below carry a grad_fn.
lin_weight = linear.weight

In [14]:
# torch.svd is deprecated in favour of torch.linalg.svd. The replacement
# returns V^H instead of V, so transpose it back: later cells reconstruct the
# matrix with `v.mT` and therefore expect `v` to hold V itself.
# full_matrices=False matches torch.svd's default reduced decomposition.
u, s, vh = torch.linalg.svd(lin_weight, full_matrices=False)
v = vh.mT
u.shape, s.shape, v.shape

(torch.Size([768, 768]), torch.Size([768]), torch.Size([768, 768]))

In [15]:
# 2-norm distance between the weight and its reconstruction U diag(s) V^T.
# Scaling the columns of `u` by `s` equals multiplying by diag(s).
torch.dist(lin_weight, (u * s) @ v.mT)

tensor(2.8584e-05, grad_fn=<DistBackward0>)

In [17]:
# Randomized truncated SVD. q=6 is the library default, spelled out here
# because it determines the rank of the factors shown in the output.
u, s, v = torch.svd_lowrank(lin_weight, q=6)
tuple(factor.shape for factor in (u, s, v))

(torch.Size([768, 6]), torch.Size([6]), torch.Size([768, 6]))

In [18]:
# 2-norm distance between the weight and its rank-limited reconstruction.
# Scaling the columns of `u` by `s` equals multiplying by diag(s).
torch.dist(lin_weight, (u * s) @ v.mT)

tensor(15.7942, grad_fn=<DistBackward0>)

In [19]:
# (two 768 x 256 factors, one dense 768 x 768 matrix) -- presumably parameter
# counts for a rank-256 factorization vs. the full matrix; TODO confirm intent.
2 * 768 * 256, 768**2

(393216, 589824)

In [6]:
# Scratch element counts -- presumably a rank-8 factor pair for a 1024 -> 3072
# projection, a 1024 x 4096 matrix, and three 1024 x 1024 matrices; TODO confirm.
8 * (1024 + 3072), 1024 * 4096, 3 * 1024**2

(32768, 4194304, 3145728)

In [1]:
# (three dense 768 x 768 matrices, three rank-8 factor pairs of the same
# shape) -- presumably parameter counts; TODO confirm intent.
3 * 768**2, 3 * (2 * 768 * 8)

(1769472, 36864)

In [7]:
# Element count of a dense 1024 x 1024 matrix.
1024**2

1048576

In [8]:
# Element count of a 1024 x 8 matrix plus an 8 x 1024 matrix -- presumably a
# rank-8 factor pair for the 1024 x 1024 matrix above; TODO confirm intent.
2 * (1024 * 8)

16384

In [1]:
# Difference between two recorded totals. Their origin and units are not shown
# in this notebook -- presumably memory or parameter counts; TODO confirm.
1_997_624_552 - 2_775_208_960

-777584408

In [3]:
# (dense 50304 x 1024 table, 50304 x 512 factor plus 512 x 1024 factor) --
# presumably a rank-512 factorization of an embedding table; TODO confirm.
50304 * 1024, 512 * (50304 + 1024)

(51511296, 26279936)

In [4]:
# Difference between two recorded totals. Their origin and units are not shown
# in this notebook -- presumably memory or parameter counts; TODO confirm.
2_775_208_960 - 1_692_226_992

1082981968

## Custom embedding: low-rank factorized `nn.Embedding`

In [9]:
from typing import Optional
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.nn.functional as F


class CustomEmbedding(torch.nn.Embedding):
    """``nn.Embedding`` variant whose table is stored as a low-rank product.

    Instead of one ``(num_embeddings, embedding_dim)`` weight, the module keeps
    two trainable factors ``A`` of shape ``(num_embeddings, rank)`` and ``B`` of
    shape ``(rank, embedding_dim)``; the effective table is ``A @ B``.
    ``self.weight`` is set to ``None``.

    Args:
        num_embeddings: Vocabulary size (number of rows of the table).
        embedding_dim: Size of each embedding vector.
        padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse:
            Same meaning as for ``torch.nn.Embedding``; forwarded to
            ``F.embedding``.
        _weight: Accepted for signature compatibility only. It is shape-checked
            by the parent constructor and then discarded; the factors are
            always randomly initialised.
        _freeze: If True, ``A`` and ``B`` are created with
            ``requires_grad=False``.
        device, dtype: Factory arguments for the factors.
        rank: Inner dimension of the factorization (default 16, the value that
            used to be hard-coded).

    Raises:
        ValueError: If ``rank`` is not a positive integer.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        _freeze: bool = False,
        device=None,
        dtype=None,
        rank: int = 16,
    ) -> None:
        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        # NOTE: the parent constructor still allocates (and initialises) a full
        # dense table when `_weight` is None; it is released right below.
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight=_weight,
            _freeze=_freeze,
            device=device,
            dtype=dtype,
        )
        # Drop the dense table: only the two factors are stored and trained.
        self.weight = None
        self.rank = rank
        factory_kwargs = {"device": device, "dtype": dtype}
        self.A = Parameter(
            torch.empty((num_embeddings, rank), **factory_kwargs),
            requires_grad=not _freeze,
        )
        self.B = Parameter(
            torch.empty((rank, embedding_dim), **factory_kwargs),
            requires_grad=not _freeze,
        )
        # Same initialisation order as before so seeded runs are unchanged.
        self.B.data.normal_(mean=0.0, std=0.02)
        self.A.data.normal_(mean=0.0, std=0.02)
        # Mirror nn.Embedding: the padding row starts as all zeros. Zeroing the
        # row of A makes the corresponding row of A @ B zero as well.
        # `self.padding_idx` has already been normalised by the parent.
        if self.padding_idx is not None:
            with torch.no_grad():
                self.A[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Look up embeddings for ``input`` (integer indices of any shape).

        Returns a tensor of shape ``(*input.shape, embedding_dim)``.
        """
        default_config = (
            self.padding_idx is None
            and self.max_norm is None
            and not self.scale_grad_by_freq
            and not self.sparse
        )
        if default_config:
            # Gather the needed rows of A first, then project with B. This is
            # mathematically the same as indexing A @ B but never materialises
            # the full (num_embeddings, embedding_dim) table.
            return F.embedding(input, self.A) @ self.B

        # The remaining options act on rows of the effective table (renorm,
        # gradient scaling, padding gradient masking), so build it explicitly
        # to keep their semantics exactly as before.
        weight = self.A @ self.B
        return F.embedding(
            input,
            weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )


# Smoke test: the factorized embedding must produce the same output shape as
# the stock nn.Embedding for an identical lookup.
custom_emb, default_emb = CustomEmbedding(123, 123), torch.nn.Embedding(123, 123)
tuple(layer(torch.tensor([1])).shape for layer in (custom_emb, default_emb))

(torch.Size([1, 123]), torch.Size([1, 123]))

In [6]:
# Shape check for the two factors only: torch.empty leaves the memory
# uninitialised, so the values of the product are meaningless.
A = Parameter(torch.empty((123, 16)), requires_grad=True)
B = Parameter(torch.empty((16, 123)), requires_grad=True)

torch.matmul(A, B).shape

tensor([[6.2585e-19, 3.1788e-36, 7.1940e-30,  ..., 7.1955e-30, 0.0000e+00,
         7.1944e-30],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]], grad_fn=<MmBackward0>)

In [12]:
# (dense 50304 x 768 table, two 50304 x 256 blocks plus two 256 x 768 blocks)
# -- presumably parameter counts for a factorized alternative; TODO confirm.
50304 * 768, 2 * 256 * (50304 + 768)

(38633472, 26148864)