# BLG-454E Learning From Data Term Project: Data Distillation

## Ömer Faruk San-150220307
## Mustafa Kerem Bulut-150220303

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import copy # For deep copying models
import os
from tqdm.notebook import tqdm

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

Using device: cuda


In [4]:
# CIFAR-10 constants
IMG_SIZE = 32
N_CHANNELS = 3
N_CLASSES = 10

# Data preprocessing
# For CIFAR-10, the mean and std are standard values.
transform_train = transforms.Compose([
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load datasets
train_dataset_full = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# DataLoaders
batch_size_real = 128 # Batch size for training on real data
train_loader_full = DataLoader(train_dataset_full, batch_size=batch_size_real, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=2)

print(f"Full training dataset size: {len(train_dataset_full)}")
print(f"Test dataset size: {len(test_dataset)}")

# Function to display images (optional, for verification)
def imshow(img_tensor, title=None):
    img_tensor = img_tensor.cpu() / 2 + 0.5  # Unnormalize (approximate)
    npimg = img_tensor.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    if title:
        plt.title(title)
    plt.show()

# Example: Get some random training images
dataiter = iter(train_loader_full)
images, labels = next(dataiter)
# imshow(torchvision.utils.make_grid(images[:4])) # Requires torchvision
# print(' '.join(f'{train_dataset_full.classes[labels[j]]:5s}' for j in range(4)))

Files already downloaded and verified
Files already downloaded and verified
Full training dataset size: 50000
Test dataset size: 10000


In [5]:
class ConvNet(nn.Module):
    def __init__(self, num_classes=N_CLASSES):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(N_CHANNELS, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 32x32 -> 16x16
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 16x16 -> 8x8
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)
        x = x.view(-1, 128 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Instantiate the model and move to device
# For now, this is our "teacher_model_architecture" and also "student_model_architecture"
model_architecture = ConvNet 

# Example:
# model = model_architecture(num_classes=N_CLASSES).to(device)
# summary(model, (N_CHANNELS, IMG_SIZE, IMG_SIZE)) # Requires torchsummary

In [11]:
# --- Configuration for Baseline Teacher Training ---
teacher_epochs = 15 # Number of epochs to train the teacher model
lr_teacher = 0.01
momentum_teacher = 0.9
weight_decay_teacher = 5e-4
# Trajectory saving: Save parameters at the end of these specific epochs
# This makes the trajectory shorter and easier to handle initially.
# The paper matches trajectories over a few steps (e.g. 10-20 inner loop steps for student)
# We might need to save parameters more frequently if we want to match short trajectories.
# Let's start by saving at the end of each epoch for a few epochs.
# TRAJECTORY_SAVE_EPOCHS = [0, 1, 2, 3, 4, 9, 19, 29, 39, 49] # Example: save at these epochs
TRAJECTORY_SAVE_EPOCHS = list(range(teacher_epochs)) # Save at the end of every epoch for now

# Path to save teacher model and trajectory
TEACHER_MODEL_PATH = "teacher_model.pth"
TEACHER_TRAJECTORY_PATH = "teacher_trajectory.pth"

# --- Helper function for training one epoch ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for inputs, labels in tqdm(dataloader, desc="Training Epoch", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

# --- Helper function for evaluating the model ---
def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

# --- Train the Teacher Model and Save Trajectory ---
teacher_model = model_architecture(num_classes=N_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=lr_teacher, momentum=momentum_teacher, weight_decay=weight_decay_teacher)
# Optional: Learning rate scheduler
scheduler_teacher = optim.lr_scheduler.CosineAnnealingLR(optimizer_teacher, T_max=teacher_epochs)

teacher_trajectory = [] # List to store model parameters at specified points

print("Starting Teacher Model Training...")

# Initial state before training (optional, but often epoch 0 means after 1st epoch)
# If you want to match from the *very first* random initialization, save here.
# For simplicity, Cazenavette et al. often re-initialize student networks for each outer loop step.
# Let's save after the first epoch (epoch 0).

for epoch in range(teacher_epochs):
    epoch_train_loss, epoch_train_acc = train_epoch(teacher_model, train_loader_full, optimizer_teacher, criterion, device)
    scheduler_teacher.step() # Update learning rate

    # Evaluate on test set (optional during teacher training, but good for monitoring)
    epoch_test_loss, epoch_test_acc = evaluate_model(teacher_model, test_loader, criterion, device)

    print(f"Teacher Epoch {epoch+1}/{teacher_epochs}:")
    print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
    print(f"  Test Loss: {epoch_test_loss:.4f}, Test Acc: {epoch_test_acc:.4f}")
    print(f"  LR: {optimizer_teacher.param_groups[0]['lr']:.6f}")


    # Save model parameters for the trajectory
    if epoch in TRAJECTORY_SAVE_EPOCHS:
        # We need to save a deep copy of the state_dict, otherwise it's just a reference
        # and will change as the model continues to train.
        trajectory_point = {k: v.detach().clone().cpu() for k, v in teacher_model.state_dict().items()}
        teacher_trajectory.append(trajectory_point)
        print(f"  Saved parameters to trajectory at epoch {epoch+1}")

# Save the final teacher model
torch.save(teacher_model.state_dict(), TEACHER_MODEL_PATH)
print(f"\nTeacher model saved to {TEACHER_MODEL_PATH}")

# Save the trajectory
torch.save(teacher_trajectory, TEACHER_TRAJECTORY_PATH)
print(f"Teacher trajectory (len: {len(teacher_trajectory)}) saved to {TEACHER_TRAJECTORY_PATH}")

# Verify saved trajectory (optional)
# loaded_trajectory = torch.load(TEACHER_TRAJECTORY_PATH)
# print(f"Loaded trajectory length: {len(loaded_trajectory)}")
# print(f"Keys in first trajectory point: {loaded_trajectory[0].keys()}")

Starting Teacher Model Training...


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 1/15:
  Train Loss: 1.6088, Train Acc: 0.4061
  Test Loss: 1.3629, Test Acc: 0.4871
  LR: 0.009891
  Saved parameters to trajectory at epoch 1


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 2/15:
  Train Loss: 1.2846, Train Acc: 0.5372
  Test Loss: 1.1631, Test Acc: 0.5857
  LR: 0.009568
  Saved parameters to trajectory at epoch 2


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 3/15:
  Train Loss: 1.1049, Train Acc: 0.6067
  Test Loss: 1.0334, Test Acc: 0.6523
  LR: 0.009045
  Saved parameters to trajectory at epoch 3


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 4/15:
  Train Loss: 0.9994, Train Acc: 0.6469
  Test Loss: 0.7971, Test Acc: 0.7197
  LR: 0.008346
  Saved parameters to trajectory at epoch 4


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 5/15:
  Train Loss: 0.9218, Train Acc: 0.6777
  Test Loss: 0.7937, Test Acc: 0.7226
  LR: 0.007500
  Saved parameters to trajectory at epoch 5


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 6/15:
  Train Loss: 0.8448, Train Acc: 0.7039
  Test Loss: 0.7413, Test Acc: 0.7434
  LR: 0.006545
  Saved parameters to trajectory at epoch 6


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 7/15:
  Train Loss: 0.7953, Train Acc: 0.7216
  Test Loss: 0.6747, Test Acc: 0.7661
  LR: 0.005523
  Saved parameters to trajectory at epoch 7


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 8/15:
  Train Loss: 0.7411, Train Acc: 0.7400
  Test Loss: 0.5924, Test Acc: 0.7978
  LR: 0.004477
  Saved parameters to trajectory at epoch 8


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 9/15:
  Train Loss: 0.6993, Train Acc: 0.7577
  Test Loss: 0.5765, Test Acc: 0.8018
  LR: 0.003455
  Saved parameters to trajectory at epoch 9


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 10/15:
  Train Loss: 0.6646, Train Acc: 0.7702
  Test Loss: 0.5724, Test Acc: 0.8022
  LR: 0.002500
  Saved parameters to trajectory at epoch 10


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 11/15:
  Train Loss: 0.6373, Train Acc: 0.7811
  Test Loss: 0.5352, Test Acc: 0.8131
  LR: 0.001654
  Saved parameters to trajectory at epoch 11


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 12/15:
  Train Loss: 0.5978, Train Acc: 0.7932
  Test Loss: 0.5189, Test Acc: 0.8192
  LR: 0.000955
  Saved parameters to trajectory at epoch 12


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 13/15:
  Train Loss: 0.5789, Train Acc: 0.7993
  Test Loss: 0.5117, Test Acc: 0.8228
  LR: 0.000432
  Saved parameters to trajectory at epoch 13


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 14/15:
  Train Loss: 0.5651, Train Acc: 0.8047
  Test Loss: 0.4916, Test Acc: 0.8290
  LR: 0.000109
  Saved parameters to trajectory at epoch 14


Training Epoch:   0%|          | 0/391 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher Epoch 15/15:
  Train Loss: 0.5531, Train Acc: 0.8104
  Test Loss: 0.4867, Test Acc: 0.8301
  LR: 0.000000
  Saved parameters to trajectory at epoch 15

Teacher model saved to teacher_model.pth
Teacher trajectory (len: 15) saved to teacher_trajectory.pth


In [12]:
# --- Configuration for Synthetic Dataset ---
IPC = 10 # Images Per Class, as per your proposal
TOTAL_SYNTHETIC_IMAGES = IPC * N_CLASSES

# --- Initialize Synthetic Data ---

# Initialize Synthetic Images (X_syn)
# Option 1: Random Noise
# X_syn_init_data = torch.randn(TOTAL_SYNTHETIC_IMAGES, N_CHANNELS, IMG_SIZE, IMG_SIZE, device=device)

# Option 2: From Real Images (often preferred as a starting point)
# We'll select IPC images from each class from the original training set.
def initialize_synthetic_images_from_real(train_dataset, num_classes, ipc, img_size, n_channels):
    """
    Initializes synthetic images by picking IPC images from each class
    from the real dataset.
    """
    X_syn_list = []
    Y_syn_list = []
    
    # Create a temporary dataloader to iterate by class easily (if not already structured that way)
    # Alternatively, iterate through the dataset and collect indices per class
    class_indices = [[] for _ in range(num_classes)]
    for i, (_, label) in enumerate(train_dataset):
        if len(class_indices[label]) < ipc:
            class_indices[label].append(i)
        # Optimization: if all classes have enough samples, break
        if all(len(indices) == ipc for indices in class_indices):
            break
            
    if not all(len(indices) == ipc for indices in class_indices):
        print("Warning: Could not find enough unique images for all classes to initialize synthetic data.")
        # Fallback or error handling needed if a class has < IPC samples, though unlikely for CIFAR-10.

    for c in range(num_classes):
        indices_c = class_indices[c][:ipc] # Take the first ipc indices for class c
        for idx in indices_c:
            img, label = train_dataset[idx] # img is already transformed
            X_syn_list.append(img)
            Y_syn_list.append(torch.tensor(label)) # Store label as tensor

    X_syn_init_data = torch.stack(X_syn_list).to(device)
    # Y_syn will be fixed based on this initialization order, or we can create it explicitly
    
    print(f"Initialized X_syn with shape: {X_syn_init_data.shape}")
    return X_syn_init_data

# Choose initialization method:
INIT_METHOD = 'real' # 'random' or 'real'

if INIT_METHOD == 'real':
    # We need the original train_dataset *without* on-the-fly random augmentations
    # for consistent initialization if we run this multiple times.
    # The transform_test is suitable here as it only does ToTensor and Normalize.
    train_dataset_for_init = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test)
    X_syn_data = initialize_synthetic_images_from_real(train_dataset_for_init, N_CLASSES, IPC, IMG_SIZE, N_CHANNELS)
else: # 'random'
    X_syn_data = torch.randn(TOTAL_SYNTHETIC_IMAGES, N_CHANNELS, IMG_SIZE, IMG_SIZE, device=device)
    # For random init, you might want to scale it to a typical image range or normalize later.
    # For now, raw randn output.

# These synthetic images are the parameters we want to learn.
# So, they need to require gradients.
X_syn = X_syn_data.detach().clone().requires_grad_(True)
# X_syn.data.clamp_(0, 1) # Optional: clamp if you initialized randomly and want to keep in [0,1] before normalization
                         # If initialized from real (normalized) data, this might not be needed or done differently.
                         # Since CIFAR-10 is normalized around 0, clamping to [0,1] is not directly applicable here
                         # unless you unnormalize first. For now, let's assume it's fine.

# Initialize Synthetic Labels (Y_syn)
# These are typically fixed and not learned.
# Create labels: IPC images for class 0, then IPC for class 1, etc.
y_syn_list = []
for c in range(N_CLASSES):
    y_syn_list.extend([c] * IPC)
Y_syn = torch.tensor(y_syn_list, dtype=torch.long, device=device)

print(f"Shape of X_syn: {X_syn.shape}") # Should be [100, 3, 32, 32]
print(f"Shape of Y_syn: {Y_syn.shape}") # Should be [100]
print(f"Sample Y_syn: {Y_syn}")

# Optimizer for the synthetic data (X_syn) itself
# This is the "outer loop" optimizer
lr_syn = 1.0 # Learning rate for updating synthetic images. This is a key hyperparameter!
             # Some papers use much smaller LRs like 0.001 or 0.01, depends on the loss scale.
             # Cazenavette et al. use Adam with lr=0.1 for X_syn.
optimizer_X_syn = optim.Adam([X_syn], lr=lr_syn)
# optimizer_X_syn = optim.SGD([X_syn], lr=lr_syn, momentum=0.5) # Another option

# We might also want to learn per-sample learning rates or other parameters,
# but let's start with just learning X_syn.

Initialized X_syn with shape: torch.Size([100, 3, 32, 32])
Shape of X_syn: torch.Size([100, 3, 32, 32])
Shape of Y_syn: torch.Size([100])
Sample Y_syn: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9], device='cuda:0')


