In [None]:
import io
import json
import os
import sys
from contextlib import redirect_stdout
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import ConcatDataset, DataLoader, Subset, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# ----- Define Tee class for logging -----
class Tee:
    """Fan-out writer that mirrors every piece of text onto several streams.

    Intended for ``contextlib.redirect_stdout`` so that printed output reaches
    the console and an in-memory capture buffer at the same time.
    """

    def __init__(self, *fileobjects):
        # Target file-like objects, kept in the order they were given.
        self.fileobjects = fileobjects

    def write(self, text):
        """Write *text* to each target, flushing that target right after its write."""
        for target in self.fileobjects:
            target.write(text)
            target.flush()

    def flush(self):
        """Flush every target stream."""
        for target in self.fileobjects:
            target.flush()

# ----- Specify paths for saving logs, metrics, and models -----
log_file_path = "./logs/cifar10_training_log.txt"
metrics_file_path = "./metrics/cifar10_training_metrics.json"
best_model_path = "./models/cifar10_best_student_model.pth"
save_path = "./models/cifar10_student_model"  # prefix; suffixes like "_10.pth" are appended
teacher_checkpoint = "/notebooks/Resnet18/models/cifar10_best_model"  # Pre-trained teacher

# Create the output directories up front: neither torch.save nor open(..., "w")
# creates missing parent directories, so a fresh working directory would
# otherwise crash at the first checkpoint/metrics/log write.
for _output_path in (log_file_path, metrics_file_path, best_model_path, save_path):
    _output_dir = os.path.dirname(_output_path)
    if _output_dir:  # a bare filename has no directory to create
        os.makedirs(_output_dir, exist_ok=True)

# ----- Define Data Augmentations for Enriched Training Data -----
# Every pipeline ends with ToTensor + Normalize so all images share one value
# range: ToTensor yields [0, 1] pixels and mean = std = 0.5 maps them to [-1, 1].
norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def _finish_pipeline(*augmentations):
    """Compose *augmentations* followed by tensor conversion and `norm`."""
    return transforms.Compose([*augmentations, transforms.ToTensor(), norm])


# View A: geometric - random 32x32 crop of a 4-pixel-padded image, random h-flip.
transform_a = _finish_pipeline(
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
)

# View B: photometric jitter plus a rotation of up to 15 degrees.
transform_b = _finish_pipeline(
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomRotation(15),
)

# View C: occasional grayscale plus a small affine (rotation and translation).
transform_c = _finish_pipeline(
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
)

# Validation/testing: no augmentation, only tensor conversion and normalization.
eval_transform = _finish_pipeline()

# Load CIFAR-10 training set three times with different augmentations.
dataset_a = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_a)
dataset_b = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_b)
dataset_c = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_c)

# Un-augmented view of the same training images, used for validation.
full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=eval_transform)

# Split the training *indices* once (80% train / 20% validation) and apply that
# split to every view of the training set. The augmented copies and
# `full_trainset` index the very same images, so training on the augmented
# copies in full would leak every validation image into training and inflate
# validation accuracy / best-model selection.
# A fixed seed keeps the split identical across runs, including resumed ones.
split_seed = 42
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size
shuffled_indices = torch.randperm(
    len(full_trainset), generator=torch.Generator().manual_seed(split_seed)
).tolist()
train_indices = shuffled_indices[:train_size]
val_indices = shuffled_indices[train_size:]

# Combine the three augmented views of the *training split* into one enriched dataset.
enriched_dataset = ConcatDataset(
    [Subset(view, train_indices) for view in (dataset_a, dataset_b, dataset_c)]
)
trainset_eval = Subset(full_trainset, train_indices)
valset = Subset(full_trainset, val_indices)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=eval_transform)

