In [None]:
import os
import sys
import json
import io
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, ConcatDataset, Dataset, random_split
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from contextlib import redirect_stdout
import numpy as np

# -------------------------------
# Logging Utilities
# -------------------------------
class Tee(object):
    """Write-only stream that fans every write out to several file-like objects.

    Each target only needs ``write(text)`` and ``flush()`` methods, e.g.
    ``sys.stdout`` together with an in-memory ``io.StringIO`` buffer.
    """

    def __init__(self, *fileobjects):
        # Targets are kept as a tuple, in the order they were passed in.
        self.fileobjects = fileobjects

    def write(self, text):
        """Write ``text`` to each target in order, flushing it immediately."""
        for stream in self.fileobjects:
            stream.write(text)
            # Flush right away so the targets stay in step with each other.
            stream.flush()

    def flush(self):
        """Flush every target stream."""
        for stream in self.fileobjects:
            stream.flush()

# -------------------------------
# Paths and Checkpoints
# -------------------------------
log_file_path = "./logs/cifar10_training_log.txt"
metrics_file_path = "./metrics/cifar10_training_metrics.json"
best_model_path = "./models/cifar10_best_student_model.pth"
save_path = "./models/cifar10_student_model"  # Prefix; "_<epoch>.pth" / "_final.pth" is appended
teacher_checkpoint = "/notebooks/Resnet18/models/cifar10_best_model"  # Pre-trained teacher

# Create the output directories up front: torch.save() and open(..., "w") do
# not create missing parent directories, so without this a fresh checkout
# fails on the first checkpoint/metrics write. The teacher checkpoint is only
# read, so its directory is deliberately not created.
for _output_path in (log_file_path, metrics_file_path, best_model_path, save_path):
    os.makedirs(os.path.dirname(_output_path), exist_ok=True)

# -------------------------------
# Data Augmentations (No Mixup)
# -------------------------------
# Per-channel mean 0.5 / std 0.5: maps ToTensor's [0, 1] output to [-1, 1].
norm = transforms.Normalize((0.5,) * 3, (0.5,) * 3)


def _to_normalized_tensor(*image_augmentations):
    """Compose the given augmentations followed by ToTensor and ``norm``."""
    return transforms.Compose([*image_augmentations, transforms.ToTensor(), norm])


# Three augmentation pipelines for the enriched dataset.
# (a) geometric: padded random crop + horizontal flip.
transform_a = _to_normalized_tensor(
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
)

# (b) photometric jitter + small rotation.
transform_b = _to_normalized_tensor(
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomRotation(15),
)

# (c) occasional grayscale + affine rotation/translation.
transform_c = _to_normalized_tensor(
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
)

# Evaluation and adversarial generation use no augmentation at all.
eval_transform = _to_normalized_tensor()

# -------------------------------
# Load CIFAR-10 Datasets
# -------------------------------
# Full training set with eval_transform (for the train/validation split and
# for generating adversarial examples).
full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=eval_transform)

# Hold out 20% of the training images for validation.
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size
trainset_eval, valset = random_split(full_trainset, [train_size, val_size])

# Indices (into the CIFAR-10 training set) that belong to the training split.
train_indices = list(trainset_eval.indices)

# Three copies of the training images, each with a different augmentation.
dataset_a = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_a)
dataset_b = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_b)
dataset_c = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_c)

# Enriched dataset: the augmented copies, restricted to the training split.
# Concatenating the unrestricted copies would put every validation image into
# the training data (merely augmented), so validation accuracy would no
# longer measure generalization.
enriched_dataset = ConcatDataset([
    torch.utils.data.Subset(augmented_copy, train_indices)
    for augmented_copy in (dataset_a, dataset_b, dataset_c)
])

# Test set.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_transform)

