In [None]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from google.colab import drive
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from tqdm import tqdm

# Prefer the GPU, but fall back to the CPU so the notebook does not fail at the
# first `.to(device)` call on a runtime without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Mount Google Drive; the checkpoint directories used by this notebook live
# under the '/content/drive' mount point.
drive.mount('/content/drive')
DRIVE_CHECKPOINT_DIR = '/content/drive/MyDrive/CIFAR10-model5-training'
# os.makedirs(DRIVE_CHECKPOINT_DIR, exist_ok=True)

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution and of the projection
            shortcut (when a projection is required).
    """

    # Channel multiplier that ResNet reads when chaining blocks together.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The skip path is the identity unless the block changes the spatial
        # size or the channel count; then a 1x1 conv + BN projects the input.
        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        skip = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + skip)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last 1x1 convolution widens the feature map to
    ``out_channels * expansion`` channels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution and of the projection shortcut
            (when a projection is required).
    """

    # Channel multiplier applied by the final 1x1 convolution.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        widened = out_channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        # Identity skip when input and output shapes agree; otherwise a
        # 1x1 conv + BN projects the input to the output shape.
        shape_preserved = stride == 1 and in_channels == widened
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        skip = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        return F.relu(hidden + skip)

class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and a linear classifier.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_channels, out_channels, stride)``.
        num_blocks: Indexable holding the number of blocks for each of the
            four stages.
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer updates it after every block.
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages: (base width, stride of the first block).
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, num_blocks[index], stride=first_stride)
            setattr(self, f'layer{index + 1}', stage)

        # Classifier
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        stage_blocks = []
        for block_stride in strides:
            stage_blocks.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # Fixed 4x4 average pooling, then flatten everything but the batch axis.
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

def ResNet18():
    """Build a ResNet from BasicBlock with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths)

def ResNet34():
    """Build a ResNet from BasicBlock with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths)

def ResNet50():
    """Build a ResNet from Bottleneck with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)

def ResNet101():
    """Build a ResNet from Bottleneck with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths)

def ResNet152():
    """Build a ResNet from Bottleneck with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths)

In [None]:
# Google Drive locations used by this notebook.
# NOTE(review): DRIVE_CHECKPOINT_DIR repeats the value already assigned in the
# first cell and is not read by the code in this cell.
DRIVE_CHECKPOINT_DIR = '/content/drive/MyDrive/CIFAR10-model5-training'
# main() writes the per-epoch and best student checkpoints and the progress
# plots into this directory.
DISTILL_CHECKPOINT_DIR = '/content/drive/MyDrive/CIFAR10-model5-distillation'
# Directory creation is left disabled at module level:
# os.makedirs(DISTILL_CHECKPOINT_DIR, exist_ok=True)


def get_data_loaders(batch_size=128, num_workers=2):
    """Create the CIFAR-10 training and test data loaders.

    Training images get a padded random crop and a random horizontal flip
    before tensor conversion and normalization; test images are only
    converted and normalized.

    Args:
        batch_size: Batch size of the training loader (the test loader always
            uses 100).
        num_workers: Worker processes used by both loaders.

    Returns:
        Tuple ``(trainloader, testloader, classes)`` where ``classes`` is the
        tuple of the ten class names.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    test_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # The dataset is downloaded into ./data on first use.
    trainset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_pipeline,
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
    )

    testset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=test_pipeline,
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=num_workers,
    )

    classes = (
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    )
    return trainloader, testloader, classes

class DistillationLoss(nn.Module):
    """Weighted sum of hard-label cross-entropy and softened teacher/student KL.

    ``loss = (1 - alpha) * CE(student, labels)
             + alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))``

    Args:
        alpha: Weight of the KL term; the cross-entropy term gets ``1 - alpha``.
        temperature: Divisor ``T`` applied to both sets of logits before the
            softmax.
    """

    def __init__(self, alpha=0.5, temperature=4.0):
        super().__init__()
        self.alpha = alpha
        self.T = temperature
        self.criterion_ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """Return ``(total_loss, ce_loss, kl_loss)`` for one batch."""
        temperature = self.T

        hard_loss = self.criterion_ce(student_logits, labels)

        # F.kl_div takes log-probabilities as input and probabilities as target.
        student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
        soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        # The KL term is scaled by T^2.
        soft_loss = soft_loss * (temperature * temperature)

        total = (1 - self.alpha) * hard_loss + self.alpha * soft_loss
        return total, hard_loss, soft_loss

