# Neuron Distillation Training

This notebook is a port of the `run_distillation.sh` script: it sets up the environment and launches knowledge distillation training on AWS Neuron with `torchrun`.

## Environment Setup

In [1]:
import os
import shlex
import subprocess

from src.util import prettyprint_python

# Process-wide environment variables (Neuron compiler flags and related
# runtime settings). They are set on os.environ, so processes spawned later
# from this kernel inherit them by default.
os.environ.update(
    {
        "NEURON_CC_FLAGS": "--model-type transformer --retry_failed_compilation",
        "NEURON_FUSE_SOFTMAX": "1",
        "NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS": "3",
        "MALLOC_ARENA_MAX": "64",
        "WORLD_SIZE": "8",
    }
)

## Training Configuration

In [2]:
# Model being fine-tuned; the output directory is named after the last
# path component of the model id.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
OUTPUT_DIR = f"{MODEL_NAME.rsplit('/', 1)[-1]}-finetuned"

# Parallelism settings
PROCESSES_PER_NODE = 2
TP_DEGREE = 2

# Optimisation schedule
NUM_EPOCHS = 3
BS = 1
GRADIENT_ACCUMULATION_STEPS = 16
LOGGING_STEPS = 1

# Use 5 steps when NEURON_EXTRACT_GRAPHS_ONLY=1 is set in the environment,
# and -1 otherwise (the value is later passed as --max_steps).
if os.environ.get('NEURON_EXTRACT_GRAPHS_ONLY') == '1':
    MAX_STEPS = 5
else:
    MAX_STEPS = -1

# Echo the resolved settings so they are visible in the notebook output
print(f"Model: {MODEL_NAME}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Max steps: {MAX_STEPS}")

Model: meta-llama/Llama-3.2-1B-Instruct
Output directory: Llama-3.2-1B-Instruct-finetuned
Max steps: -1


## KnowledgeDistillationTrainer Code

Let's examine the KnowledgeDistillationTrainer class from `distill_neuron_torchrun.py`:

In [3]:
# Render lines 35-78 of the training script with line numbers; per the output
# below this span is the whole KnowledgeDistillationTrainer class, and the
# displayed numbering restarts at 1 rather than using the file's line numbers.
prettyprint_python("src/distill_neuron_torchrun.py", line_numbers=True, line_range=(35, 78))

0,1
1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44,"class KnowledgeDistillationTrainer(NeuronTrainer):  def __init__(self, temperature=4.0, alpha=0.7, *args, **kwargs):  super().__init__(*args, **kwargs)  self.temperature = temperature  self.alpha = alpha  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):  """"""  Compute knowledge distillation loss combining:  - Hard loss: standard cross-entropy with true labels  - Soft loss: KL divergence between teacher and student logits  """"""  if num_items_in_batch is not None:  # Do something with num_items_in_batch  pass  student_outputs = model(  input_ids=inputs['input_ids'],  attention_mask=inputs['attention_mask']  )  student_logits = student_outputs.logits  inputs = {k: v.to('xla') if torch.is_tensor(v) else v for k, v in inputs.items()}  # Hard loss (standard language modeling loss)  hard_loss = F.cross_entropy(  student_logits.view(-1, student_logits.size(-1)),  inputs['labels'].view(-1)  )  # Soft loss (knowledge distillation)  student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)  teacher_soft = F.softmax(inputs['teacher_logits'] / self.temperature, dim=-1)  soft_loss = F.kl_div(  student_soft.view(-1, student_soft.size(-1)),  teacher_soft.view(-1, teacher_soft.size(-1)),  reduction='batchmean'  ) * (self.temperature ** 2)  # Combined loss  total_loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss  return (total_loss, student_outputs) if return_outputs else total_loss"


## Key Methods of KnowledgeDistillationTrainer

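The `compute_loss` method is the core of the knowledge distillation process. It blends a temperature-scaled KL-divergence term against the teacher logits (soft loss) with the standard cross-entropy on the true labels (hard loss): $\mathcal{L} = \alpha \, T^2 \, \mathrm{KL}(p_{\text{teacher}} \,\|\, p_{\text{student}}) + (1 - \alpha) \, \mathcal{L}_{\text{CE}}$, where $T$ is `temperature` and $\alpha$ is `alpha`: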

In [5]:
# Render lines 41-78 of the training script with line numbers; per the output
# below this is just the compute_loss method (a subset of the class view above).
prettyprint_python("src/distill_neuron_torchrun.py", line_numbers=True, line_range=(41, 78))

0,1
1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38,"def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):  """"""  Compute knowledge distillation loss combining:  - Hard loss: standard cross-entropy with true labels  - Soft loss: KL divergence between teacher and student logits  """"""  if num_items_in_batch is not None:  # Do something with num_items_in_batch  pass  student_outputs = model(  input_ids=inputs['input_ids'],  attention_mask=inputs['attention_mask']  )  student_logits = student_outputs.logits  inputs = {k: v.to('xla') if torch.is_tensor(v) else v for k, v in inputs.items()}  # Hard loss (standard language modeling loss)  hard_loss = F.cross_entropy(  student_logits.view(-1, student_logits.size(-1)),  inputs['labels'].view(-1)  )  # Soft loss (knowledge distillation)  student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)  teacher_soft = F.softmax(inputs['teacher_logits'] / self.temperature, dim=-1)  soft_loss = F.kl_div(  student_soft.view(-1, student_soft.size(-1)),  teacher_soft.view(-1, teacher_soft.size(-1)),  reduction='batchmean'  ) * (self.temperature ** 2)  # Combined loss  total_loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss  return (total_loss, student_outputs) if return_outputs else total_loss"


## Run Distillation Training

Execute the training with torchrun:

In [None]:
# Build the torchrun command as an argument list: no shell is involved, so
# every value reaches the training script verbatim.
cmd = [
    "torchrun",
    "--nproc_per_node", str(PROCESSES_PER_NODE),
    "src/distill_neuron_torchrun.py",
    "--model_id", MODEL_NAME,
    "--num_train_epochs", str(NUM_EPOCHS),
    "--do_train",
    "--max_steps", str(MAX_STEPS),
    "--per_device_train_batch_size", str(BS),
    "--gradient_accumulation_steps", str(GRADIENT_ACCUMULATION_STEPS),
    "--learning_rate", "1e-4",
    "--bf16",
    "--tensor_parallel_size", str(TP_DEGREE),
    "--warmup_steps", "5",
    "--pipeline_parallel_size", "1",
    "--logging_steps", str(LOGGING_STEPS),
    "--output_dir", OUTPUT_DIR,
    "--overwrite_output_dir"
]

print("Running command:")
# shlex.join quotes each argument, so the printed line can be pasted into a shell as-is.
print(shlex.join(cmd))
print("\n" + "="*50)

# Stream the launcher's output line by line instead of buffering it until the
# process exits: training is long-running and progress should be visible while
# it runs. stderr is merged into stdout so lines appear in the order emitted.
log_lines = []
with subprocess.Popen(
    cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    bufsize=1,  # line-buffered on our side of the pipe
    # Ask Python children not to block-buffer stdout when writing to a pipe;
    # everything else is inherited from the notebook's environment.
    env={**os.environ, "PYTHONUNBUFFERED": "1"},
) as proc:
    for line in proc.stdout:
        print(line, end="")
        log_lines.append(line)

# Leaving the `with` block waits for the process, so returncode is set here.
# `result` keeps the CompletedProcess shape used before; stderr is empty
# because it was merged into stdout above.
result = subprocess.CompletedProcess(
    cmd, proc.returncode, stdout="".join(log_lines), stderr=""
)
print(f"\nReturn code: {result.returncode}")

# Fail the cell (raises subprocess.CalledProcessError) when training did not
# exit cleanly, so "Run All" cannot silently continue past a failed run.
result.check_returncode()