<a href="https://colab.research.google.com/github/mridul-sahu/knowledge_distillation_intuitions/blob/main/Knowledge_Distillation_Intuitions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Intuitive Knowledge Distillation: From Teacher to Student with JAX and FLAX NNX

**Goal:** This notebook provides a clear, step-by-step guide to understanding and implementing various knowledge distillation techniques. We'll use JAX for high-performance computation, FLAX (specifically the new NNX API) for building neural networks with a more explicit stateful feel, and Optax for optimization. Our focus will be on building intuition for *why* and *how* each distillation method works, using the CIFAR-10 dataset as our playground.

**What you'll learn:**
* The core concepts of knowledge distillation.
* How to define and train models using FLAX NNX.
* How to implement baseline teacher and student models.
* How to apply standard knowledge distillation (matching output logits).
* How to use `flax.nnx.Intermediate` to capture and use intermediate model representations for more advanced distillation techniques like:
    * Matching hidden state representations using Cosine Similarity.
    * Using an intermediate regressor (FitNets-style) with MSE loss.
* How to manage state, parameters, and RNGs in FLAX NNX.

**Acknowledgement**
* This tutorial is a translation of https://docs.pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html into JAX and FLAX NNX.

## Part 0: Introduction & Setup

### 0.1. A Quick Word on JAX, FLAX NNX, and Optax

* **JAX:** JAX is a Python library for high-performance numerical computation, especially popular for machine learning research. It combines NumPy's familiar API with automatic differentiation (`jax.grad`), composition of function transformations (`jax.jit` for JIT compilation to XLA, `jax.vmap` for automatic vectorization, `jax.pmap` for SPMD-style parallel programming), and execution on accelerators like GPUs and TPUs. JAX emphasizes pure functions, which means functions don't have side effects and always produce the same output for the same input.

* **FLAX:** FLAX is a neural network library built on top of JAX, designed for flexibility and clarity.
    * **FLAX NNX (`flax.nnx`):** NNX is an API within FLAX that aims to provide a more PyTorch-like, object-oriented programming model. Modules in NNX are stateful objects, meaning their parameters and other state (like BatchNorm statistics) are stored as attributes directly on the module instance. This can make model definition and state management feel more intuitive, especially when dealing with complex models or needing to access/modify parts of the model state (like we will for intermediate representations). It still seamlessly integrates with JAX's functional transformations.

* **Optax:** Optax is a gradient processing and optimization library for JAX. It provides a wide range of popular optimizers (Adam, SGD, etc.) and makes it easy to build custom optimization schemes.

Now, let's get our environment set up.

In [None]:
# Install necessary libraries if not already present in your Colab environment
# It's good practice to ensure you have recent versions.
!pip install --upgrade pip
!pip install --upgrade jax # jaxlib
!pip install --upgrade flax optax
!pip install --upgrade tqdm
!pip install tensorflow_datasets orbax-checkpoint matplotlib

In [None]:
import os
import jax
import jax.numpy as jnp
import flax
from flax import nnx
import optax
import tensorflow_datasets as tfds
import tensorflow as tf # For tf.data pipelines
import numpy as np
import matplotlib.pyplot as plt
from functools import partial # For using partial functions
import orbax.checkpoint as ocp # For saving and loading model checkpoints
from typing import Sequence, Tuple, Dict, Any, Optional, Callable
# Additional import for progress bar
from tqdm.auto import tqdm

# Helper for managing PRNGKeys in JAX
from jax import random

import pathlib
import shutil # For cleaning up directories if needed
import urllib.request
import zipfile

# Check JAX device
print(f"JAX version: {jax.__version__}")
print(f"FLAX version: {flax.__version__}") # NNX is part of flax
print(f"Optax version: {optax.__version__}")
print(f"Default JAX backend: {jax.default_backend()}")
print(f"Available JAX devices: {jax.devices()}")

Download all pre-trained checkpoints and define a utility to restore models from them.

In [None]:
checkpoint_urls = [
 ("student_model_final", "https://raw.githubusercontent.com/mridul-sahu/knowledge_distillation_intuitions/main/student_model_final.zip"),
 ("student_model_kd_final", "https://raw.githubusercontent.com/mridul-sahu/knowledge_distillation_intuitions/main/student_model_kd_final.zip"),
 ("teacher_model_final", "https://raw.githubusercontent.com/mridul-sahu/knowledge_distillation_intuitions/main/teacher_model_final.zip")
]
checkpoint_dir = pathlib.Path('/tmp/flax_nnx_kd_checkpoints/')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

for model_name, checkpoint_url in checkpoint_urls:
  checkpoint_zip_path = checkpoint_dir / f"{model_name}.zip"
  checkpoint_dir_sub = checkpoint_dir / model_name

  # --- Download and Extract Checkpoints ---
  if not checkpoint_dir_sub.exists():
      if not checkpoint_zip_path.exists():
          print(f"Downloading checkpoints from {checkpoint_url}...")
          urllib.request.urlretrieve(checkpoint_url, checkpoint_zip_path)
          print(f"Checkpoints downloaded to {checkpoint_zip_path}")
      else:
          print(f"Checkpoints zip file already exists at {checkpoint_zip_path}")

      print(f"Extracting {checkpoint_zip_path}...")
      with zipfile.ZipFile(checkpoint_zip_path, "r") as zip_ref:
          zip_ref.extractall("/") # zips have full path /tmp/flax_nnx_kd_checkpoints/{model_name}
      print(f"Checkpoints extracted to {checkpoint_dir_sub}")
  else:
      print(f"Checkpoints directory {checkpoint_dir_sub} already exists. Skipping download and extraction.")

In [None]:
!ls /tmp/flax_nnx_kd_checkpoints/

In [None]:
def restore_model(model: nnx.Module, model_name):
  with ocp.StandardCheckpointer() as checkpointer:
    model_ckpt_path = os.path.join(checkpoint_dir, model_name)
    model_weights = checkpointer.restore(model_ckpt_path, nnx.state(model, nnx.Param))
    nnx.update(model, model_weights)

### 0.2. Global Configurations

We can set some global configurations like batch size, learning rate, etc., here. These can be overridden later if needed for specific experiments.

In [None]:
# Global configurations
BATCH_SIZE = 128
LEARNING_RATE = 0.001 # A common starting point for Adam
NUM_EPOCHS_TEACHER = 30 # Teacher needs to learn well
NUM_EPOCHS_STUDENT = 30 # Student also gets a fair number of epochs
NUM_CLASSES = 10
RNG_SEED = 42 # For reproducibility

# For NNX, we often need to manage PRNG key sequences.
# Let's create a main key and split it for different purposes.
# This initial key will be used to derive more keys as needed.
main_key = random.key(RNG_SEED)

# CIFAR-10 per-channel (RGB) mean and std for normalization.
# These are the statistics commonly used for CIFAR-10.
# Plain NumPy arrays, since they are used inside the tf.data pipeline and for display.
CIFAR10_MEAN = np.array([0.49139968, 0.48215841, 0.44653091], dtype=np.float32)
CIFAR10_STD = np.array([0.24703223, 0.24348513, 0.26158784], dtype=np.float32)

## Part 1: Understanding Knowledge Distillation (A Quick Recap)

Before diving into the code, let's briefly revisit the core idea of knowledge distillation.

Imagine you have a very skilled teacher (a large, complex, and accurate neural network) and a student (a smaller, faster neural network). The student wants to learn to perform a task as well as the teacher.

**How does the student learn?**

1.  **Learning from the ground truth:** The student can learn directly from the actual labels in the training data (e.g., "this image is a cat"). This is the standard way of training.
2.  **Learning from the teacher's "wisdom":** The teacher, being larger and already well trained, often has a more nuanced view of the data. For example, the teacher might say, "I'm 90% sure this is a cat, but it also looks a tiny bit like a dog (5%), and definitely not like a car (0.01%)." This full probability distribution over classes is called "soft targets" or "soft labels."

**Knowledge Distillation (KD)** is a technique where the student model is trained to do both:
* Match the true labels (hard targets).
* Mimic the soft targets produced by the pre-trained teacher model.

**Why is this helpful?**

* **Richer Information:** The soft targets from the teacher provide more information per training sample than just the hard labels. They reveal how the teacher model "thinks" and generalizes, including similarities between classes.
* **Model Compression:** It allows us to "distill" the knowledge from a large, computationally expensive teacher model into a smaller, more efficient student model.
* **Improved Performance:** The student often achieves better performance than if it were trained solely on hard labels, making it suitable for deployment on devices with limited computational resources (like mobile phones or embedded systems).

In this tutorial, we'll explore a few ways to transfer this knowledge.

## Part 2: Data Loading and Preprocessing (CIFAR-10)

We'll use the CIFAR-10 dataset. It consists of 60,000 32x32 color images in 10 classes (50,000 training, 10,000 test). We'll use `tensorflow_datasets` (TFDS) and `tf.data` for efficient input pipelines.

In [None]:
def preprocess_image(image, label, training: bool):
    """Converts the image to float (scaling it to [0, 1]), applies augmentations (if training),
       and then applies standard per-channel mean/std normalization.
    """
    image = tf.image.convert_image_dtype(image, tf.float32) # uint8 -> float32, scales to [0,1]

    if training:
        # Augmentations
        # Random horizontal flip
        image = tf.image.random_flip_left_right(image)

        # Pad and random crop (CIFAR-10: 32x32 -> pad to 40x40 -> crop 32x32)
        image_shape = tf.shape(image)
        image_height, image_width = image_shape[0], image_shape[1]
        padded_image = tf.image.resize_with_crop_or_pad(
            image, image_height + 8, image_width + 8
        )
        image = tf.image.random_crop(padded_image, size=[image_height, image_width, 3])

    # Standard mean/std normalization
    image = (image - CIFAR10_MEAN) / CIFAR10_STD
    return image, label

def create_datasets(batch_size: int):
    """Creates training and test datasets for CIFAR-10."""
    # Load using TFDS
    (train_ds_tf, test_ds_tf), ds_info = tfds.load('cifar10', split=['train', 'test'], as_supervised=True, with_info=True)

    # Training dataset
    train_preprocess_fn = partial(preprocess_image, training=True)
    train_ds = train_ds_tf.map(train_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.shuffle(ds_info.splits['train'].num_examples)
    train_ds = train_ds.batch(batch_size)
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

    # Test dataset
    test_preprocess_fn = partial(preprocess_image, training=False) # No augmentation key needed
    test_ds = test_ds_tf.map(test_preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
    test_ds = test_ds.batch(batch_size)
    test_ds = test_ds.prefetch(tf.data.AUTOTUNE)

    return tfds.as_numpy(train_ds), tfds.as_numpy(test_ds), ds_info
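The mean/std normalization in `preprocess_image` (and its inverse in `denormalize_for_display` below) relies on broadcasting over the trailing channel axis, so each of R, G, B is normalized with its own statistics. A minimal NumPy sketch on a random fake batch:

```python
import numpy as np

# Per-channel CIFAR-10 statistics (same values as CIFAR10_MEAN / CIFAR10_STD above).
mean = np.array([0.49139968, 0.48215841, 0.44653091], dtype=np.float32)
std = np.array([0.24703223, 0.24348513, 0.26158784], dtype=np.float32)

rng = np.random.default_rng(0)
images01 = rng.random((4, 32, 32, 3), dtype=np.float32)  # fake NHWC batch already in [0, 1)

# (N, H, W, 3) - (3,) broadcasts over the last axis: one mean/std per channel.
normalized = (images01 - mean) / std
restored = normalized * std + mean  # the inverse, used only for display
```

For inputs in [0, 1], the normalized values land roughly in [-2.0, 2.13], which is what the min/max printout of a real batch should reflect.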

In [None]:
train_ds, test_ds, ds_info = create_datasets(BATCH_SIZE)

# Inspect a batch
sample_train_batch = next(iter(train_ds))
images, labels = sample_train_batch
print("Sample Training Batch:")
print(f"  Images shape: {images.shape}, dtype: {images.dtype}")
print(f"  Labels shape: {labels.shape}, dtype: {labels.dtype}")
print(f"  Image min/max after full normalization: {images.min():.2f} / {images.max():.2f}")


# Optional: Display a few images
def denormalize_for_display(image_batch_normalized):
    # Reverse the mean/std normalization
    return np.clip((image_batch_normalized * CIFAR10_STD) + CIFAR10_MEAN, 0, 1)

images_to_show = denormalize_for_display(images[:4]) # Show fewer images

plt.figure(figsize=(8, 2)) # Adjusted figure size
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(images_to_show[i])
    plt.title(f"Label: {labels[i]}")
    plt.axis('off')
plt.suptitle("Sample Training Images (Augmented & De-normalized for Display)")
plt.tight_layout(rect=[0, 0, 1, 0.90])
plt.show()

num_train_examples = ds_info.splits['train'].num_examples
num_test_examples = ds_info.splits['test'].num_examples
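# batch() keeps the final partial batch, so the number of batches per epoch is the ceiling.
steps_per_epoch_train = -(-num_train_examples // BATCH_SIZE)  # ceil(50000 / 128) = 391
steps_per_epoch_test = -(-num_test_examples // BATCH_SIZE)    # ceil(10000 / 128) = 79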

print(f"\nNumber of training examples: {num_train_examples}")
print(f"Number of test examples: {num_test_examples}")
print(f"Training steps per epoch: {steps_per_epoch_train}")
print(f"Test steps per epoch: {steps_per_epoch_test}")
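Because `create_datasets` calls `batch()` without `drop_remainder=True`, the last, smaller batch of each epoch is kept, so the number of batches is the ceiling of examples / batch size rather than the floor. A minimal sketch of that arithmetic (the helper name `num_batches` is hypothetical):

```python
import math

def num_batches(num_examples: int, batch_size: int, drop_remainder: bool = False) -> int:
    """Number of batches tf.data's batch() yields for a dataset of known size."""
    if drop_remainder:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)

train_batches = num_batches(50_000, 128)  # 391: 390 full batches + one batch of 80
test_batches = num_batches(10_000, 128)   # 79: 78 full batches + one batch of 16
```

The final partial batch also has a different shape, so `nnx.jit` compiles the step functions once more for it.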

## Part 3: Defining Model Architectures with FLAX NNX

Now, we'll define our neural network architectures using FLAX NNX. NNX allows us to define modules in a more object-oriented way, where layers (and their parameters) are attributes of the module class.

**Key NNX Concepts Used Here:**
* `nnx.Module`: Base class for all NNX modules.
* `nnx.Conv`: Convolutional layer.
* `nnx.Linear`: Fully connected (dense) layer.
* `nnx.Dropout`: Dropout layer for regularization.
* `nnx.max_pool` (functional): Max pooling operation.
* `nnx.Sequential`: A container for running a sequence of layers or functions (not used below; the models spell out their forward passes in `__call__`).
* `nnx.Rngs`: A way to manage JAX PRNGKeys for operations like dropout and parameter initialization within NNX modules.
* `nnx.Param`: Wrapper to indicate that an attribute is a trainable parameter. NNX layers like `nnx.Linear` and `nnx.Conv` manage their `nnx.Param` attributes internally.
* **Initialization**: In NNX, layers are typically instantiated in the `__init__` method, and they create their parameters at that time, given an appropriate `rngs` context for `'params'`.
* **`__call__` method**: This is where the forward pass logic is defined.

We'll create two CNNs:
1.  `DeepNN_NNX`: A deeper network to serve as our teacher.
2.  `LightNN_NNX`: A shallower, lightweight network for our student.

The architectures will be similar to the PyTorch tutorial to facilitate comparison.

In [None]:
class ConvBlock(nnx.Module):
  """A helper module for a common Conv -> ReLU -> Conv -> ReLU -> MaxPool sequence."""
  def __init__(self, in_filters1: int, out_filters1: int, out_filters2: int, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(in_filters1, out_filters1, kernel_size=(3, 3), padding=1, rngs=rngs)
    self.conv2 = nnx.Conv(out_filters1, out_filters2, kernel_size=(3, 3), padding=1, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = nnx.relu(self.conv1(x))
    x = nnx.relu(self.conv2(x))
    x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')
    return x

class Classifier(nnx.Module):
  """A helper module for the classification head."""
  def __init__(self, in_features: int, hidden_features: int, out_features: int, dropout_rate: float, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(in_features, hidden_features, rngs=rngs)
    self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs) # rngs for dropout state init
    self.linear2 = nnx.Linear(hidden_features, out_features, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = nnx.relu(self.linear1(x))
    x = self.dropout(x)
    x = self.linear2(x)
    return x


class DeepNN_NNX(nnx.Module):
  """Deeper CNN for the teacher model."""
  def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
    # features will apply ConvBlocks sequentially
    # Each ConvBlock handles its own convs, relus, and max_pool
    self.features_block1 = ConvBlock(3, 128, 64, rngs=rngs)      # Output (H x W x C): 16 x 16 x 64
    self.features_block2 = ConvBlock(64, 64, 32, rngs=rngs)       # Output (H x W x C): 8 x 8 x 32

    # After two ConvBlocks, image size 32x32 -> 16x16 (after block1) -> 8x8 (after block2).
    # So, flattened features = 32 * 8 * 8 = 2048
    self.classifier = Classifier(32 * 8 * 8, 512, num_classes, dropout_rate=0.1, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = self.features_block1(x)
    x = self.features_block2(x)
    x = x.reshape((x.shape[0], -1))  # Flatten
    x = self.classifier(x)
    return x


class LightNN_NNX(nnx.Module):
  """Lightweight CNN for the student model."""
  def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(3, 16, kernel_size=(3,3), padding=1, rngs=rngs)
    self.conv2 = nnx.Conv(16, 16, kernel_size=(3,3), padding=1, rngs=rngs)
    # Pooling and relu are functional

    # After feature extraction (2x Conv->ReLU->Pool), image size 32x32 -> 16x16 -> 8x8.
    # Filters: 16. So, flattened features = 16 * 8 * 8 = 1024
    self.classifier = Classifier(16 * 8 * 8, 256, num_classes, dropout_rate=0.1, rngs=rngs)

  def __call__(self, x: jax.Array) -> jax.Array:
    x = nnx.relu(self.conv1(x))
    x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')
    x = nnx.relu(self.conv2(x))
    x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')
    x = x.reshape((x.shape[0], -1))  # Flatten
    x = self.classifier(x)
    return x

In [None]:
# Let's initialize the models and test them with a dummy input
# We need to provide appropriate RNGs for 'params' and 'dropout' (if used during call)

# Derive keys for model initialization
key_teacher_params, key_student_params, main_key = random.split(main_key, 3)

# Create RNGs contexts for NNX
# 'params' for parameter initialization
# 'dropout' for dropout layers during forward pass if training
# A 'default' stream (not used here) would serve as a fallback for any stream name not defined explicitly
teacher_rngs = nnx.Rngs(params=key_teacher_params, dropout=random.key(1)) # new dropout key for teacher
student_rngs = nnx.Rngs(params=key_student_params, dropout=random.key(2)) # new dropout key for student

# Initialize Teacher Model
teacher_model = DeepNN_NNX(num_classes=NUM_CLASSES, rngs=teacher_rngs)
print("Teacher Model Initialized.")

# Initialize Student Model
student_model = LightNN_NNX(num_classes=NUM_CLASSES, rngs=student_rngs)
print("Student Model Initialized.")

# Create a dummy input batch (Batch, Height, Width, Channels)
dummy_images = jnp.ones((BATCH_SIZE, 32, 32, 3))

print("\nTeacher Model Summary:")
print(nnx.tabulate(teacher_model, dummy_images))

print("\nStudent Model Summary:")
print(nnx.tabulate(student_model, dummy_images))

## Part 4: Training Utilities in FLAX NNX

With our models and data pipelines ready, we'll define the utilities for training and evaluation. We will closely follow the idiomatic patterns for Flax NNX, leveraging:

* `nnx.Optimizer`: To bundle our model with an Optax optimizer, simplifying parameter updates and optimizer state management.
* `nnx.jit`: For JIT-compiling our training and evaluation steps. It's NNX-aware and handles module state correctly.
* `nnx.value_and_grad`: The NNX version for computing loss and gradients with respect to the model's `nnx.Param` variables.
* `nnx.MultiMetric`: A convenient container for managing and computing multiple metrics like loss and accuracy.

This approach makes the training code concise and leverages NNX's strengths.

In [None]:
# --- Loss and Metrics Calculation Function ---
def calculate_loss_and_logits(model: nnx.Module, batch: Tuple[jax.Array, jax.Array]):
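    """
    Calculates the mean cross-entropy loss and the logits for a given model and batch.
    Train/eval behavior of layers like Dropout is controlled by calling
    `model.train()` or `model.eval()` beforehand, not by an argument of this function.
    """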
    images, labels = batch # Assuming batch is a tuple (images, labels)

    logits = model(images)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    return loss, logits

# --- Train Step and Eval Step ---

@nnx.jit
def train_step(
    model: nnx.Module,
    optimizer: nnx.Optimizer,
    metrics_computer: nnx.MultiMetric,
    batch: Tuple[jax.Array, jax.Array]
):
    """
    Performs a single training step.
    Updates model parameters, optimizer state, and metrics in-place.
    """
    # Create a gradient function for calculate_loss_and_logits with respect to the model.
    # has_aux=True is used because calculate_loss_and_logits returns (loss, logits).
    grad_fn = nnx.value_and_grad(calculate_loss_and_logits, has_aux=True)

    # Compute loss, logits, and gradients.
    # Any state mutated during the forward pass (e.g. the Dropout RNG stream) is updated here too.
    (loss, logits), grads = grad_fn(model, batch)

    # Update metrics (e.g., loss, accuracy) in-place.
    metrics_computer.update(loss=loss, logits=logits, labels=batch[1]) # batch[1] contains labels

    # Apply gradients to update model parameters and optimizer state in-place.
    optimizer.update(grads)


@nnx.jit
def eval_step(
    model: nnx.Module,
    metrics_computer: nnx.MultiMetric,
    batch: Tuple[jax.Array, jax.Array]
):
    """
    Performs a single evaluation step.
    Updates metrics in-place.
    """
    # Compute loss and logits with the model in evaluation mode.
    loss, logits = calculate_loss_and_logits(model, batch)

    # Update metrics in-place.
    metrics_computer.update(loss=loss, logits=logits, labels=batch[1])

**Workflow with `nnx.Optimizer` and `nnx.MultiMetric`:**

1.  **Initialization:**
    * An `nnx.Optimizer` is created by passing it an `nnx.Module` instance and an Optax optimizer definition (e.g., `optax.adam(learning_rate)`).
    * An `nnx.MultiMetric` object is created to hold individual metrics like `nnx.metrics.Average` (for loss) and `nnx.metrics.Accuracy`.

2.  **`train_step(model, optimizer, metrics_computer, batch)`:**
    * The `model` argument here is the same model instance that was passed to `nnx.Optimizer` (in `train_and_evaluate`, `model_to_train`).
    * `nnx.value_and_grad` is used to get the loss, auxiliary outputs (logits), and gradients for `calculate_loss_and_logits`.
    * `metrics_computer.update(...)` is called with the loss and logits to accumulate metric values for the current epoch/evaluation period. This updates the `metrics_computer` object in-place.
    * `optimizer.update(grads)` applies the gradients using the wrapped Optax optimizer. This updates the parameters of the wrapped model and the Optax optimizer's internal state, all in-place.

3.  **`eval_step(model, metrics_computer, batch)`:**
    * Similar to `train_step`, but without gradient computation or optimizer updates. Evaluation behavior (e.g. Dropout disabled) comes from calling `model.eval()` in the main loop before the evaluation phase.

4.  **Main Loop:**
    * At the end of each evaluation period (e.g., an epoch), `metrics_computer.compute()` is called to get the aggregated metric values.
    * `metrics_computer.reset()` is then called to clear the accumulators for the next period.
    * All stateful objects (the model, the `optimizer`, and the metrics computers) are modified in-place by the JIT-compiled step functions, because NNX's lifted transforms like `nnx.jit` propagate state updates back to the Python objects passed in.

In [None]:
# --- Main Training Loop ---
def train_and_evaluate(
    model_to_train: nnx.Module,
    optax_optimizer: optax.GradientTransformation,
    train_ds: Any,
    test_ds: Any,
    num_epochs: int,
    steps_per_epoch_train: int,
    steps_per_epoch_test: int,
    eval_every_epochs: int = 1, # Evaluate after every 'eval_every_epochs'
    model_name: str = "Model"
) -> Tuple[nnx.Module, Dict[str, list]]:
    """
    Trains and evaluates an NNX model.
    Uses tqdm for progress bars and plots metrics at the end.
    """
    optimizer = nnx.Optimizer(model_to_train, optax_optimizer)

    train_metrics_computer = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"), # Explicitly name for clarity in dict keys
        accuracy=nnx.metrics.Accuracy()
    )
    eval_metrics_computer = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
        accuracy=nnx.metrics.Accuracy()
    )

    history = {
        'train_loss': [], 'train_accuracy': [],
        'test_loss': [], 'test_accuracy': []
    }

    print(f"\nStarting training for {model_name} for {num_epochs} epochs...")

    for epoch in tqdm(range(num_epochs), desc=f"Training {model_name}"):
        # --- Training Phase ---
        model_to_train.train()
        train_ds_iterator = iter(train_ds)
        train_metrics_computer.reset()
        for batch in tqdm(train_ds_iterator, total=steps_per_epoch_train, desc=f"Epoch {epoch+1} Training", leave=False):
            # Ensure data is JAX array; TFDS usually yields NumPy.
            batch_tuple = (jnp.asarray(batch[0]), jnp.asarray(batch[1]))
            train_step(model_to_train, optimizer, train_metrics_computer, batch_tuple)

        computed_train_metrics = train_metrics_computer.compute()
        history['train_loss'].append(computed_train_metrics['loss'].item())
        history['train_accuracy'].append(computed_train_metrics['accuracy'].item())

        tqdm.write(f"{model_name} - Epoch {epoch+1}/{num_epochs} Training: "
                   f"Avg Loss: {computed_train_metrics['loss']:.4f}, "
                   f"Avg Acc: {computed_train_metrics['accuracy']:.4f}")

        # --- Evaluation Phase ---
        if (epoch + 1) % eval_every_epochs == 0 or (epoch + 1) == num_epochs:
          model_to_train.eval()
          test_ds_iterator = iter(test_ds)
          eval_metrics_computer.reset()
          for eval_batch in tqdm(test_ds_iterator, total=steps_per_epoch_test, desc=f"Epoch {epoch+1} Evaluation", leave=False):
              eval_batch_tuple = (jnp.asarray(eval_batch[0]), jnp.asarray(eval_batch[1]))
              eval_step(model_to_train, eval_metrics_computer, eval_batch_tuple)

          computed_eval_metrics = eval_metrics_computer.compute()
          history['test_loss'].append(computed_eval_metrics['loss'].item())
          history['test_accuracy'].append(computed_eval_metrics['accuracy'].item())

          tqdm.write(f"{model_name} - Epoch {epoch+1}/{num_epochs} Evaluation: "
                      f"Avg Loss: {computed_eval_metrics['loss']:.4f}, "
                      f"Avg Acc: {computed_eval_metrics['accuracy']:.4f}")
        tqdm.write("-" * 70)

    print(f"Training finished for {model_name}.")

    # --- Plotting Final Metrics ---
    epochs_evaluated_train = list(range(1, num_epochs + 1))
    # Test metrics exist only for the epochs on which evaluation ran (same condition as in the loop above).
    epochs_evaluated_test = [
        e for e in range(1, num_epochs + 1)
        if e % eval_every_epochs == 0 or e == num_epochs
    ]


    plt.figure(figsize=(14, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_evaluated_train, history['train_loss'], label='Train Loss', marker='o', linestyle='-')
    if history['test_loss']: # Only plot if there's data
        plt.plot(epochs_evaluated_test, history['test_loss'], label='Test Loss', marker='x', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'{model_name} - Loss Over Epochs')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(epochs_evaluated_train, history['train_accuracy'], label='Train Accuracy', marker='o', linestyle='-')
    if history['test_accuracy']: # Only plot if there's data
        plt.plot(epochs_evaluated_test, history['test_accuracy'], label='Test Accuracy', marker='x', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title(f'{model_name} - Accuracy Over Epochs')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    return model_to_train, history

## Part 5: Experiment 1 - Baseline: Training Teacher and Student Independently

Now that we have our model architectures (Part 3) and training utilities (Part 4) defined, we can start training our models.

**Goals for this section:**
1.  Train the `DeepNN_NNX` (Teacher) model from scratch on CIFAR-10 and record its performance.
2.  Train the `LightNN_NNX` (Student) model from scratch on CIFAR-10, also independently, and record its performance. This will serve as our baseline student accuracy.

We will use the `train_and_evaluate` function we defined earlier. For reproducibility, and to ensure fair comparisons later (e.g., when training the student with distillation vs. without), we need to be careful with our JAX PRNGKeys for model initialization.

**Checkpointing:**
We'll also include a basic setup for saving and loading model weights using Orbax, which is the recommended checkpointing library for JAX/Flax.

In [None]:
# --- 1. Train the Teacher Model (DeepNN_NNX) ---

print("="*30)
print("Training Teacher Model (DeepNN_NNX)")
print("="*30)

# Define Optax optimizer for the teacher
teacher_optimizer = optax.adam(learning_rate=LEARNING_RATE)

restored = False
try:
  restore_model(teacher_model, "teacher_model_final")
  trained_teacher_model = teacher_model
  restored = True
except Exception as e:
  print(f"Restore failed: {e}")
  restored=False

if not restored:
  trained_teacher_model, teacher_history = train_and_evaluate(
      model_to_train=teacher_model, # Pass the instantiated model
      optax_optimizer=teacher_optimizer,
      train_ds=train_ds,
      test_ds=test_ds,
      num_epochs=NUM_EPOCHS_TEACHER,
      steps_per_epoch_train=steps_per_epoch_train,
      steps_per_epoch_test=steps_per_epoch_test,
      model_name="Teacher (DeepNN)"
  )

In [None]:
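try:
  teacher_model_save_path = os.path.join(checkpoint_dir, "teacher_model_final")
  # The context manager waits for the save to finish (Orbax may write asynchronously),
  # so the checkpoint directory is complete before it is zipped in the next cell.
  with ocp.StandardCheckpointer() as checkpointer:
    checkpointer.save(teacher_model_save_path, nnx.state(trained_teacher_model, nnx.Param))
  print(f"Teacher model parameters saved to {teacher_model_save_path}")
except Exception as e:
  print(f"Could not save teacher model: {e}")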

In [None]:
!zip -r /tmp/flax_nnx_kd_checkpoints/teacher_model_final.zip /tmp/flax_nnx_kd_checkpoints/teacher_model_final

from google.colab import files
files.download("/tmp/flax_nnx_kd_checkpoints/teacher_model_final.zip")

In [None]:
# --- 2. Train the Baseline Student Model (LightNN_NNX) ---

print("\n" + "="*30)
print("Training Baseline Student Model (LightNN_NNX)")
print("="*30)

# Define Optax optimizer for the student
student_optimizer = optax.adam(learning_rate=LEARNING_RATE)

restored = False
try:
  restore_model(student_model, "student_model_final")
  trained_student_model = student_model
  restored = True
except Exception as e:
  print(f"Restore failed: {e}")
  restored=False

if not restored:
  trained_student_model, student_history = train_and_evaluate(
      model_to_train=student_model, # Pass the instantiated model
      optax_optimizer=student_optimizer,
      train_ds=train_ds,
      test_ds=test_ds,
      num_epochs=NUM_EPOCHS_STUDENT,
      steps_per_epoch_train=steps_per_epoch_train,
      steps_per_epoch_test=steps_per_epoch_test,
      model_name="Student (LightNN)"
  )

In [None]:
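try:
  student_model_save_path = os.path.join(checkpoint_dir, "student_model_final")
  # The context manager waits for the save to finish (Orbax may write asynchronously),
  # so the checkpoint directory is complete before it is zipped in the next cell.
  with ocp.StandardCheckpointer() as checkpointer:
    checkpointer.save(student_model_save_path, nnx.state(trained_student_model, nnx.Param))
  print(f"Student model parameters saved to {student_model_save_path}")
except Exception as e:
  print(f"Could not save student model: {e}")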

In [None]:
!zip -r /tmp/flax_nnx_kd_checkpoints/student_model_final.zip /tmp/flax_nnx_kd_checkpoints/student_model_final

from google.colab import files  # Colab-only; re-imported so this cell also runs on its own
files.download("/tmp/flax_nnx_kd_checkpoints/student_model_final.zip")

### Baseline Performance Summary

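After training both models independently, we have:

* **Teacher Model (DeepNN_NNX) Final Test Accuracy:** the last `Evaluation` accuracy printed by the teacher training cell above.
* **Baseline Student Model (LightNN_NNX) Final Test Accuracy:** the last `Evaluation` accuracy printed by the student training cell above.

If a model was restored from a checkpoint, training (and its log) is skipped; run `eval_step` over `test_ds` to obtain its test accuracy.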

The goal of knowledge distillation is to train a *new instance* of `LightNN_NNX` (initialized with the *same* starting parameters as this baseline student, for a fair comparison) so that it outperforms this baseline student and closes some of the gap to the teacher, without changing the student model's architecture.

## Part 6: Experiment 2 - Standard Knowledge Distillation (Matching Output Logits)

Now that we have baseline performances for our teacher and student models, we'll implement the classic knowledge distillation technique, often attributed to Hinton et al. (2015).

**The Core Idea:**

The student model learns by minimizing a combined loss function:
1.  **Standard Cross-Entropy Loss:** Calculated between the student's predictions and the true labels (hard targets). This ensures the student still learns the primary task correctly.
2.  **Distillation Loss (Soft Target Loss):** Calculated between the student's predictions and the "soft targets" provided by the pre-trained teacher model.

**Soft Targets and Temperature:**

* **Teacher's Knowledge:** A well-trained teacher model, even when it makes a wrong prediction, often assigns higher probabilities to classes that are semantically similar to the true class. For example, it might confuse a "truck" with an "automobile" but is unlikely to confuse it with a "dog". This similarity information is present in the teacher's full output probability distribution (the softmax of its logits), not just in its top prediction.
* **Softening Probabilities:** To make this nuanced information more accessible to the student (especially the smaller probabilities for non-target classes), we use a "temperature" hyperparameter ($T$).
    * The logits of both the teacher and the student are divided by $T$ before applying the softmax function: $p_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$, where $z_i$ are the logits.
    * A higher $T > 1$ "softens" the probability distribution, making it less peaky (i.e., probabilities become more uniform). This amplifies the contribution of smaller logit values, allowing the student to learn from the relative similarities the teacher has learned.
    * When $T=1$, it's the standard softmax.
* **Distillation Loss Calculation:** The distillation loss measures the difference between the student's softened probabilities and the teacher's softened probabilities. A common choice is the Kullback-Leibler (KL) divergence or a cross-entropy loss between these softened distributions. The original paper by Hinton et al. also suggests scaling this part of the loss by $T^2$: the gradients of the softened term scale roughly as $1/T^2$, so multiplying by $T^2$ keeps their magnitude comparable to the hard-label term as $T$ changes.
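To see the softening effect described above, the short sketch below applies a temperature-scaled softmax to a single set of logits. The logits are made-up numbers chosen only for illustration (not outputs of the notebook's models), and `softened_probs` / `entropy` are hypothetical helper names. As $T$ grows, the distribution flattens (its entropy rises) while the ranking of the classes stays the same.

```python
import jax
import jax.numpy as jnp

# Illustrative teacher logits for 4 classes (made-up numbers, not model outputs).
logits = jnp.array([6.0, 3.0, 1.0, -2.0])

def softened_probs(z: jax.Array, temperature: float) -> jax.Array:
    """Softmax of the logits divided by the temperature T."""
    return jax.nn.softmax(z / temperature, axis=-1)

def entropy(p: jax.Array) -> jax.Array:
    """Shannon entropy in nats; a higher value means a flatter distribution."""
    return -jnp.sum(p * jnp.log(p), axis=-1)

for T in (1.0, 2.0, 5.0):
    p = softened_probs(logits, T)
    print(f"T={T}: probs={jnp.round(p, 3)}, entropy={float(entropy(p)):.3f}")
```

At $T=1$ almost all of the mass sits on the first class; at $T=5$ the smaller classes receive a visible share, which is the "dark knowledge" the student is meant to pick up.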

**Combined Loss Function:**
The total loss for the student is a weighted sum:

$$L_{total} = w_{CE} \cdot L_{CE}(z^{s}, y) + w_{KD} \cdot L_{KD}(z^{s}, z^{t}; T)$$

where $z^{s}$ and $z^{t}$ are the student and teacher logits, $y$ are the true labels, and $w_{CE}$ and $w_{KD}$ are weights that balance the two loss components. $L_{KD}$ compares the softened student and teacher distributions obtained from $z^{s}/T$ and $z^{t}/T$.
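The $T^2$ factor can be checked numerically. The gradient of the softened cross-entropy with respect to the student logits is $(p^{s} - p^{t})/T$, where $p^{s}$ and $p^{t}$ are the softened student and teacher probabilities, and for large $T$ the difference $p^{s} - p^{t}$ itself shrinks roughly like $1/T$. The sketch below uses made-up logits for one example with 5 classes (an assumption for illustration) and compares gradient norms with and without the $T^2$ factor.

```python
import jax
import jax.numpy as jnp

# Made-up student and teacher logits for a single example with K = 5 classes.
student_logits = jnp.array([1.0, 0.5, -0.5, 0.2, -1.0])
teacher_logits = jnp.array([2.0, -1.0, 0.5, 0.0, -1.5])

def soft_ce(z_s, z_t, temperature):
    """Cross-entropy between softened teacher and student distributions (no T^2 factor)."""
    p_t = jax.nn.softmax(z_t / temperature)
    log_p_s = jax.nn.log_softmax(z_s / temperature)
    return -jnp.sum(p_t * log_p_s)

def grad_norm(z_s, z_t, temperature, scale_by_t2):
    """L2 norm of d(loss)/d(student logits), optionally with the T^2 factor applied."""
    scale = temperature**2 if scale_by_t2 else 1.0
    g = jax.grad(lambda z: scale * soft_ce(z, z_t, temperature))(z_s)
    return float(jnp.linalg.norm(g))

for T in (2.0, 5.0, 10.0, 20.0):
    raw = grad_norm(student_logits, teacher_logits, T, scale_by_t2=False)
    scaled = grad_norm(student_logits, teacher_logits, T, scale_by_t2=True)
    print(f"T={T:>4}: raw grad norm={raw:.6f}, T^2-scaled grad norm={scaled:.6f}")
```

The unscaled gradient keeps shrinking as $T$ increases, while the $T^2$-scaled one settles to a roughly constant magnitude, so changing $T$ does not silently reweight the distillation term against the hard-label term.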

In this experiment, we will:
1.  Initialize a new student model instance **with the exact same initial parameters** as our baseline student for a fair comparison.
2.  Use our already trained `trained_teacher_model` (in evaluation mode) to provide soft targets.
3.  Define a new training step that incorporates this combined loss.
4.  Train the student and compare its performance against the baseline student.


In [None]:
# --- Knowledge Distillation Specific Functions ---

def distillation_loss_calculation(
    student_logits: jax.Array,
    teacher_logits: jax.Array, # Teacher logits should be from teacher in eval mode
    temperature: float
) -> jax.Array:
    """
    Calculates the knowledge distillation loss.
    Uses cross-entropy between softened teacher predictions and softened student predictions.
    """
    # Soften probabilities with temperature
    # Teacher provides soft targets (probabilities)
    soft_teacher_targets = jax.nn.softmax(teacher_logits / temperature, axis=-1)

    # Student's output (log probabilities for numerical stability with cross-entropy)
    log_soft_student_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)

    # Distillation loss: Cross-entropy between teacher's soft targets and student's soft log-probabilities
    # This is equivalent to minimizing KL divergence KL(P_teacher_soft || P_student_soft)
    # when the P_teacher_soft part of KL divergence is considered constant wrt student params.
    kd_loss = -jnp.sum(soft_teacher_targets * log_soft_student_probs, axis=-1).mean()

    # Scale by T^2 as suggested in Hinton et al. (2015).
    # The gradients of the softened loss w.r.t. the student logits scale roughly
    # as 1/T^2, so this factor keeps the distillation term's contribution
    # comparable to the hard-label CE term when the temperature changes.
    return kd_loss * (temperature**2)


def combined_loss_and_logits_for_distillation(
    student_model_for_grad: nnx.Module, # The student model instance being differentiated
    teacher_model_eval: nnx.Module,     # The pre-trained teacher model (used in eval mode)
    batch: Tuple[jax.Array, jax.Array],
    temperature: float,
    ce_student_weight: float, # Weight for the student's cross-entropy loss with true labels
    kd_loss_weight: float,    # Weight for the distillation loss
):
    """
    Calculates the combined loss (CE + KD) and student logits.
    This function will be differentiated w.r.t. student_model_for_grad.
    """
    images, true_labels = batch

    # 1. Forward pass for student
    student_logits = student_model_for_grad(images)

    # 2. Standard Cross-Entropy loss for student with true labels
    ce_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=student_logits, labels=true_labels
    ).mean()

    # 3. Forward pass for teacher (in evaluation mode, no gradients for teacher)
    # Ensure teacher_model_eval is not updated by JAX's autodiff by not including its
    # parameters in the differentiation target. NNX handles this if teacher_model_eval
    # is just a regular argument and not the one specified for `nnx.value_and_grad`.
    teacher_logits_eval = teacher_model_eval(images) # Teacher always in eval mode

    # 4. Knowledge Distillation loss
    kd_loss = distillation_loss_calculation(student_logits, teacher_logits_eval, temperature)

    # 5. Combined loss
    total_loss = (ce_student_weight * ce_loss) + (kd_loss_weight * kd_loss)

    return total_loss, student_logits # Return student_logits for accuracy calculation


@nnx.jit
def train_step_distillation(
    student_model: nnx.Module,        # This is optimizer_student_kd.module
    optimizer_student_kd: nnx.Optimizer,
    teacher_model_eval: nnx.Module,   # Pre-trained teacher model
    metrics_aggregator: nnx.MultiMetric,
    batch: Tuple[jax.Array, jax.Array],
    temperature: float,
    ce_student_weight: float,
    kd_loss_weight: float
):
    """
    Performs a single training step with knowledge distillation.
    """
    # Define the function to be differentiated. nnx.value_and_grad differentiates
    # w.r.t. its first argument (by default, the student's nnx.Param variables).
    # The teacher, batch, temperature and loss weights are closed over, so no
    # gradients are computed for the teacher. Closing over the teacher is fine here
    # because its eval-mode forward pass only reads its state; a teacher that updates
    # its own state (e.g. via sow) is safer to pass as an explicit argument, as the
    # cosine-distillation step later in the notebook does.

    def loss_fn_wrapper(model_to_diff): # model_to_diff will be student_model
        return combined_loss_and_logits_for_distillation(
            model_to_diff, # Student model (target of differentiation)
            teacher_model_eval,
            batch,
            temperature,
            ce_student_weight,
            kd_loss_weight,
        )

    grad_fn = nnx.value_and_grad(loss_fn_wrapper, has_aux=True)

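    # Compute the loss, the student logits (aux output) and the gradients w.r.t.
    # the student's parameters. The parameters themselves are not changed here.
    (loss, student_logits_for_metric), grads = grad_fn(student_model)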

    metrics_aggregator.update(loss=loss, logits=student_logits_for_metric, labels=batch[1])
    optimizer_student_kd.update(grads) # Updates student_model in-place

# Note: The eval_step remains the same as before, as it only evaluates the student.
# We will reuse the `eval_step` defined in Part 4.
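
The comment in `distillation_loss_calculation` states that the soft cross-entropy differs from $KL(P_{teacher}^{soft} \| P_{student}^{soft})$ only by a term that does not depend on the student. A quick numerical check of that claim, as a sketch: random logits stand in for model outputs, and `soft_cross_entropy` / `soft_kl` are hypothetical helper names. The gap between the two losses equals the entropy of the softened teacher distribution, and their gradients w.r.t. the student logits coincide.

```python
import jax
import jax.numpy as jnp

def soft_cross_entropy(z_s, z_t, T):
    """Cross-entropy between softened teacher and student distributions (batch mean)."""
    p_t = jax.nn.softmax(z_t / T, axis=-1)
    return -jnp.sum(p_t * jax.nn.log_softmax(z_s / T, axis=-1), axis=-1).mean()

def soft_kl(z_s, z_t, T):
    """KL(p_t || p_s) between softened distributions (batch mean)."""
    log_p_t = jax.nn.log_softmax(z_t / T, axis=-1)
    log_p_s = jax.nn.log_softmax(z_s / T, axis=-1)
    return jnp.sum(jnp.exp(log_p_t) * (log_p_t - log_p_s), axis=-1).mean()

key_s, key_t = jax.random.split(jax.random.PRNGKey(0))
z_s = jax.random.normal(key_s, (4, 10))        # 4 examples, 10 classes (random stand-ins)
z_t = 3.0 * jax.random.normal(key_t, (4, 10))
T = 2.0

# CE - KL should equal the (batch-mean) entropy of the softened teacher distribution.
p_t = jax.nn.softmax(z_t / T, axis=-1)
teacher_entropy = -jnp.sum(p_t * jax.nn.log_softmax(z_t / T, axis=-1), axis=-1).mean()
gap = soft_cross_entropy(z_s, z_t, T) - soft_kl(z_s, z_t, T)

# Gradients w.r.t. the student logits (argument 0) should be identical.
g_ce = jax.grad(soft_cross_entropy)(z_s, z_t, T)
g_kl = jax.grad(soft_kl)(z_s, z_t, T)
print("CE - KL:", float(gap), "teacher entropy:", float(teacher_entropy))
print("max |grad difference|:", float(jnp.max(jnp.abs(g_ce - g_kl))))
```

Because the teacher entropy is constant w.r.t. the student, either loss gives the same training signal; the cross-entropy form simply skips computing $\log p_t$.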

In [None]:
# --- Setup for Knowledge Distillation Training ---
# The main training loop `train_and_evaluate` needs to be adapted or a new one created
# to handle the specific arguments of `train_step_distillation`.
# Let's create a new specific training loop for distillation for clarity.

def train_and_evaluate_distillation(
    student_model_to_train: nnx.Module,
    teacher_model_for_guidance: nnx.Module,
    optax_optimizer: optax.GradientTransformation,
    train_ds: Any,
    test_ds: Any,
    num_epochs: int,
    steps_per_epoch_train: int,
    steps_per_epoch_test: int,
    temperature: float,
    ce_weight: float,
    kd_weight: float,
    eval_every_epochs: int = 1,
    model_name: str = "Student_KD"
) -> Tuple[nnx.Module, Dict[str, list]]:

    optimizer = nnx.Optimizer(student_model_to_train, optax_optimizer)

    train_metrics_aggregator = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
        accuracy=nnx.metrics.Accuracy()
    )
    eval_metrics_aggregator = nnx.MultiMetric( # For student's performance on true labels
        loss=nnx.metrics.Average("loss"),
        accuracy=nnx.metrics.Accuracy()
    )

    history = {
        'train_loss': [], 'train_accuracy': [],
        'test_loss': [], 'test_accuracy': []
    }

    total_training_steps = num_epochs * steps_per_epoch_train
    print(f"\nStarting distillation training for {model_name} for {num_epochs} epochs...")

    for epoch in tqdm(range(num_epochs), desc=f"Training {model_name}"):
        student_model_to_train.train()
        teacher_model_for_guidance.eval()
        train_metrics_aggregator.reset()
        train_epoch_bar = tqdm(iter(train_ds), total=steps_per_epoch_train,
                               desc=f"Epoch {epoch+1}/{num_epochs} [Distill Train]",
                               leave=False)
        for batch in train_epoch_bar:
            batch_images, batch_labels = batch
            batch_tuple = (jnp.asarray(batch_images), jnp.asarray(batch_labels))

            train_step_distillation(
                student_model_to_train,
                optimizer,
                teacher_model_for_guidance,
                train_metrics_aggregator,
                batch_tuple,
                temperature,
                ce_weight,
                kd_weight
            )

        computed_train_metrics = train_metrics_aggregator.compute()
        history['train_loss'].append(computed_train_metrics['loss'].item())
        history['train_accuracy'].append(computed_train_metrics['accuracy'].item())
        tqdm.write(f"{model_name} - Epoch {epoch+1} Training: Avg Loss: {computed_train_metrics['loss']:.4f}, Avg Acc: {computed_train_metrics['accuracy']:.4f}")

        if (epoch + 1) % eval_every_epochs == 0 or (epoch + 1) == num_epochs:
            student_model_to_train.eval()
            eval_metrics_aggregator.reset()
            eval_epoch_bar = tqdm(iter(test_ds), total=steps_per_epoch_test,
                                  desc=f"Epoch {epoch+1}/{num_epochs} [Distill Eval]",
                                  leave=False)
            for eval_batch in eval_epoch_bar:
                eval_batch_images, eval_batch_labels = eval_batch
                eval_batch_tuple = (jnp.asarray(eval_batch_images), jnp.asarray(eval_batch_labels))
                # Evaluate the student model using the standard eval_step
                eval_step(student_model_to_train, eval_metrics_aggregator, eval_batch_tuple)

            computed_eval_metrics = eval_metrics_aggregator.compute()
            history['test_loss'].append(computed_eval_metrics['loss'].item())
            history['test_accuracy'].append(computed_eval_metrics['accuracy'].item())
            tqdm.write(f"{model_name} - Epoch {epoch+1} Evaluation: Avg Loss: {computed_eval_metrics['loss']:.4f}, Avg Acc: {computed_eval_metrics['accuracy']:.4f}")
        tqdm.write("-" * 70)

    print(f"Distillation training finished for {model_name}.")

    # Plotting (same as before)
    evaluated_test_epochs = [e + 1 for e in range(num_epochs) if (e + 1) % eval_every_epochs == 0 or (e + 1) == num_epochs]
    plt.figure(figsize=(14, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(1, num_epochs + 1), history['train_loss'], label='Train Loss', marker='o')
    if history['test_loss']:
        plt.plot(evaluated_test_epochs, history['test_loss'], label='Test Loss', marker='x', linestyle='--')
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title(f'{model_name}: Loss'); plt.legend(); plt.grid(True)
    plt.subplot(1, 2, 2)
    plt.plot(range(1, num_epochs + 1), history['train_accuracy'], label='Train Accuracy', marker='o')
    if history['test_accuracy']:
        plt.plot(evaluated_test_epochs, history['test_accuracy'], label='Test Accuracy', marker='x', linestyle='--')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title(f'{model_name}: Accuracy'); plt.legend(); plt.grid(True)
    plt.tight_layout(); plt.show()

    return student_model_to_train, history



In [None]:
print("="*30)
print("Preparing for Knowledge Distillation Training")
print("="*30)

# Note: nnx.Rngs is stateful and advances its counter on every draw, so this student
# starts from the same initial parameters as the baseline student only if `student_rngs`
# is a freshly created nnx.Rngs seeded the same way as for the baseline.
student_model_kd = LightNN_NNX(num_classes=NUM_CLASSES, rngs=student_rngs)
print("Student model for distillation initialized.")

# 1. Define hyperparameters for distillation
TEMPERATURE = 2.0  # Common value, can be tuned
CE_STUDENT_WEIGHT = 0.25 # Weight for the student's own CE loss on true labels
KD_LOSS_WEIGHT = 0.75    # Weight for the distillation loss from teacher

# 2. Create an optimizer for this new student model
student_kd_optimizer = optax.adam(learning_rate=LEARNING_RATE) # Can be same or different LR

print("\n" + "="*30)
print("Training Student Model with Knowledge Distillation")
print("="*30)

restored = False
try:
  restore_model(student_model_kd, "student_model_kd_final")
  trained_student_model_kd = student_model_kd
  restored = True
except Exception as e:
  print(f"Restore failed: {e}")
  restored=False

if not restored:
  trained_student_model_kd, student_kd_history = train_and_evaluate_distillation(
      student_model_to_train=student_model_kd, # The re-initialized student
      teacher_model_for_guidance=trained_teacher_model, # From Part 5
      optax_optimizer=student_kd_optimizer,
      train_ds=train_ds,
      test_ds=test_ds,
      num_epochs=NUM_EPOCHS_STUDENT, # Same number of epochs as baseline student
      steps_per_epoch_train=steps_per_epoch_train,
      steps_per_epoch_test=steps_per_epoch_test,
      temperature=TEMPERATURE,
      ce_weight=CE_STUDENT_WEIGHT,
      kd_weight=KD_LOSS_WEIGHT,
      model_name="Student_KD (LightNN + Distill)"
  )

In [None]:
try:
  student_model_kd_save_path = os.path.join(checkpoint_dir, "student_model_kd_final")
  ocp.StandardCheckpointer().save(student_model_kd_save_path, nnx.state(trained_student_model_kd, nnx.Param))
  print(f"Student model parameters saved to {student_model_kd_save_path}")
except Exception as e:
  print(f"Could not save student model kd: {e}")

In [None]:
!zip -r /tmp/flax_nnx_kd_checkpoints/student_model_kd_final.zip /tmp/flax_nnx_kd_checkpoints/student_model_kd_final

files.download("/tmp/flax_nnx_kd_checkpoints/student_model_kd_final.zip")

In [None]:
def compare_model_accuracies(models: Dict[str, nnx.Module], test_ds: Any) -> Dict[str, float]:
  output = {}
  for model_name, model in models.items():
    eval_metrics_aggregator = nnx.MultiMetric( # Accuracy on the true test labels
          accuracy=nnx.metrics.Accuracy()
    )
    model.eval()
    eval_metrics_aggregator.reset()
    for eval_batch in tqdm(iter(test_ds), total=steps_per_epoch_test, desc=f"Eval for {model_name}", leave=False):
      eval_batch_images, eval_batch_labels = eval_batch
      eval_batch_tuple = (jnp.asarray(eval_batch_images), jnp.asarray(eval_batch_labels))
      # Evaluate the model using the standard eval_step
      eval_step(model, eval_metrics_aggregator, eval_batch_tuple)
      # Discard any values sown during the forward pass (e.g. by ModifiedLight) so they
      # do not accumulate across batches; models that sow nothing are unaffected.
      nnx.pop(model, nnx.Intermediate)

    computed_eval_metrics = eval_metrics_aggregator.compute()
    output[model_name] = computed_eval_metrics['accuracy'].item()
    tqdm.write(f"{model_name} Evaluation: Avg Acc: {computed_eval_metrics['accuracy']:.4f}")
  return output

In [None]:
# Store results for final comparison table
baseline_accuracies = compare_model_accuracies({
    "Teacher": teacher_model,
    "Student_Baseline": student_model,
    "Student_KD (Output Logits)": student_model_kd
}, test_ds)

print("\n--- Accuracy Summary So Far ---")
for model_name, acc in baseline_accuracies.items():
    if isinstance(acc, str):
        print(f"{model_name}: {acc}")
    else:
        print(f"{model_name}: {acc*100:.2f}%")

## Part 7: Experiment 3 - Distilling from Intermediate Representations

Previously, we performed knowledge distillation by matching the student's final output (logits) to the teacher's. Now, we'll explore a more nuanced approach: distilling knowledge from the **intermediate feature representations** learned within the teacher model.

**Why Distill from Intermediate Layers?**

* **Richer Supervisory Signals:** Intermediate layers of a well-trained teacher often capture complex data invariances and semantic information. Guiding the student to mimic these internal representations can provide a more potent learning signal.
* **Hint-Based Learning:** This is akin to the teacher providing "hints" (as in FitNets by Romero et al., 2014) to the student about how to form its own internal features, rather than just matching the final answer.
* **Bridging Capacity Gaps:** For a student with significantly less capacity than the teacher, directly matching the teacher's final output might be too challenging. Learning from intermediate "stepping stones" can be more effective.

**Capturing Intermediates in NNX with `self.sow()`**

Flax NNX provides a mechanism to "sow" or record intermediate values during the forward pass of a module. We can use `self.sow(VariableType, name, value)` within a module's `__call__` method.
* `self.sow(nnx.Intermediate, 'my_feature_map', features_to_capture)`: This call stores `features_to_capture` in an `nnx.Intermediate` variable named `my_feature_map`.
* After a forward pass, the sown variable becomes an attribute of the model instance and can be accessed as `model.my_feature_map.value`. By default `sow` *appends* to a tuple, so `.value` is a tuple of captured arrays (hence the `[0]` or `[-1]` indexing), and every further forward pass adds another entry until the intermediates are cleared.
* Alternatively, all intermediates of a given type can be retrieved with `nnx.state(model, nnx.Intermediate)`, or retrieved and removed from the model in one step with `nnx.pop(model, nnx.Intermediate)`.

**Our Approach:**
1.  Subclass `DeepNN_NNX` and `LightNN_NNX` (leaving the base classes unchanged) and override `__call__` so that it uses `self.sow()` to capture the feature map produced right before the classifier, i.e., just before flattening.
2.  Implement distillation techniques that use these captured intermediate features.


In [None]:
class ModifiedDeep(DeepNN_NNX):
  """Deeper CNN for the teacher model with feature capture."""

  def __call__(self, x: jax.Array) -> jax.Array:
    x = self.features_block1(x)
    x = self.features_block2(x)
    self.sow(nnx.Intermediate, 'deep_feature_map', x) # Capture Feature Map in self.deep_feature_map
    x = x.reshape((x.shape[0], -1))  # Flatten
    x = self.classifier(x)
    return x


class ModifiedLight(LightNN_NNX):
  """Lightweight CNN for the student model with feature capture."""

  def __call__(self, x: jax.Array) -> jax.Array:
    x = nnx.relu(self.conv1(x))
    x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')
    x = nnx.relu(self.conv2(x))
    x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='VALID')
    self.sow(nnx.Intermediate, 'light_feature_map', x) # Capture Feature Map in self.light_feature_map
    x = x.reshape((x.shape[0], -1))  # Flatten
    x = self.classifier(x)
    return x

In [None]:
# Initialize Teacher Modified Model
teacher_model_cosine_pt = ModifiedDeep(num_classes=NUM_CLASSES, rngs=teacher_rngs)
print("Teacher Model Modified Initialized.")

# Initialize Student Modified Model
student_model_cosine_pt = ModifiedLight(num_classes=NUM_CLASSES, rngs=student_rngs)
print("Student Model Modified Initialized.")

print("\nTeacher Model Modified Summary:")
print(nnx.tabulate(teacher_model_cosine_pt, dummy_images))

print("\nStudent Model Modified Summary:")
print(nnx.tabulate(student_model_cosine_pt, dummy_images))

## Part 8: Experiment 3a - Cosine Loss on Hidden States (Teacher Feature Pooling)

In this first variant of intermediate feature distillation, we aim to make the student's internal feature representations similar to a downsampled version of the teacher's. We will use **cosine similarity** as the metric.

**Key Steps:**

1.  **Feature Extraction (via `self.sow`)**:
    * The **teacher model** (`ModifiedDeep`) provides its intermediate feature map (e.g., `deep_feature_map`).
    * The **student model** (`ModifiedLight`) provides its intermediate feature map (e.g., `light_feature_map`).

2.  **Dimensionality Matching (Teacher Pooling)**:
    * The teacher's feature map (e.g., `B, 8, 8, 32` -> flattened to `B, 2048`) has a higher dimensionality than the student's (e.g., `B, 8, 8, 16` -> flattened to `B, 1024`).
    * To match dimensions for the cosine loss, we will **flatten both feature maps** and then apply **1D average pooling to the teacher's flattened feature vector** to reduce its dimensionality to match the student's. This pooling will be done *within the loss function itself*, not by modifying the teacher model's architecture.

3.  **Loss Calculation:**
    * **Cosine Similarity Loss:** Calculated between the student's flattened features and the teacher's *pooled and flattened* features.
    * **Cross-Entropy Loss:** The standard classification loss on the student's final output and the true labels.
    * The total loss will be a weighted sum of these two.
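The dimensionality matching and the cosine loss from the steps above can be sketched on random arrays with the shapes quoted in step 2. This is an illustration under stated assumptions: the arrays stand in for real feature maps, the layout is NHWC as in the notebook's models, and `pool_teacher_features` and `cosine_loss` are hypothetical helpers that mirror the factor-of-2 pooling branch and `cosine_similarity_loss_fn`.

```python
import jax
import jax.numpy as jnp

def pool_teacher_features(teacher_fm: jax.Array, student_dim: int) -> jax.Array:
    """Flatten an NHWC feature map and average adjacent entries down to student_dim."""
    flat = teacher_fm.reshape((teacher_fm.shape[0], -1))
    pool_factor = flat.shape[-1] // student_dim
    assert flat.shape[-1] == pool_factor * student_dim, "dims must divide evenly"
    return flat.reshape((flat.shape[0], student_dim, pool_factor)).mean(axis=-1)

def cosine_loss(x1: jax.Array, x2: jax.Array, eps: float = 1e-8) -> jax.Array:
    """Mean of 1 - cosine similarity over the batch."""
    x1n = x1 / (jnp.linalg.norm(x1, axis=-1, keepdims=True) + eps)
    x2n = x2 / (jnp.linalg.norm(x2, axis=-1, keepdims=True) + eps)
    return jnp.mean(1.0 - jnp.sum(x1n * x2n, axis=-1))

key_t, key_s = jax.random.split(jax.random.PRNGKey(0))
teacher_fm = jax.random.normal(key_t, (4, 8, 8, 32))  # (B, H, W, C_teacher)
student_fm = jax.random.normal(key_s, (4, 8, 8, 16))  # (B, H, W, C_student)

student_flat = student_fm.reshape((4, -1))                                  # (4, 1024)
teacher_pooled = pool_teacher_features(teacher_fm, student_flat.shape[-1])  # (4, 1024)
print(student_flat.shape, teacher_pooled.shape)

# In NHWC order the channel index varies fastest, so averaging adjacent pairs of the
# flattened vector averages channel pairs (2k, 2k+1) at each spatial location.
channel_pair_mean = 0.5 * (teacher_fm[..., 0::2] + teacher_fm[..., 1::2])
print(bool(jnp.allclose(teacher_pooled, channel_pair_mean.reshape((4, -1)), atol=1e-6)))

# Sanity checks: identical vectors give a loss near 0, opposite vectors near 2.
print(float(cosine_loss(student_flat, student_flat)), float(cosine_loss(student_flat, -student_flat)))
```

One consequence of the NHWC layout is that this "1D pooling" merges neighboring teacher channels rather than neighboring pixels, so the student's 16 channels are pushed toward averaged pairs of the teacher's 32.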

In [None]:
# --- Cosine Similarity Loss Function ---
def cosine_similarity_loss_fn(x1: jax.Array, x2: jax.Array, epsilon: float = 1e-8) -> jax.Array:
  """
  Computes 1 - cosine_similarity between two batches of vectors (mean over batch).
  x1, x2 are expected to be 2D arrays (batch_size, num_features).
  """
  x1_norm = x1 / (jnp.linalg.norm(x1, axis=-1, keepdims=True) + epsilon)
  x2_norm = x2 / (jnp.linalg.norm(x2, axis=-1, keepdims=True) + epsilon)
  similarity = jnp.sum(x1_norm * x2_norm, axis=-1)
  return jnp.mean(1.0 - similarity)


# --- Combined Loss Function for Cosine Distillation (Teacher Pooling) ---
def combined_loss_logits_features_cosine_pooled_teacher(
    student_model: ModifiedLight,   # Student model (sows its features)
    teacher_model: ModifiedDeep,    # Teacher model (sows its features)
    batch: Tuple[jax.Array, jax.Array],
    ce_student_weight: float,
    cosine_loss_weight: float
):
    """
    Calculates combined loss (CE + Cosine Sim with teacher feature pooling)
    and returns student logits.
    """
    images, true_labels = batch

    # 1. Forward pass for Student to get logits and sow its intermediate features
    student_logits = student_model(images)
    # nnx.pop returns the sown Intermediate variables *and* removes them from the model.
    # Without popping, sow() keeps appending to the tuple on every call (NNX transforms
    # propagate the new state back to the model), so `.value[0]` would return a stale
    # feature map from an earlier step instead of the one computed on this batch.
    student_sown_state = nnx.pop(student_model, nnx.Intermediate)
    student_fm_raw = student_sown_state['light_feature_map'].value[-1]  # Most recent capture
    student_flattened_features = student_fm_raw.reshape((student_fm_raw.shape[0], -1))
    # student_flattened_features shape: (B, 1024) for LightNN

    # 2. Standard Cross-Entropy loss for student
    ce_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=student_logits, labels=true_labels
    ).mean()

    # 3. Forward pass for Teacher (in eval mode) to sow its intermediate features
    _ = teacher_model(images)
    teacher_sown_state = nnx.pop(teacher_model, nnx.Intermediate)
    teacher_fm_raw = teacher_sown_state['deep_feature_map'].value[-1]  # Most recent capture
    teacher_flattened_features_full = teacher_fm_raw.reshape((teacher_fm_raw.shape[0], -1))
    # teacher_flattened_features_full shape: (B, 2048) for DeepNN

    # 4. Dimensionality matching: Pool teacher's flattened features down to student's dimension
    # Teacher: 2048 features, Student: 1024 features. Pooling factor = 2.
    if teacher_flattened_features_full.shape[-1] == 2 * student_flattened_features.shape[-1]:
        pool_factor = 2
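        # Reshape for 1D pooling: (B, num_features) -> (B, num_features/pool_factor, pool_factor)
        # then take the mean over the last axis. With NHWC layout, adjacent flattened entries
        # are neighboring channels at the same spatial position, so this averages channel pairs.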
        teacher_reshaped_for_pool = teacher_flattened_features_full.reshape(
            (teacher_flattened_features_full.shape[0], -1, pool_factor)
        )
        teacher_pooled_features = jnp.mean(teacher_reshaped_for_pool, axis=-1)
    elif teacher_flattened_features_full.shape[-1] == student_flattened_features.shape[-1]:
        # Dimensions already match, no pooling needed (should not happen with current models)
        teacher_pooled_features = teacher_flattened_features_full
    else:
        raise ValueError(f"Teacher feature dim {teacher_flattened_features_full.shape[-1]} "
                         f"cannot be easily pooled to match student dim {student_flattened_features.shape[-1]} "
                         f"with a simple factor of 2 pooling.")

    # teacher_pooled_features should now be (B, 1024)

    # 5. Cosine Similarity Loss
    cos_loss = cosine_similarity_loss_fn(student_flattened_features, teacher_pooled_features)

    # 6. Combined loss
    total_loss = (ce_student_weight * ce_loss) + (cosine_loss_weight * cos_loss)

    return total_loss, student_logits

@nnx.jit
def train_step_cosine_pooled_teacher(
    student_sow_model: ModifiedLight,
    optimizer_student_cosine_pt: nnx.Optimizer,
    teacher_sow_model: ModifiedDeep, # Pre-trained teacher, used in eval
    metrics_aggregator: nnx.MultiMetric,
    batch: Tuple[jax.Array, jax.Array],
    ce_weight: float,
    cosine_weight: float
):
    """
    Training step with CE loss and Cosine feature distillation loss (teacher pooled).
    """
    def loss_fn_wrapper(model_to_diff, teacher_model): # model_to_diff will be student_sow_model
        return combined_loss_logits_features_cosine_pooled_teacher(
            model_to_diff, # Student model
            teacher_model,
            batch,
            ce_weight,
            cosine_weight
        )

    grad_fn = nnx.value_and_grad(loss_fn_wrapper, has_aux=True) # has_aux for (loss, logits)

    (loss, student_logits_for_metric), grads = grad_fn(student_sow_model, teacher_sow_model)

    metrics_aggregator.update(loss=loss, logits=student_logits_for_metric, labels=batch[1])
    optimizer_student_cosine_pt.update(grads) # Updates student_sow_model in-place


In [None]:
def train_and_evaluate_cosine_pooled_distillation(
    student_model_to_train: ModifiedLight,
    teacher_model_for_guidance: ModifiedDeep,
    optax_optimizer: optax.GradientTransformation,
    train_ds: Any,
    test_ds: Any,
    num_epochs: int,
    steps_per_epoch_train: int,
    steps_per_epoch_test: int,
    ce_weight: float,
    cosine_weight: float,
    eval_every_epochs: int = 1,
    model_name: str = "Student_Cosine_PooledTeacher"
) -> Tuple[nnx.Module, Dict[str, list]]:

    optimizer = nnx.Optimizer(student_model_to_train, optax_optimizer)

    train_metrics_aggregator = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
        accuracy=nnx.metrics.Accuracy()
    )
    eval_metrics_aggregator = nnx.MultiMetric( # For student's performance on true labels
        loss=nnx.metrics.Average("loss"),
        accuracy=nnx.metrics.Accuracy()
    )

    history = {
        'train_loss': [], 'train_accuracy': [],
        'test_loss': [], 'test_accuracy': []
    }

    total_training_steps = num_epochs * steps_per_epoch_train
    print(f"\nStarting distillation training for {model_name} for {num_epochs} epochs...")

    for epoch in tqdm(range(num_epochs), desc=f"Training {model_name}"):
        student_model_to_train.train()
        teacher_model_for_guidance.eval()
        train_metrics_aggregator.reset()
        train_epoch_bar = tqdm(iter(train_ds), total=steps_per_epoch_train,
                               desc=f"Epoch {epoch+1}/{num_epochs} [Distill Train]",
                               leave=False)
        for batch in train_epoch_bar:
            batch_images, batch_labels = batch
            batch_tuple = (jnp.asarray(batch_images), jnp.asarray(batch_labels))

            train_step_cosine_pooled_teacher(
                student_model_to_train,
                optimizer,
                teacher_model_for_guidance,
                train_metrics_aggregator,
                batch_tuple,
                ce_weight,
                cosine_weight
            )

        computed_train_metrics = train_metrics_aggregator.compute()
        history['train_loss'].append(computed_train_metrics['loss'].item())
        history['train_accuracy'].append(computed_train_metrics['accuracy'].item())
        tqdm.write(f"{model_name} - Epoch {epoch+1} Training: Avg Loss: {computed_train_metrics['loss']:.4f}, Avg Acc: {computed_train_metrics['accuracy']:.4f}")

        if (epoch + 1) % eval_every_epochs == 0 or (epoch + 1) == num_epochs:
            student_model_to_train.eval()
            eval_metrics_aggregator.reset()
            eval_epoch_bar = tqdm(iter(test_ds), total=steps_per_epoch_test,
                                  desc=f"Epoch {epoch+1}/{num_epochs} [Distill Eval]",
                                  leave=False)
            for eval_batch in eval_epoch_bar:
                eval_batch_images, eval_batch_labels = eval_batch
                eval_batch_tuple = (jnp.asarray(eval_batch_images), jnp.asarray(eval_batch_labels))
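                # Evaluate the student model using the standard eval_step
                eval_step(student_model_to_train, eval_metrics_aggregator, eval_batch_tuple)
                # Clear the features sown during evaluation so they do not accumulate across batches
                nnx.pop(student_model_to_train, nnx.Intermediate)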

            computed_eval_metrics = eval_metrics_aggregator.compute()
            history['test_loss'].append(computed_eval_metrics['loss'].item())
            history['test_accuracy'].append(computed_eval_metrics['accuracy'].item())
            tqdm.write(f"{model_name} - Epoch {epoch+1} Evaluation: Avg Loss: {computed_eval_metrics['loss']:.4f}, Avg Acc: {computed_eval_metrics['accuracy']:.4f}")
        tqdm.write("-" * 70)

    print(f"Distillation training finished for {model_name}.")

    # Plotting (same as before)
    evaluated_test_epochs = [e + 1 for e in range(num_epochs) if (e + 1) % eval_every_epochs == 0 or (e + 1) == num_epochs]
    plt.figure(figsize=(14, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(1, num_epochs + 1), history['train_loss'], label='Train Loss', marker='o')
    if history['test_loss']:
        plt.plot(evaluated_test_epochs, history['test_loss'], label='Test Loss', marker='x', linestyle='--')
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title(f'{model_name}: Loss'); plt.legend(); plt.grid(True)
    plt.subplot(1, 2, 2)
    plt.plot(range(1, num_epochs + 1), history['train_accuracy'], label='Train Accuracy', marker='o')
    if history['test_accuracy']:
        plt.plot(evaluated_test_epochs, history['test_accuracy'], label='Test Accuracy', marker='x', linestyle='--')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title(f'{model_name}: Accuracy'); plt.legend(); plt.grid(True)
    plt.tight_layout(); plt.show()

    return student_model_to_train, history



In [None]:
print("="*30)
print("Preparing for Cosine Distillation Training")
print("="*30)

# The feature-capturing student (student_model_cosine_pt) and teacher
# (teacher_model_cosine_pt) were initialized in the cell above.

# 1. Define hyperparameters for distillation
CE_STUDENT_WEIGHT_COSINE_PT = 0.25 # Weight for student's own CE loss
COSINE_LOSS_WEIGHT_PT = 0.75     # Weight for cosine similarity loss

# 2. Create an optimizer for this new student model
student_cosine_pt_optimizer = optax.adam(learning_rate=LEARNING_RATE) # Can be same or different LR

# 3. Copy the trained teacher's state (including its parameters) into teacher_model_cosine_pt
nnx.update(teacher_model_cosine_pt, nnx.state(teacher_model))

print("\n" + "="*30)
print("Training Student with Cosine Distillation (Teacher Feature Pooling)")
print("="*30)

# Reuse a saved checkpoint if one exists; otherwise train from scratch below.
restored = False
try:
  restore_model(student_model_cosine_pt, "student_cosine_pt_final")
  trained_student_cosine_pt = student_model_cosine_pt
  restored = True
except Exception as e:
  print(f"Restore failed: {e}")
  restored = False

if not restored:
  trained_student_cosine_pt, student_cosine_pt_history = train_and_evaluate_cosine_pooled_distillation(
      student_model_to_train=student_model_cosine_pt,
      teacher_model_for_guidance=teacher_model_cosine_pt,
      optax_optimizer=student_cosine_pt_optimizer,
      train_ds=train_ds,
      test_ds=test_ds,
      num_epochs=NUM_EPOCHS_STUDENT, # Same number of epochs as baseline student
      steps_per_epoch_train=steps_per_epoch_train,
      steps_per_epoch_test=steps_per_epoch_test,
      ce_weight=CE_STUDENT_WEIGHT_COSINE_PT,
      cosine_weight=COSINE_LOSS_WEIGHT_PT,
  )
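Whether the student built in the cell above starts from the same weights as the baseline student depends on the random state behind `student_rngs`. `nnx.Rngs` is stateful: each key it hands out advances an internal counter, so identical parameters require starting again from the same seed. The minimal sketch below shows the underlying property with plain `jax.random`. `init_dense` is a hypothetical stand-in for a layer initializer and is not part of the notebook.

```python
import jax
import jax.numpy as jnp


def init_dense(key, in_dim, out_dim):
    # Hypothetical initializer: normal weights scaled by 1/sqrt(fan_in).
    return jax.random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim)


seed_key = jax.random.key(42)
w_baseline = init_dense(seed_key, 8, 4)

# Re-creating the key from the same seed reproduces the weights exactly.
w_same_seed = init_dense(jax.random.key(42), 8, 4)

# A key drawn from an advanced stream (as a reused, stateful Rngs would give) does not.
next_key, _ = jax.random.split(seed_key)
w_advanced = init_dense(next_key, 8, 4)

print(bool(jnp.array_equal(w_baseline, w_same_seed)))  # True
print(bool(jnp.array_equal(w_baseline, w_advanced)))   # False
```

For a fair comparison between distillation methods, giving every student the same starting point isolates the effect of the loss from the effect of a lucky initialization.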

In [None]:
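try:
  student_cosine_pt_save_path = os.path.join(checkpoint_dir, "student_cosine_pt_final")
  checkpointer = ocp.StandardCheckpointer()
  # Save only the student's trainable parameters (nnx.Param).
  checkpointer.save(student_cosine_pt_save_path, nnx.state(trained_student_cosine_pt, nnx.Param))
  # StandardCheckpointer writes asynchronously; block until the files exist before zipping them.
  checkpointer.wait_until_finished()
  print(f"Student model parameters saved to {student_cosine_pt_save_path}")
except Exception as e:
  print(f"Could not save student model cosine pt: {e}")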

In [None]:
!zip -r /tmp/flax_nnx_kd_checkpoints/student_cosine_pt_final.zip /tmp/flax_nnx_kd_checkpoints/student_cosine_pt_final

files.download("/tmp/flax_nnx_kd_checkpoints/student_cosine_pt_final.zip")

In [None]:
# Store results for final comparison table
baseline_accuracies = compare_model_accuracies({
    "Teacher": teacher_model,
    "Student_Baseline": student_model,
    "Student_KD (Output Logits)": student_model_kd,
    "Student_Cosine_PooledTeacher": student_model_cosine_pt,
}, test_ds)

print("\n--- Accuracy Summary So Far ---")
for model_name, acc in baseline_accuracies.items():
    if isinstance(acc, str):
        print(f"{model_name}: {acc}")
    else:
        print(f"{model_name}: {acc*100:.2f}%")
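`compare_model_accuracies` is defined earlier in the notebook and is not shown here. The cell below is a minimal sketch of the top-1 accuracy such a helper would compute for each test batch before the summary above formats it as a percentage. `batch_accuracy` is a hypothetical name, not the notebook's function.

```python
import jax.numpy as jnp


def batch_accuracy(logits, labels):
    # Fraction of examples whose highest-scoring class matches the label.
    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)


logits = jnp.array([[2.0, 0.1, -1.0],   # predicts 0, label 0 -> correct
                    [0.3, 0.2, 0.9],    # predicts 2, label 2 -> correct
                    [0.0, 1.5, 0.2],    # predicts 1, label 0 -> wrong
                    [1.0, 0.9, 0.8]])   # predicts 0, label 0 -> correct
labels = jnp.array([0, 2, 0, 0])

acc = batch_accuracy(logits, labels)
print(f"{float(acc) * 100:.2f}%")  # 75.00%
```

Averaging per-batch accuracies gives a smaller final batch the same weight as full batches. Summing correct predictions across the whole test set and dividing once by the number of examples avoids that skew.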