# Lab 1: Neuron Distillation Training

## Introduction

In this lab, you will learn how to train a smaller "student" model using knowledge distillation on AWS Neuron hardware. This lab builds directly on Lab 0, where you generated teacher model logits from the Qwen3-30B-A3B model.

Knowledge distillation allows you to compress a large, high-performance model into a smaller, more efficient model while retaining much of the original model's capabilities. The student model learns not just from the hard labels (POSITIVE, NEGATIVE, NEUTRAL), but from the full probability distributions (soft labels) produced by the teacher model.

**Why Knowledge Distillation?**
- **Cost Reduction**: Smaller models require fewer compute resources and lower inference costs
- **Faster Inference**: Reduced model size means faster response times
- **Better Generalization**: Learning from soft labels often produces models that generalize better than training on hard labels alone
- **Deployment Flexibility**: Smaller models can be deployed on edge devices or in resource-constrained environments

**Training Approach:**

You will use a custom `KnowledgeDistillationTrainer` that combines two loss functions:
1. **Hard Loss**: Standard cross-entropy loss with true labels (teaches the model the correct answers)
2. **Soft Loss**: KL divergence between teacher and student logits (teaches the student the teacher's relative confidence across all possible outputs)

The combined loss is: `total_loss = α × soft_loss + (1 - α) × hard_loss`

Where α (alpha) controls the balance between learning from the teacher vs. learning from the labels.
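To make the role of α concrete, here is a minimal pure-Python sketch of the weighting. The loss values are made-up placeholders, not measurements from this lab.

```python
def combined_loss(soft_loss: float, hard_loss: float, alpha: float) -> float:
    """Weighted distillation objective: alpha * soft + (1 - alpha) * hard."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    return alpha * soft_loss + (1 - alpha) * hard_loss


# Placeholder loss values, chosen only to show how alpha shifts the balance.
soft, hard = 2.0, 1.0
for alpha in (0.0, 0.5, 0.7, 1.0):
    print(f"alpha={alpha:.1f} -> total={combined_loss(soft, hard, alpha):.2f}")
```

With α = 0 the student ignores the teacher entirely; with α = 1 it ignores the true labels.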

**Student Model:**

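You'll train [Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B), a 600-million-parameter model that is 50x smaller than the 30B-parameter teacher by total parameter count (the teacher is a mixture-of-experts model that activates about 3B parameters per token).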

**Prerequisites:**
- Completed Lab 0 with teacher logits saved to `data/output.json`
- AWS Trainium-based EC2 instance
- AWS Neuron SDK installed
- Sufficient disk space for model compilation and checkpoints (~30GB recommended)

## Environment Setup

Configure environment variables to optimize Neuron compilation and runtime performance for distributed training.

**Environment Variables Explained:**

- **NEURON_CC_FLAGS**: Compiler flags for the Neuron compiler
  - `--model-type transformer`: Optimizes compilation for transformer architectures
  - `--retry_failed_compilation`: Automatically retry if compilation fails (improves reliability)

- **NEURON_FUSE_SOFTMAX**: Enable softmax fusion optimization (combines operations for better performance)

- **NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS**: Maximum number of asynchronous execution requests the Neuron runtime keeps in flight at once (3 provides a good throughput/latency balance)

- **MALLOC_ARENA_MAX**: Limits memory arenas to reduce memory fragmentation during training (important for long-running jobs)

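- **WORLD_SIZE**: Total number of processes in the distributed job. `torchrun` sets `WORLD_SIZE` itself for every worker it launches (`--nproc_per_node` × number of nodes, i.e. 2 in this lab), so the value of 8 exported here is overridden inside the training processes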

These settings are critical for stable, efficient training on Neuron hardware.

```python
import os

from src.util import prettyprint_python

# Set Neuron compilation flags
os.environ['NEURON_CC_FLAGS'] = "--model-type transformer --retry_failed_compilation"
os.environ['NEURON_FUSE_SOFTMAX'] = "1"
os.environ['NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS'] = "3"
os.environ['MALLOC_ARENA_MAX'] = "64"
os.environ['WORLD_SIZE'] = "8"
```
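`torchrun` exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to each worker it starts. The sketch below shows how a worker could read them; `read_dist_env` is a hypothetical helper (not part of `distill_neuron_torchrun.py`), and the example environment simulates what the second of the two workers launched later in this lab would see.

```python
import os
from typing import Mapping


def read_dist_env(env: Mapping[str, str] = os.environ) -> dict:
    """Hypothetical helper: read the rank variables that torchrun sets per worker."""
    return {
        "rank": int(env.get("RANK", "0")),              # global rank of this process
        "local_rank": int(env.get("LOCAL_RANK", "0")),  # rank on this node
        "world_size": int(env.get("WORLD_SIZE", "1")),  # total worker processes
    }


# Simulated environment of the second worker from `torchrun --nproc_per_node 2`
# on one node: inside the worker, torchrun's WORLD_SIZE=2 replaces the exported 8.
worker_env = {"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"}
print(read_dist_env(worker_env))
```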

## Training Configuration

Define the hyperparameters and settings for the distillation training job.

**Training Parameters:**

- **PROCESSES_PER_NODE**: Number of training processes that `torchrun` launches (2)
  - Each process drives one NeuronCore
  - Processes are grouped into tensor-parallel groups of TP_DEGREE, so the data-parallel degree is PROCESSES_PER_NODE / TP_DEGREE = 2 / 2 = 1 and 2 NeuronCores are used in total

- **NUM_EPOCHS**: Number of complete passes through the training dataset (3 epochs)
  - More epochs = more training, but risk of overfitting on small datasets

- **TP_DEGREE**: Tensor parallelism degree (2)
  - Splits the weight matrices inside each layer across 2 NeuronCores (one per process), reducing per-core memory
  - For a 0.6B model, TP=2 is sufficient

- **BS**: Batch size per device (1 sample at a time)
  - Small batch size for memory efficiency during compilation
  - Effective batch size = BS × GRADIENT_ACCUMULATION_STEPS × data-parallel degree = 1 × 16 × 1 = 16

- **GRADIENT_ACCUMULATION_STEPS**: Number of steps to accumulate gradients before updating weights (16)
  - Simulates larger batch sizes without increasing memory usage
  - Provides more stable gradient estimates

- **LOGGING_STEPS**: Log training metrics every N steps (1 = log every step for detailed monitoring)

- **MODEL_NAME**: Hugging Face model identifier for the student model
  - "Qwen/Qwen3-0.6B": 600M parameter model (50x smaller than the 30B teacher)

- **OUTPUT_DIR**: Directory where trained model checkpoints will be saved

- **MAX_STEPS**: Maximum training steps
  - Set to 5 if NEURON_EXTRACT_GRAPHS_ONLY=1 (for graph extraction/compilation testing)
  - Set to -1 otherwise (train for full NUM_EPOCHS)
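The relationships above reduce to a little arithmetic. This sketch assumes the usual Neuron distributed-training layout, in which each `torchrun` worker drives one NeuronCore and the data-parallel degree is the worker count divided by TP × PP. The 1,000-example dataset size is a made-up placeholder, and the trainer's own rounding of steps per epoch may differ slightly.

```python
import math


def training_layout(processes: int, tp: int, pp: int, bs: int, grad_accum: int,
                    num_examples: int) -> dict:
    """Derive data-parallel degree, effective batch size and optimizer steps per epoch."""
    model_parallel = tp * pp
    if processes % model_parallel != 0:
        raise ValueError("processes must be divisible by tp * pp")
    dp = processes // model_parallel            # number of model replicas
    effective_batch = bs * grad_accum * dp      # samples per optimizer update
    steps_per_epoch = math.ceil(num_examples / effective_batch)  # approximate
    return {"dp": dp, "effective_batch": effective_batch,
            "steps_per_epoch": steps_per_epoch}


# Values from this lab; num_examples=1000 is a hypothetical dataset size.
print(training_layout(processes=2, tp=2, pp=1, bs=1, grad_accum=16, num_examples=1000))
```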

```python
# Training parameters
PROCESSES_PER_NODE = 2
NUM_EPOCHS = 3
TP_DEGREE = 2
BS = 1
GRADIENT_ACCUMULATION_STEPS = 16
LOGGING_STEPS = 1
MODEL_NAME = "Qwen/Qwen3-0.6B"
OUTPUT_DIR = f"{MODEL_NAME.split('/')[-1]}-finetuned"

# Set max steps based on environment
MAX_STEPS = 5 if os.environ.get('NEURON_EXTRACT_GRAPHS_ONLY') == '1' else -1

print(f"Model: {MODEL_NAME}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Max steps: {MAX_STEPS}")
```

```text
Model: Qwen/Qwen3-0.6B
Output directory: Qwen3-0.6B-finetuned
Max steps: -1
```


## KnowledgeDistillationTrainer Code

Let's examine the custom `KnowledgeDistillationTrainer` class from `distill_neuron_torchrun.py`.

This trainer extends the Optimum Neuron `NeuronTrainer` class to implement knowledge distillation. The core logic lives in the `compute_loss` method, which combines a hard loss and a soft loss.

**Class Structure:**

```python
class KnowledgeDistillationTrainer(NeuronTrainer):
    def __init__(self, temperature=4.0, alpha=0.7, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature
        self.alpha = alpha
```

**Hyperparameters:**
- **temperature** (default=4.0): Controls the "softness" of probability distributions
  - Higher temperature = softer distributions (more uniform probabilities)
  - Softer distributions reveal more about the teacher's uncertainty and reasoning
  - Typical range: 2.0-10.0

- **alpha** (default=0.7): Balances soft loss vs. hard loss
  - alpha=0.7 means 70% weight on teacher's soft labels, 30% on true labels
  - Higher alpha = more emphasis on mimicking the teacher
  - Lower alpha = more emphasis on getting correct answers
  - Typical range: 0.5-0.9
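To see what temperature does, this pure-Python sketch applies a temperature-scaled softmax to one set of hypothetical teacher logits for (POSITIVE, NEUTRAL, NEGATIVE). As T grows, the top class loses probability mass and the smaller classes gain it, while their ranking stays the same.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                          # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for (POSITIVE, NEUTRAL, NEGATIVE).
logits = [6.0, 2.0, 0.0]
for t in (1.0, 4.0, 10.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t:>4}: " + ", ".join(f"{p:.3f}" for p in probs))
```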

The following cell displays the complete class implementation:

```python
prettyprint_python("src/distill_neuron_torchrun.py", line_numbers=True, line_range=(35, 78))
```

```python
class KnowledgeDistillationTrainer(NeuronTrainer):
    def __init__(self, temperature=4.0, alpha=0.7, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute knowledge distillation loss combining:
        - Hard loss: standard cross-entropy with true labels
        - Soft loss: KL divergence between teacher and student logits
        """
        if num_items_in_batch is not None:
            # Do something with num_items_in_batch
            pass

        student_outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask']
        )
        student_logits = student_outputs.logits

        inputs = {k: v.to('xla') if torch.is_tensor(v) else v for k, v in inputs.items()}

        # Hard loss (standard language modeling loss)
        hard_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            inputs['labels'].view(-1)
        )

        # Soft loss (knowledge distillation)
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_soft = F.softmax(inputs['teacher_logits'] / self.temperature, dim=-1)

        soft_loss = F.kl_div(
            student_soft.view(-1, student_soft.size(-1)),
            teacher_soft.view(-1, teacher_soft.size(-1)),
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Combined loss
        total_loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss

        return (total_loss, student_outputs) if return_outputs else total_loss
```


## Key Methods of KnowledgeDistillationTrainer

The `compute_loss` method is the core of the knowledge distillation process. Let's break down how it works:

**Step 1: Generate Student Predictions**
```python
student_outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
student_logits = student_outputs.logits
```
Run the student model to get its predictions (logits) for the input text.

**Step 2: Compute Hard Loss**
```python
hard_loss = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), inputs['labels'].view(-1))
```
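Token-level cross-entropy between the student's logits and the label token IDs, which encode the correct answer (POSITIVE/NEGATIVE/NEUTRAL). Positions whose label is `-100` (the default `ignore_index` of `F.cross_entropy`) do not contribute to the loss.
This ensures the student learns to produce correct answers.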
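For intuition, this pure-Python sketch mirrors what `F.cross_entropy` computes on flattened `(tokens, vocab)` logits with its default settings: the mean negative log-probability of each label token, skipping positions labelled `-100`. The toy 3-token vocabulary and the values are hypothetical.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(z - m) for z in row))
    return [z - log_z for z in row]


def hard_loss(logits, labels):
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX."""
    losses = [-log_softmax(row)[y] for row, y in zip(logits, labels) if y != IGNORE_INDEX]
    if not losses:
        return float("nan")  # no unmasked positions: the mean over zero items is undefined
    return sum(losses) / len(losses)


# Two positions over a toy 3-token vocabulary; the first position is masked out.
logits = [[0.5, 0.1, -0.3], [2.0, 0.0, -1.0]]
labels = [IGNORE_INDEX, 0]
print(f"hard loss: {hard_loss(logits, labels):.4f}")
```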

**Step 3: Compute Soft Loss**
```python
student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
teacher_soft = F.softmax(inputs['teacher_logits'] / self.temperature, dim=-1)
soft_loss = F.kl_div(student_soft.view(-1, student_soft.size(-1)),
                     teacher_soft.view(-1, teacher_soft.size(-1)),
                     reduction='batchmean') * (self.temperature ** 2)
```
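- Divide both the student and teacher logits by the temperature before the softmax, which softens both distributions
- `F.kl_div` takes the student's log-probabilities as input and the teacher's probabilities as target, so it computes KL(teacher || student): how far the student's distribution is from the teacher's. In the full implementation the logits are flattened to `(tokens, vocab)` first, so `reduction='batchmean'` averages over all token positions
- Multiply by temperature² because the gradients of the softened loss shrink roughly as 1/T²; the factor keeps the soft loss's contribution comparable to the hard loss when you change the temperature (Hinton et al.)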
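The soft-loss step can be sketched in pure Python as follows. It assumes the same conventions as the implementation (KL(teacher || student) per row, averaged over rows as `reduction='batchmean'` does on flattened logits, then scaled by T²); the logits are hypothetical.

```python
import math


def log_softmax(row, temperature):
    """Log-softmax of one row of logits after dividing by the temperature."""
    scaled = [z / temperature for z in row]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_z for z in scaled]


def soft_loss(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) at temperature T, averaged over rows, scaled by T**2."""
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        log_q = log_softmax(s_row, temperature)   # student log-probabilities
        log_p = log_softmax(t_row, temperature)   # teacher log-probabilities
        total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    return total / len(student_logits) * temperature ** 2


teacher = [[6.0, 2.0, 0.0]]
student = [[3.0, 1.0, 0.5]]
print(f"soft loss (T=4): {soft_loss(student, teacher, 4.0):.4f}")
print(f"identical logits: {soft_loss(teacher, teacher, 4.0):.4f}")
```

Because each row's KL is averaged, duplicating every row leaves the loss unchanged, which is the point of flattening before `batchmean`.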

**Step 4: Combine Losses**
```python
total_loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss
```
Weighted combination of soft and hard losses.

**Example:**
For input "This phone's battery life is absolutely amazing!":
- Hard loss: Penalizes if student doesn't predict "POSITIVE"
- Soft loss: Penalizes if student's confidence distribution differs from teacher's (e.g., teacher might be 95% confident POSITIVE, 4% NEUTRAL, 1% NEGATIVE)
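The example above can be worked through numerically. This sketch treats the three labels as the whole output space and works directly with probabilities (temperature 1, so no T² factor). The teacher's 95/4/1 split comes from the text; the student's 70/20/10 split is made up. In the real trainer both losses run over the full vocabulary at every token position.

```python
import math

labels = ["POSITIVE", "NEUTRAL", "NEGATIVE"]
teacher_probs = [0.95, 0.04, 0.01]   # illustrative teacher confidence from the text
student_probs = [0.70, 0.20, 0.10]   # hypothetical student prediction
true_label = labels.index("POSITIVE")
alpha = 0.7

# Hard loss: negative log-probability the student assigns to the true label.
hard = -math.log(student_probs[true_label])
# Soft loss: KL(teacher || student) = sum_i p_i * log(p_i / q_i).
soft = sum(p * math.log(p / q) for p, q in zip(teacher_probs, student_probs))
total = alpha * soft + (1 - alpha) * hard

print(f"hard={hard:.4f} soft={soft:.4f} total={total:.4f}")
```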

The following cell displays the complete `compute_loss` implementation:

```python
prettyprint_python("src/distill_neuron_torchrun.py", line_numbers=True, line_range=(41, 78))
```



## Run Distillation Training

Execute the distributed training job using `torchrun`, PyTorch's distributed training launcher.

**Why torchrun?**
- Manages multiple training processes automatically
- Sets up distributed communication between processes
- Handles process synchronization and fault tolerance

**Command Structure:**

```bash
torchrun --nproc_per_node 2 src/distill_neuron_torchrun.py [training args]
```

**Key Arguments:**

- **--nproc_per_node**: Number of worker processes to launch (2, one per NeuronCore; with `--tensor_parallel_size 2` both workers form a single tensor-parallel group, so there is no data parallelism)
- **--model_id**: Student model to train (Qwen/Qwen3-0.6B)
- **--num_train_epochs**: Number of complete passes through the dataset (3)
- **--do_train**: Enable training mode
- **--max_steps**: Maximum training steps (-1 = train for full epochs, 5 = quick test)
- **--per_device_train_batch_size**: Batch size per device (1)
- **--gradient_accumulation_steps**: Accumulate gradients over 16 steps (effective batch size = 16)
- **--learning_rate**: Optimizer learning rate (1e-4 = 0.0001)
- **--bf16**: Use BFloat16 precision (faster training, lower memory, minimal accuracy loss)
- **--tensor_parallel_size**: Shard the model across 2 NeuronCores (i.e. across the 2 worker processes)
- **--warmup_steps**: Gradually increase learning rate over first 5 steps (stabilizes training)
- **--pipeline_parallel_size**: No pipeline parallelism (1 = disabled)
- **--logging_steps**: Log metrics every step for detailed monitoring
- **--output_dir**: Where to save model checkpoints
- **--overwrite_output_dir**: Overwrite existing checkpoints if present
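To see what `--warmup_steps 5` does, this sketch assumes the trainer keeps the Hugging Face default `lr_scheduler_type="linear"` (the command does not override it): the learning rate rises linearly from 0 to the base rate over the warmup steps, then decays linearly to 0. The total of 100 optimizer steps is a placeholder.

```python
def linear_warmup_lr(step: int, base_lr: float = 1e-4, warmup_steps: int = 5,
                     total_steps: int = 100) -> float:
    """Learning rate at an optimizer step under linear warmup, then linear decay."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


for step in (0, 1, 5, 50, 100):
    print(f"step {step:>3}: lr={linear_warmup_lr(step):.2e}")
```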

**Training Process:**

1. **Compilation Phase** (first run only, ~20-30 minutes):
   - Neuron compiler optimizes the model for Trainium hardware
   - Generates NEFF (Neuron Executable File Format) files
   - Cached for subsequent runs

2. **Training Phase** (~10-15 minutes):
   - Loads teacher logits from `data/output1.json`
   - Trains student model using knowledge distillation
   - Saves checkpoints to OUTPUT_DIR

3. **Final Model** saved to `./final_distilled_model`

**Monitoring Training:**
- Watch for loss values decreasing over time
- As written, `compute_loss` returns only `total_loss`, so the logged `loss` is the weighted combination of the soft and hard terms
- Training is complete when you see "Training completed" message
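To check that the loss is decreasing without reading every line, you can pull the values out of the captured log text. This sketch assumes the log lines are Python dict literals like the `{'loss': ..., 'learning_rate': ..., 'epoch': ...}` lines Hugging Face trainers print; `extract_losses` is a hypothetical helper and the sample text is illustrative, not real output from this lab.

```python
import ast


def extract_losses(log_text: str) -> list:
    """Collect 'loss' values from dict-style log lines such as "{'loss': 1.23, ...}"."""
    losses = []
    for line in log_text.splitlines():
        line = line.strip()
        if not (line.startswith("{") and "'loss'" in line):
            continue
        try:
            record = ast.literal_eval(line)
        except (ValueError, SyntaxError):
            continue  # skip lines that only look like dicts
        if isinstance(record, dict) and "loss" in record:
            losses.append(float(record["loss"]))
    return losses


sample = """\
{'loss': 2.31, 'learning_rate': 2e-05, 'epoch': 0.02}
some other log line
{'loss': 1.87, 'learning_rate': 4e-05, 'epoch': 0.03}
"""
losses = extract_losses(sample)
print(losses, "decreasing" if losses[-1] < losses[0] else "not decreasing")
```

After the training cell finishes, you could pass `result.stdout` (and `result.stderr`, in case the logs go there) to `extract_losses`.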

**Note:** The first run will take significantly longer due to compilation. Be patient!

```python
import subprocess

# Build the torchrun command
cmd = [
    "torchrun",
    "--nproc_per_node", str(PROCESSES_PER_NODE),
    "src/distill_neuron_torchrun.py",
    "--model_id", MODEL_NAME,
    "--num_train_epochs", str(NUM_EPOCHS),
    "--do_train",
    "--max_steps", str(MAX_STEPS),
    "--per_device_train_batch_size", str(BS),
    "--gradient_accumulation_steps", str(GRADIENT_ACCUMULATION_STEPS),
    "--learning_rate", "1e-4",
    "--bf16",
    "--tensor_parallel_size", str(TP_DEGREE),
    "--warmup_steps", "5",
    "--pipeline_parallel_size", "1",
    "--logging_steps", str(LOGGING_STEPS),
    "--output_dir", OUTPUT_DIR,
    "--overwrite_output_dir"
]

print("Running command:")
print(" ".join(cmd))
print("\n" + "="*50)

# Execute the command
result = subprocess.run(cmd, capture_output=True, text=True)
print("STDOUT:")
print(result.stdout)
if result.stderr:
    print("\nSTDERR:")
    print(result.stderr)
print(f"\nReturn code: {result.returncode}")
```
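Because `subprocess.run(..., capture_output=True)` prints nothing until the job exits, a 30+ minute compile-and-train run gives no feedback along the way. A minimal alternative, sketched with a hypothetical `run_streaming` helper, reads the merged stdout/stderr line by line as it arrives. The demo runs a tiny Python child process rather than the real job; child processes may still buffer their own output, so you may also want `PYTHONUNBUFFERED=1` in their environment.

```python
import subprocess
import sys


def run_streaming(cmd: list) -> tuple:
    """Run cmd, echoing its combined stdout/stderr line by line as it is produced."""
    lines = []
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          text=True, bufsize=1) as proc:
        for line in proc.stdout:
            print(line, end="")   # show progress immediately
            lines.append(line)
    # Leaving the with-block waits for the process, so returncode is set here.
    return proc.returncode, "".join(lines)


# Demo with a tiny Python child process instead of the real training job.
code, output = run_streaming([sys.executable, "-c", "print('compiling'); print('training')"])
print(f"Return code: {code}")
```

To use it for training, call `run_streaming(cmd)` with the `cmd` list built above.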

## Training Results and Next Steps

Congratulations on completing the knowledge distillation training!

**What You've Accomplished:**

1. ✅ Trained a 0.6B parameter student model using knowledge from a 30B parameter teacher
2. ✅ Produced a student roughly 50x smaller than the teacher (evaluate it to measure how much of the teacher's accuracy it retains)
3. ✅ Learned to use distributed training on AWS Trainium with Neuron SDK
4. ✅ Implemented custom knowledge distillation loss combining soft and hard targets

**Model Outputs:**

Your trained model is saved in two locations:
- **Checkpoints**: `{OUTPUT_DIR}/` - Contains intermediate training checkpoints
- **Final Model**: `./final_distilled_model/` - The completed distilled model ready for deployment

**Evaluating Your Model:**

To test your distilled model's performance:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the distilled model and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("./final_distilled_model")
model = AutoModelForCausalLM.from_pretrained("./final_distilled_model")

# Test on a sample (use the same prompt format as your training data)
test_text = "This phone's battery life is absolutely amazing!"
inputs = tokenizer(test_text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)

# Decode only the newly generated tokens, not the echoed prompt
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```

**Expected Benefits of Distillation:**

The student model (0.6B parameters) is approximately 50x smaller than the teacher model (30B parameters), which typically translates to:
- **Reduced Model Size**: Significantly smaller memory footprint for deployment
- **Faster Inference**: Fewer parameters mean faster forward passes
- **Lower Costs**: Reduced compute requirements for inference
- **Accuracy Trade-off**: Some performance loss is expected, but distillation typically retains more capability than fine-tuning the student on hard labels alone

**Note:** Actual performance metrics will depend on your specific dataset, training configuration, and evaluation criteria. We recommend benchmarking your distilled model against the teacher on your test set to quantify the accuracy/efficiency trade-off.
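A minimal benchmarking sketch follows. It assumes you already have the student's generated texts (for example from the generation snippet above), the teacher's predicted labels from Lab 0, and gold labels. `extract_label` and `compare` are hypothetical helpers, and the three examples are made up.

```python
LABELS = ("POSITIVE", "NEGATIVE", "NEUTRAL")


def extract_label(generated: str):
    """Hypothetical parser: return the first sentiment label found in the text, else None."""
    upper = generated.upper()
    positions = {label: upper.find(label) for label in LABELS if label in upper}
    return min(positions, key=positions.get) if positions else None


def compare(student_outputs, teacher_labels, gold_labels) -> dict:
    """Accuracy of student and teacher against gold, plus student/teacher agreement."""
    student = [extract_label(text) for text in student_outputs]
    n = len(gold_labels)
    return {
        "student_acc": sum(s == g for s, g in zip(student, gold_labels)) / n,
        "teacher_acc": sum(t == g for t, g in zip(teacher_labels, gold_labels)) / n,
        "agreement": sum(s == t for s, t in zip(student, teacher_labels)) / n,
    }


print(compare(["Sentiment: POSITIVE", "negative", "I think NEUTRAL"],
              ["POSITIVE", "NEGATIVE", "POSITIVE"],
              ["POSITIVE", "NEGATIVE", "NEUTRAL"]))
```

Agreement with the teacher tells you how faithfully the student imitates it; accuracy against gold labels tells you whether that imitation is useful.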

**Next Steps:**

1. **Deploy for Inference**: Use the distilled model with AWS Inferentia for cost-effective production inference
2. **Fine-tune Further**: Continue training on domain-specific data to improve performance
3. **Experiment with Hyperparameters**:
   - Try different temperature values (2.0-10.0)
   - Adjust alpha to balance soft vs. hard loss
   - Increase training epochs for better convergence
4. **Quantization**: Apply INT8 quantization to roughly halve the size of BF16 weights (a 4x reduction relative to FP32)
5. **Benchmark**: Compare inference latency and accuracy against the teacher model
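The size arithmetic behind the 50x and quantization figures can be checked quickly. This sketch uses the nominal parameter counts from the model names (actual counts differ slightly) and counts weights only, ignoring activations, KV cache, and runtime overhead.

```python
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "int8": 1}


def weight_size_gb(num_params: float, dtype: str) -> float:
    """Approximate size of the weights alone, in gigabytes (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


student_params = 0.6e9   # nominal size of Qwen3-0.6B
teacher_params = 30e9    # nominal total size of Qwen3-30B-A3B
print(f"teacher/student parameter ratio: {teacher_params / student_params:.0f}x")
for dtype in ("fp32", "bf16", "int8"):
    print(f"student weights in {dtype}: {weight_size_gb(student_params, dtype):.1f} GB")
```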

**Troubleshooting:**

If training failed or results are poor:
- Check that `data/output1.json` contains valid teacher logits from Lab 0
- Verify sufficient disk space for compilation artifacts (~30GB)
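- Review the STDERR output printed by the training cell (and CloudWatch logs, if your instance forwards logs there) for detailed error messages
- If you run out of device memory, reduce the sequence length of the training examples (the batch size is already 1)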
- Ensure Neuron SDK version compatibility with the model
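A quick sanity check on the teacher-logits file can rule out the most common data problem. This sketch assumes Lab 0 wrote a JSON list of records with the logits stored under a `teacher_logits` key (the name `compute_loss` reads from its inputs); the real schema may differ, and `check_teacher_file` is a hypothetical helper.

```python
import json
import math
import tempfile
from pathlib import Path


def _all_finite(value) -> bool:
    """True if value is a finite number or a non-empty (nested) list of finite numbers."""
    if isinstance(value, bool):
        return False
    if isinstance(value, (int, float)):
        return math.isfinite(value)
    if isinstance(value, list):
        return len(value) > 0 and all(_all_finite(v) for v in value)
    return False


def check_teacher_file(path: str, key: str = "teacher_logits") -> int:
    """Hypothetical check: a non-empty JSON list whose records all hold finite logits under key."""
    records = json.loads(Path(path).read_text())
    if not isinstance(records, list) or not records:
        raise ValueError("expected a non-empty JSON list of records")
    for i, record in enumerate(records):
        if not isinstance(record, dict) or not _all_finite(record.get(key)):
            raise ValueError(f"record {i} has missing or non-finite '{key}'")
    return len(records)


# Demo on a throwaway file; point check_teacher_file at your Lab 0 output instead.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.json"
    sample.write_text(json.dumps([{"teacher_logits": [[0.1, 2.0, -1.0]]}]))
    print(check_teacher_file(str(sample)), "record(s) look valid")
```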

**Additional Resources:**

- [AWS Neuron Documentation](https://awsdocs-neuron.readthedocs-hosted.com/)
- [Optimum Neuron GitHub](https://github.com/huggingface/optimum-neuron)
- [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531) (Hinton et al., 2015)
- [Qwen3 Model Card](https://huggingface.co/Qwen/Qwen3-0.6B)