In [None]:
import os
import sys
import json
import io
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, ConcatDataset, Dataset, random_split
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from contextlib import redirect_stdout
import numpy as np

# -------------------------------
# Logging Utilities
# -------------------------------
class Tee:
    """Duplicate everything written to it onto several file-like objects."""

    def __init__(self, *fileobjects):
        # Targets are kept in call order; writes and flushes follow that order.
        self.fileobjects = fileobjects

    def write(self, text):
        """Write ``text`` to every target, flushing each one right after its write."""
        for sink in self.fileobjects:
            sink.write(text)
            sink.flush()

    def flush(self):
        """Flush every target."""
        for sink in self.fileobjects:
            sink.flush()

# -------------------------------
# Paths and Checkpoints
# -------------------------------
# Output artifacts, all relative to the current working directory.
log_file_path = "./logs/cifar10_training_log.txt"
metrics_file_path = "./metrics/cifar10_training_metrics.json"
best_model_path = "./models/cifar10_best_student_model.pth"
# Prefix only: the training loop appends "_<epoch>.pth" / "_final.pth".
save_path = "./models/cifar10_student_model"
teacher_checkpoint = "/notebooks/Resnet18/models/cifar10_best_model"  # Pre-trained teacher

# Neither torch.save() nor open(path, "w") creates missing parent directories,
# so make sure every output directory exists before any training time is spent.
for _output_path in (log_file_path, metrics_file_path, best_model_path, save_path):
    os.makedirs(os.path.dirname(_output_path), exist_ok=True)

# -------------------------------
# Data Augmentations (No Mixup)
# -------------------------------
# Per-channel normalization with mean 0.5 and std 0.5, shared by every pipeline.
norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def _to_normalized_tensor(*augmentations):
    """Compose the given augmentations, then convert to a tensor and normalize."""
    return transforms.Compose([*augmentations, transforms.ToTensor(), norm])


# Three different augmentation pipelines for the enriched dataset.
# A: padded random crop + horizontal flip.
transform_a = _to_normalized_tensor(
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
)

# B: color jitter + small random rotation.
transform_b = _to_normalized_tensor(
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomRotation(15),
)

# C: occasional grayscale + random affine (rotation and translation).
transform_c = _to_normalized_tensor(
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
)

# For evaluation (and also for generating adversarial examples): no augmentation,
# only tensor conversion and normalization.
eval_transform = _to_normalized_tensor()

# -------------------------------
# Load CIFAR-10 Datasets
# -------------------------------
# Full training dataset with eval_transform (no augmentation).
full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=eval_transform)

# Split full training set for validation (80% train / 20% validation).
# The split is drawn first, with a fixed seed, so that (a) the augmented training copies
# below can be restricted to the training indices and (b) the same images are held out
# on every run, including runs resumed from a checkpoint.
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size
split_generator = torch.Generator().manual_seed(0)
trainset_eval, valset = random_split(full_trainset, [train_size, val_size], generator=split_generator)

# Enriched training dataset: three copies with different augmentations.
# Each copy only exposes the training indices; building the copies from the whole
# training set would place every validation image in the training data and make the
# validation accuracy (and the best-model selection based on it) meaningless.
train_indices = list(trainset_eval.indices)
dataset_a = torch.utils.data.Subset(
    torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_a),
    train_indices,
)
dataset_b = torch.utils.data.Subset(
    torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_b),
    train_indices,
)
dataset_c = torch.utils.data.Subset(
    torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_c),
    train_indices,
)
enriched_dataset = ConcatDataset([dataset_a, dataset_b, dataset_c])

# Standard test set.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_transform)

