<a href="https://colab.research.google.com/github/leonsuarez24/Notebooks/blob/main/Adversarial_Network_Compression_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [21]:
import sys
import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from scipy.ndimage.interpolation import rotate as scipyrotate
!pip install torchinfo
from torchinfo import summary
from torch.autograd import Function
import torchvision.utils as vutils
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import torchvision.utils as vutils
!pip install torchmetrics
from torchmetrics import Accuracy
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt


  from scipy.ndimage.interpolation import rotate as scipyrotate




# Utils

In [22]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [23]:
def get_time():
    return str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))

# Networks

##Teacher

In [24]:
class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.fc = nn.Linear(512,512)
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        features = self.fc(out)
        logits = self.classifier(features)
        return logits, features

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

In [25]:
t = VGG('VGG11')
print(summary(t, input_size=(12, 3, 32, 32)))

Layer (type:depth-idx)                   Output Shape              Param #
VGG                                      [12, 10]                  --
├─Sequential: 1-1                        [12, 512, 1, 1]           --
│    └─Conv2d: 2-1                       [12, 64, 32, 32]          1,792
│    └─BatchNorm2d: 2-2                  [12, 64, 32, 32]          128
│    └─ReLU: 2-3                         [12, 64, 32, 32]          --
│    └─MaxPool2d: 2-4                    [12, 64, 16, 16]          --
│    └─Conv2d: 2-5                       [12, 128, 16, 16]         73,856
│    └─BatchNorm2d: 2-6                  [12, 128, 16, 16]         256
│    └─ReLU: 2-7                         [12, 128, 16, 16]         --
│    └─MaxPool2d: 2-8                    [12, 128, 8, 8]           --
│    └─Conv2d: 2-9                       [12, 256, 8, 8]           295,168
│    └─BatchNorm2d: 2-10                 [12, 256, 8, 8]           512
│    └─ReLU: 2-11                        [12, 256, 8, 8]           --


## Student

In [26]:
class StudentNetwork(nn.Module):
    def __init__(self, channel, num_classes):
        super(StudentNetwork, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(channel, 6, kernel_size=5, padding=2 if channel==1 else 0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_1 = nn.Linear(16 * 5 * 5, 256)
        self.fc_2 = nn.Linear(256, 512)
        self.fc_3 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc_1(x))
        features = F.relu(self.fc_2(x))
        logits = self.fc_3(features)
        return logits, features

In [27]:
s = StudentNetwork(3,10)
print(summary(s, input_size=(12, 3, 32, 32)))

Layer (type:depth-idx)                   Output Shape              Param #
StudentNetwork                           [12, 10]                  --
├─Sequential: 1-1                        [12, 16, 5, 5]            --
│    └─Conv2d: 2-1                       [12, 6, 28, 28]           456
│    └─ReLU: 2-2                         [12, 6, 28, 28]           --
│    └─MaxPool2d: 2-3                    [12, 6, 14, 14]           --
│    └─Conv2d: 2-4                       [12, 16, 10, 10]          2,416
│    └─ReLU: 2-5                         [12, 16, 10, 10]          --
│    └─MaxPool2d: 2-6                    [12, 16, 5, 5]            --
├─Linear: 1-2                            [12, 256]                 102,656
├─Linear: 1-3                            [12, 512]                 131,584
├─Linear: 1-4                            [12, 10]                  5,130
Total params: 242,242
Trainable params: 242,242
Non-trainable params: 0
Total mult-adds (M): 10.06
Input size (MB): 0.15
Forward/backward 

## Discriminator

In [28]:
class Discriminator(nn.Module):
    def __init__(self, input):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 128)

        self.dropout = nn.Dropout(p=0.5)
        self.output_layer = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, drop):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        if drop == True:
            x = self.dropout(x)
        x = self.sigmoid(self.output_layer(x))
        return x