def train_distill(student_model, teacher_model, trainloader, criterion, optimizer, scheduler, epoch):
    """Train the student for one epoch against the teacher's outputs.

    Args:
        student_model: Model being optimized (put into train mode).
        teacher_model: Model providing target logits (put into eval mode and
            run under ``torch.no_grad()``).
        trainloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Callable returning ``(loss, ce_loss, kl_loss)`` from
            ``(student_logits, teacher_logits, targets)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        Tuple ``(mean_loss, mean_ce_loss, mean_kl_loss, accuracy_percent)``,
        with the means taken over batches.
    """
    student_model.train()
    teacher_model.eval()

    running_total = 0
    running_ce = 0
    running_kl = 0
    num_correct = 0
    num_seen = 0

    progress = tqdm(trainloader, desc=f"Distill Epoch {epoch + 1}")
    for step, (inputs, targets) in enumerate(progress, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        student_outputs = student_model(inputs)
        # The teacher only supplies targets, so no graph is built for it.
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)

        loss, ce_loss, kl_loss = criterion(student_outputs, teacher_outputs, targets)
        loss.backward()
        optimizer.step()

        running_total += loss.item()
        running_ce += ce_loss.item()
        running_kl += kl_loss.item()

        predicted = student_outputs.max(1)[1]
        num_seen += targets.size(0)
        num_correct += predicted.eq(targets).sum().item()

        # Running averages over the batches seen so far in this epoch.
        progress.set_postfix(loss=running_total / step,
                             ce_loss=running_ce / step,
                             kl_loss=running_kl / step,
                             acc=100. * num_correct / num_seen)

    scheduler.step()

    num_batches = len(trainloader)
    return (
        running_total / num_batches,
        running_ce / num_batches,
        running_kl / num_batches,
        100. * num_correct / num_seen,
    )

def val(model, testloader, criterion=nn.CrossEntropyLoss()):
    """Evaluate ``model`` on ``testloader`` without tracking gradients.

    Args:
        model: Model to evaluate (put into eval mode).
        testloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss applied to ``(outputs, targets)`` for each batch.

    Returns:
        Tuple ``(mean_loss_per_batch, accuracy_percent)``.
    """
    model.eval()
    loss_sum, hits, seen = 0, 0, 0

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()

            predictions = logits.max(1)[1]
            seen += targets.size(0)
            hits += predictions.eq(targets).sum().item()

    return loss_sum / len(testloader), 100. * hits / seen