In [1]:
# --- Configuration for Distillation Loop ---
distill_epochs = 1000       # Number of outer loop optimization steps for X_syn
student_inner_steps = 10    # Number of SGD steps to train student on X_syn in each outer step
lr_student_inner = 0.01     # LR for student model training on X_syn
# lr_syn was defined in Step 4 for optimizer_X_syn (e.g., Adam with lr=0.1 or SGD with lr=1.0)
# Re-confirm optimizer_X_syn from Step 4. If you used SGD there, lr_syn=1.0 might be okay to start.
# If Adam, lr_syn=0.1 is more common. Let's assume Adam as per Cazenavette et al. for X_syn.
# If optimizer_X_syn was SGD(lr=1.0), you might need to adjust.
# Re-define optimizer_X_syn here for clarity or if you want to change it.
# Assuming X_syn is already defined and requires_grad=True from Step 4.
optimizer_X_syn = optim.Adam([X_syn], lr=0.1, betas=(0.5, 0.999)) # Common choice from papers

# Load teacher trajectory (ensure this was saved successfully in Step 3)
try:
    teacher_trajectory = torch.load(TEACHER_TRAJECTORY_PATH, map_location=torch.device('cpu')) # Load to CPU first
    print(f"Loaded teacher trajectory with {len(teacher_trajectory)} points.")
    if not teacher_trajectory: # Empty list
        raise FileNotFoundError