In [29]:
model = Discriminator(100)
input_data = torch.randn(12,100)  # Example input with shape [batch_size, channels, height, width]
output = model(input_data, True)
print(output.shape)  # Should be [12, 10]

torch.Size([12, 1])


# Dataset

In [30]:
def get_dataset(dataset, data_path):

    if dataset == 'MNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        transform = transforms.Compose([transforms.ToTensor()])
        dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'FashionMNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        transform = transforms.Compose([transforms.ToTensor()])
        dst_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'CIFAR10':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        transform = transforms.Compose([transforms.ToTensor()])
        #transform = transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip(0.5)])
        dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform) # no augmentation
        dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    else:
        exit('unknown dataset: %s'%dataset)

    testloader = torch.utils.data.DataLoader(dst_test, batch_size=32, shuffle=False, num_workers=0)
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=32, shuffle=True, num_workers=0)
    return channel, im_size, num_classes, class_names, dst_train, dst_test, testloader, trainloader

# Configs

In [31]:
save_path = f'/content/drive/MyDrive/Proyectos/Adversarial Knowledge Distillation/experiments/{get_time()}'

tb_path_teacher = save_path + '/tensorboard_teacher'
tb_path_student = save_path + '/tensorboard_student'
tb_path_student_distilled = save_path + '/tensorboard_student_distilled'

model_teacher = save_path + '/model_teacher'
model_student = save_path + '/model_student'
model_student_distilled = save_path + '/model_student_distilled'

os.makedirs(save_path, exist_ok=True)
os.makedirs(tb_path_teacher, exist_ok=True)
os.makedirs(tb_path_student, exist_ok=True)
os.makedirs(model_teacher, exist_ok=True)
os.makedirs(model_student, exist_ok=True)
os.makedirs(tb_path_student_distilled, exist_ok=True)
os.makedirs(model_student_distilled, exist_ok=True)



writer_teacher = SummaryWriter(tb_path_teacher)
writer_student = SummaryWriter(tb_path_student)
writer_student_distilled = SummaryWriter(tb_path_student_distilled)

In [32]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr_net = 0.001
teacher_epochs = 20
student_epochs = 20
dataset = 'CIFAR10'
data_path = 'data'
channel, im_size, num_classes, class_names, dst_train, dst_test, testloader, trainloader = get_dataset(dataset, data_path)
accuracy = Accuracy(task="multiclass", num_classes=num_classes, top_k=1).to(device)

Files already downloaded and verified
Files already downloaded and verified


# Baseline

In [33]:
teacher = VGG('VGG16').to(device)
optimizer = torch.optim.Adam(teacher.parameters(), lr=lr_net)
criterion = torch.nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=6, gamma=0.5)

for epoch in range(teacher_epochs):

    train_loss = AverageMeter()
    accuracy_train = AverageMeter()
    data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')

    teacher.train(True)
    for _, train_data in data_loop_train:
        train_img, train_label = train_data
        train_img = train_img.to(device)
        train_label = train_label.to(device)
        optimizer.zero_grad()

        train_pred, _ = teacher(train_img)
        loss = criterion(train_pred, train_label)
        acc = accuracy(train_pred, train_label)

        train_loss.update(loss.item(), train_img.size(0))
        accuracy_train.update(acc.item(), train_img.size(0))

        dict_metrics = dict(loss = train_loss.avg, acc = accuracy_train.avg)

        loss.backward()
        optimizer.step()

        data_loop_train.set_description(f'Train Epoch [{epoch + 1} / {teacher_epochs}]')
        data_loop_train.set_postfix(**dict_metrics)

        for key, value in dict_metrics.items():
              writer_teacher.add_scalar(f'train_{key}', value, epoch)

    scheduler.step()
    # Evaluation phase
    teacher.eval()
    data_loop_test = tqdm(enumerate(testloader), total=len(testloader),colour='green')
    with torch.no_grad():

        test_loss = AverageMeter()
        accuracy_test = AverageMeter()

        for _, test_data in data_loop_test:
            test_img, test_label = test_data
            test_img = test_img.to(device)
            test_label = test_label.to(device)

            test_pred, _ = teacher(test_img)
            loss = criterion(test_pred, test_label)
            acc = accuracy(test_pred, test_label)

            test_loss.update(loss.item(), test_img.size(0))
            accuracy_test.update(acc.item(), test_img.size(0))
            dict_metrics = dict(loss = test_loss.avg, acc = accuracy_test.avg)

            data_loop_test.set_description(f'Test  Epoch [{epoch + 1} / {teacher_epochs}]')
            data_loop_test.set_postfix(**dict_metrics)

            for key, value in dict_metrics.items():
                writer_teacher.add_scalar(f'test_{key}', value, epoch)

