NB: I have executed paper's code locally, but it required a specific setup and 24+ hours of training time. Thus, I decided to try to **reproduce the experiment with smaller dataset**.

In [1]:
import pandas as pd

from datasets import load_dataset

from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage

import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments, AutoModelForSequenceClassification, pipeline




## 1. Load dataset
Instead of the boring datasets, provided in the paper, we'll use a more interesting one - [Stanford Sentiment Treebank](https://huggingface.co/datasets/stanfordnlp/sst2).

In [2]:
sentences_ds = load_dataset("sst2")

# Convert to pandas dataframe. Limit to 1000 samples for faster processing.
sentences_train_df = pd.DataFrame(sentences_ds['train'][:5000])
sentences_train_df.head()

Unnamed: 0,idx,sentence,label
0,0,hide new secretions from the parental units,0
1,1,"contains no wit , only labored gags",0
2,2,that loves its characters and communicates som...,1
3,3,remains utterly satisfied to remain the same t...,0
4,4,on the worst revenge-of-the-nerds clichés the ...,0


In [3]:
# Prepare test dataset
sentences_test_df = pd.DataFrame(sentences_ds['train'][5000:6000])
sentences_test_df.head()

Unnamed: 0,idx,sentence,label
0,5000,entirely stale concept,0
1,5001,will amuse or entertain them,1
2,5002,wobbly premise work,0
3,5003,drifts aimlessly,0
4,5004,town,1


## 2. Teacher Model
For teacher model we will use **DeepSeek R1 (1.5B parameters)**.

In [4]:
# Initialize the chat model
TEACHER_MODEL = ChatOllama(
    model="deepseek-r1:1.5b",
    base_url="http://localhost:11434",
    temperature=0.9 # We want the teacher to be less random and more deterministic. However, if the T is low, then it won't respect our ask for specific ouput JSON format.
)

### 2.1 Teacher prompt

To ensure we give our teacher model the best opportunity, we'll be employing two techniques in our classification prompt:
* **Chain-of-Thought reasoning**: Making the language model write a reasoning description to "think" through the problem before giving an answer
* **Few-shot prompting**: Providing robust examples about your expectations of both performance and format to better guide the LLM.

In [5]:
SENTIMENT_COT_PROMPT = """\
Your task is to briefly analyze the sentiment in the TEXT below and then label it with only one of these two labels:
positive, negative.
Base your label decision only on the TEXT and do not speculate e.g. based on prior knowledge about the context. 
You first reason step by step about the correct label and then return your label.
You ALWAYS respond once in the following JSON format with brackets: {{"reason": "...", "label": "..."}}. 
Again, you should ALWAYS respond in the JSON format and respect the given examples.

Examples:
Text: oh oh oh are you offering to send ducks! I love love love confit duck
JSON: {{"reason": "The text expresses enthusiasm and love for confit duck, indicating a positive sentiment", "label": "positive"}}
Text: Beautiful Day..takn it down twitters tell ALL mothers Happy Mothers Day
JSON: {{"reason": "The text describes a beautiful day and expresses positive wishes for Mother's Day", "label": "positive"}}
Text: wished didnt spend money last night
JSON: {{"reason": "The text expresses regret about spending money, indicating a negative sentiment", "label": "negative"}}
Text: yo wake your **** up and go to work go get that paper u aint sick dont lie
JSON: {{"reason": "The text is aggressive and accusatory, suggesting a negative sentiment", "label": "negative"}}
Text: Such a beautiful morning
JSON: {{"reason": "The text expresses appreciation for the morning, indicating a positive sentiment", "label": "positive"}}
Text: Nooo...i forgot my calculator for physics oh well class is allmost over :3
JSON: {{"reason": "The text expresses initial disappointment about forgetting a calculator, indicating a negative sentiment", "label": "negative"}}
"""