def _plot_distillation_curves(train_losses, ce_losses, kl_losses, train_accs, test_accs, teacher_acc):
    """Draw the loss and accuracy panels on a new figure; the caller saves/shows/closes it."""
    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    plt.plot(train_losses, label='Total Loss')
    plt.plot(ce_losses, label='CE Loss')
    plt.plot(kl_losses, label='KL Loss')
    plt.title('Training Losses vs. Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(2, 2, 2)
    plt.plot(train_accs, label='Train Accuracy')
    plt.plot(test_accs, label='Test Accuracy')
    plt.axhline(y=teacher_acc, color='r', linestyle='--', label=f'Teacher Acc ({teacher_acc:.2f}%)')
    plt.title('Accuracy vs. Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()


def _measure_inference_time(model, input_size=(1, 3, 32, 32), num_iterations=100):
    """Return the mean forward-pass time in milliseconds on a random input of ``input_size``."""
    import time  # only used by the non-CUDA fallback below

    model.eval()
    dummy_input = torch.randn(input_size).to(device)

    with torch.no_grad():
        # Warm-up passes, excluded from the measurement.
        for _ in range(10):
            model(dummy_input)

        if dummy_input.is_cuda:
            # CUDA kernels run asynchronously, so time with CUDA events and
            # synchronize before reading the elapsed time.
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            for _ in range(num_iterations):
                model(dummy_input)
            end_event.record()
            torch.cuda.synchronize()
            return start_event.elapsed_time(end_event) / num_iterations

        # CUDA events are unavailable off-GPU; use a wall-clock timer instead.
        begin = time.perf_counter()
        for _ in range(num_iterations):
            model(dummy_input)
        return (time.perf_counter() - begin) * 1000.0 / num_iterations


def main():
    """Distill a pretrained ResNet152 teacher into a ResNet18 student on CIFAR-10.

    Loads the teacher weights from Google Drive, trains the student for
    ``num_epochs`` epochs with ``DistillationLoss``, writes a checkpoint every
    epoch plus the best-accuracy weights, saves progress plots, and finally
    prints accuracy, parameter-count and inference-time comparisons.
    """
    num_epochs = 100
    batch_size = 128
    lr = 0.1
    alpha = 0.5
    temperature = 4.0

    # Both output locations must exist before the first torch.save/savefig.
    os.makedirs('distill_results', exist_ok=True)
    os.makedirs(DISTILL_CHECKPOINT_DIR, exist_ok=True)

    trainloader, testloader, _ = get_data_loaders(batch_size)

    teacher_model = ResNet152().to(device)
    teacher_path = '/content/drive/MyDrive/CIFAR10-model5-finetuning/resnet152_finetune_best.pt'
    # map_location keeps the load working when the checkpoint was saved from a
    # different device than the one in use now.
    teacher_model.load_state_dict(torch.load(teacher_path, map_location=device))
    teacher_model.eval()
    print(f"Loaded teacher model ResNet152 from {teacher_path}")

    _, teacher_acc = val(teacher_model, testloader)
    print(f"Teacher model (ResNet152) - Test Acc: {teacher_acc:.2f}%")

    student_model = ResNet18().to(device)
    student_size = sum(p.numel() for p in student_model.parameters())
    teacher_size = sum(p.numel() for p in teacher_model.parameters())
    print(f"Created student model ResNet18 with {student_size} parameters")
    print(f"Teacher model has {teacher_size} parameters")
    print(f"Size reduction: {teacher_size / student_size:.1f}x")

    criterion = DistillationLoss(alpha=alpha, temperature=temperature)
    optimizer = optim.SGD(student_model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0)
    test_criterion = nn.CrossEntropyLoss()

    train_losses = []
    ce_losses = []
    kl_losses = []
    test_losses = []
    train_accs = []
    test_accs = []
    best_acc = 0

    for epoch in range(num_epochs):
        train_loss, ce_loss, kl_loss, train_acc = train_distill(
            student_model, teacher_model, trainloader, criterion, optimizer, scheduler, epoch)
        train_losses.append(train_loss)
        ce_losses.append(ce_loss)
        kl_losses.append(kl_loss)
        train_accs.append(train_acc)

        test_loss, test_acc = val(student_model, testloader, test_criterion)
        test_losses.append(test_loss)
        test_accs.append(test_acc)

        print(f'\nEpoch {epoch + 1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f} (CE: {ce_loss:.4f}, KL: {kl_loss:.4f}), Train Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
        # train_distill has already stepped the scheduler, so this is the rate
        # the next epoch will use.
        print(f'Learning rate: {scheduler.get_last_lr()[0]:.6f}')

        # Full training state every epoch; note 'epoch' is stored zero-based
        # while the file name is one-based.
        drive_checkpoint_path = f'{DISTILL_CHECKPOINT_DIR}/resnet18_distill_epoch_{epoch+1}.pt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': student_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'test_loss': test_loss,
            'train_acc': train_acc,
            'test_acc': test_acc
        }, drive_checkpoint_path)
        print(f'Checkpoint saved to Google Drive: resnet18_distill_epoch_{epoch+1}.pt')

        # Save best model (weights only), both locally and to Google Drive.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(student_model.state_dict(), 'distill_results/resnet18_distill_best.pt')

            best_drive_path = f'{DISTILL_CHECKPOINT_DIR}/resnet18_distill_best.pt'
            torch.save(student_model.state_dict(), best_drive_path)
            print(f'New best student model saved! (Accuracy: {test_acc:.2f}%)')

        # Progress snapshot every 20 epochs and after the last epoch.
        if (epoch + 1) % 20 == 0 or epoch == num_epochs - 1:
            _plot_distillation_curves(train_losses, ce_losses, kl_losses, train_accs, test_accs, teacher_acc)
            plt.savefig(f'distill_results/distillation_progress_epoch_{epoch+1}.png')
            plt.savefig(f'{DISTILL_CHECKPOINT_DIR}/distillation_progress_epoch_{epoch+1}.png')
            plt.close()

    _plot_distillation_curves(train_losses, ce_losses, kl_losses, train_accs, test_accs, teacher_acc)
    plt.savefig('distill_results/final_distillation_curves.png')
    plt.savefig(f'{DISTILL_CHECKPOINT_DIR}/final_distillation_curves.png')
    plt.show()

    print("=== Final Results ===")
    print(f"Teacher (ResNet152) accuracy: {teacher_acc:.2f}%")
    print(f"Student (ResNet18) best accuracy: {best_acc:.2f}%")
    print(f"Accuracy difference: {teacher_acc - best_acc:.2f}%")
    print(f"Size reduction: {teacher_size/student_size:.1f}x ({teacher_size:,} → {student_size:,} parameters)")

    # Time both models; the teacher measurement was previously missing, which
    # made the prints below fail with a NameError.
    teacher_time = _measure_inference_time(teacher_model)
    student_time = _measure_inference_time(student_model)
    print(f"Teacher inference time: {teacher_time:.2f} ms")
    print(f"Student inference time: {student_time:.2f} ms")
    print(f"Speedup: {teacher_time/student_time:.1f}x")


# Run the full distillation pipeline when this code is executed as the main
# module (rather than imported).
if __name__ == '__main__':
    main()