except FileNotFoundError:
    print(f"Error: Teacher trajectory not found or empty at {TEACHER_TRAJECTORY_PATH}. Please run Step 3 first.")
    # For notebooks, you might want to raise an error to stop execution
    # For scripts, sys.exit() or raise is appropriate.
    raise SystemExit("Teacher trajectory missing or empty.")


# Criterion for student training (already defined as `criterion` global)
# criterion = nn.CrossEntropyLoss()

print("\nStarting Dataset Distillation Loop...")
if not X_syn.is_leaf or not X_syn.requires_grad:
     print("Warning: X_syn is not a leaf tensor or does not require gradients!")
     X_syn.requires_grad_(True) # Ensure it does

for epoch_distill in tqdm(range(distill_epochs), desc="Distillation Epochs"):
    optimizer_X_syn.zero_grad()

    # 1. Initialize Student Model and Select Teacher States
    student_model = model_architecture(num_classes=N_CLASSES).to(device)
    
    # Pick a starting point from the teacher trajectory
    teacher_idx_start = epoch_distill % len(teacher_trajectory)
    teacher_start_params_cpu = teacher_trajectory[teacher_idx_start]
    
    # Load these params into the student model.
    student_model.load_state_dict({k: v.clone().to(device) for k, v in teacher_start_params_cpu.items()})

    # Student model requires its own optimizer for the inner loop
    optimizer_student = optim.SGD(student_model.parameters(), lr=lr_student_inner, momentum=0.5)

    # 2. Train Student Model on current X_syn for `student_inner_steps`
    for _inner_step in range(student_inner_steps):
        outputs_student = student_model(X_syn) # X_syn is used here
        loss_student = criterion(outputs_student, Y_syn)
        
        optimizer_student.zero_grad()
        loss_student.backward() 
        optimizer_student.step()

    # 3. Select Teacher Target Parameters
    if len(teacher_trajectory) == 1:
        teacher_idx_target = teacher_idx_start
    else:
        teacher_idx_target = (teacher_idx_start + 1) % len(teacher_trajectory)
    teacher_target_params_cpu = teacher_trajectory[teacher_idx_target]
    
    # ////////////////////////////////////////////////////////////////////////
    # /// START OF REVISED META-LOSS CALCULATION ///
    # ////////////////////////////////////////////////////////////////////////
    
    # 4. Calculate Trajectory Matching Loss (Meta-Loss)
    meta_loss_components = []
    
    # Prepare teacher target parameters on the correct device for easy lookup
    teacher_target_params_dict_device = {k: v.clone().to(device) for k, v in teacher_target_params_cpu.items()}

    # Iterate through the student model's named parameters.
    # These parameters were updated in the inner loop and have grad_fns linking back to X_syn.
    for name, param_student in student_model.named_parameters():
        if param_student.requires_grad: # Only consider learnable parameters
            if name in teacher_target_params_dict_device:
                param_teacher_target = teacher_target_params_dict_device[name]
                # param_teacher_target is detached (from trajectory)
                # param_student has a grad_fn from the inner loop
                loss_component = torch.sum((param_student - param_teacher_target)**2)
                meta_loss_components.append(loss_component)
            else:
                # This should ideally not happen if architectures are identical
                # and state_dicts are complete.
                print(f"Warning during meta-loss: Student parameter '{name}' not found in teacher target parameters. Skipping.")
    
    if not meta_loss_components:
        print(f"CRITICAL Error at Distill Epoch {epoch_distill}: No components found for meta_loss. Meta-loss will be zero.")
        print(f"  Teacher start idx: {teacher_idx_start}, Teacher target idx: {teacher_idx_target}")
        print(f"  Number of student_model.named_parameters(): {len(list(student_model.named_parameters()))}")
        print(f"  Number of teacher_target_params_dict_device keys: {len(teacher_target_params_dict_device.keys())}")
        # This is a fallback, but indicates a serious issue if it happens.
        meta_loss = torch.tensor(0.0, device=device, requires_grad=True) 
    else:
        meta_loss = torch.stack(meta_loss_components).sum()
        

    # 5. Backpropagate Meta-Loss and Update X_syn
    meta_loss.backward()
    optimizer_X_syn.step()
    
    if epoch_distill % 50 == 0 or epoch_distill == distill_epochs - 1 :
        print(f"Distill Epoch {epoch_distill+1}/{distill_epochs}, Meta Loss: {meta_loss.item():.4f}")
        # print(f"  X_syn stats: min={X_syn.min().item():.4f}, max={X_syn.max().item():.4f}, mean={X_syn.mean().item():.4f}")