def generate_label_and_rationale(input_text: str) -> tuple[int, str, str]:
    """
    Use teacher model to generate a step-by-step rationale and sentiment label for the given sentence.

    Args:
        input_text (str): The input sentence to generate a sentiment label and rationale for.

    Returns:
        tuple[int, str, str]: A tuple containing the sentiment label (0 for negative, 1 for positive), the sentiment label text, and the rationale.
    """
    # Use LangChain to invoke the teacher model
    response = TEACHER_MODEL.invoke([
        SystemMessage(content=SENTIMENT_COT_PROMPT),
        HumanMessage(content=input_text)
    ])

    # Extract the rationale and sentiment label from the response
    rationale = _extract_rationale(response.content)
    label_text = _extract_label(response.content)
    label = 1 if label_text == "positive" else 0 # 0 - negative, 1 - positive
    return label, rationale

def _extract_rationale(content: str) -> str:
    return __extract_text(content, "reason")

def _extract_label(content: str) -> str:
    return __extract_text(content, "label")

def __extract_text(content: str, json_key: str) -> str:
    # Catch errors if the JSON format is not respected
    try:
        start = content.index(f'"{json_key}": "')
        chunk = content[start+len(json_key)+5:]
        end = chunk.index('"')
        return chunk[:end]
    except ValueError:
        # Second attempt to extract the text
        try:
            start = content.index(f'**{json_key.capitalize()}:** ')
            chunk = content[start+len(json_key)+6:]
            end = chunk.index('\n')
            return chunk[:end]
        except ValueError:
            print('\tERROR: Incorrect JSON format')
            return "" # This row will be removed.

generate_label_and_rationale("Want to get a Blackberry but can`t afford it. Just watching the telly and relaxing. Hard sesion tomorrow.")

(1,
 'The text expresses enjoyment from spending time on the telly, lack of difficulty in the session, and overall positive outlook')

### 2.2 Generate labels and rationales

In [6]:
# For each sentence, generate a label and rationale.
# Add logging to the function to see the progress.
for i, row in sentences_train_df.iterrows():
    print(f"Processing sentence {i+1} of {len(sentences_train_df)}")
    sentences_train_df.loc[i, 'teacher_label'], sentences_train_df.loc[i, 'teacher_rationale'] = generate_label_and_rationale(row['sentence'])

Processing sentence 1 of 5000
Processing sentence 2 of 5000
Processing sentence 3 of 5000
Processing sentence 4 of 5000
Processing sentence 5 of 5000
Processing sentence 6 of 5000
Processing sentence 7 of 5000
Processing sentence 8 of 5000
Processing sentence 9 of 5000
	ERROR: Incorrect JSON format
Processing sentence 10 of 5000
	ERROR: Incorrect JSON format
Processing sentence 11 of 5000
Processing sentence 12 of 5000
Processing sentence 13 of 5000
Processing sentence 14 of 5000
Processing sentence 15 of 5000
Processing sentence 16 of 5000
Processing sentence 17 of 5000
Processing sentence 18 of 5000
Processing sentence 19 of 5000
Processing sentence 20 of 5000
Processing sentence 21 of 5000
Processing sentence 22 of 5000
Processing sentence 23 of 5000
Processing sentence 24 of 5000
	ERROR: Incorrect JSON format
Processing sentence 25 of 5000
Processing sentence 26 of 5000
Processing sentence 27 of 5000
Processing sentence 28 of 5000
	ERROR: Incorrect JSON format
	ERROR: Incorrect JSO

### 2.3 Clean data

In [7]:
# Remove samples with empty rationales
sentences_train_df = sentences_train_df[(sentences_train_df['teacher_rationale'].notna()) & (sentences_train_df['teacher_rationale'] != "")]
# Reset the index
sentences_train_df = sentences_train_df.reset_index(drop=True)
# Target value must be integer
sentences_train_df['teacher_label'] = sentences_train_df['teacher_label'].astype(int)
sentences_train_df.head()