# -------------------------------
# FGSM Adversarial Example Generator
# -------------------------------
def generate_FGSM(image, label, teacher, epsilon=0.01):
    """
    Generates an adversarial example using the FGSM method.

    The input is moved one ``epsilon``-sized step along the sign of the
    gradient of the teacher's cross-entropy loss with respect to the input,
    then clamped back to [-1, 1].

    Args:
        image (torch.Tensor): A normalized image tensor (range [-1,1]) without
            a batch dimension. It is not modified.
        label (int): The true label of the image.
        teacher (torch.nn.Module): The teacher model.
        epsilon (float): Perturbation magnitude.

    Returns:
        torch.Tensor: The adversarial image, detached from the graph, on the
        same device as the teacher's parameters (or the image's device if the
        teacher has no parameters).
    """
    # Work on the teacher's own device rather than a module-level `device`
    # global, so the function is self-contained.
    first_param = next(teacher.parameters(), None)
    target_device = first_param.device if first_param is not None else image.device

    # Fresh leaf copy of the input so gradients can be taken w.r.t. it.
    adv_input = image.clone().detach().to(target_device).requires_grad_(True)
    target = torch.as_tensor([int(label)], dtype=torch.long, device=target_device)

    # Forward pass (add a batch dimension) and loss against the true label.
    logits = teacher(adv_input.unsqueeze(0))
    loss = nn.CrossEntropyLoss()(logits, target)

    # Only the input gradient is needed; autograd.grad leaves the teacher's
    # parameter .grad buffers untouched instead of filling them.
    (data_grad,) = torch.autograd.grad(loss, adv_input)

    # Step along the gradient sign, then keep the result in the valid range
    # for normalized images, [-1, 1].
    perturbed_image = adv_input.detach() + epsilon * data_grad.sign()
    perturbed_image = torch.clamp(perturbed_image, -1, 1)

    return perturbed_image.detach()

# -------------------------------
# Prepare the Teacher Model (ResNet18 for CIFAR-10)
# -------------------------------
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ResNet18 without pretrained weights, adapted to 32x32 CIFAR-10 inputs:
# a 3x3 stride-1 stem convolution, no max-pooling, and a 10-way classifier.
teacher = models.resnet18(weights=None)
teacher.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
teacher.maxpool = nn.Identity()
num_ftrs = teacher.fc.in_features
teacher.fc = nn.Linear(num_ftrs, 10)

# Restore the teacher weights when the checkpoint is present; otherwise warn
# and carry on with the randomly initialized network.
if not os.path.exists(teacher_checkpoint):
    print("Warning: Teacher checkpoint not found. Using an untrained teacher model.")
else:
    teacher_state = torch.load(teacher_checkpoint, map_location=device)
    teacher.load_state_dict(teacher_state)

# Move to the target device and switch to inference mode.
teacher.to(device).eval()

# -------------------------------
# Create FGSM-Based Adversarial Dataset
# -------------------------------
# Number of adversarial samples, expressed as a fraction of full_trainset.
adv_fraction = 0.1
# Candidates are restricted to the training split: perturbing held-out
# validation images and training on them would leak validation data.
candidate_indices = np.asarray(trainset_eval.indices)
num_adv_samples = min(int(len(full_trainset) * adv_fraction), len(candidate_indices))
print(f"Generating FGSM adversarial examples for {num_adv_samples} samples...")

adv_samples = []
# Randomly select (without repetition) the images to perturb.
selected_indices = np.random.choice(candidate_indices, num_adv_samples, replace=False)
# FGSM epsilon value (tunable).
fgsm_epsilon = 0.01

for idx in tqdm(selected_indices, desc="Generating FGSM samples"):
    # int() turns the NumPy integer into a plain Python index.
    image, label = full_trainset[int(idx)]
    adv_image = generate_FGSM(image, label, teacher, epsilon=fgsm_epsilon)
    # Keep the original label; store on the CPU so the DataLoader can batch it.
    adv_samples.append((adv_image.cpu(), label))

# Wrap the adversarial examples in a simple Dataset.
class AdversarialDataset(Dataset):
    """Dataset view over an in-memory sequence of ``(image, label)`` samples."""

    def __init__(self, samples):
        # The sequence is stored as given; it is not copied.
        self.samples = samples

    def __len__(self):
        """Return the number of stored samples."""
        sample_count = len(self.samples)
        return sample_count

    def __getitem__(self, idx):
        """Return the stored ``(image, label)`` entry at position ``idx``."""
        sample = self.samples[idx]
        return sample

# Training data = augmented (enriched) images + FGSM adversarial examples.
adv_dataset = AdversarialDataset(adv_samples)
combined_dataset = ConcatDataset([enriched_dataset, adv_dataset])

# -------------------------------
# DataLoaders (No Mixup)
# -------------------------------
# Settings shared by all three loaders.
_loader_kwargs = {"batch_size": 256, "num_workers": 2}

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(combined_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(valset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(testset, shuffle=False, **_loader_kwargs)

# -------------------------------
# Prepare Teacher and Student Models for Distillation
# -------------------------------
# The teacher model is already prepared above and remains frozen.
# Prepare the student model (SqueezeNet1_1 adapted for CIFAR-10).
# `weights=None` (random initialization) replaces the deprecated
# `pretrained=False` argument and matches how the teacher is constructed.
student = models.squeezenet1_1(weights=None)
# Swap the final 1x1 classifier convolution for a 10-class one.
student.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1,1), stride=(1,1))
student.num_classes = 10
student.to(device)