# ----- Define Mixup Function and Custom Collate -----
def mixup_data(x, y, alpha=1.0):
    """Blend each sample with a randomly chosen partner from the same batch.

    Args:
        x: Batch tensor whose first dimension indexes samples (indexed as
            ``x[perm, :]``, so at least 2-D).
        y: Labels aligned with ``x`` along the first dimension.
        alpha: Beta(alpha, alpha) concentration. ``alpha <= 0`` disables
            mixing by fixing the coefficient to 1.0.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` itself,
        ``y_b`` is ``y[perm]`` and ``perm`` is a random permutation of the batch.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = torch.randperm(x.size(0))
    partner_x = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, y[perm], lam

def mixup_collate_fn(batch, alpha=0.4):
    """Collate ``(image, label)`` pairs into one batch and apply mixup to it.

    Args:
        batch: Sequence of ``(image_tensor, label)`` tuples from the dataset.
        alpha: Beta-distribution parameter forwarded to ``mixup_data``.

    Returns:
        ``(mixed_images, labels_a, labels_b, lam)`` exactly as returned by
        ``mixup_data`` for the stacked images and labels.
    """
    images, labels = zip(*batch)
    return mixup_data(torch.stack(images), torch.tensor(labels), alpha)

# Create DataLoader for the enriched training dataset with mixup applied on the fly.
# `functools.partial` binds alpha just like a lambda would, but unlike a lambda
# it can be pickled, which DataLoader worker processes require when they are
# started with the "spawn" method (the default on Windows and macOS).
train_loader = DataLoader(enriched_dataset, batch_size=256, shuffle=True, num_workers=2,
                          collate_fn=partial(mixup_collate_fn, alpha=0.4))

# For validation and test, use standard DataLoaders (default collate, no mixup).
val_loader = DataLoader(valset, batch_size=256, shuffle=False, num_workers=2)
test_loader = DataLoader(testset, batch_size=256, shuffle=False, num_workers=2)

# ----- Prepare Teacher and Student Models -----
num_classes = 10  # CIFAR-10

# Teacher: ResNet18 adapted to 32x32 inputs (3x3 stride-1 stem, no initial
# max-pool) with a CIFAR-10 classification head. This layout must match the
# teacher checkpoint loaded later.
teacher = models.resnet18(weights=None)
teacher.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
teacher.maxpool = nn.Identity()
num_ftrs = teacher.fc.in_features
teacher.fc = nn.Linear(num_ftrs, num_classes)

# Student: SqueezeNet 1.1, randomly initialised. `weights=None` replaces the
# deprecated `pretrained=False` and matches the teacher construction above.
student = models.squeezenet1_1(weights=None)
# torchvision's SqueezeNet classifier is Sequential(Dropout, Conv2d, ReLU,
# AdaptiveAvgPool2d); index 1 is the 1x1 conv that emits the class scores.
# Its input width is read from the existing layer rather than hard-coded.
student.classifier[1] = nn.Conv2d(student.classifier[1].in_channels, num_classes,
                                  kernel_size=(1, 1), stride=(1, 1))
student.num_classes = num_classes

# ----- Device and Checkpoints Setup -----
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training schedule.
num_epochs = 200
learning_rate = 0.001
checkpoint_frequency = 10  # how often (in epochs) student weights are checkpointed
start_epoch = 0
start_checkpoint = None  # path to a student state_dict to resume from, or None

# Teacher weights: load the CIFAR-10 checkpoint when the file exists, else warn.
if not os.path.exists(teacher_checkpoint):
    print("Warning: Teacher checkpoint not found. Teacher model may be untrained.")
else:
    teacher_state = torch.load(teacher_checkpoint, map_location=device)
    teacher.load_state_dict(teacher_state)

# Freeze the teacher: move it to the device, switch to eval mode and turn off
# gradients for every parameter. It only supplies soft targets.
teacher.to(device).eval().requires_grad_(False)

# Student: optionally resume from a saved state_dict, then move it to the device.
if start_checkpoint:
    student_state = torch.load(start_checkpoint, map_location=device)
    student.load_state_dict(student_state)
student.to(device)

# ----- Setup Loss Functions and Optimizer -----
# Distillation term: KL divergence between temperature-softened distributions,
# averaged over the batch.
kl_loss = nn.KLDivLoss(reduction='batchmean')
# Hard-label cross-entropy (used for validation; optional in training).
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(
    student.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4,
)
temperature = 4.0  # softening temperature for teacher/student logits

# ----- Initialize Metrics Storage -----
# One independent list per tracked quantity, each appended to once per epoch.
metrics = {key: [] for key in ("epochs", "train_loss", "train_acc", "val_loss", "val_acc")}
best_val_acc = 0.0  # best validation accuracy (%) seen so far
best_epoch = None  # 1-based epoch at which best_val_acc was reached

# ----- Setup Logging (Tee and TensorBoard) -----
log_capture = io.StringIO()
tee = Tee(sys.stdout, log_capture)  # Prints and captures output
writer = SummaryWriter(log_dir='./runs/cifar10_distillation_experiment')
# Last (1-based) epoch of this run. Used in every progress label so resumed
# runs (start_epoch > 0) report one consistent total.
total_epochs = start_epoch + num_epochs


def _distillation_train_epoch(model, teacher_model, loader, desc):
    """Run one epoch of mixup distillation; return (mean loss, accuracy %).

    The loss is the temperature-scaled KL divergence between the teacher's and
    the model's softened predictions on the same mixup batch. Accuracy is the
    usual mixup estimate: a prediction earns ``lam`` for matching ``labels_a``
    and ``1 - lam`` for matching ``labels_b``.

    Uses the module-level ``optimizer``, ``device``, ``temperature`` and ``kl_loss``.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0.0
    for mixed_inputs, labels_a, labels_b, lam in tqdm(loader, desc=desc):
        mixed_inputs = mixed_inputs.to(device)
        labels_a = labels_a.to(device)
        labels_b = labels_b.to(device)

        optimizer.zero_grad()
        # The teacher is frozen, so its forward pass needs no autograd graph.
        with torch.no_grad():
            teacher_logits = teacher_model(mixed_inputs)
        student_logits = model(mixed_inputs)

        teacher_soft = torch.softmax(teacher_logits / temperature, dim=1)
        student_log_soft = torch.log_softmax(student_logits / temperature, dim=1)
        # T^2 rescales the softened-target gradients back to a comparable magnitude.
        loss = (temperature * temperature) * kl_loss(student_log_soft, teacher_soft)
        # A hard-label term could be added here, weighted by the mixup coefficient:
        # lam * criterion(student_logits, labels_a) + (1 - lam) * criterion(student_logits, labels_b)

        loss.backward()
        optimizer.step()

        batch_size = mixed_inputs.size(0)
        running_loss += loss.item() * batch_size
        total += batch_size
        _, predicted = student_logits.max(1)
        # Credit each prediction by the mixing weight of the label it matches
        # (scoring only labels_a would be wrong whenever lam < 0.5).
        correct += (lam * predicted.eq(labels_a).sum().item()
                    + (1 - lam) * predicted.eq(labels_b).sum().item())
    return running_loss / total, 100. * correct / total