Unnamed: 0,idx,sentence,label,teacher_label,teacher_rationale
0,0,hide new secretions from the parental units,0,0,The text expresses playful and humorous sentim...
1,1,"contains no wit , only labored gags",0,0,The text is laborious and indicates difficulty...
2,2,that loves its characters and communicates som...,1,1,The text uses descriptive words like 'loves' a...
3,3,remains utterly satisfied to remain the same t...,0,1,The text uses the word 'satisfied' which indic...
4,4,on the worst revenge-of-the-nerds clichés the ...,0,0,The text expresses frustration from forgetting...


## 3. Student model

The distilled model will be based on **BERT base (110M parameters)** (14x times smaller than the teacher model).

In [8]:
STUDENT_MODEL = "google-bert/bert-base-uncased"
student_model = AutoModelForSequenceClassification.from_pretrained(STUDENT_MODEL)
tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### 3.1 Evaluate student model (before distillation)

In [9]:
def eval_student_model(model, sentences_df: pd.DataFrame):
    student_text_classifier = pipeline("sentiment-analysis", model, tokenizer=tokenizer)

    # Measure the accuracy of the student model
    correct = 0
    false_positive = 0
    false_negative = 0
    true_positive = 0
    true_negative = 0
    for i, row in sentences_df.iterrows():
        answer = student_text_classifier(row['sentence'])
        label = 1 if answer[0]['label'] == "LABEL_1" else 0
        if label == row['label']:
            correct += 1
            if label == 1:
                true_positive += 1
            else:
                true_negative += 1
        else:
            if label == 1:
                false_positive += 1
            else:
                false_negative += 1
    print(f"Student model accuracy: {correct/len(sentences_df):.2}")

    print(f"True positives: {true_positive}")
    print(f"True negatives: {true_negative}")
    print(f"False positives: {false_positive}")
    print(f"False negatives: {false_negative}")

eval_student_model(student_model, sentences_test_df)

Device set to use cuda:0
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Student model accuracy: 0.45
True positives: 0
True negatives: 448
False positives: 0
False negatives: 552


### 3.2 Preprocess training dataset

In [10]:
# Prepare dataset with rationales
def process_data(example: dict) -> dict:
    input_text = example["sentence"]
    rationale  = example["teacher_rationale"]
    label      = example["teacher_label"]
    
    # Tokenize input and rationale
    input_enc     = tokenizer(input_text, truncation=True, padding="max_length", max_length=64)
    rationale_enc = tokenizer(rationale, truncation=True, padding="max_length", max_length=64)

    return {
        "input_ids"     : input_enc["input_ids"],
        "attention_mask": input_enc["attention_mask"],
        "labels"        : label,
        "rationale_ids" : rationale_enc["input_ids"],
        "rationale_mask": rationale_enc["attention_mask"]
    }

# Apply function to dataframe
processed_train_dataset = sentences_train_df.apply(process_data, axis=1)
print(processed_train_dataset[0])

