<a href="https://colab.research.google.com/github/datacraft-paris/2311-Cerisara-LLM/blob/main/LowRankCompression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. [Introduction](https://www.google.com/)
2. [A Brief Overview of LLMs](https://colab.research.google.com/github/datacraft-paris/2311-Cerisara-LLM/blob/main/LLMs.ipynb)
3. [Preparing the data](https://colab.research.google.com/github/datacraft-paris/2311-Cerisara-LLM/blob/main/Data.ipynb)
4. Low-Rank Approximation (This notebook)
    1. [Recall](#recall)
    2. [Setup](#setup)
    3. [Low-Rank Linear](#lrl)
    4. [~22% Reduction](#reduction)

In this notebook, we explore a classic low-rank decomposition applied to LLMs. As you will see, it does not work very well, but the experiment gives us a solid basis for the next notebook.

# Recall <a name="recall"></a>

A linear function is defined by a matrix $W$:

> $$
f(X) = XW
$$

with $X \in \mathbb{R}^{b \times d_{1}}$ and $W \in \mathbb{R}^{d_{1} \times d_{2}}$.
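In PyTorch, `torch.nn.Linear` stores its weight as an `(out_features, in_features)` tensor and computes `x @ weight.T`, so a layer's `weight` corresponds to $W^{T}$ in the notation above. A minimal check on an arbitrary toy layer with $d_1 = 4$ and $d_2 = 3$:

```python
import torch

torch.manual_seed(0)
d1, d2, b = 4, 3, 2  # toy sizes, chosen arbitrarily

linear = torch.nn.Linear(d1, d2, bias=False)
X = torch.randn(b, d1)

# The stored weight has shape (d2, d1): it is W^T, not W
print(linear.weight.shape)  # torch.Size([3, 4])
W = linear.weight.T  # shape (d1, d2), matching f(X) = XW

# nn.Linear computes X @ weight.T, i.e. XW
assert torch.allclose(linear(X), X @ W)
```

Keep this orientation in mind when factorizing a layer's `weight` later: the factors then approximate $W^{T}$ rather than $W$.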


We want to compress the pretrained matrix $W \in \mathbb{R}^{d_{1} \times d_{2}}$. One way to achieve this is a low-rank approximation of $W$:

> $$
W \approx W_{1}W_{2}
$$
where $W_{1} \in \mathbb{R}^{d_{1} \times r}$ and $W_{2} \in \mathbb{R}^{r \times d_{2}}$.

The value $r$ is the rank of the approximation. It must be chosen so that $W_{1}$ and $W_{2}$ together have fewer parameters than $W$, that is $r(d_{1} + d_{2}) < d_{1}d_{2}$, or equivalently $r < \frac{d_{1}d_{2}}{d_{1} + d_{2}}$.
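A quick way to sanity-check a candidate rank is to count parameters directly. The sketch below uses the $128 \times 128$ toy matrix with $r = 16$ from later in this notebook, plus an illustrative $2560 \times 2560$ matrix with $r = 384$ (the rank used for `q_proj` further down; the $2560$ size is an assumption for the example, not read from a model config).

```python
def lowrank_params(d1: int, d2: int, r: int) -> int:
    """Parameters in W1 (d1 x r) plus W2 (r x d2)."""
    return r * (d1 + d2)


def break_even_rank(d1: int, d2: int) -> float:
    """Rank at or above which the factorization no longer saves parameters."""
    return d1 * d2 / (d1 + d2)


for d1, d2, r in [(128, 128, 16), (2560, 2560, 384)]:
    full = d1 * d2
    low = lowrank_params(d1, d2, r)
    print(f"{d1}x{d2}, r={r}: {full:,} -> {low:,} params "
          f"({1 - low / full:.1%} fewer, break-even r < {break_even_rank(d1, d2):g})")
```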

The question now is how to obtain $W_{1}$ and $W_{2}$. One way is to use the singular value decomposition (SVD), which gives the best rank-$r$ approximation of $W$ in the Frobenius and spectral norms (Eckart-Young theorem):

> $$
W = USV^{T}
$$
where $U \in \mathbb{R}^{d_{1} \times d_{1}}$ and $V \in \mathbb{R}^{d_{2} \times d_{2}}$ are orthogonal matrices, and $S \in \mathbb{R}^{d_{1} \times d_{2}}$ is a rectangular diagonal matrix whose diagonal entries are the singular values of $W$ in decreasing order.
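`torch.linalg.svd` returns $U$, the singular values as a 1-D tensor (not the matrix $S$), and $V^{T}$ directly. With `full_matrices=False` it returns the thin SVD, where $U$ is $d_1 \times k$ and $V^{T}$ is $k \times d_2$ with $k = \min(d_1, d_2)$, which is all that truncation needs. A small check on an arbitrary $6 \times 4$ matrix:

```python
import torch

torch.manual_seed(0)
W = torch.randn(6, 4)  # d1 = 6, d2 = 4

# Full SVD: U is (d1 x d1), Vh = V^T is (d2 x d2), S is 1-D
U, S, Vh = torch.linalg.svd(W)
print(U.shape, S.shape, Vh.shape)  # torch.Size([6, 6]) torch.Size([4]) torch.Size([4, 4])

# Thin SVD: U is (d1 x k), Vh is (k x d2), with k = min(d1, d2)
U_t, S_t, Vh_t = torch.linalg.svd(W, full_matrices=False)
print(U_t.shape, S_t.shape, Vh_t.shape)  # torch.Size([6, 4]) torch.Size([4]) torch.Size([4, 4])

# Singular values are sorted in decreasing order, and U diag(S) V^T rebuilds W
assert torch.all(S_t[:-1] >= S_t[1:])
assert torch.allclose(U_t @ torch.diag(S_t) @ Vh_t, W, atol=1e-5)
```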

Keeping only the $r$ largest singular values, together with the corresponding columns of $U$ and rows of $V^{T}$, gives the optimal rank-$r$ approximation of $W$:

$$
W \approx U_{:, :r}\left(S_{:r, :r}(V^{T})_{:r, :}\right)
$$
Where:

$$
W_{1} = U_{:, :r}
$$

$$
W_{2} = S_{:r, :r}(V^{T})_{:r, :}
$$
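Truncating the SVD this way is the Eckart-Young result in action: the Frobenius error of the rank-$r$ approximation equals the square root of the sum of the squared singular values that were dropped. A sketch on a random $64 \times 48$ matrix with $r = 8$ (sizes chosen arbitrarily):

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 48, dtype=torch.float64)
r = 8

U, S, Vh = torch.linalg.svd(W, full_matrices=False)
W1 = U[:, :r]                       # U_{:, :r}, shape (64, r)
W2 = torch.diag(S[:r]) @ Vh[:r, :]  # S_{:r, :r} (V^T)_{:r, :}, shape (r, 48)

err = torch.linalg.norm(W - W1 @ W2)       # default norm of a matrix = Frobenius norm
expected = torch.sqrt((S[r:] ** 2).sum())  # energy of the discarded singular values
print(f"{err.item():.6f} vs {expected.item():.6f}")
assert torch.isclose(err, expected)
```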

Now that we have $W_{1}$ and $W_{2}$, we can approximate the linear function as:

$$
f(X) \approx XW_{1}W_{2}
$$
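Evaluating the product as $(XW_{1})W_{2}$ never forms the full $d_1 \times d_2$ matrix, so it costs $b\,r\,(d_1 + d_2)$ multiply-adds instead of $b\,d_1 d_2$, the same ratio as the parameter count. A small check of both the associativity and the arithmetic, with arbitrary toy sizes:

```python
import torch

torch.manual_seed(0)
b, d1, d2, r = 4, 32, 24, 6  # arbitrary toy sizes

X = torch.randn(b, d1, dtype=torch.float64)
W1 = torch.randn(d1, r, dtype=torch.float64)
W2 = torch.randn(r, d2, dtype=torch.float64)

# Same result either way, by associativity of the matrix product
assert torch.allclose((X @ W1) @ W2, X @ (W1 @ W2))

# Multiply-adds: (X W1) W2 vs. X W with a dense d1 x d2 matrix
lowrank_macs = b * d1 * r + b * r * d2
dense_macs = b * d1 * d2
print(lowrank_macs, dense_macs)
```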

# Setup <a name="setup"></a>

In [None]:
%%capture
!pip uninstall -y transformers
!pip install git+https://github.com/huggingface/transformers accelerate tiktoken

In [None]:
import copy
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

In [None]:
llm = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-3b-4e1t",
                                           torch_dtype=torch.bfloat16,
                                           device_map="cuda",
                                           trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t",
                                          trust_remote_code=True)

In [None]:
lr_llm = copy.deepcopy(llm)

# Low-Rank Linear <a name="lrl"></a>

Before writing a `LowRankLinear` class, let's work through a quick example.

In [None]:
import torch

In [None]:
# pretrained matrix
W = torch.rand((128, 128))
W.requires_grad = False
# rank
r = 16

Look at the documentation of the SVD implementation of PyTorch: https://pytorch.org/docs/stable/generated/torch.linalg.svd.html

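Then, call this function to get the $U$, $S$ and $V$ matrices. Note that it returns $S$ as a 1-D tensor of singular values, and returns $V^{T}$ (called `Vh` in the documentation) rather than $V$: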

In [None]:
# torch.linalg.svd returns U, the singular values (1-D) and V^T (not V)
u, s, v = torch.linalg.svd(W)
# diagonalize 's':
s = torch.diag(s)

Based on the formula above, compute $W_{1}$ and $W_{2}$:

In [None]:
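w1 = u[:, :r]              # U_{:, :r}, shape (128, r)
w2 = s[:r, :r] @ v[:r, :]  # S_{:r, :r} (V^T)_{:r, :}, shape (r, 128); `v` already holds V^T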
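To get a feel for how much of $W$ a given rank keeps, one common diagnostic (not used elsewhere in this notebook) is the fraction of squared singular values retained, $\sum_{i \le r} s_i^2 / \sum_i s_i^2$, which equals $1 - \lVert W - W_1W_2 \rVert_F^2 / \lVert W \rVert_F^2$. The sketch below computes it for a fresh $128 \times 128$ matrix drawn like the toy `W` above and finds the smallest rank reaching a target fraction; the 90% target is an arbitrary choice.

```python
import torch

torch.manual_seed(0)
W_toy = torch.rand((128, 128), dtype=torch.float64)  # same distribution as the toy W above
sv = torch.linalg.svdvals(W_toy)                     # singular values, decreasing order

# Fraction of the squared singular values kept by each rank 1..128
energy = torch.cumsum(sv ** 2, dim=0) / (sv ** 2).sum()
rank = 16
print(f"rank {rank} keeps {energy[rank - 1].item():.1%} of the squared singular values")

# Smallest rank whose retained fraction reaches the (arbitrary) 90% target;
# energy is non-decreasing, so counting the entries below the target gives its index
target = 0.90
r_target = int((energy < target).sum().item()) + 1
print(f"rank needed for {target:.0%}: {r_target}")
```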

In [None]:
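class LowRankLinear(torch.nn.Module):
    """Replacement for a bias-free torch.nn.Linear whose weight (out x in)
    is approximated by w1 @ w2, with w1: (out x r) and w2: (r x in)."""

    def __init__(self, w, rank):
        super().__init__()
        self._decompose(w, rank)

    @torch.no_grad()
    def _decompose(self, w, r):
        # SVD in float32 for accuracy; the thin SVD avoids building a full square U
        u, s, vh = torch.linalg.svd(w.to(dtype=torch.float32), full_matrices=False)
        w1 = u[:, :r]                       # (out, r)
        w2 = torch.diag(s[:r]) @ vh[:r, :]  # (r, in); vh is V^T
        # Store the factors in the original dtype (bfloat16 for this model)
        self.w1 = torch.nn.Parameter(w1.to(dtype=w.dtype))
        self.w2 = torch.nn.Parameter(w2.to(dtype=w.dtype))

    def forward(self, x):
        # nn.Linear computes x @ weight.T; with weight ~ w1 @ w2 this is
        # (x @ w2.T) @ w1.T, two small products that never form the full matrix.
        # Any bias of the original layer is not carried over.
        return torch.nn.functional.linear(torch.nn.functional.linear(x, self.w2), self.w1)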
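Because `torch.nn.Linear` stores `weight` as `(out_features, in_features)` and computes `x @ weight.T`, factorizing `weight` $\approx W_1 W_2$ means the forward pass is `(x @ W2.T) @ W1.T`, which `torch.nn.functional.linear` expresses directly. Below is a self-contained sketch (not the class itself) on a non-square toy layer, with arbitrary sizes loosely mimicking an up-projection such as `gate_proj`, compared against the original layer at full rank and at a reduced rank:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
layer = torch.nn.Linear(in_features=16, out_features=40, bias=False)
x = torch.randn(3, 16)

with torch.no_grad():
    # weight is (out, in) = (40, 16); thin SVD gives U (40 x 16), S (16,), Vh (16 x 16)
    U, S, Vh = torch.linalg.svd(layer.weight, full_matrices=False)
    ref = layer(x)


def lowrank_forward(x, r):
    w1 = U[:, :r]                       # (out, r)
    w2 = torch.diag(S[:r]) @ Vh[:r, :]  # (r, in)
    # x @ (w1 @ w2).T == (x @ w2.T) @ w1.T
    return F.linear(F.linear(x, w2), w1)


full = lowrank_forward(x, r=16)  # r = min(in, out): exact up to float error
half = lowrank_forward(x, r=8)
print(full.shape)  # torch.Size([3, 40])

assert torch.allclose(full, ref, atol=1e-5)
print(((half - ref).norm() / ref.norm()).item())  # relative error of the rank-8 version
```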

# ~22% Reduction <a name="reduction"></a>

In [None]:
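# Rank r for each projection, keyed by the module's attribute name
# (any module whose name ends in ".q_proj" matches "q_proj", and so on).
config = {
    "q_proj": 384,
    "o_proj": 384,
    "gate_proj": 512,
}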

In [None]:
def setmodule(module, target_module, value):
    """Replace the submodule at dotted path `target_module` inside `module` with `value`."""
    submodules = target_module.split(".", 1)
    if len(submodules) == 1:
        setattr(module, submodules[0], value)
    else:
        setmodule(getattr(module, submodules[0]), submodules[-1], value)
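`setmodule` walks the dotted name recursively. An equivalent, non-recursive alternative (a swapped-in variant, not what this notebook uses) splits off the last component and looks up the parent with PyTorch's `Module.get_submodule`. Here it is as a hypothetical helper, `set_submodule_by_name`, applied to a nested toy model:

```python
import torch


def set_submodule_by_name(model, name, new):
    """Hypothetical helper: replace the submodule at dotted path `name` with `new`."""
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)  # an empty path returns `model` itself
    setattr(parent, child_name, new)


# Nested toy model; "0.0" is the dotted name named_modules() reports for the inner Linear
toy = torch.nn.Sequential(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
set_submodule_by_name(toy, "0.0", torch.nn.Identity())
print(toy)
```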

In [None]:
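def lowrank_model(config, model):
    """Replace every nn.Linear whose name ends with a key of `config` by a LowRankLinear."""
    # Collect the modules first so the model is not modified while being iterated over
    named_modules = list(model.named_modules())
    for name, module in tqdm(named_modules):
        module_name = name.split(".")[-1]
        if module_name in config and isinstance(module, torch.nn.Linear):
            rank = config[module_name]
            lowrank_linear = LowRankLinear(w=module.weight, rank=rank)
            setmodule(model, name, lowrank_linear)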

In [None]:
lowrank_model(config, lr_llm)

In [None]:
print(f'Number of parameters of the Base LLM: {llm.num_parameters(only_trainable=True):,}')
print(f'Number of parameters of the Low-Rank LLM: {lr_llm.num_parameters(only_trainable=True):,}')

In [None]:
pipe = pipeline("text-generation", model=lr_llm, tokenizer=tokenizer, do_sample=False)
print(pipe("Here a python function that sum up 3 numbers:", max_new_tokens=32, min_new_tokens=8)[0]["generated_text"])

In [None]:
pipe = pipeline("text-generation", model=llm, tokenizer=tokenizer, do_sample=True)
print(pipe("One day, I will", max_new_tokens=16, min_new_tokens=8)[0]["generated_text"])