# Optionally resume student training from a checkpoint.
# Set to a state-dict file path to resume; None starts from scratch.
start_checkpoint = None
if start_checkpoint:
    student_state = torch.load(start_checkpoint, map_location=device)
    student.load_state_dict(student_state)

# Freeze teacher parameters.
for param in teacher.parameters():
    param.requires_grad = False

# -------------------------------
# Training Hyperparameters and Loss Setup
# -------------------------------
# Schedule.
num_epochs = 200
learning_rate = 0.001
checkpoint_frequency = 10  # Save a periodic checkpoint every this many epochs
temperature = 4.0  # Softmax temperature for the distillation targets

# Losses: KL divergence for distillation, cross-entropy for validation.
kl_loss = nn.KLDivLoss(reduction='batchmean')
criterion = nn.CrossEntropyLoss()

# Only the student's parameters are optimized.
optimizer = optim.SGD(
    student.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)

# CIFAR-10 class names (for logging if desired).
cifar_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']

# -------------------------------
# Metrics and Logging Setup
# -------------------------------
# One independent list per tracked series, appended to once per epoch.
metrics = {
    series: []
    for series in ("epochs", "train_loss", "train_acc", "val_loss", "val_acc")
}
best_val_acc = 0.0
best_epoch = None

# Console output is mirrored into an in-memory buffer so it can be written to
# the log file later.
log_capture = io.StringIO()
tee = Tee(sys.stdout, log_capture)
writer = SummaryWriter(log_dir='./runs/cifar10_distillation_experiment')

# -------------------------------
# Training Loop: Knowledge Distillation with FGSM Augmentation
# -------------------------------
# The whole run is wrapped in try/finally so that the TensorBoard writer is
# closed and the captured log is written to disk even when training is
# interrupted or fails part-way through, not only after the last epoch.
try:
    with redirect_stdout(tee):
        print("Training student (SqueezeNet) with teacher (ResNet18) distillation and FGSM adversarial augmentation...")
        for epoch in range(num_epochs):
            student.train()
            running_loss = 0.0
            total = 0
            correct = 0

            # Periodic checkpoint, taken at the start of the epoch (so the
            # file for epoch 0 holds the weights training started from).
            if epoch % checkpoint_frequency == 0:
                torch.save(student.state_dict(), f"{save_path}_{epoch}.pth")

            for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # Forward pass through teacher (for soft targets).
                with torch.no_grad():
                    teacher_logits = teacher(inputs)
                # Forward pass through student.
                student_logits = student(inputs)

                # Distillation loss: temperature-scaled KL divergence between
                # the student's log-probabilities and the teacher's
                # probabilities, multiplied by T^2. Labels are not part of the
                # training loss; they only feed the accuracy counter below.
                teacher_soft = torch.softmax(teacher_logits / temperature, dim=1)
                student_log_soft = torch.log_softmax(student_logits / temperature, dim=1)
                loss_distill = (temperature ** 2) * kl_loss(student_log_soft, teacher_soft)
                loss = loss_distill

                loss.backward()
                optimizer.step()

                # Accumulate sample-weighted loss and hard-label accuracy.
                running_loss += loss.item() * inputs.size(0)
                total += inputs.size(0)
                _, predicted = student_logits.max(1)
                correct += predicted.eq(labels).sum().item()

            train_epoch_loss = running_loss / total
            train_epoch_acc = 100. * correct / total

            # ----- Validation Loop -----
            # Validation uses plain cross-entropy against the true labels, so
            # its loss is not on the same scale as the distillation loss.
            student.eval()
            val_loss = 0.0
            val_correct = 0
            val_total = 0
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = student(inputs)
                    loss_val = criterion(outputs, labels)
                    val_loss += loss_val.item() * inputs.size(0)
                    _, predicted = outputs.max(1)
                    val_total += labels.size(0)
                    val_correct += predicted.eq(labels).sum().item()
            val_epoch_loss = val_loss / val_total
            val_epoch_acc = 100. * val_correct / val_total

            # Epochs are recorded 1-based in the metrics file.
            metrics["epochs"].append(epoch + 1)
            metrics["train_loss"].append(train_epoch_loss)
            metrics["train_acc"].append(train_epoch_acc)
            metrics["val_loss"].append(val_epoch_loss)
            metrics["val_acc"].append(val_epoch_acc)

            print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_acc:.2f}% | Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.2f}%")
            writer.add_scalar('Loss/Train', train_epoch_loss, epoch)
            writer.add_scalar('Accuracy/Train', train_epoch_acc, epoch)
            writer.add_scalar('Loss/Validation', val_epoch_loss, epoch)
            writer.add_scalar('Accuracy/Validation', val_epoch_acc, epoch)

            # Keep the weights with the best validation accuracy so far.
            if val_epoch_acc > best_val_acc:
                best_val_acc = val_epoch_acc
                best_epoch = epoch + 1
                torch.save(student.state_dict(), best_model_path)
                print(f"New best student model found at epoch {epoch+1} with Val Acc: {val_epoch_acc:.2f}%")

            # Rewrite the metrics file every epoch so it is always current.
            with open(metrics_file_path, "w") as f:
                json.dump(metrics, f, indent=4)

        print("Training complete.")
        torch.save(student.state_dict(), save_path + "_final.pth")