Train Epoch [1 / 20]: 100%|[31m██████████[0m| 1563/1563 [01:41<00:00, 15.44it/s, acc=0.258, loss=1.88]
Test  Epoch [1 / 20]: 100%|[32m██████████[0m| 313/313 [00:06<00:00, 51.39it/s, acc=0.352, loss=1.55]
Train Epoch [2 / 20]: 100%|[31m██████████[0m| 1563/1563 [01:30<00:00, 17.23it/s, acc=0.489, loss=1.36]
Test  Epoch [2 / 20]: 100%|[32m██████████[0m| 313/313 [00:05<00:00, 56.97it/s, acc=0.582, loss=1.17]
Train Epoch [3 / 20]: 100%|[31m██████████[0m| 1563/1563 [01:37<00:00, 15.99it/s, acc=0.646, loss=1]
Test  Epoch [3 / 20]: 100%|[32m██████████[0m| 313/313 [00:05<00:00, 52.39it/s, acc=0.646, loss=1.09]
Train Epoch [4 / 20]: 100%|[31m██████████[0m| 1563/1563 [01:33<00:00, 16.75it/s, acc=0.726, loss=0.802]
Test  Epoch [4 / 20]: 100%|[32m██████████[0m| 313/313 [00:05<00:00, 57.45it/s, acc=0.747, loss=0.759]
Train Epoch [5 / 20]: 100%|[31m██████████[0m| 1563/1563 [01:33<00:00, 16.71it/s, acc=0.778, loss=0.664]
Test  Epoch [5 / 20]: 100%|[32m██████████[0m| 313/313 [00:05<

In [37]:
student = StudentNetwork(3,10).to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=lr_net)
criterion = torch.nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=6, gamma=0.5)

for epoch in range(student_epochs):

    train_loss = AverageMeter()
    accuracy_train = AverageMeter()
    data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')

    student.train(True)
    for _, train_data in data_loop_train:
        train_img, train_label = train_data
        train_img = train_img.to(device)
        train_label = train_label.to(device)
        optimizer.zero_grad()

        train_pred, _ = student(train_img)
        loss = criterion(train_pred, train_label)
        acc = accuracy(train_pred, train_label)

        train_loss.update(loss.item(), train_img.size(0))
        accuracy_train.update(acc.item(), train_img.size(0))

        dict_metrics = dict(loss = train_loss.avg, acc = accuracy_train.avg)

        loss.backward()
        optimizer.step()

        data_loop_train.set_description(f'Train Epoch [{epoch + 1} / {student_epochs}]')
        data_loop_train.set_postfix(**dict_metrics)

        for key, value in dict_metrics.items():
              writer_student.add_scalar(f'train_{key}', value, epoch)

    scheduler.step()
    # Evaluation phase
    student.eval()
    data_loop_test = tqdm(enumerate(testloader), total=len(testloader),colour='green')
    with torch.no_grad():

        test_loss = AverageMeter()
        accuracy_test = AverageMeter()

        for _, test_data in data_loop_test:
            test_img, test_label = test_data
            test_img = test_img.to(device)
            test_label = test_label.to(device)

            test_pred, _ = student(test_img)
            loss = criterion(test_pred, test_label)
            acc = accuracy(test_pred, test_label)

            test_loss.update(loss.item(), test_img.size(0))
            accuracy_test.update(acc.item(), test_img.size(0))
            dict_metrics = dict(loss = test_loss.avg, acc = accuracy_test.avg)

            data_loop_test.set_description(f'Test  Epoch [{epoch + 1} / {student_epochs}]')
            data_loop_test.set_postfix(**dict_metrics)

            for key, value in dict_metrics.items():
                writer_student.add_scalar(f'test_{key}', value, epoch)

Train Epoch [1 / 20]: 100%|[31m██████████[0m| 1563/1563 [00:38<00:00, 40.08it/s, acc=0.366, loss=1.7]
Test  Epoch [1 / 20]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 80.57it/s, acc=0.462, loss=1.48]
Train Epoch [2 / 20]: 100%|[31m██████████[0m| 1563/1563 [00:37<00:00, 41.51it/s, acc=0.482, loss=1.42]
Test  Epoch [2 / 20]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 80.12it/s, acc=0.492, loss=1.41]
Train Epoch [3 / 20]: 100%|[31m██████████[0m| 1563/1563 [00:37<00:00, 41.63it/s, acc=0.528, loss=1.3]
Test  Epoch [3 / 20]: 100%|[32m██████████[0m| 313/313 [00:04<00:00, 77.31it/s, acc=0.543, loss=1.27]
Train Epoch [4 / 20]: 100%|[31m██████████[0m| 1563/1563 [00:36<00:00, 42.25it/s, acc=0.563, loss=1.22]
Test  Epoch [4 / 20]: 100%|[32m██████████[0m| 313/313 [00:04<00:00, 65.05it/s, acc=0.558, loss=1.24]
Train Epoch [5 / 20]: 100%|[31m██████████[0m| 1563/1563 [00:37<00:00, 41.61it/s, acc=0.592, loss=1.15]
Test  Epoch [5 / 20]: 100%|[32m██████████[0m| 313/313 [00:04<00

# Distillation

In [35]:
real_label = 1
fake_label = 0
student_distilled = StudentNetwork(3,10).to(device)
teacher.eval()
discriminator = Discriminator(512).to(device)
bce_loss = nn.BCELoss()
mse_loss = nn.MSELoss()
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr_net)
optimizer_student = torch.optim.Adam(student_distilled.parameters(), lr=lr_net)
criterion = torch.nn.CrossEntropyLoss()
torch.autograd.set_detect_anomaly(True)

scheduler_discriminator = StepLR(optimizer_discriminator, step_size=6, gamma=0.5)
scheduler_student = StepLR(optimizer_student, step_size=6, gamma=0.5)

In [36]:
for epoch in range(student_epochs):
    # For each batch in the dataloader
    data_loop_train = tqdm(enumerate(trainloader), total=len(trainloader), colour='red')
    student_distilled.train(True)
    train_disc_loss = AverageMeter()
    train_stu_loss = AverageMeter()
    for _, train_data in data_loop_train:

        # UPDATE DISCRIMINATOR ADVERSARIAL

        train_img, _ = train_data
        train_img = train_img.to(device)
        batch_size = train_img.size(0)
        optimizer_discriminator.zero_grad()

        disc_label_real = torch.full((batch_size,), real_label, dtype=torch.float).unsqueeze(-1).to(device)
        disc_label_fake = torch.full((batch_size,), fake_label, dtype=torch.float).unsqueeze(-1).to(device)

        _, features_teacher = teacher(train_img)
        output_real = discriminator(features_teacher, False)
        err_real = bce_loss(output_real, disc_label_real)

        _, features_student = student_distilled(train_img)
        output_fake = discriminator(features_student, False)
        err_real = bce_loss(output_fake, disc_label_fake)

        # REGULARIZER DISCRIMINATOR
        reg = bce_loss(output_fake, disc_label_real)
        loss = err_real + err_real + reg
        loss.backward()
        optimizer_discriminator.step()

        # UPDATE STUDENT ADVERSARIAL + MSE
        optimizer_student.zero_grad()
        stu_labels = torch.full((batch_size,), real_label, dtype=torch.float).unsqueeze(-1).to(device)
        logits_student, features_student = student_distilled(train_img)
        logits_teacher, features_teacher = teacher(train_img)
        output_fake = discriminator(features_student, True)
        s_loss = bce_loss(output_fake, stu_labels) + mse_loss(logits_student, logits_teacher)

        s_loss.backward()
        optimizer_student.step()

        # METRICS
        train_disc_loss.update(loss.item(), train_img.size(0))
        train_stu_loss.update(s_loss.item(), train_img.size(0))
        dict_metrics = dict(disc_loss = train_disc_loss.avg, stu_loss = train_stu_loss.avg)

        data_loop_train.set_description(f'Train  Epoch [{epoch + 1} / {student_epochs}]')
        data_loop_train.set_postfix(**dict_metrics)

        for key, value in dict_metrics.items():
            writer_student_distilled.add_scalar(f'ttrain_{key}', value, epoch)

    scheduler_discriminator.step()
    scheduler_student.step()
    # Eval model
    student_distilled.eval()
    data_loop_test = tqdm(enumerate(testloader), total=len(testloader),colour='green')
    with torch.no_grad():

        test_loss = AverageMeter()
        accuracy_test = AverageMeter()

        for _, test_data in data_loop_test:
            test_img, test_label = test_data
            test_img = test_img.to(device)
            test_label = test_label.to(device)

            test_pred, _ = student_distilled(test_img)
            loss = criterion(test_pred, test_label)
            acc = accuracy(test_pred, test_label)

            test_loss.update(loss.item(), test_img.size(0))
            accuracy_test.update(acc.item(), test_img.size(0))
            dict_metrics = dict(loss = test_loss.avg, acc = accuracy_test.avg)

            data_loop_test.set_description(f'Test  Epoch [{epoch + 1} / {student_epochs}]')
            data_loop_test.set_postfix(**dict_metrics)

            for key, value in dict_metrics.items():
                writer_student_distilled.add_scalar(f'test_{key}', value, epoch)

Train  Epoch [1 / 20]: 100%|[31m██████████[0m| 1563/1563 [03:15<00:00,  7.98it/s, disc_loss=1.91, stu_loss=60.4]
Test  Epoch [1 / 20]: 100%|[32m██████████[0m| 313/313 [00:03<00:00, 80.61it/s, acc=0.452, loss=2.91]
Train  Epoch [2 / 20]: 100%|[31m██████████[0m| 1563/1563 [03:19<00:00,  7.85it/s, disc_loss=1.91, stu_loss=46.1]
Test  Epoch [2 / 20]: 100%|[32m██████████[0m| 313/313 [00:05<00:00, 58.07it/s, acc=0.501, loss=2.69]
Train  Epoch [3 / 20]: 100%|[31m██████████[0m| 1563/1563 [03:16<00:00,  7.95it/s, disc_loss=1.91, stu_loss=41]
Test  Epoch [3 / 20]: 100%|[32m██████████[0m| 313/313 [00:04<00:00, 71.76it/s, acc=0.51, loss=3.26]
Train  Epoch [4 / 20]: 100%|[31m██████████[0m| 1563/1563 [03:14<00:00,  8.05it/s, disc_loss=1.91, stu_loss=37.9]
Test  Epoch [4 / 20]: 100%|[32m██████████[0m| 313/313 [00:04<00:00, 74.81it/s, acc=0.552, loss=2.76]
Train  Epoch [5 / 20]: 100%|[31m██████████[0m| 1563/1563 [03:14<00:00,  8.04it/s, disc_loss=1.91, stu_loss=35.1]
Test  Epoch [5 /