In [1]:
from tqdm.notebook import tqdm
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader

from llama import Llama
from short_llama import ShortLlama

### Load Data

In [None]:
# PG19 validation split; the authors sample 10,000 texts to compute block influences
data = load_dataset("pg19", split = "validation")

# Yield one shuffled document per step; the shuffle draws from a generator created on the CUDA device.
# NOTE(review): presumably the CUDA generator is needed because the model build changes torch's
# default tensor device -- confirm against Llama.build.
dataloader = DataLoader(
    data,
    generator = torch.Generator(device = "cuda"),
    shuffle = True,
    batch_size = 1,
)

### Fetch and Wrap Model

In [None]:
# Context width shared by the model build and the sliding-window slicing further down
# (the authors use a context width of 1024).
MAX_SEQ_LEN = 1024

# Build the model from the checkpoint directory and tokenizer file given by the relative paths below.
# max_batch_size = 1 matches the DataLoader's batch_size of 1.
llama = Llama.build(
    max_batch_size = 1,
    max_seq_len = MAX_SEQ_LEN,
    tokenizer_path = "../../llama/tokenizer.model",
    ckpt_dir = "../../llama/llama-2-7b",
)

In [None]:
# Wrap the built model. n_prune_layers = 9 is presumably the number of layers later dropped by
# remove_layers() (the notebook reports 9 removed layers) -- confirm against ShortLlama.
short_llama = ShortLlama(
    n_prune_layers = 9,
    llama = llama,
)

# Display the layer stack as it stands before any pruning.
short_llama.llama.model.layers

In [None]:
# Sample generation with the model as built (nothing removed yet); the returned
# completion is the cell's displayed output.
short_llama.llama.text_completion(max_gen_len = 20, prompts = ["I am an avid fan of "])

### Compute Importances

In [None]:
# Score layer importance over every document yielded by the dataloader.
# authors use a sliding window of size 1024 with a shift of 256
for batch in tqdm(dataloader):
    # tokenize each text in the batch, with a BOS token and without an EOS token
    token_seqs = [
        short_llama.llama.tokenizer.encode(text, bos = True, eos = False)
        for text in batch['text']
    ]
    longest = max(map(len, token_seqs))

    for offset in range(0, longest, 256):
        # keep only sequences that still have tokens at this offset, cut to at most MAX_SEQ_LEN tokens
        window = [
            seq[offset:offset + MAX_SEQ_LEN]
            for seq in token_seqs
            if len(seq) > offset
        ]

        # max_gen_len = 0: no new tokens are requested, the window is only evaluated
        short_llama.eval_importance(prompt_tokens = window, max_gen_len = 0)

In [None]:
# Display the importance scores held on the wrapper (presumably one entry per layer,
# filled in by the eval_importance calls above -- confirm against ShortLlama).
short_llama.importances

### Remove unimportant layers

Layers removed when using pg19 val set: [25, 27, 24, 26, 28, 29, 23, 22, 21]

Note: The order differs from the paper's, but these are the same 9 least important layers. The paper's order is [27, 26, 25, 28, 24, 29, 23, 21, 22].

Additionally, authors mention that the layer order is quite nuanced and can vary with different datasets. However, relative order suggests similar importance.

In [None]:
# Prune the model in place. Called with no arguments, this presumably drops the
# n_prune_layers layers with the lowest importance scores -- confirm against ShortLlama.
short_llama.remove_layers()

In [None]:
# Display the layer stack again to check the result of remove_layers().
short_llama.llama.model.layers

As the paper states:

> "Our experiments reveal that the effect of layer removal is significantly more pronounced on generative
> tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and
> HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance
> drop, with scores approaching zero."

In [None]:
# Same prompt as the earlier sample generation, now run after remove_layers(),
# so the two completions can be compared.
short_llama.llama.text_completion(max_gen_len = 20, prompts = ["I am an avid fan of "])

### Compute Angular Importances

In [None]:
# Same sliding-window pass as the earlier importance loop, but with angular = True.
# NOTE(review): in notebook order, remove_layers() has already been called on short_llama before
# this cell, so these scores come from the pruned model and may add to the earlier importances
# unless the kernel/model was reloaded in between -- confirm this is intended.
# authors use a sliding window of size 1024 with a shift of 256
for batch in tqdm(dataloader):
    # tokenize each text in the batch, with a BOS token and without an EOS token
    token_seqs = [
        short_llama.llama.tokenizer.encode(text, bos = True, eos = False)
        for text in batch['text']
    ]
    longest = max(map(len, token_seqs))

    for offset in range(0, longest, 256):
        # keep only sequences that still have tokens at this offset, cut to at most MAX_SEQ_LEN tokens
        window = [
            seq[offset:offset + MAX_SEQ_LEN]
            for seq in token_seqs
            if len(seq) > offset
        ]

        # max_gen_len = 0: no new tokens are requested, the window is only evaluated
        short_llama.eval_importance(
            prompt_tokens = window,
            max_gen_len = 0,
            angular = True,
        )

In [None]:
# Display the importance scores after the angular = True pass above.
short_llama.importances

### Remove unimportant layers

In [None]:
# Prune using the angular variant of layer selection.
# NOTE(review): in notebook order the model was already pruned by the earlier remove_layers()
# call; confirm this cell is meant to run on a freshly built model.
short_llama.remove_layers(angular = True)

In [None]:
# Display the layer stack to check the result of remove_layers(angular = True).
short_llama.llama.model.layers

In [None]:
# Same prompt as the earlier sample generations, run after remove_layers(angular = True).
short_llama.llama.text_completion(max_gen_len = 20, prompts = ["I am an avid fan of "])