{'input_ids': [101, 5342, 2047, 3595, 8496, 2013, 1996, 18643, 3197, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': 0, 'rationale_ids': [101, 1996, 3793, 16783, 18378, 1998, 14742, 15792, 2055, 2437, 4569, 1997, 18643, 4506, 1010, 2302, 3154, 12407, 1997, 4997, 7848, 2030, 6699, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'rationale_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


### 3.3 Train student model

In [11]:
training_args = TrainingArguments(
    output_dir="./results",  # Directory to save the model and checkpoints
    eval_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=100,
    weight_decay=0.01,
    save_strategy="epoch",
    push_to_hub=False
)

In [12]:
class MultiTaskTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        labels = inputs.pop("labels")
        rationale_ids = inputs.pop("rationale_ids", None)
        
        outputs = model(**inputs)
        
        loss_fn = torch.nn.CrossEntropyLoss()
        label_loss = loss_fn(outputs.logits, labels)
        
        if rationale_ids is not None:
            rationale_outputs = model(input_ids=rationale_ids, attention_mask=inputs["attention_mask"])
            rationale_loss = loss_fn(rationale_outputs.logits, rationale_ids)
            loss = label_loss + 0.5 * rationale_loss  # Weighted loss
        else:
            loss = label_loss
        
        return (loss, outputs) if return_outputs else loss

trainer = MultiTaskTrainer(
    model           = student_model,
    args            = training_args,
    train_dataset   = processed_train_dataset,
    eval_dataset    = processed_train_dataset
)

trainer.train()
trainer.save_model("./results")
print("✅ Distillation Complete! Smaller model saved.")

Epoch,Training Loss,Validation Loss
1,0.5915,0.499357
2,0.5056,0.350284
3,0.4405,0.229755
4,0.3639,0.127773
5,0.2483,0.089736
6,0.1789,0.044202
7,0.1554,0.032344
8,0.129,0.032695
9,0.0854,0.016657
10,0.0633,0.011157


✅ Distillation Complete! Smaller model saved.


## 4. Accuracy

### 4.1 Baseline (Teacher model)
In order evaluate our distilled model, we need to first establish a baseline by evaluating the LLM (teacher) model.

In [13]:
# Measure the accuracy of the teacher model first as a baseline.
for i, row in sentences_test_df.iterrows():
    print(f"Test {i+1} of {len(sentences_test_df)}")
    sentences_test_df.loc[i, 'teacher_label'], sentences_test_df.loc[i, 'teacher_rationale'] = generate_label_and_rationale(row['sentence'])

# Remove samples with empty rationales
sentences_test_df = sentences_test_df[(sentences_test_df['teacher_rationale'].notna()) & (sentences_test_df['teacher_rationale'] != "")]

# Measure the accuracy of the teacher model
correct = 0
for i, row in sentences_test_df.iterrows():
    if row['teacher_label'] == row['label']:
        correct += 1
print(f"Teacher model accuracy: {correct/len(sentences_test_df):.2}")

Test 1 of 1000
Test 2 of 1000
Test 3 of 1000
Test 4 of 1000
Test 5 of 1000
Test 6 of 1000
Test 7 of 1000
Test 8 of 1000
Test 9 of 1000
Test 10 of 1000
Test 11 of 1000
Test 12 of 1000
Test 13 of 1000
	ERROR: Incorrect JSON format
	ERROR: Incorrect JSON format
Test 14 of 1000
Test 15 of 1000
Test 16 of 1000
Test 17 of 1000
Test 18 of 1000
Test 19 of 1000
Test 20 of 1000
Test 21 of 1000
Test 22 of 1000
Test 23 of 1000
Test 24 of 1000
Test 25 of 1000
Test 26 of 1000
Test 27 of 1000
Test 28 of 1000
Test 29 of 1000
Test 30 of 1000
Test 31 of 1000
Test 32 of 1000
Test 33 of 1000
Test 34 of 1000
Test 35 of 1000
Test 36 of 1000
	ERROR: Incorrect JSON format
Test 37 of 1000
Test 38 of 1000
Test 39 of 1000
Test 40 of 1000
Test 41 of 1000
Test 42 of 1000
Test 43 of 1000
	ERROR: Incorrect JSON format
	ERROR: Incorrect JSON format
Test 44 of 1000
Test 45 of 1000
Test 46 of 1000
Test 47 of 1000
Test 48 of 1000
Test 49 of 1000
Test 50 of 1000
Test 51 of 1000
Test 52 of 1000
Test 53 of 1000
Test 54 of 

### 4.2 Student model
Let's evaluate the distilled model.

In [14]:
eval_student_model(student_model, sentences_test_df)

Device set to use cuda:0


Student model accuracy: 0.72
True positives: 304
True negatives: 399
False positives: 40
False negatives: 234


## Conclusion
* The student model has increased its accuracy **from $45$% to $72$%** after training.
* Furthermore, the student model **surpassed its teacher** model's accuracy ($66$%).