finally:
    writer.close()

    # Save captured logs.
    with open(log_file_path, "w") as f:
        f.write(log_capture.getvalue())

print("Training log captured and saved to", log_file_path)
print("Training metrics saved to", metrics_file_path)
print("Best student model saved from epoch", best_epoch, "with validation accuracy of", best_val_acc)


2025-04-13 03:40:03.125950: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-13 03:40:03.126046: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-13 03:40:03.130116: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-13 03:40:03.144867: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 96050175.82it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Generating FGSM adversarial examples for 5000 samples...


Generating FGSM samples: 100%|██████████| 5000/5000 [00:47<00:00, 105.81it/s]

Training student (SqueezeNet) with teacher (ResNet18) distillation and FGSM adversarial augmentation...



Epoch 1/200: 100%|██████████| 606/606 [00:34<00:00, 17.33it/s]


Epoch [1/200] - Train Loss: 6.9489, Train Acc: 17.18% | Val Loss: 1.8601, Val Acc: 30.39%
New best student model found at epoch 1 with Val Acc: 30.39%


Epoch 2/200: 100%|██████████| 606/606 [00:36<00:00, 16.40it/s]


Epoch [2/200] - Train Loss: 5.5457, Train Acc: 30.41% | Val Loss: 1.8068, Val Acc: 38.54%
New best student model found at epoch 2 with Val Acc: 38.54%


Epoch 3/200: 100%|██████████| 606/606 [00:37<00:00, 16.30it/s]


Epoch [3/200] - Train Loss: 4.9497, Train Acc: 37.69% | Val Loss: 1.6160, Val Acc: 45.21%
New best student model found at epoch 3 with Val Acc: 45.21%


Epoch 4/200: 100%|██████████| 606/606 [00:37<00:00, 16.25it/s]


Epoch [4/200] - Train Loss: 4.5365, Train Acc: 42.06% | Val Loss: 1.7525, Val Acc: 47.42%
New best student model found at epoch 4 with Val Acc: 47.42%


Epoch 5/200: 100%|██████████| 606/606 [00:36<00:00, 16.49it/s]


Epoch [5/200] - Train Loss: 4.2473, Train Acc: 45.55% | Val Loss: 1.5415, Val Acc: 54.36%
New best student model found at epoch 5 with Val Acc: 54.36%


Epoch 6/200: 100%|██████████| 606/606 [00:36<00:00, 16.40it/s]


Epoch [6/200] - Train Loss: 4.0151, Train Acc: 47.97% | Val Loss: 1.5071, Val Acc: 57.38%
New best student model found at epoch 6 with Val Acc: 57.38%


Epoch 7/200: 100%|██████████| 606/606 [00:37<00:00, 16.24it/s]


Epoch [7/200] - Train Loss: 3.8248, Train Acc: 49.78% | Val Loss: 1.4025, Val Acc: 58.90%
New best student model found at epoch 7 with Val Acc: 58.90%


Epoch 8/200: 100%|██████████| 606/606 [00:37<00:00, 16.20it/s]


Epoch [8/200] - Train Loss: 3.6777, Train Acc: 51.32% | Val Loss: 1.4989, Val Acc: 58.92%
New best student model found at epoch 8 with Val Acc: 58.92%


Epoch 9/200: 100%|██████████| 606/606 [00:36<00:00, 16.42it/s]


Epoch [9/200] - Train Loss: 3.5390, Train Acc: 52.74% | Val Loss: 1.4708, Val Acc: 61.53%
New best student model found at epoch 9 with Val Acc: 61.53%


Epoch 10/200:  67%|██████▋   | 406/606 [00:24<00:12, 16.60it/s]