# -------------------------------
# Define the Adversarial Example Generator
# -------------------------------
def generate_adversarial_trajectory(image, teacher, num_steps=10, step_size=0.001):
    """
    Generate a trajectory of adversarial images targeting a fixed class (the second-highest logit).

    The target class is chosen once, from the teacher's logits on the unperturbed image,
    and every step moves the image by ``step_size`` along the sign of the gradient that
    increases that class's logit.

    Args:
      image (torch.Tensor): the normalized image tensor (a single image, no batch dimension).
      teacher (torch.nn.Module): the teacher model; it must output at least two logits.
      num_steps (int): number of gradient steps.
      step_size (float): step size for each update.

    Returns:
      List[torch.Tensor]: a list of adversarial images (one per step), detached from the
      autograd graph and located on the teacher's device. Empty when ``num_steps`` is 0.
    """
    # Run on whatever device the teacher lives on instead of relying on a module-level
    # `device` global; fall back to the image's own device for parameter-free models.
    first_param = next(teacher.parameters(), None)
    target_device = first_param.device if first_param is not None else image.device
    x_adv = image.clone().detach().to(target_device)
    adversarial_images = []

    # Determine fixed target using the second highest logit.
    with torch.no_grad():
        logits_init = teacher(x_adv.unsqueeze(0))
        sorted_indices = torch.argsort(logits_init, descending=True)
        fixed_target = sorted_indices[0, 1].item()  # second highest logit index

    # Generate adversarial trajectory.
    for _ in range(num_steps):
        # Fresh leaf tensor per step so the gradient refers to the current image only.
        x_in = x_adv.detach().requires_grad_(True)
        logits = teacher(x_in.unsqueeze(0))
        loss = -logits[0, fixed_target]  # maximize fixed target logit (i.e. minimize negative)

        # Only the input gradient is needed; autograd.grad avoids computing and
        # accumulating gradients on the teacher's weights.
        (grad,) = torch.autograd.grad(loss, x_in)

        # Update adversarial image using the sign of the gradient.
        x_adv = x_adv - step_size * grad.sign()
        adversarial_images.append(x_adv)

    return adversarial_images

# -------------------------------
# Create Adversarial Dataset
# -------------------------------
# We now generate adversarial samples on a random subset of the full training set.
# Here, we generate one adversarial example per selected image by taking the final sample of the adversarial trajectory.
# Adjust adv_fraction as needed (e.g., 0.1 means 10% of full_trainset).
adv_fraction = 0.1
num_adv_samples = int(len(full_trainset) * adv_fraction)
print(f"Generating adversarial examples for {num_adv_samples} samples...")

# Device shared by the teacher, the student and every batch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------
# Prepare Teacher Model (ResNet18 for CIFAR-10)
# -------------------------------
# The teacher must be fully set up before adversarial examples are generated from it.
teacher = models.resnet18(weights=None)
# Adjust first conv layer and maxpool as appropriate for CIFAR-10.
teacher.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
teacher.maxpool = nn.Identity()
num_ftrs = teacher.fc.in_features
teacher.fc = nn.Linear(num_ftrs, 10)

# Load teacher checkpoint if available.
if os.path.exists(teacher_checkpoint):
    # NOTE: torch.load unpickles the file; only point this at a trusted checkpoint.
    teacher_state = torch.load(teacher_checkpoint, map_location=device)
    teacher.load_state_dict(teacher_state)
else:
    print("Warning: Teacher checkpoint not found. Using an untrained teacher model.")

teacher.to(device)
teacher.eval()  # Set teacher to evaluation mode
# Freeze the teacher here, before adversarial generation: the attack only needs the
# gradient with respect to the input image, so leaving the weights trainable would make
# every backward pass compute gradients for all teacher parameters for nothing.
teacher.requires_grad_(False)

# Create a list to hold (adversarial image, label) pairs.
adv_samples = []
# Draw the source images from the training split only. valset is carved out of
# full_trainset, so sampling from full_trainset would put perturbed copies of
# validation images into the training data.
# The count is capped so that sampling without replacement cannot fail when
# adv_fraction exceeds the training share of the split.
adv_sample_count = min(num_adv_samples, len(trainset_eval))
selected_indices = np.random.choice(len(trainset_eval), adv_sample_count, replace=False)
# Number of gradient descent steps and step size for adversarial attack.
num_attack_steps = 10
attack_step_size = 0.001

for idx in tqdm(selected_indices, desc="Generating adversarial samples"):
    # int() turns the NumPy integer into a plain Python index.
    image, label = trainset_eval[int(idx)]
    # Generate the trajectory.
    adv_traj = generate_adversarial_trajectory(image, teacher, num_steps=num_attack_steps, step_size=attack_step_size)
    # Choose the final adversarial sample from the trajectory.
    adv_image = adv_traj[-1]
    # Append tuple (adv_image, label). Here we keep the original label.
    # The image is moved to the CPU so the stored samples do not occupy device memory.
    adv_samples.append((adv_image.cpu(), label))