# --- End of Distillation Loop ---

distilled_data_path = 'distilled_dataset_final.pth'
torch.save({'X_syn': X_syn.detach().cpu(), 
            'Y_syn': Y_syn.cpu(),
            'N_CLASSES': N_CLASSES,
            'IPC': IPC}, 
            distilled_data_path)
print(f"Final distilled dataset saved to {distilled_data_path}")

# --- Optional: Visualize some final synthetic images ---
# (Visualization code would go here, same as before)
# Define MEAN_CIFAR and STD_CIFAR if you haven't globally
MEAN_CIFAR = torch.tensor([0.4914, 0.4822, 0.4465], device='cpu').view(1, 3, 1, 1)
STD_CIFAR = torch.tensor([0.2023, 0.1994, 0.2010], device='cpu').view(1, 3, 1, 1)

X_syn_to_show = X_syn.detach().cpu()
# Ensure X_syn_to_show has the same channel order as MEAN/STD if needed for broadcasting
# Assuming X_syn_to_show is [N, C, H, W]
X_syn_unnormalized = X_syn_to_show * STD_CIFAR + MEAN_CIFAR # Unnormalize
X_syn_unnormalized = torch.clamp(X_syn_unnormalized, 0, 1) # Clamp to valid image range

fig, axes = plt.subplots(N_CLASSES, IPC, figsize=(IPC * 1.2, N_CLASSES * 1.2))
if N_CLASSES == 1 and IPC == 1: # Handle single image case
    axes = np.array([[axes]])
