# Paper
* **Title**: [Distilling Step-by-Step! Outperforming LLMs with less training data and smaller model sizes](https://arxiv.org/pdf/2305.02301)
* **Date**: 05 Jul 2023
* **Authors**: University of Washington, Google Cloud AI Research

- - -
# Summary
### Problem
LLMs are compute-intensive and memory-inefficient for specific tasks. As a result, engineers train smaller, task-specific models by either finetuning with human labels or distilling with LLM-generated labels.

However, standard finetuning and distillation techniques require large amounts of training data to achieve comparable performance to LLMs.

### Solution
The paper presents a novel technique for reducing the amount of data needed for distillation. It trains small models on small datasets, and the resulting models outperform LLMs on **specific tasks**. The whole process is 100% reproducible on a consumer-grade GPU.


### How it works
The idea is to distill **rationales** instead of just **raw information**, because it seems to drastically cut the required amount of data and model size. Figuratively speaking, instead of memorizing all the facts that the teacher knows you (the student) can simply learn how the teacher "thinks".

The technique, called "Distilling Step-by-Step", extracts **rationales** from larger models and uses them to train smaller ones within a **multi-task** framework (predict label + predict rationale).

For example, when asked *“Jesse’s room is 11 feet long and 15 feet wide. If she already has 16 square feet of carpet. How much more carpet does she need to cover the whole floor?”*, an LLM can be prompted by **chain-of-thought (CoT) technique** to provide **intermediate rationales** *“Area = length × width. Jesse’s room has 11 × 15 square feet.”* that better connects the input to the final answer *“(11 × 15) − 16”*. These rationales can contain relevant task knowledge, such as *“Area = length × width”*, that may originally require many data for small task-specific models to learn. We thus utilize these extracted rationales as additional, richer information to train small models through a **multi-task training setup, with both label prediction and rationale prediction tasks**.

<center><img src="img/llm_rationale.png" alt="How distilling step-by-step works" width="784" height="296" /></center>
<p style="text-align: center; font-size: small;"><i><b>Figure 1.</b> We first utilize CoT prompting to extract rationales from an LLM. We then use the generated rationales to train<br/>
small task-specific models within a <b>multi-task learning</b> framework where we prepend task prefixes to the input examples<br/>
and train the model to output differently based on the given task prefix.</i></p>
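
The multi-task setup described above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' code: the `Example` class, the helper names, and the exact prefix strings `[label]` / `[rationale]` are assumptions made for this note, and the weight `lam` stands for the factor that balances the two losses.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Task prefixes prepended to the input; the exact strings are an assumption.
LABEL_PREFIX = "[label]"
RATIONALE_PREFIX = "[rationale]"


@dataclass
class Example:
    question: str
    rationale: str  # extracted from the teacher LLM via CoT prompting
    label: str      # human label, or the label predicted by the teacher LLM


def to_multitask_pairs(ex: Example) -> List[Tuple[str, str]]:
    """Turn one example into two (input, target) pairs, one per task."""
    return [
        (f"{LABEL_PREFIX} {ex.question}", ex.label),
        (f"{RATIONALE_PREFIX} {ex.question}", ex.rationale),
    ]


def multitask_loss(label_loss: float, rationale_loss: float, lam: float = 1.0) -> float:
    """Combined objective: L = L_label + lam * L_rationale."""
    return label_loss + lam * rationale_loss


ex = Example(
    question="Jesse's room is 11 feet long and 15 feet wide. If she already has "
             "16 square feet of carpet. How much more carpet does she need to "
             "cover the whole floor?",
    rationale="Area = length x width. Jesse's room has 11 x 15 square feet.",
    label="(11 x 15) - 16",
)

for model_input, target in to_multitask_pairs(ex):
    print(model_input, "->", target)
```

The same small model sees both pairs during training. At inference time only the label prefix is needed, so the rationale task adds no cost at prediction time.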

### Results
* distillation of an LLM into a small task-specific model that outperforms the LLM while using **500-2000× fewer** model parameters
* on average **50% fewer** training examples (compared with standard task distillation)

More specifically:
* the performance of a **540B LLM** is surpassed by a **770M T5** model; this smaller model uses only 80% of the labeled dataset that an existing finetuning method would otherwise require.
* when only unlabeled data is available, the small models still perform on par with or better than LLMs. For example, **540B PaLM** is outperformed by an **11B T5** model.

The paper demonstrates the effectiveness of the method on **fully unlabeled datasets**.


<center><img src="img/distilling_results.png" alt="Distilling step-by-step and Standard task distillation comparison" width="866" height="245" /></center>
<p style="text-align: center; font-size: small;"><i><b>Figure 2.</b> We compare Distilling step-by-step and Standard task distillation using 220M T5 models on varying sizes of unlabeled datasets. <br/>Distilling step-by-step is able to outperform Standard task distillation by using only a small subset of the full unlabeled dataset (e.g., 12.5% on ANLI dataset).
</i></p>

### Caveats
* **Multi-task training is much more effective than single-task rationale training**
    - single-task training with LLM rationales can at times lead to worse performance than standard finetuning
    - simply treating rationale and label predictions as a single joint task may harm the model’s performance on label prediction

* **Limitations**
    - distilling step-by-step requires the user to manually produce example demonstrations (roughly 10-shot for each task) in order to use the few-shot CoT prompting mechanism of the teacher LLM
    - training task-specific models with rationales slightly increases training time (due to the rationale generation overhead)
    - finally, while we observe success using LLM rationales, there is evidence that **LLMs exhibit limited reasoning capability** on more complex reasoning and planning tasks


- - -
# Next papers
Related papers that seem interesting:

* [Distilling Task-Specific Knowledge from BERT into Simple NN](https://arxiv.org/pdf/1903.12136) - Recent developments have led to the conviction that previous-generation, shallower neural networks for language understanding are obsolete. In this paper, however, we demonstrate that rudimentary, lightweight neural networks can still be made competitive

* [Chain-of-Thought Prompting Elicits Reasoning in LLM](https://arxiv.org/pdf/2201.11903) - For instance, prompting a PaLM 540B with just 8 chain-of-thought exemplars achieves SotA accuracy on the GSM8K benchmark of math word problems, surpassing even finetuned GPT-3 with a verifier

* [Distilling the Knowledge in a Neural Network](https://arxiv.org/pdf/1503.02531) - Caruana and his collaborators [1] have shown that it is possible to compress the knowledge in an ensemble into a single model which is much easier to deploy and we develop this approach further using a different compression technique. We achieve some surprising results on MNIST and we show that we can significantly improve the acoustic model of a heavily used commercial system by distilling the knowledge in an ensemble of models into a single model.

- - -

# New Knowledge


Below is the terminology I had to get acquainted with in order to understand the paper above.

* **finetuning** vs **distillation**: Fine-tuning adapts a model for specific tasks using human-labeled data, while distillation is a method of fine-tuning where a small, "student" model learns to mimic the behavior of a larger, "teacher" model, using the teacher's outputs. The primary goals differ: fine-tuning improves a model's accuracy and capabilities on a target task, whereas distillation focuses on creating a smaller, more efficient model that maintains performance with reduced computational cost. 

* **zero-shot**: "zero-shot" models are capable of performing tasks on new, unseen data categories without specific prior training examples for those categories. They achieve this by leveraging knowledge gained from large pre-training datasets to understand conceptual relationships between different concepts and attributes. By describing new categories in natural language or as embedding vectors, the model can infer connections and correctly categorize unseen instances based on its general understanding, rather than relying on labeled examples for that specific task.

* **few-shot**: "few-shot" models can pick up a new task from a handful of examples placed directly in the prompt (in-context learning). To teach GPT-3 a new task, you do not need to update the model's weights; you only give it a few natural-language demonstrations (input-output examples).

<center><img src="img/few_shot_models.png" alt="Few-shot ability explained" width="800" height="186" /></center>
<p style="text-align: center; font-size: small;"><i><b>Figure 3.</b> Few-shot models can be adapted to a new task with a few in-prompt examples, without updating weights.</i></p>
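
To make the few-shot idea concrete, here is a small sketch of how such a prompt could be assembled. The `Q:` / `A:` layout, the function name, and the demonstrations are made up for illustration; real prompt formats vary by model. When each demonstration also carries its reasoning, the same prompt becomes a few-shot chain-of-thought prompt.

```python
from typing import List, Optional, Tuple

# (question, rationale or None, answer); the demos below are invented examples.
Demo = Tuple[str, Optional[str], str]


def build_few_shot_prompt(demos: List[Demo], new_question: str) -> str:
    """Concatenate worked examples and leave the last answer open for the model."""
    blocks = []
    for question, rationale, answer in demos:
        # Including a rationale turns a plain demo into a chain-of-thought demo.
        reasoning = f"{rationale} " if rationale else ""
        blocks.append(f"Q: {question}\nA: {reasoning}The answer is {answer}.")
    blocks.append(f"Q: {new_question}\nA:")
    return "\n\n".join(blocks)


demos = [
    ("A shelf is 2 m long and 3 m wide. What is its area?",
     "Area = length x width. 2 x 3 = 6.", "6"),
    ("Tom has 5 apples and eats 2. How many are left?",
     "He starts with 5 and removes 2. 5 - 2 = 3.", "3"),
]

prompt = build_few_shot_prompt(demos, "A room is 11 ft by 15 ft. What is its area?")
print(prompt)
```

No weights change here: the model is expected to continue the pattern after the final `A:`.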

* **few-shot CoT**: Few-shot Chain-of-Thought (CoT) is a prompting technique for Large Language Models (LLMs) that combines the benefits of providing several examples (few-shot) with the step-by-step reasoning of Chain-of-Thought prompting to improve accuracy and performance. By including a few examples that demonstrate not just the final answer but also the intermediate reasoning process, the LLM is guided to understand complex tasks and produce more reliable, logical, and desired outputs, especially for tasks that require complex reasoning or structured outputs.  

* **logits**: the raw, unscaled output values from the final layer of a neural network - before any activation function like softmax or sigmoid is applied. They are not probabilities yet, but they play a critical role in converting model predictions into understandable outputs. Basically, **logits** are the unnormalized final scores of your model.

* **distillation techniques**: 
There are 3 types of distillation that I know of:

    * *(a.k.a. poor man's distillation)* Generating data with LLM1 and running SFT (supervised fine-tuning) on LLM2 with it (usually LLM1 is stronger than LLM2), e.g. DeepSeek-R1-Distill-Qwen-7B is a Qwen2.5 7B model SFT'd on ~800k samples generated with DeepSeek-R1 (the full 600B+ model)

    * *Logit-based Distillation* (models must be from the same family) - here you run a completion on LLM1, log the entire logit distribution and train LLM2 to match the entire distribution, not just "the best token". The obvious downside is that the two models need to share tokenizers and so on (i.e. you can do Qwen 2.5 32B -> Qwen 2.5 7B, but not Qwen 2.5 32B -> Llama 3 8B).

    * *Hidden States-based Distillation* (models can be different architectures) - I haven't tried this, but IIRC the upside is support for out-of-family models, at the cost of a lot of storage to hold the hidden states for many generations.

    Types 2 and 3 can be done with repos such as https://github.com/arcee-ai/DistillKit , while type 1 can be done with any workflow that can generate samples and run SFT or other fine-tuning strategies (DPO, KTO, etc.)

* **distillation loss**: a loss function that measures the difference between the outputs of the "student" and "teacher" models. It is typically a **KL divergence** or **cross-entropy loss** between the student's soft predictions and the teacher's soft targets, often with a **temperature parameter** applied to the softmax function.
<center><img src="img/distillation_loss.png" alt="Distillation loss measures the diff between the 'student' and 'teacher' models" width="400" height="180" /></center>
<p style="text-align: center; font-size: small;"><i><b>Figure 4.</b> Distillation loss measures the diff between the "student" and "teacher" models.</i></p>
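
As a rough illustration of the logits and distillation-loss entries above, the sketch below computes a KL-divergence loss between temperature-softened teacher and student distributions in plain Python. The function names and the default temperature are assumptions for this note; the `T^2` scaling follows the classic formulation in "Distilling the Knowledge in a Neural Network", and real training code would normally use a deep-learning framework instead.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Convert raw logits into probabilities; a higher temperature softens them."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: List[float],
                      teacher_logits: List[float],
                      temperature: float = 2.0) -> float:
    """KL(teacher || student) on temperature-softened distributions."""
    p = softmax(teacher_logits, temperature)  # teacher's soft targets
    q = softmax(student_logits, temperature)  # student's soft predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    # Scale by T^2 so the loss magnitude stays comparable across temperatures.
    return kl * temperature ** 2


teacher = [1.0, 2.0, 3.0]
print(distillation_loss(teacher, teacher))          # identical logits: no penalty
print(distillation_loss([3.0, 2.0, 1.0], teacher))  # mismatched logits: positive loss
```

The loss is zero only when the student reproduces the teacher's whole distribution, which is what distinguishes this from training on the single best token.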

* **multi-task training**: fine-tunes a pre-trained LLM to perform multiple tasks simultaneously (e.g. classify text & reason), rather than training separate models for each task.

* **temperature**:  The [temperature parameter](https://medium.com/@kelseyywang/a-comprehensive-guide-to-llm-temperature-%EF%B8%8F-363a40bbc91f) in LLMs directly affects the **variability** and **randomness** of the generated responses. Higher values like $0.8$ will make the output more random, while lower values like $0.2$ will make it more focused and deterministic.
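
A small sketch of why temperature changes randomness, assuming the common implementation in which logits are divided by the temperature before the softmax. The function name and the logits for the three candidate tokens are invented for illustration.

```python
import math
from typing import List


def token_probs(logits: List[float], temperature: float) -> List[float]:
    """Softmax over logits divided by the temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens

focused = token_probs(logits, temperature=0.2)
varied = token_probs(logits, temperature=0.8)

# At low temperature almost all probability mass sits on the top token;
# at higher temperature the other tokens get a realistic chance of being sampled.
print([round(p, 3) for p in focused])
print([round(p, 3) for p in varied])
```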