def _evaluate(model, loader):
    """Return (mean cross-entropy loss, accuracy %) of *model* on *loader*, without mixup."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # criterion returns the batch mean; re-weight by batch size for an exact average.
            loss_sum += criterion(outputs, labels).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()
    return loss_sum / n_seen, 100. * n_correct / n_seen


try:
    with redirect_stdout(tee):
        print("Training student (SqueezeNet) with teacher (ResNet18) distillation on enriched CIFAR-10 data...")
        for epoch in range(start_epoch, total_epochs):
            # Checkpoint the weights (state after `epoch` completed epochs) every
            # `checkpoint_frequency` epochs.
            if epoch % checkpoint_frequency == 0:
                torch.save(student.state_dict(), f"{save_path}_{epoch}.pth")

            train_epoch_loss, train_epoch_acc = _distillation_train_epoch(
                student, teacher, train_loader, desc=f"Epoch {epoch+1}/{total_epochs}")
            val_epoch_loss, val_epoch_acc = _evaluate(student, val_loader)

            metrics["epochs"].append(epoch + 1)
            metrics["train_loss"].append(train_epoch_loss)
            metrics["train_acc"].append(train_epoch_acc)
            metrics["val_loss"].append(val_epoch_loss)
            metrics["val_acc"].append(val_epoch_acc)

            print(f"Epoch [{epoch+1}/{total_epochs}] - Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_acc:.2f}% | Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.2f}%")

            # TensorBoard logging.
            writer.add_scalar('Loss/Train', train_epoch_loss, epoch)
            writer.add_scalar('Accuracy/Train', train_epoch_acc, epoch)
            writer.add_scalar('Loss/Validation', val_epoch_loss, epoch)
            writer.add_scalar('Accuracy/Validation', val_epoch_acc, epoch)

            # Save the best model based on validation accuracy.
            if val_epoch_acc > best_val_acc:
                best_val_acc = val_epoch_acc
                best_epoch = epoch + 1
                torch.save(student.state_dict(), best_model_path)
                print(f"New best student model found at epoch {epoch+1} with Val Acc: {val_epoch_acc:.2f}%")

            # Save metrics JSON after each epoch so partial runs keep their history.
            with open(metrics_file_path, "w") as f:
                json.dump(metrics, f, indent=4)

        print("Training complete.")
        torch.save(student.state_dict(), save_path + "_final.pth")
finally:
    # Runs even when training is interrupted (e.g. KeyboardInterrupt), so the
    # TensorBoard events are flushed and the captured log is not lost.
    writer.close()
    with open(log_file_path, "w") as f:
        f.write(log_capture.getvalue())

print("Training log captured and saved to", log_file_path)
print("Training metrics saved to", metrics_file_path)
print("Best student model saved from epoch", best_epoch, "with validation accuracy of", best_val_acc)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Training student (SqueezeNet) with teacher (ResNet18) distillation on enriched CIFAR-10 data...


Epoch 1/200: 100%|██████████| 586/586 [00:23<00:00, 25.10it/s]


Epoch [1/200] - Train Loss: 5.6375, Train Acc: 14.34% | Val Loss: 2.2623, Val Acc: 24.74%
New best student model found at epoch 1 with Val Acc: 24.74%


Epoch 2/200: 100%|██████████| 586/586 [00:22<00:00, 25.72it/s]


Epoch [2/200] - Train Loss: 4.8802, Train Acc: 19.40% | Val Loss: 1.8642, Val Acc: 37.61%
New best student model found at epoch 2 with Val Acc: 37.61%


Epoch 3/200: 100%|██████████| 586/586 [00:23<00:00, 25.41it/s]


Epoch [3/200] - Train Loss: 4.3838, Train Acc: 22.69% | Val Loss: 1.7658, Val Acc: 43.98%
New best student model found at epoch 3 with Val Acc: 43.98%


Epoch 4/200: 100%|██████████| 586/586 [00:23<00:00, 25.40it/s]


Epoch [4/200] - Train Loss: 4.0163, Train Acc: 23.56% | Val Loss: 1.7773, Val Acc: 45.73%
New best student model found at epoch 4 with Val Acc: 45.73%


Epoch 5/200: 100%|██████████| 586/586 [00:23<00:00, 24.78it/s]


Epoch [5/200] - Train Loss: 3.7601, Train Acc: 25.87% | Val Loss: 1.5069, Val Acc: 50.88%
New best student model found at epoch 5 with Val Acc: 50.88%


Epoch 6/200: 100%|██████████| 586/586 [00:22<00:00, 25.51it/s]


Epoch [6/200] - Train Loss: 3.6073, Train Acc: 27.12% | Val Loss: 1.5855, Val Acc: 53.32%
New best student model found at epoch 6 with Val Acc: 53.32%


Epoch 7/200: 100%|██████████| 586/586 [00:23<00:00, 24.85it/s]


Epoch [7/200] - Train Loss: 3.4166, Train Acc: 27.01% | Val Loss: 1.4784, Val Acc: 55.56%
New best student model found at epoch 7 with Val Acc: 55.56%


Epoch 8/200: 100%|██████████| 586/586 [00:23<00:00, 25.20it/s]


Epoch [8/200] - Train Loss: 3.3251, Train Acc: 28.62% | Val Loss: 1.4103, Val Acc: 56.61%
New best student model found at epoch 8 with Val Acc: 56.61%


Epoch 9/200: 100%|██████████| 586/586 [00:23<00:00, 25.37it/s]


Epoch [9/200] - Train Loss: 3.2460, Train Acc: 29.44% | Val Loss: 1.3726, Val Acc: 58.44%
New best student model found at epoch 9 with Val Acc: 58.44%


Epoch 10/200: 100%|██████████| 586/586 [00:23<00:00, 24.99it/s]


Epoch [10/200] - Train Loss: 3.1164, Train Acc: 29.24% | Val Loss: 1.3796, Val Acc: 59.99%
New best student model found at epoch 10 with Val Acc: 59.99%


Epoch 11/200: 100%|██████████| 586/586 [00:23<00:00, 24.93it/s]


Epoch [11/200] - Train Loss: 3.0353, Train Acc: 30.00% | Val Loss: 1.2708, Val Acc: 62.47%
New best student model found at epoch 11 with Val Acc: 62.47%


Epoch 12/200: 100%|██████████| 586/586 [00:22<00:00, 25.72it/s]


Epoch [12/200] - Train Loss: 2.9050, Train Acc: 29.59% | Val Loss: 1.2658, Val Acc: 63.94%
New best student model found at epoch 12 with Val Acc: 63.94%


Epoch 13/200: 100%|██████████| 586/586 [00:23<00:00, 25.47it/s]


Epoch [13/200] - Train Loss: 2.8105, Train Acc: 30.88% | Val Loss: 1.2570, Val Acc: 63.24%


Epoch 14/200: 100%|██████████| 586/586 [00:23<00:00, 24.58it/s]


Epoch [14/200] - Train Loss: 2.7521, Train Acc: 30.78% | Val Loss: 1.3240, Val Acc: 64.04%
New best student model found at epoch 14 with Val Acc: 64.04%


Epoch 15/200:  88%|████████▊ | 518/586 [00:20<00:02, 25.32it/s]