elif N_CLASSES == 1: 
    axes = axes.reshape(1, IPC)
elif IPC == 1: 
    axes = axes.reshape(N_CLASSES, 1)


for i_class in range(N_CLASSES):
    for j_ipc in range(IPC):
        idx = i_class * IPC + j_ipc
        if idx < X_syn_unnormalized.shape[0]:
            img = X_syn_unnormalized[idx]
            current_ax = axes[i_class, j_ipc]
            current_ax.imshow(img.permute(1, 2, 0).numpy()) # CHW to HWC for matplotlib
            current_ax.axis('off')
            if j_ipc == 0: # Add class label to the first image of each row
                current_ax.set_title(f"C{i_class}", fontsize=8) # Using train_dataset_full.classes[i_class] would be better if available

plt.suptitle(f"Distilled Images (End of Distillation)", fontsize=10)
plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust for suptitle
plt.show()

NameError: name 'optim' is not defined

In [3]:
# --- Configuration for Evaluating Distilled Dataset ---
eval_epochs_on_distilled = 200 # Number of epochs to train a NEW model on the distilled set
lr_eval = 0.01                 # Learning rate for training on distilled set
batch_size_distilled = min(IPC * N_CLASSES, 100) # Use all 100 images in a batch, or smaller if preferred
                                              # For 100 images, full-batch GD is feasible.

