<a href="https://colab.research.google.com/github/datacraft-paris/2311-Cerisara-LLM/blob/main/Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import copy
from pathlib import Path
from typing import Iterator

import torch
from datasets import Dataset
from datasets import load_dataset
from pandas import DataFrame
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import pipeline, AutoModelForCausalLM, PreTrainedModel, AutoTokenizer

In [None]:
# TODO: Load the model
# Fetch the pretrained causal LM with bfloat16 weights and move it to the GPU
# in one chained expression (`.cuda()` returns the module itself).
llm = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",
    torch_dtype=torch.bfloat16,
).cuda()
# Independent deep copy of the base model, so that `llm` can be kept unchanged
# as a reference while `lr_llm` is modified.
lr_llm = copy.deepcopy(llm)

In [None]:
class LowRankLinear(torch.nn.Module):
    """Bias-free linear map factorised through a ``rank``-sized bottleneck.

    Two ``torch.nn.Linear`` layers without bias are chained in a
    ``torch.nn.Sequential`` stored as ``self.linear``: the first maps
    ``in_features -> rank`` and the second maps ``rank -> out_features``.
    The resulting state-dict keys are ``linear.0.weight`` and
    ``linear.1.weight``.

    Parameters
    ----------
    - in_features: size of the last dimension of the input.
    - rank: size of the intermediate (bottleneck) dimension.
    - out_features: size of the last dimension of the output.
    - dtype: dtype of both weight matrices. Defaults to ``torch.bfloat16``.
    """

    def __init__(self, in_features, rank, out_features, dtype=torch.bfloat16):
        super().__init__()
        self.linear = torch.nn.Sequential(
            # Down-projection: in_features -> rank.
            torch.nn.Linear(in_features=in_features,
                            out_features=rank,
                            bias=False,
                            dtype=dtype),
            # Up-projection: rank -> out_features.
            torch.nn.Linear(in_features=rank,
                            out_features=out_features,
                            bias=False,
                            dtype=dtype)
        )

    def forward(self, x):
        """Apply the down-projection followed by the up-projection to ``x``."""
        return self.linear(x)

In [None]:
def setmodule(module, target_module, value):
    """Assign ``value`` to the attribute of ``module`` named by a dotted path.

    ``target_module`` is split on ``"."``. Every component except the last is
    resolved with ``getattr``, starting from ``module``; the last component is
    then assigned with ``setattr`` on the object reached that way.
    """
    *parent_names, leaf_name = target_module.split(".")
    owner = module
    # Walk down to the object that directly holds the attribute to replace.
    for parent_name in parent_names:
        owner = getattr(owner, parent_name)
    setattr(owner, leaf_name, value)

In [None]:
def load_lowrank_weights(path, llm):
    """
    Loads distilled Low-Rank Linears into the LLM.

    Every ``*.pt`` file found directly in ``path`` is read with ``torch.load``
    as a state dict holding ``linear.0.weight`` and ``linear.1.weight``. A
    ``LowRankLinear`` of matching size is built from it and stored on ``llm``
    at the dotted attribute path given by the file name without its extension
    (e.g. ``a.b.c.pt`` sets ``llm.a.b.c``), replacing whatever was there.

    Parameters
    ----------
    - path: folder containing the saved distilled Low-Rank Linears.
    - llm: LLM on which to load the Low-Rank Linear. It is modified in place.
    """
    # Glob once instead of twice; sorting makes the load order deterministic.
    weight_files = sorted(Path(path).glob("*.pt"))
    # The replacement layers are created on CPU, so they must be moved next to
    # the model's existing weights or the forward pass mixes devices.
    # NOTE(review): uses the device of the first parameter; a model sharded
    # over several devices would need the device of each replaced module.
    reference_param = next(llm.parameters(), None)
    for weights in tqdm(weight_files):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or consider `weights_only=True`).
        loaded_weights = torch.load(weights, map_location=torch.device('cpu'))
        # torch.nn.Linear stores its weight as (out_features, in_features), so
        # read the sizes positionally. Guessing them with max()/min() breaks
        # when out_features < rank or rank > in_features.
        rank, in_features = loaded_weights["linear.0.weight"].shape
        out_features = loaded_weights["linear.1.weight"].shape[0]
        lowrank_linear = LowRankLinear(in_features, rank, out_features)
        lowrank_linear.load_state_dict(loaded_weights)
        if reference_param is not None:
            lowrank_linear.to(reference_param.device)
        setmodule(llm, weights.stem, lowrank_linear)

In [None]:
# Inject the distilled low-rank layers into the copy `lr_llm`; the base `llm`
# is left as it is.
# TODO: replace "TODO" with the folder that holds the saved *.pt files.
load_lowrank_weights("TODO", lr_llm)

# Hand Test

In [None]:
# The tokenizer has to share the model's vocabulary, so load it from the same
# checkpoint name the model was loaded from instead of hard-coding another
# model's tokenizer.
tokenizer = AutoTokenizer.from_pretrained(lr_llm.config.name_or_path,
                                          trust_remote_code=True)
# Text-generation pipeline around the low-rank model, with sampling disabled.
pipe = pipeline("text-generation", model=lr_llm, tokenizer=tokenizer, do_sample=False)

In [None]:
# Hand test: generate between 8 and 32 new tokens after a fixed prompt and
# print the text of the first returned result.
print(pipe("One day, I will", max_new_tokens=32, min_new_tokens=8)[0]["generated_text"])

# Perplexity

We want to assess the quality of the Low-Rank LLM on a given test set. We can achieve this with the perplexity metric, which measures how well the model predicts the test set (the lower, the better).

The cross-entropy loss of the next token is the negative log-probability that the model assigns to the gold next token at each timestep in the sequence:

$$
CE(x_{i+1}) = -\log p(y = x_{i+1} \mid x_{1}, x_{2}, \dots, x_{i})
$$
where $x_{i+1}$ is the token that follows the token $x_{i}$.

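Then, the perplexity is just defined as the exponential of the cross-entropy averaged over the $N$ predicted tokens:

$$
perplexity = \exp\left(\frac{1}{N}\sum_{i} CE(x_{i+1})\right)
$$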

Use the test corpus and compute the perplexity for the low-rank LLM and the base LLM.

In [None]:
# Placeholder (empty string) — presumably to be replaced with the test corpus
# used for the perplexity computation; TODO confirm.
test_data = ""

In [None]:
@torch.no_grad()
def forward_dataset(dataset: Dataset,
                    llm: PreTrainedModel,
                    batch_size: int=64
                    ) -> Iterator[dict]:
    """Forwards all the dataset through the LLM and computes the perplexity.

    This is a generator: one dict is yielded per batch, in dataloader order.

    Parameters
    ----------
    - dataset: dataset with an ``input_ids`` column. It is not modified; a
      torch-formatted view restricted to ``input_ids`` is used instead.
    - llm: causal LM called with ``input_ids`` and ``labels`` set to the same
      tensor, so that it returns a ``loss``.
    - batch_size: number of rows per forward pass.

    Yields
    ------
    - ``{"perplexity": float}`` where the value is ``exp(loss)`` of the batch.
    """
    # `with_format` returns a new view, whereas `set_format` would change the
    # caller's dataset in place as a side effect.
    torch_dataset = dataset.with_format(type="torch", columns=["input_ids"])
    dataloader = DataLoader(torch_dataset, batch_size=batch_size)
    for batch in tqdm(dataloader, total=len(dataloader)):
        inputs = batch["input_ids"].to(llm.device)
        # NOTE(review): no attention mask is passed and the labels are the raw
        # input_ids, so padding tokens (if any) would count towards the loss —
        # confirm the rows are unpadded, fixed-length sequences.
        outputs = llm(input_ids=inputs, labels=inputs)
        # perplexity = exp(cross-entropy); exponentiate in float32 because the
        # loss may come back in a low-precision dtype such as bfloat16.
        perplexity = torch.exp(outputs.loss.float()).item()
        yield {"perplexity": perplexity}