# Distilling the Mistral Small Model from DeepSeek R1 on TPU

In this notebook we use the Wikitext-2 dataset from Hugging Face Datasets to distill a teacher language model into a smaller student model.

**Steps:**
1. Set up the TPU environment.
2. Load the teacher (DeepSeek R1) and student (Mistral Small) models along with the tokenizer.
3. Load and tokenize the Wikitext-2 dataset.
4. Define a collator for the DataLoader.
5. Define a distillation loss function (KL divergence + cross-entropy).
6. Run the training loop on TPU using PyTorch/XLA’s distributed tools.
7. Save the distilled student model.

**Required Packages:**

- `torch`  
- `torch_xla`  
- `transformers`  
- `datasets`  
- `huggingface_hub`

You can install them via pip:

```bash
pip install torch torchvision torchaudio
pip install torch_xla  # Make sure you use the appropriate version for your TPU environment
pip install transformers datasets huggingface_hub
```

In [None]:
!pip install torch torchvision torchaudio
# !pip install torch_xla
!pip install transformers==4.44.2 datasets huggingface_hub

In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Import Hugging Face Transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

# Import PyTorch XLA modules for TPU support
# import torch_xla
# import torch_xla.core.xla_model as xm
# import torch_xla.distributed.parallel_loader as pl

## 1. Setup TPU Environment

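This cell selects the device. The TPU line (`xm.xla_device()`) is commented out, so as written the cell uses a CUDA GPU when one is available and the CPU otherwise. To train on a TPU, select TPU as the accelerator in Google Colab, install `torch_xla`, and uncomment the XLA imports and the `xm.xla_device()` line.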
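
The device choice can also be made automatically. The following is a minimal sketch, assuming `torch_xla` is only importable in a TPU environment; the helper name `pick_device` is hypothetical and not part of this notebook, while `xm.xla_device()` is the same PyTorch/XLA call that appears commented out in the cell below.

```python
import torch

def pick_device() -> torch.device:
    """Return an XLA (TPU) device if torch_xla is usable, else CUDA, else CPU."""
    try:
        # torch_xla is normally only installed in TPU environments.
        import torch_xla.core.xla_model as xm
        return xm.xla_device()
    except (ImportError, RuntimeError):
        # No torch_xla, or no XLA device could be initialized: fall through.
        pass
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = pick_device()
print("Using device:", device)
```

With this approach the rest of the notebook can keep calling `.to(device)` without knowing which backend was chosen.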


In [None]:
# device = xm.xla_device() 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 2. Load Models and Tokenizer

**Note:** The cells below use `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` as the teacher and `meta-llama/Llama-3.2-1B` as the student. Replace these with the model identifiers or paths that you want to distill.


In [None]:
# Replace YOUR_HF_TOKEN with your Hugging Face access token (required for gated models such as Llama)
!huggingface-cli login --token YOUR_HF_TOKEN

In [None]:
# Model names
teacher_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
student_model_name = "meta-llama/Llama-3.2-1B"

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)

# Ensure tokenizer has a pad token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Load teacher model
teacher = AutoModelForCausalLM.from_pretrained(teacher_model_name, trust_remote_code=True)
teacher = teacher.to(dtype=torch.float16)
teacher.resize_token_embeddings(len(tokenizer))
teacher.to(device)


In [None]:
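# Load student model
student = AutoModelForCausalLM.from_pretrained(student_model_name, trust_remote_code=True)
student = student.to(dtype=torch.float16)
# The student is resized to the teacher tokenizer's vocabulary so that both models
# produce logits over the same token IDs. Note that the student's pretrained embeddings
# were learned with its own tokenizer, so they must be re-learned during distillation.
student.resize_token_embeddings(len(tokenizer))
student.to(device)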


## 3. Load and Tokenize the Wikitext-2 Dataset

Here we load the training split of the `wikitext-2-raw-v1` configuration and tokenize the text using the loaded tokenizer.

We use a maximum sequence length of 128 tokens. Adjust as needed.
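
To make the effect of a fixed maximum length concrete, here is a small pure-Python sketch of truncating and padding one token sequence. It is an illustration only: the function `pad_or_truncate` is hypothetical, it assumes right-side padding (the real side depends on the tokenizer's `padding_side`), and the token IDs are made up.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Clip a token list to max_length, right-pad it, and build the attention mask."""
    clipped = token_ids[:max_length]
    n_pad = max_length - len(clipped)
    input_ids = clipped + [pad_id] * n_pad
    # 1 marks real tokens, 0 marks padding that the model should not attend to.
    attention_mask = [1] * len(clipped) + [0] * n_pad
    return input_ids, attention_mask

ids, mask = pad_or_truncate([11, 12, 13], max_length=5, pad_id=0)
print(ids)   # [11, 12, 13, 0, 0]
print(mask)  # [1, 1, 1, 0, 0]
```

Every example ends up with exactly `max_length` entries, which is what lets the DataLoader stack them into rectangular tensors.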


## 4. Define a Data Collator

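A collator is used to batch the data correctly. In this notebook the tokenization step already pads every sequence to `max_length`, so the collator mainly stacks the `input_ids` and `attention_mask` tensors into a batch and adds a `labels` tensor in which padding positions are set to `-100`. If you remove the fixed-length padding from tokenization, the same collator pads each batch dynamically to its longest sequence.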
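
The label construction can be sketched in a few lines of PyTorch. This is a simplified illustration of what `DataCollatorForLanguageModeling` does with `mlm=False`, not its actual implementation; the helper name `make_causal_lm_labels` and the pad ID of 0 are assumptions made for the example.

```python
import torch

def make_causal_lm_labels(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Copy input_ids into labels and mask padding with -100 so the loss ignores it."""
    labels = input_ids.clone()
    labels[labels == pad_token_id] = -100
    return labels

# Two already-padded sequences (0 is the assumed pad ID).
batch_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
labels = make_causal_lm_labels(batch_ids, pad_token_id=0)
print(labels)
```

The value `-100` is the default `ignore_index` of `nn.CrossEntropyLoss`, so masked positions contribute nothing to a cross-entropy loss computed on these labels.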


## 5. Define Distillation Loss and Helper Functions

We combine two losses:

 - **KL Divergence Loss:** Between the softened output distributions (using a temperature) of teacher and student.
 - **Cross Entropy Loss:** Using ground truth tokens.

The final loss is a weighted sum of both. Adjust the temperature and alpha as needed.
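
Written out, the weighted sum takes the following form. The notation is introduced here for illustration only: $z_t$ and $z_s$ are the teacher and student logits, $\tau$ is `temperature`, $\alpha$ is `alpha`, and $y$ denotes the next-token targets.

$$
\mathcal{L} = \alpha \, \tau^{2} \, \mathrm{KL}\big(\mathrm{softmax}(z_t / \tau) \,\|\, \mathrm{softmax}(z_s / \tau)\big) + (1 - \alpha) \, \mathrm{CE}(z_s, y)
$$

The $\tau^{2}$ factor corresponds to the `temperature ** 2` scaling in the loss function defined in this notebook. It compensates for the gradients of the softened term shrinking roughly as $1/\tau^{2}$, so that the balance set by $\alpha$ stays comparable when the temperature changes.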


In [None]:
# Loss functions and hyperparameters
temperature = 2.0
alpha = 0.7
kl_loss_fn = nn.KLDivLoss(reduction="batchmean")
ce_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
max_length = 128

In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling

# Tokenize dataset
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_length,
        padding="max_length",  # Pad to max_length
        return_tensors="pt"
    )

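raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
# Wikitext contains many blank lines; drop them so that no example consists only of padding.
raw_dataset = raw_dataset.filter(lambda example: example["text"].strip() != "")
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])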

# Data collator and DataLoader
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Disable MLM for causal language modeling
    pad_to_multiple_of=8  # Optional: Pad to a multiple of 8 for better performance
)

batch_size = 2
data_loader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(
    student_logits, 
    teacher_logits, 
    target_ids, 
    temperature=2.0, 
    alpha=0.5,
    kl_loss_fn=nn.KLDivLoss(reduction="batchmean"), 
    ce_loss_fn=nn.CrossEntropyLoss(ignore_index=-100)
):
    """
    Computes the distillation loss as a weighted sum of KL divergence loss (between teacher and student outputs)
    and cross-entropy loss (using target_ids for next-token prediction).

    Args:
        student_logits (Tensor): Logits from the student model with shape [batch, seq_len, vocab_size].
        teacher_logits (Tensor): Logits from the teacher model with shape [batch, seq_len, vocab_size].
        target_ids (Tensor): Ground-truth token IDs with shape [batch, seq_len].
        temperature (float): Temperature for smoothing the logits.
        alpha (float): Weight for the distillation (KL) loss; (1 - alpha) is used for the CE loss.
        kl_loss_fn (nn.Module): KL divergence loss function.
        ce_loss_fn (nn.Module): Cross-entropy loss function.

    Returns:
        Tensor: The computed combined loss.
    """
    
    # 1. Apply temperature scaling to logits
    student_logits_temp = student_logits / temperature
    teacher_logits_temp = teacher_logits / temperature

    # 2. Compute softened probabilities/log-probabilities
    student_log_probs = F.log_softmax(student_logits_temp, dim=-1)
    teacher_probs = F.softmax(teacher_logits_temp, dim=-1)
    
    # 3. Compute KL divergence loss
    #    Reshape to 2D tensors with shape [batch * seq_len, vocab_size]
    loss_kl = kl_loss_fn(
        student_log_probs.view(-1, student_log_probs.size(-1)),
        teacher_probs.view(-1, teacher_probs.size(-1))
    ) * (temperature ** 2)
    
    # 4. Prepare logits and labels for the next-token prediction cross-entropy loss.
    #    We shift the student logits and target_ids so that each prediction is compared to the next token.
    shift_logits = student_logits[..., :-1, :].contiguous()  # shape: [batch, seq_len-1, vocab_size]
    shift_labels = target_ids[..., 1:].contiguous()           # shape: [batch, seq_len-1]
    
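    # 5. Check that the labels are within the valid range, skipping positions
    #    that the cross-entropy loss ignores (ignore_index, -100 by default).
    vocab_size = student_logits.size(-1)
    ignore_index = getattr(ce_loss_fn, "ignore_index", -100)
    valid_labels = shift_labels[shift_labels != ignore_index]
    if torch.any(valid_labels < 0) or torch.any(valid_labels >= vocab_size):
        raise ValueError(f"Target ids contain values outside the valid range [0, {vocab_size-1}].")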
    
    # 6. Compute the cross-entropy loss for next-token prediction.
    loss_ce = ce_loss_fn(
        shift_logits.view(-1, vocab_size),  # flatten logits to [batch * (seq_len-1), vocab_size]
        shift_labels.view(-1)               # flatten labels to [batch * (seq_len-1)]
    )
    
    # 7. Combine losses with weighting
    loss = alpha * loss_kl + (1 - alpha) * loss_ce
    return loss


## 6. Training Loop on TPU

On TPU, PyTorch/XLA’s `ParallelLoader` feeds data to the TPU cores, and in a multi-core setup you might launch training with `xmp.spawn()`. The loop below is a single-process loop that runs on whichever device was selected in step 1.

The loop performs the following for each batch:
 - Moves the batch to the selected device.
 - Computes the teacher’s outputs (without gradient computation).
 - Computes the student’s outputs.
 - Computes the distillation loss.
 - Backpropagates and, every `accumulation_steps` batches, updates the student model with an optimizer step.
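
The training cell in this section applies the optimizer step only every `accumulation_steps` batches. The toy sketch below, which uses a small linear model and random data that are not part of this notebook, illustrates why each micro-batch loss is usually divided by the number of accumulation steps: the accumulated gradient then matches the gradient of one large batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)
accumulation_steps = 4  # illustrative value

# Reference: gradient of the mean loss over the full batch of 8 examples.
optimizer.zero_grad()
loss_fn(model(inputs), targets).backward()
full_batch_grad = model.weight.grad.clone()

# Accumulated: 4 micro-batches of 2 examples, each loss divided by accumulation_steps.
optimizer.zero_grad()
micro_batches = zip(inputs.chunk(accumulation_steps), targets.chunk(accumulation_steps))
for micro_inputs, micro_targets in micro_batches:
    micro_loss = loss_fn(model(micro_inputs), micro_targets) / accumulation_steps
    micro_loss.backward()  # gradients add up in .grad across backward calls
accumulated_grad = model.weight.grad.clone()

grads_match = torch.allclose(full_batch_grad, accumulated_grad, atol=1e-5)
print("Accumulated gradient matches full-batch gradient:", grads_match)

optimizer.step()  # one update for the whole accumulated batch
```

Without the division, the accumulated gradient would be `accumulation_steps` times larger, which acts like a silently increased learning rate.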


In [None]:
num_epochs = 3
learning_rate = 2e-5
optimizer = optim.AdamW(student.parameters(), lr=learning_rate)

In [None]:
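accumulation_steps = 4

def train_loop_fn(loader, epoch, student, teacher, optimizer, device):
    student.train()
    teacher.eval()

    total_loss = 0.0
    total_steps = len(loader)
    optimizer.zero_grad()

    for step, batch in enumerate(loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        # Drop the collator's "labels" entry: only logits are needed here, and
        # passing labels would make each model compute an extra, unused loss.
        inputs = {k: v for k, v in batch.items() if k != "labels"}

        # The teacher is frozen, so no gradients are tracked for its forward pass.
        with torch.no_grad():
            teacher_logits = teacher(**inputs).logits

        # bfloat16 autocast needs no loss scaling, and GradScaler cannot unscale
        # the float16 gradients of a float16 model, so no scaler is used.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            student_logits = student(**inputs).logits
            # Pass the hyperparameters defined earlier instead of the function defaults.
            loss = distillation_loss(
                student_logits, teacher_logits, batch["input_ids"],
                temperature=temperature, alpha=alpha,
            )

        # Divide so the accumulated gradient is the average over the micro-batches.
        (loss / accumulation_steps).backward()

        # Step every accumulation_steps batches, and on the final batch so that
        # leftover gradients are not discarded.
        if (step + 1) % accumulation_steps == 0 or (step + 1) == total_steps:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item()
        if step % 500 == 0:
            print(f"Epoch {epoch} | Step {step}/{total_steps} | Loss: {loss.item():.4f}")

    average_loss = total_loss / total_steps
    print(f"Epoch {epoch} completed. Average Loss: {average_loss:.4f}")
    return average_loss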


## 7. Run Training with PyTorch/XLA Parallel Loader

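On TPU, the DataLoader would be wrapped in a `ParallelLoader` so that batches are distributed to the TPU cores. The cell below passes the plain `data_loader` to the training function, which matches the CUDA/CPU device selected in step 1.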
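
One way to keep a single code path for both cases is a small wrapper. This is a sketch under assumptions rather than a tested TPU recipe: the helper name `wrap_loader_for_device` is hypothetical, and it relies on `MpDeviceLoader` from `torch_xla.distributed.parallel_loader`, which is imported only when the device is an XLA device.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def wrap_loader_for_device(loader, device):
    """Return an XLA device loader on TPU; return the loader unchanged otherwise."""
    if device.type == "xla":
        # Imported lazily so the function also works where torch_xla is not installed.
        import torch_xla.distributed.parallel_loader as pl
        return pl.MpDeviceLoader(loader, device)
    return loader

# Toy loader standing in for the notebook's data_loader.
toy_loader = DataLoader(TensorDataset(torch.arange(6).view(3, 2)), batch_size=1)
cpu_loader = wrap_loader_for_device(toy_loader, torch.device("cpu"))
print(cpu_loader is toy_loader)
```

On CPU or CUDA the original loader is returned untouched, so the training function can be called the same way in every environment.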

In [None]:
# Training (train_loop_fn already prints the average loss at the end of each epoch)
for epoch in range(1, num_epochs + 1):
    epoch_loss = train_loop_fn(data_loader, epoch, student, teacher, optimizer, device)

## 8. Save the Distilled Student Model

Finally, we save the distilled student model and tokenizer. Adjust the saving path as necessary.


In [None]:
save_directory = "./R1-distill-Qwen2.5"
os.makedirs(save_directory, exist_ok=True)

student.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
print("Distilled model saved to:", save_directory)

## 9. Push the Distilled Model to the Hugging Face Hub

Before pushing, you can log in interactively if needed. Run the following cell and follow the instructions:

```python
from huggingface_hub import notebook_login
notebook_login()
```

Once authenticated, push the model and tokenizer to your repository.

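**Important:** Replace the repository ID passed to `push_to_hub` with your own, for example `"your-username/R1-distill-llama"`. Hub repository names may contain only letters, digits, `-`, `_`, and `.`, and may not contain `--`, so avoid characters such as `!`.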


In [None]:

# Repository ID on the Hub ("!" is not a valid character in a repo name, so "R1" is used)
repo_id = "codewithdark/R1-distill-Qwen2.5"

# Push the student model
student.push_to_hub(repo_id)
# Push the tokenizer
tokenizer.push_to_hub(repo_id)

print("Model and tokenizer have been pushed to the Hugging Face Hub!")