# --- Load the Distilled Dataset ---
distilled_data_path = 'distilled_dataset_final.pth'
try:
    saved_data = torch.load(distilled_data_path)
    X_syn_final = saved_data['X_syn'].to(device) # Move to device for training
    Y_syn_final = saved_data['Y_syn'].to(device) # Move to device
    N_CLASSES_loaded = saved_data['N_CLASSES']
    IPC_loaded = saved_data['IPC']
    print(f"Loaded distilled dataset: X_syn shape {X_syn_final.shape}, Y_syn shape {Y_syn_final.shape}")
    assert N_CLASSES_loaded == N_CLASSES
    assert IPC_loaded == IPC
except FileNotFoundError:
    print(f"Error: Distilled dataset not found at {distilled_data_path}. Please run Step 5 first.")
    raise SystemExit("Distilled dataset missing.")

# Create DataLoader for the distilled dataset
distilled_dataset = TensorDataset(X_syn_final, Y_syn_final)
distilled_loader = DataLoader(distilled_dataset, batch_size=batch_size_distilled, shuffle=True)

# --- Initialize and Train a New Model on the Distilled Dataset ---
eval_model = model_architecture(num_classes=N_CLASSES).to(device) # Fresh model
criterion_eval = nn.CrossEntropyLoss() # Same criterion as before
optimizer_eval = optim.SGD(eval_model.parameters(), lr=lr_eval, momentum=0.9, weight_decay=5e-4) # Standard SGD
# Optional: Learning rate scheduler for this evaluation training
# scheduler_eval = optim.lr_scheduler.CosineAnnealingLR(optimizer_eval, T_max=eval_epochs_on_distilled)
scheduler_eval = optim.lr_scheduler.MultiStepLR(optimizer_eval, milestones=[int(0.5*eval_epochs_on_distilled), int(0.75*eval_epochs_on_distilled)], gamma=0.1)


print("\nStarting Evaluation: Training a new model on the distilled dataset...")
for epoch_eval in tqdm(range(eval_epochs_on_distilled), desc="Eval Epochs on Distilled Data"):
    eval_model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # Since the distilled dataset is small, one epoch might just be one batch if batch_size_distilled is large enough.
    for inputs_syn, labels_syn in distilled_loader: # inputs_syn and labels_syn are already on device
        optimizer_eval.zero_grad()
        outputs = eval_model(inputs_syn)
        loss = criterion_eval(outputs, labels_syn)
        loss.backward()
        optimizer_eval.step()

        running_loss += loss.item() * inputs_syn.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels_syn.size(0)
        correct_predictions += (predicted == labels_syn).sum().item()
    
    scheduler_eval.step()

    epoch_train_loss = running_loss / total_samples
    epoch_train_acc = correct_predictions / total_samples

    # Evaluate on the *real* CIFAR-10 test set
    # The `evaluate_model` function was defined in Step 3
    # def evaluate_model(model, dataloader, criterion, device): ...
    epoch_test_loss, epoch_test_acc = evaluate_model(eval_model, test_loader, criterion_eval, device)
    
    if (epoch_eval + 1) % 20 == 0 or epoch_eval == eval_epochs_on_distilled - 1:
        print(f"Eval Epoch {epoch_eval+1}/{eval_epochs_on_distilled}:")
        print(f"  Train Loss (on distilled): {epoch_train_loss:.4f}, Train Acc (on distilled): {epoch_train_acc:.4f}")
        print(f"  Test Loss (on CIFAR-10 test): {epoch_test_loss:.4f}, Test Acc (on CIFAR-10 test): {epoch_test_acc:.4f}")
        print(f"  LR: {optimizer_eval.param_groups[0]['lr']:.6f}")


# Final performance on the test set
final_test_loss, final_test_acc = evaluate_model(eval_model, test_loader, criterion_eval, device)
print("\n--- Evaluation Complete ---")
print(f"Final Test Accuracy of model trained on distilled data: {final_test_acc*100:.2f}%")

NameError: name 'IPC' is not defined