# Create a simple Dataset for adversarial samples.
class AdversarialDataset(Dataset):
    """Map-style dataset over an in-memory sequence of pre-computed samples."""

    def __init__(self, sample_list):
        # Stored by reference; no copy is made.
        self.samples = sample_list

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the stored entry at ``idx`` unchanged (an (image, label) pair)."""
        entry = self.samples[idx]
        return entry

# Wrap the pre-computed adversarial pairs and append them to the augmented copies
# so that training draws from a single combined dataset.
adv_dataset = AdversarialDataset(adv_samples)
combined_dataset = ConcatDataset([enriched_dataset, adv_dataset])

# -------------------------------
# DataLoaders (No mixup used)
# -------------------------------
# Only the training loader shuffles; validation and test share identical settings.
_eval_loader_options = dict(batch_size=256, shuffle=False, num_workers=2)
train_loader = DataLoader(dataset=combined_dataset, batch_size=256, shuffle=True, num_workers=2)
val_loader = DataLoader(valset, **_eval_loader_options)
test_loader = DataLoader(testset, **_eval_loader_options)

# -------------------------------
# Prepare Teacher and Student Models for Distillation
# -------------------------------
# Teacher is already defined above. Teacher remains fixed.
# Prepare the student model (SqueezeNet1_1 adapted for CIFAR-10).
# Randomly initialized SqueezeNet 1.1. `weights=None` replaces the deprecated
# `pretrained=False` and matches how the teacher is constructed.
student = models.squeezenet1_1(weights=None)
# Replace the final 1x1 classifier convolution so the network emits 10 class scores.
student.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1,1), stride=(1,1))
student.num_classes = 10
student.to(device)

# Optionally resume student training from a checkpoint by setting start_checkpoint.
start_checkpoint = None
if start_checkpoint:
    student_state = torch.load(start_checkpoint, map_location=device)
    student.load_state_dict(student_state)

# Freeze teacher parameters.
for param in teacher.parameters():
    param.requires_grad = False

# -------------------------------
# Training Hyperparameters and Loss Setup
# -------------------------------
# Schedule and optimizer settings.
num_epochs = 200
learning_rate = 0.001
checkpoint_frequency = 10  # a checkpoint is written every this many epochs
temperature = 4.0  # divisor applied to both teacher and student logits in the distillation loss

# Distillation loss: KL divergence with temperature scaling.
kl_loss = nn.KLDivLoss(reduction='batchmean')
# Cross-entropy against the hard labels; the training loop uses it for validation.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    student.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

# CIFAR-10 class names (for logging if desired).
cifar_classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# -------------------------------
# Metrics and Logging Setup
# -------------------------------
# One list per tracked quantity; the training loop appends one value per epoch.
metrics = {key: [] for key in ("epochs", "train_loss", "train_acc", "val_loss", "val_acc")}
best_val_acc = 0.0
best_epoch = None
# stdout is mirrored into an in-memory buffer so it can be written to a log file later.
log_capture = io.StringIO()
tee = Tee(sys.stdout, log_capture)
writer = SummaryWriter(log_dir='./runs/cifar10_distillation_experiment')

# -------------------------------
# Training Loop: Knowledge Distillation with Adversarial Augmentation
# -------------------------------
# The try/finally guarantees that the TensorBoard writer is closed and the captured log
# is written to disk even when training is interrupted or raises part-way through;
# otherwise everything captured so far would be lost.
try:
    with redirect_stdout(tee):
        print("Training student (SqueezeNet) with teacher (ResNet18) distillation and adversarial sample augmentation...")
        for epoch in range(num_epochs):
            student.train()
            running_loss = 0.0
            total = 0
            correct = 0

            # Save checkpoint at specified frequency. This happens before the epoch's
            # updates, so "<save_path>_<epoch>.pth" holds the weights after `epoch`
            # completed epochs (the epoch-0 file is the initial student).
            if epoch % checkpoint_frequency == 0:
                torch.save(student.state_dict(), f"{save_path}_{epoch}.pth")

            # Standard training loop over batches (no mixup).
            for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # Forward pass through teacher (teacher outputs are used as soft targets).
                with torch.no_grad():
                    teacher_logits = teacher(inputs)
                    teacher_soft = torch.softmax(teacher_logits / temperature, dim=1)
                # Forward pass through student.
                student_logits = student(inputs)

                # Compute distillation loss with temperature scaling. The T^2 factor is
                # applied to the KL term; the hard labels do not enter the loss.
                student_log_soft = torch.log_softmax(student_logits / temperature, dim=1)
                loss_distill = (temperature ** 2) * kl_loss(student_log_soft, teacher_soft)
                loss = loss_distill

                loss.backward()
                optimizer.step()

                # The loss is a per-batch mean, so weight it by the batch size to get a
                # correct per-sample average over the epoch.
                running_loss += loss.item() * inputs.size(0)
                total += inputs.size(0)
                # Training accuracy is measured against the hard labels.
                _, predicted = student_logits.max(1)
                correct += predicted.eq(labels).sum().item()

            train_epoch_loss = running_loss / total
            train_epoch_acc = 100. * correct / total

            # ----- Validation Loop (standard CIFAR-10 evaluation using cross-entropy loss) -----
            # Note: the validation loss is cross-entropy while the training loss is the
            # scaled KL term, so the two loss values are not directly comparable.
            student.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = student(inputs)
                    loss_val = criterion(outputs, labels)
                    val_loss += loss_val.item() * inputs.size(0)
                    _, predicted = outputs.max(1)
                    val_total += labels.size(0)
                    val_correct += predicted.eq(labels).sum().item()
            val_epoch_loss = val_loss / val_total
            val_epoch_acc = 100. * val_correct / val_total

            metrics["epochs"].append(epoch + 1)
            metrics["train_loss"].append(train_epoch_loss)
            metrics["train_acc"].append(train_epoch_acc)
            metrics["val_loss"].append(val_epoch_loss)
            metrics["val_acc"].append(val_epoch_acc)

            print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_acc:.2f}% | Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.2f}%")

            writer.add_scalar('Loss/Train', train_epoch_loss, epoch)
            writer.add_scalar('Accuracy/Train', train_epoch_acc, epoch)
            writer.add_scalar('Loss/Validation', val_epoch_loss, epoch)
            writer.add_scalar('Accuracy/Validation', val_epoch_acc, epoch)

            # Keep the weights with the best validation accuracy seen so far.
            if val_epoch_acc > best_val_acc:
                best_val_acc = val_epoch_acc
                best_epoch = epoch + 1
                torch.save(student.state_dict(), best_model_path)
                print(f"New best student model found at epoch {epoch+1} with Val Acc: {val_epoch_acc:.2f}%")

            # Rewrite the metrics file every epoch so it is current if the run stops early.
            with open(metrics_file_path, "w") as f:
                json.dump(metrics, f, indent=4)

        print("Training complete.")
        torch.save(student.state_dict(), save_path + "_final.pth")
finally:
    writer.close()

    # Save captured logs.
    with open(log_file_path, "w") as f:
        f.write(log_capture.getvalue())

print("Training log captured and saved to", log_file_path)
print("Training metrics saved to", metrics_file_path)
print("Best student model saved from epoch", best_epoch, "with validation accuracy of", best_val_acc)


2025-04-13 03:34:54.618088: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-13 03:34:54.618163: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-13 03:34:54.619382: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-13 03:34:54.647204: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Generating adversarial examples for 5000 samples...


Generating adversarial samples: 100%|██████████| 5000/5000 [06:16<00:00, 13.30it/s]


Training student (SqueezeNet) with teacher (ResNet18) distillation and adversarial sample augmentation...


Epoch 1/200: 100%|██████████| 606/606 [00:37<00:00, 16.16it/s]


Epoch [1/200] - Train Loss: 6.8891, Train Acc: 19.90% | Val Loss: 1.8230, Val Acc: 33.17%
New best student model found at epoch 1 with Val Acc: 33.17%


Epoch 2/200: 100%|██████████| 606/606 [00:36<00:00, 16.53it/s]


Epoch [2/200] - Train Loss: 5.3871, Train Acc: 34.48% | Val Loss: 1.8410, Val Acc: 43.38%
New best student model found at epoch 2 with Val Acc: 43.38%


Epoch 3/200:  63%|██████▎   | 384/606 [00:24<00:13, 15.93it/s]