
1. Introduction
2. [Recall on Language Modeling and Transformers](https://colab.research.google.com/github/datacraft-paris/2311-Cerisara-LLM/blob/main/LLMs.ipynb)
3. [Preparing the data](https://colab.research.google.com/github/datacraft-paris/2311-Cerisara-LLM/blob/main/Data.ipynb)
4. [Low-Rank Approximation](https://colab.research.google.com/github/datacraft-paris/2311-Cerisara-LLM/blob/main/LowRankCompression.ipynb)
5. Low-Rank Features Mimicking (This notebook)
    1. [Weights are not Low-Rank... 😢](#not_low_rank)
    2. [Features Mimicking](#ft_mm)
    3. [Setup](#setup)
    4. [Faster Low-Rank Linear](#fast)
    5. [Forward Hooks](#hooks)
        1. [Callable Hook Class](#callable)
        2. [Register Hooks](#register)
    6. [Train the LowRank Linears](#train)

# Weights are not Low-Rank... 😢 <a name=not_low_rank></a>


So far, we have seen how to approximate the weights of the LLM with smaller, low-rank weights. However, the compressed model does not perform well. One reason is that the pretrained weights of LLMs are often full rank: studies have shown that transformer weights tend to be full rank, whereas the activations (the outputs of the linear layers) are much closer to low rank.
 <p align="center">
  <img src="https://github.com/datacraft-paris/2311-Cerisara-LLM/blob/main/illustration/weights_are_full.png?raw=true" alt="attention" width=500 class="center">
</p>

The authors of this paper therefore proposed to decompose the weights using the activations rather than the weights alone.

# Features Mimicking <a name=ft_mm></a>


<p align="center">
<img src="https://github.com/datacraft-paris/2311-Cerisara-LLM/blob/main/illustration/Capture d’écran 2024-02-25 à 01.56.15.png?raw=true" alt="attention" width=500 class="center">
<br>
    <em>
    Illustration of how we train a low-rank linear layer to output activations that mimic the activations of the pretrained weights.
    <br>
    The blue part represents the pretrained module, and the pink part the low-rank module trained to mimic its activations.
    </em>
</p>


In our case, low-rank approximation can be formulated as the following minimisation objective:

$$
\underset{W_{1}, W_{2}}{\mathrm{argmin}} \;\;\|W - W_{1}W_{2}\|_{F}^{2}
$$

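where $W_{1}$ and $W_{2}$ are low-rank matrices: for $W \in \mathbb{R}^{d_{in} \times d_{out}}$ and a target rank $r$, $W_{1} \in \mathbb{R}^{d_{in} \times r}$ and $W_{2} \in \mathbb{R}^{r \times d_{out}}$. An analytic solution of this problem is given by the truncated $SVD$ (Eckart-Young theorem), where $W_{1}$ and $W_{2}$ are built from $U$, $S$ and $V$ as in the previous notebook.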
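A minimal sketch of this SVD solution, assuming a random matrix and an arbitrary rank (both are illustrative choices, not values from the notebook). It builds $W_{1} = U_{r}$ and $W_{2} = S_{r}V_{r}^{\top}$ from a truncated SVD and checks the Eckart-Young property: the squared Frobenius error equals the sum of the squared discarded singular values.

```python
import torch

torch.manual_seed(0)

def truncated_svd_factors(w: torch.Tensor, r: int):
    """Return (W1, W2) such that W1 @ W2 is the best rank-r approximation of w (Frobenius norm)."""
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    w1 = u[:, :r]                    # (m, r)
    w2 = s[:r, None] * vh[:r, :]     # (r, n), i.e. diag(S_r) @ V_r^T
    return w1, w2

w = torch.randn(64, 48)              # illustrative weight matrix
r = 8                                # illustrative target rank
w1, w2 = truncated_svd_factors(w, r)

err = torch.linalg.matrix_norm(w - w1 @ w2) ** 2    # squared Frobenius error of the approximation
s = torch.linalg.svdvals(w)
discarded = (s[r:] ** 2).sum()                       # Eckart-Young: the optimal error
print(f"squared error = {err.item():.4f}, sum of discarded s^2 = {discarded.item():.4f}")
```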

However, as discussed above, the pretrained weight matrices are usually not low-rank, hence the poor approximation. Instead, we approximate the weights through the linear activations, which appear to be much closer to low-rank. For an input $X$, the linear activations are simply:

$$
f_{b}(X) = XW
$$

We want to find another linear function with low-rank weights

$$
f_{a}(X) = XW_{1}W_{2}
$$

such that $f_{a}(X) \approx f_{b}(X)$.

We can use a simple objective function to achieve this:

$$
\underset{W_{1}, W_{2}}{\mathrm{argmin}} \;\;\|f_{b}(X) - f_{a}(X)\|_{F}^{2}
$$

With this objective function, the low-rank weights learn to mimic the activations produced by the full-rank weight.
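In practice the squared Frobenius norm is usually computed as a mean squared error, which differs only by the constant factor $1/(n \cdot d_{out})$ and therefore has the same minimizer. A small sketch with random tensors (shapes are arbitrary assumptions; $W$ is applied as `X @ W`, matching $f_{b}(X) = XW$ rather than the transposed storage used by `torch.nn.Linear`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_in, d_out, r = 32, 16, 12, 4     # illustrative sizes

X = torch.randn(n, d_in)
W = torch.randn(d_in, d_out)          # full-rank "teacher" weight
W1 = torch.randn(d_in, r)             # low-rank "student" factors
W2 = torch.randn(r, d_out)

f_b = X @ W                           # teacher activations
f_a = (X @ W1) @ W2                   # student activations, without forming W1 @ W2 explicitly

frob_sq = torch.linalg.matrix_norm(f_b - f_a) ** 2
mse = F.mse_loss(f_a, f_b)            # mean over all n * d_out elements
print(frob_sq.item() / (n * d_out), mse.item())
```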

We can find an analytic solution to this problem using an _eigendecomposition_ of the covariance matrix of the activations.
In our case, we will instead search for a solution with a gradient descent algorithm that minimizes the Mean Squared Error (MSE).
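The closed-form route can be sketched as follows, assuming all activations $Y = f_{b}(X)$ fit in memory. Taking $V_{r}$ as the top-$r$ eigenvectors of $Y^{\top}Y$ (the uncentered covariance of the activations), the choice $W_{1} = WV_{r}$, $W_{2} = V_{r}^{\top}$ gives $f_{a}(X) = YV_{r}V_{r}^{\top}$, the best rank-$r$ approximation of $Y$. The toy inputs below are an assumption, built to lie close to an $r$-dimensional subspace so that the activations are nearly low-rank while $W$ itself is full rank; the truncated SVD of $W$ is shown as a baseline.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_in, d_out, r = 512, 64, 64, 8    # illustrative sizes
dt = torch.float64

# Toy inputs living (almost) in an r-dimensional subspace, so the activations are nearly low-rank.
X = torch.randn(n, r, dtype=dt) @ torch.randn(r, d_in, dtype=dt) + 0.01 * torch.randn(n, d_in, dtype=dt)
W = torch.randn(d_in, d_out, dtype=dt)          # full-rank weight
Y = X @ W                                       # teacher activations f_b(X)

# Eigendecomposition of the (uncentered) activation covariance Y^T Y.
eigvals, eigvecs = torch.linalg.eigh(Y.T @ Y)   # eigenvalues in ascending order
V_r = eigvecs[:, -r:]                           # top-r eigenvectors, shape (d_out, r)
W1, W2 = W @ V_r, V_r.T                         # shapes (d_in, r) and (r, d_out)

# Baseline: truncated SVD of the weight itself, which ignores the data.
u, s, vh = torch.linalg.svd(W, full_matrices=False)
W1_svd, W2_svd = u[:, :r], s[:r, None] * vh[:r, :]

mse_act = F.mse_loss(X @ W1 @ W2, Y)
mse_svd = F.mse_loss(X @ W1_svd @ W2_svd, Y)
print(f"activation-based: {mse_act.item():.6f}, weight SVD: {mse_svd.item():.2f}")
```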

# Setup <a name=setup></a>

In [None]:
import re
import copy
import logging
from pathlib import Path
from typing import Dict, Set
from argparse import ArgumentParser, BooleanOptionalAction
from tqdm.auto import tqdm
import yaml
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from transformers import PreTrainedModel, AutoTokenizer, AutoModelForCausalLM

In [None]:
# Load the pretrained model in bfloat16 and move it to the GPU
llm = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
                                           torch_dtype=torch.bfloat16)
llm = llm.cuda()

In [None]:
# Freeze all the parameters of the LLM: only the low-rank linears will be trained
def freeze_llm(llm):
    """A simple function that freezes the LLM by setting 'requires_grad' to False."""
    for param in llm.parameters():
        param.requires_grad_(False)

freeze_llm(llm)

# Faster Low-Rank Linear <a name=fast></a>

This is the same low-rank linear as in the previous notebook. We simply use PyTorch's `torch.nn.Linear` to make it a bit faster, because built-in modules are usually better optimized.

In [None]:
class LowRankLinear(torch.nn.Module):
    """Rank-r replacement for a linear layer: x -> x @ (W1 @ W2).T with W ~= W1 @ W2."""

    def __init__(self, w, rank):
        super().__init__()
        self._decompose(w, rank)

    @torch.no_grad()
    def _decompose(self, w, r):
        # w is a torch.nn.Linear weight of shape (out_features, in_features).
        out_features, in_features = w.shape
        # SVD in float32 for accuracy; the reduced form avoids building the full U.
        u, s, vh = torch.linalg.svd(w.to(dtype=torch.float32), full_matrices=False)
        w1 = u[:, :r]                   # (out_features, r)
        w2 = s[:r, None] * vh[:r, :]    # (r, in_features), i.e. diag(S_r) @ V_r^T
        # First projection: in_features -> r
        linear_2 = torch.nn.Linear(in_features=in_features,
                                   out_features=r,
                                   bias=False,
                                   dtype=w.dtype,
                                   device=w.device)
        linear_2.weight.copy_(w2)       # copy_ also casts to the weight dtype (e.g. bfloat16)
        # Second projection: r -> out_features
        linear_1 = torch.nn.Linear(in_features=r,
                                   out_features=out_features,
                                   bias=False,
                                   dtype=w.dtype,
                                   device=w.device)
        linear_1.weight.copy_(w1)
        self.linear = torch.nn.Sequential(linear_2, linear_1)

    def forward(self, x):
        return self.linear(x)
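A quick sanity check of the same construction on a small `torch.nn.Linear`, using a hypothetical standalone helper `lowrank_sequential` that mirrors the decomposition in `_decompose` (float32 and the sizes are assumptions, so it runs on CPU). At full rank the two stacked linears reproduce the original layer; at rank 4 they need $4 \times (32 + 24) = 224$ parameters instead of $32 \times 24 = 768$.

```python
import torch

torch.manual_seed(0)

def lowrank_sequential(w: torch.Tensor, r: int) -> torch.nn.Sequential:
    """Hypothetical helper: Linear(in -> r) then Linear(r -> out) from the truncated SVD of w (out, in)."""
    out_f, in_f = w.shape
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    first = torch.nn.Linear(in_f, r, bias=False)
    second = torch.nn.Linear(r, out_f, bias=False)
    with torch.no_grad():
        first.weight.copy_(s[:r, None] * vh[:r, :])   # (r, in)
        second.weight.copy_(u[:, :r])                 # (out, r)
    return torch.nn.Sequential(first, second)

def n_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

teacher = torch.nn.Linear(32, 24, bias=False)
x = torch.randn(5, 32)

full = lowrank_sequential(teacher.weight.detach(), r=24)   # r = min(out, in): exact reconstruction
small = lowrank_sequential(teacher.weight.detach(), r=4)

print(torch.allclose(full(x), teacher(x), atol=1e-4))
print(n_params(teacher), n_params(small))
```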

# Forward Hooks <a name=hooks></a>

A forward hook in PyTorch is simply a callable that can be attached to any module with `register_forward_hook`. Each time the module's forward is called, the hooks attached to it are called right after, with the module's inputs and output. By default, a forward hook has this signature:

```python
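def forward_hook(module, input, output):
    """
    Parameters
    ----------
    - module:
      The module to which the forward hook is attached.
    - input:
      The positional inputs passed to the module's forward, as a tuple.
    - output:
      The output produced by the module.

    Returns
    -------
    None to keep the module's output unchanged, or a new value that replaces it.
    """
    # Inspect, log or train something with `input` and `output` here.
    return None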
```
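A minimal runnable example of this signature on a toy `torch.nn.Linear` (the layer sizes and the `record_output` name are illustrative): the hook records the input and output shapes, and the handle returned by `register_forward_hook` detaches it again.

```python
import torch

captured = []

def record_output(module, input, output):
    # `input` is a tuple of the positional arguments given to forward.
    captured.append((input[0].shape, output.shape))
    return None   # returning None keeps the module's output unchanged

layer = torch.nn.Linear(8, 3)
handle = layer.register_forward_hook(record_output)
layer(torch.randn(2, 8))
handle.remove()              # detach the hook: later calls are not recorded
layer(torch.randn(2, 8))
print(captured)              # [(torch.Size([2, 8]), torch.Size([2, 3]))]
```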

We want to attach a low-rank linear to each linear module of the LLM. Whenever the forward of a linear module of the LLM is called, the corresponding low-rank linear is called too. The hook trains the low-rank linear to mimic the output of the LLM's linear module, by minimizing the `MSE` loss via gradient descent.
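A minimal sketch of that idea on a toy layer, outside of any LLM (sizes, learning rate, optimizer and number of steps are assumptions): a plain function hook on a frozen "teacher" linear trains a rank-4 "student" so that its output mimics the teacher's, and the loss goes down over repeated forward passes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher = torch.nn.Linear(16, 16, bias=False).requires_grad_(False)   # frozen "pretrained" layer
student = torch.nn.Sequential(torch.nn.Linear(16, 4, bias=False),
                              torch.nn.Linear(4, 16, bias=False))     # rank-4 student
optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)
losses = []

def mimic_hook(module, input, output):
    # One gradient step on the MSE between the student's and the teacher's outputs.
    optimizer.zero_grad()
    loss = F.mse_loss(student(input[0].detach()), output.detach())
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

teacher.register_forward_hook(mimic_hook)
for _ in range(200):
    teacher(torch.randn(64, 16))     # each teacher forward triggers one student update
print(losses[0], losses[-1])
```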

## Callable Hook <a name=callable></a>

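We use a callable class (a class that defines `__call__`) so that the hook can keep state between calls: the low-rank module, its optimizer, the loss history and the best loss so far.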
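To show why a callable class is convenient, here is a hypothetical minimal stateful hook (`CallCounter` is an illustrative name, not part of the notebook). An instance that defines `__call__` can be registered exactly like a function, and its attributes persist across forward passes.

```python
import torch

class CallCounter:
    """Hypothetical minimal stateful hook: counts calls and keeps the last output norm."""

    def __init__(self, name: str):
        self.name = name
        self.calls = 0
        self.last_norm = None

    def __call__(self, module, input, output):
        self.calls += 1
        self.last_norm = output.norm().item()

layer = torch.nn.Linear(4, 2)
counter = CallCounter("toy_linear")
layer.register_forward_hook(counter)   # the instance is used as the hook
for _ in range(3):
    layer(torch.randn(1, 4))
print(counter.name, counter.calls)     # toy_linear 3
```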

In [None]:
class ForwardHook:
    def __init__(self,
                 name: str,
                 lowrank_linear: torch.nn.Module,
                 output_folder: Path,
                 lr: float=0.000086,
                 log_interval: int=8
                 ):
        self.name = name # The name of the linear module
        self.weight_name = name.split(".")[-1]
        self.layer_idx = re.search(r'\d+', name).group()
        self.lowrank_linear = lowrank_linear
        self.optimizer = torch.optim.AdamW(self.lowrank_linear.parameters(), lr=lr)
        self.losses = []
        self.log_interval = log_interval
        self.current_log_idx = 0
        self.start = True
        self.best_loss = float("Inf")
        self.output_folder = Path(output_folder)
        self.output_folder.mkdir(exist_ok=True, parents=True)

    def __call__(self, module, input, output) -> None:
        """Forward through the low-rank linear, then take one optimizer step on the MSE loss."""
        self.optimizer.zero_grad()
        # Detach so that the backward pass never reaches the (frozen) LLM graph.
        student_output = self.lowrank_linear(input[0].detach())
        loss = F.mse_loss(student_output, output.detach())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.lowrank_linear.parameters(), max_norm=2.0, norm_type=2)
        self.optimizer.step()
        self.losses.append(str(loss.item()))
        if self.current_log_idx == self.log_interval or self.start:
            self.start = False
            self.current_log_idx = 0
            print()
            print(f"Layer={self.layer_idx}, Module={self.weight_name}, Loss={loss.item()}")
            if self.best_loss > loss.item():
                self.save()
                self.best_loss = loss.item()
        self.current_log_idx += 1

    def save(self):
        """Save the low-rank module."""
        torch.save(self.lowrank_linear.state_dict(),
                   f=self.output_folder / f"{self.name}.pt")

    def save_losses(self):
        with open(self.output_folder / f"{self.name}.losses", "w") as loss_file:
            loss_file.write("\n".join(self.losses))

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.lowrank_linear.__str__()

## Register Hooks <a name=register></a>

In [None]:
def register_forward_hooks(config, model):
    total = sum(1 for _ in model.named_modules())
    hooks = set()
    for name, module in tqdm(model.named_modules(), total=total):
        module_name = name.split(".")[-1]
        if module_name in config:
            rank = config[module_name]
            lowrank_linear = LowRankLinear(w=module.weight, rank=rank)
            hook = ForwardHook(name=name,
                               lowrank_linear=lowrank_linear,
                               output_folder="lowrank_weights")
            module.register_forward_hook(hook)
            hooks.add(hook)
    return hooks
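The name matching used above can be checked on a toy model, assuming a hypothetical `Block` module with `q_proj` / `o_proj` attributes (illustrative names that only mimic the LLM's naming). Only modules whose last name component is a key of the (hypothetical) `toy_config` get a hook; keeping the handles returned by `register_forward_hook` makes it possible to remove the hooks after training.

```python
import torch

class Block(torch.nn.Module):
    """Hypothetical toy block mimicking the naming of the LLM's projections."""

    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)
        self.o_proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.o_proj(self.q_proj(x))

model = torch.nn.Sequential(Block(), Block())
toy_config = {"q_proj": 2}   # hypothetical: only hook the q_proj layers

seen, handles = [], []
for name, module in model.named_modules():
    if name.split(".")[-1] in toy_config:
        # Default argument binds the current name for each lambda.
        handles.append(module.register_forward_hook(
            lambda m, i, o, name=name: seen.append(name)))
model(torch.randn(1, 8))
print(seen)        # ['0.q_proj', '1.q_proj']
for h in handles:
    h.remove()
```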

In [None]:
config = {
    "q_proj": 384,
    "o_proj": 384,
    "gate_proj": 512,
}
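To get a feel for what these ranks save, here is a small parameter count. The shapes are an assumption based on TinyLlama-1.1B (hidden size 2048, intermediate size 5632); they can be checked with `llm.config.hidden_size` and `llm.config.intermediate_size`. A rank-$r$ pair $W_{1}, W_{2}$ stores $r(d_{in} + d_{out})$ values instead of $d_{in} d_{out}$, so it only saves parameters when $r < d_{in} d_{out} / (d_{in} + d_{out})$, i.e. below 1024 for a $2048 \times 2048$ projection.

```python
# Assumed TinyLlama-1.1B shapes (in_features, out_features); verify them with
# llm.config.hidden_size and llm.config.intermediate_size.
shapes = {
    "q_proj": (2048, 2048),
    "o_proj": (2048, 2048),
    "gate_proj": (2048, 5632),
}
ranks = {"q_proj": 384, "o_proj": 384, "gate_proj": 512}   # mirrors `config` above

counts = {}
for name, (d_in, d_out) in shapes.items():
    full = d_in * d_out                   # parameters of the original weight
    low = ranks[name] * (d_in + d_out)    # parameters of W1 and W2 together
    counts[name] = (full, low)
    print(f"{name}: {full:,} -> {low:,} parameters ({low / full:.1%})")
```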

In [None]:
# 2. Register the forward hooks
hooks = register_forward_hooks(config, llm)

# Train the LowRank Linears <a name=train></a>

In [None]:
# TODO: 1. Load the tokenized data
train_data = "TODO"

In [None]:
# freeze_llm(llm)
def forward_dataset(dataset: Dataset,
                    llm: PreTrainedModel,
                    batch_size: int=64
                    ) -> None:
    """Forward the whole dataset through the LLM; the registered hooks train the low-rank linears on the fly."""
    dataset.set_format(type="torch", columns=["input_ids"])
    dataloader = DataLoader(dataset, batch_size=batch_size)
    for batch in tqdm(dataloader, total=len(dataloader)):
        inputs = batch["input_ids"].to(llm.device)
        llm(input_ids=inputs)

In [None]:
forward_dataset(train_data, llm)