In [None]:
# import necessary dependencies
from functools import partial
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor
import torchvision
import torchvision.transforms as transforms

import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

In [None]:
import torch.nn as nn
import torch.optim as optim
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F

In [None]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Create a bias-free 3x3 convolution whose padding equals its dilation.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation; also used as the padding amount.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    # With a 3x3 kernel, padding == dilation leaves H x W unchanged at stride 1.
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)

def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Create a bias-free 1x1 (pointwise) convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    pointwise = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

class BasicBlock(nn.Module):
    """Two-layer residual block: 3x3 conv -> norm -> ReLU -> 3x3 conv -> norm, plus a shortcut.

    The output has ``planes * expansion`` channels (``expansion`` is 1 here).
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """Build the block.

        Args:
            inplanes: Channels of the incoming tensor.
            planes: Channels produced by both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the input before the addition.
            groups: Must be 1 for this block.
            base_width: Must be 64 for this block.
            dilation: Must not exceed 1 for this block.
            norm_layer: Factory for the normalization layers; ``nn.BatchNorm2d`` when None.

        Raises:
            ValueError: If ``groups != 1`` or ``base_width != 64``.
            NotImplementedError: If ``dilation > 1``.
        """
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Reject configurations that only the bottleneck variant accepts.
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Only the first convolution carries the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # The shortcut is the raw input unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)

class Bottleneck(nn.Module):
    """Three-layer residual block: 1x1 conv -> 3x3 conv -> 1x1 conv, plus a shortcut.

    The output has ``planes * expansion`` channels (``expansion`` is 4 here).
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """Build the block.

        Args:
            inplanes: Channels of the incoming tensor.
            planes: Base channel count; the block emits ``planes * expansion`` channels.
            stride: Stride of the middle 3x3 convolution.
            downsample: Optional module applied to the input before the addition.
            groups: Groups of the 3x3 convolution; also scales the inner width.
            base_width: Scales the inner width relative to 64.
            dilation: Dilation of the 3x3 convolution.
            norm_layer: Factory for the normalization layers; ``nn.BatchNorm2d`` when None.
        """
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Channel count used between the first and the last 1x1 convolution.
        inner_width = int(planes * (base_width / 64.0)) * groups
        out_channels = planes * self.expansion

        self.conv1 = conv1x1(inplanes, inner_width)
        self.bn1 = make_norm(inner_width)
        # Stride, groups and dilation all live on the 3x3 convolution.
        self.conv2 = conv3x3(inner_width, inner_width, stride, groups, dilation)
        self.bn2 = make_norm(inner_width)
        self.conv3 = conv1x1(inner_width, out_channels)
        self.bn3 = make_norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        # The shortcut is the raw input unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)

class ResNet(nn.Module):
    """Residual network: convolutional stem, four residual stages, and a linear classifier.

    The stem is a single 3x3, stride-1 convolution (3 -> 64 channels) followed by
    normalization and ReLU. The max-pool is commented out, so the stem does not
    reduce the spatial size (presumably to suit small inputs such as 32x32 CIFAR
    images -- confirm against the intended dataset).
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 10,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """Build the network.

        Args:
            block: Residual block class used in every stage.
            layers: Number of blocks in each of the four stages (indices 0..3 are read).
            num_classes: Output size of the final linear layer.
            zero_init_residual: If True, zero the weight of the last normalization
                layer of every residual block.
            groups: Forwarded to each block as ``groups``.
            width_per_group: Forwarded to each block as ``base_width``.
            replace_stride_with_dilation: Three flags, one per stage 2..4; a True flag
                makes that stage use dilation instead of stride 2. None means all False.
            norm_layer: Factory for normalization layers; ``nn.BatchNorm2d`` when None.

        Raises:
            ValueError: If ``replace_stride_with_dilation`` does not have 3 elements.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running state consumed and updated by _make_layer as stages are appended.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 3x3 convolution with stride 1 and padding 1, so spatial size is preserved.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # Max-pool intentionally disabled (see the matching commented line in _forward_impl).
        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages; stages 2-4 request stride 2 unless dilation replaces it.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Explicit initialization: Kaiming-normal for convolutions, identity affine for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        Args:
            block: Residual block class to instantiate.
            planes: Base channel count; the stage emits ``planes * block.expansion`` channels.
            blocks: Number of blocks in the stage.
            stride: Stride given to the first block (forced to 1 when ``dilate`` is True).
            dilate: If True, multiply ``self.dilation`` by ``stride`` instead of striding.

        Returns:
            The blocks wrapped in an ``nn.Sequential``.
        """
        norm_layer = self._norm_layer
        downsample = None
        # The first block uses the dilation in effect before this stage.
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # A projection shortcut is needed whenever the first block changes the
        # spatial size (stride != 1) or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        # Remaining blocks take the stage's output width as input and use default stride / no downsample.
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run stem, the four stages, global average pooling and the classifier."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything after the batch dimension before the linear layer.
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier outputs for ``x`` (delegates to ``_forward_impl``)."""
        return self._forward_impl(x)

In [None]:
# specify the device for computation
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device =='cuda':
    print("Run on GPU...")
else:
    print("Run on CPU...")

ResNet18_model = ResNet(BasicBlock, [2, 2, 2, 2])
ResNet18_model = ResNet18_model.to(device)

!nvidia-smi

Run on GPU...
Fri Nov 25 00:53:43 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    49W / 400W |   1060MiB / 40536MiB |      5%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------------------------------------------------------------------------

# CIFAR-10

In [None]:
# Image preprocessing modules
# Evaluation pipeline: tensor conversion followed by per-channel normalization.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

# Training pipeline: random flip and padded 32x32 random crop, then the same
# two steps as the evaluation pipeline.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    *transform_val.transforms,
])

In [None]:

from torch.utils.data import DataLoader

# a few arguments, do NOT change these
DATA_ROOT = "./data"
TRAIN_BATCH_SIZE = 256
VAL_BATCH_SIZE = 256

# CIFAR-10 splits; only the training constructor asks for a download.
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    transform=transform_train,
    download=True,
)
test_dataset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    transform=transform_val,
)

# Shuffle only the training batches; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

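# CIFAR-100

Note: the cells in this section reuse the names `transform_train`, `transform_val`, `train_dataset`, `test_dataset`, `train_loader` and `test_loader`, so running them replaces the CIFAR-10 objects defined above. The `ResNet` class defaults to `num_classes=10`, and the training functions below build their models with that default.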

In [None]:
# Image preprocessing modules
# Evaluation pipeline: tensor conversion followed by per-channel normalization.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
])

# Training pipeline: random flip and padded 32x32 random crop, then the same
# two steps as the evaluation pipeline.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    *transform_val.transforms,
])

In [None]:
from torch.utils.data import DataLoader

# a few arguments, do NOT change these
DATA_ROOT = "./data"
TRAIN_BATCH_SIZE = 256
VAL_BATCH_SIZE = 256

# CIFAR-100 splits; only the training constructor asks for a download.
train_dataset = torchvision.datasets.CIFAR100(
    root=DATA_ROOT,
    train=True,
    transform=transform_train,
    download=True,
)
test_dataset = torchvision.datasets.CIFAR100(
    root=DATA_ROOT,
    train=False,
    transform=transform_val,
)

# Shuffle only the training batches; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data


# ResNet18

In [None]:
def train_resnet_18(init_lr: float, momentum: float, regularization: float, decay_factor: float, total_epochs: int, change_lr_epochs: List[int]):
    """Train a ResNet-18 from scratch with SGD and keep the best checkpoint.

    Reads the notebook-level ``train_loader`` and ``test_loader``. The weights with
    the highest accuracy on ``test_loader`` are saved to
    ``./saved_model/only_train_ResNet18.pth``.

    Args:
        init_lr: Learning rate used from the first epoch.
        momentum: SGD momentum.
        regularization: Weight decay passed to SGD.
        decay_factor: Multiplier applied to the learning rate at each milestone.
        total_epochs: Number of passes over ``train_loader``.
        change_lr_epochs: 0-based epoch indices at whose start the learning rate is
            multiplied by ``decay_factor``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    # Size the classifier from the dataset behind train_loader (CIFAR-10 -> 10,
    # CIFAR-100 -> 100); fall back to the ResNet default when no class list exists.
    class_names = getattr(train_loader.dataset, "classes", None)
    num_classes = len(class_names) if class_names else 10

    ResNet18_model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    ResNet18_model = ResNet18_model.to(device)

    # Use the caller's hyperparameters rather than fixed values.
    optimizer = torch.optim.SGD(
        ResNet18_model.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization
    )
    criterion = nn.CrossEntropyLoss()

    current_learning_rate = init_lr
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"

    best_val_acc = 0.0
    best_val_epoch = 0

    print("==> Training starts!")
    print("="*50)
    for i in range(0, total_epochs):

        # Step schedule: decay at the start of every milestone epoch.
        # (A plain membership test; "i in xs == 0" would chain the comparisons.)
        if i in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay_factor
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" % current_learning_rate)

        print("Epoch %d Current learning rate %f" % (i, current_learning_rate))

        # ---- training pass ----
        ResNet18_model.train()
        total_examples = 0
        correct_examples = 0
        train_loss = 0.0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = ResNet18_model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() yields Python numbers, so no autograd graph is kept across batches.
            train_loss += loss.item()
            correct_examples += (outputs.argmax(dim=1) == targets).sum().item()
            total_examples += targets.shape[0]

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" % (avg_loss, avg_acc))

        # ---- validation pass ----
        ResNet18_model.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0

        # disable gradient during validation, which can save GPU memory
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = ResNet18_model(inputs)
                val_loss += criterion(outputs, targets).item()
                correct_examples += (outputs.argmax(dim=1) == targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples

        # Update the running best before reporting so the printed best includes this epoch.
        improved = avg_acc > best_val_acc
        if improved:
            best_val_acc = avg_acc
            best_val_epoch = i
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))

        # save the model checkpoint
        if improved:
            os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
            print("Saving ...")
            state = {'state_dict': ResNet18_model.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'only_train_ResNet18.pth'))

        print('')

    print("="*50)
    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))

# ResNet50

In [None]:
def train_resnet_50(init_lr: float, momentum: float, regularization: float, decay_factor: float, total_epochs: int, change_lr_epochs: List[int]):
    """Train the notebook's "ResNet50" model from scratch with SGD and keep the best checkpoint.

    Reads the notebook-level ``train_loader`` and ``test_loader``. The weights with
    the highest accuracy on ``test_loader`` are saved to
    ``./saved_model/only_train_ResNet50.pth``.

    Args:
        init_lr: Learning rate used from the first epoch.
        momentum: SGD momentum.
        regularization: Weight decay passed to SGD.
        decay_factor: Multiplier applied to the learning rate at each milestone.
        total_epochs: Number of passes over ``train_loader``.
        change_lr_epochs: 0-based epoch indices at whose start the learning rate is
            multiplied by ``decay_factor``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    # Size the classifier from the dataset behind train_loader (CIFAR-10 -> 10,
    # CIFAR-100 -> 100); fall back to the ResNet default when no class list exists.
    class_names = getattr(train_loader.dataset, "classes", None)
    num_classes = len(class_names) if class_names else 10

    # NOTE(review): BasicBlock with [3, 4, 6, 3] is the ResNet-34 layout; the standard
    # ResNet-50 uses Bottleneck blocks. Kept as-is so existing checkpoints still load.
    ResNet50_model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    ResNet50_model = ResNet50_model.to(device)

    optimizer = optim.SGD(
        ResNet50_model.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization
    )
    criterion = nn.CrossEntropyLoss()

    current_learning_rate = init_lr
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"

    best_val_acc = 0.0
    best_val_epoch = 0

    print("==> Training starts!")
    print("="*50)
    for i in range(0, total_epochs):

        # Step schedule: decay at the start of every milestone epoch.
        # (A plain membership test; "i in xs == 0" would chain the comparisons.)
        if i in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay_factor
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" % current_learning_rate)

        print("Epoch %d Current learning rate %f" % (i, current_learning_rate))

        # ---- training pass ----
        ResNet50_model.train()
        total_examples = 0
        correct_examples = 0
        train_loss = 0.0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = ResNet50_model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() yields Python numbers, so no autograd graph is kept across batches.
            train_loss += loss.item()
            correct_examples += (outputs.argmax(dim=1) == targets).sum().item()
            total_examples += targets.shape[0]

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" % (avg_loss, avg_acc))

        # ---- validation pass ----
        ResNet50_model.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0

        # disable gradient during validation, which can save GPU memory
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = ResNet50_model(inputs)
                val_loss += criterion(outputs, targets).item()
                correct_examples += (outputs.argmax(dim=1) == targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples

        # Update the running best before reporting so the printed best includes this epoch.
        improved = avg_acc > best_val_acc
        if improved:
            best_val_acc = avg_acc
            best_val_epoch = i
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))

        # save the model checkpoint
        if improved:
            os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
            print("Saving ...")
            state = {'state_dict': ResNet50_model.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'only_train_ResNet50.pth'))

        print('')

    print("="*50)
    print(f"==> Optimization finished! ResNet50 Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))

# Knowledge distillation: ResNet50 teacher, ResNet18 student

In [None]:
def train_resnet18_student_resnet50_teacher(init_lr: float, momentum: float, regularization: float, decay: float, total_epochs: int, change_lr_epochs: List[int], temp: float, alpha: float):
    """Distill the saved "ResNet50" teacher into a freshly initialized ResNet-18 student.

    Reads the notebook-level ``train_loader`` and ``test_loader`` and the teacher
    checkpoint ``./saved_model/only_train_ResNet50.pth``. The student weights with the
    highest accuracy on ``test_loader`` are saved to
    ``./saved_model/KD_ResNet50_Teacher_ResNet18_Student.pth``.

    The training loss is
    ``(1 - alpha) * CE(student, labels) + alpha * temp**2 * KL(teacher_T || student_T)``,
    where the ``_T`` distributions are softmaxes of the logits divided by ``temp``.

    Args:
        init_lr: Learning rate used from the first epoch.
        momentum: SGD momentum.
        regularization: Weight decay passed to SGD.
        decay: Multiplier applied to the learning rate at each milestone.
        total_epochs: Number of passes over ``train_loader``.
        change_lr_epochs: 0-based epoch indices at whose start the learning rate is
            multiplied by ``decay``.
        temp: Softmax temperature for the distillation term.
        alpha: Weight of the distillation term; ``1 - alpha`` weights the cross-entropy.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    # Size both classifiers from the dataset behind train_loader (CIFAR-10 -> 10,
    # CIFAR-100 -> 100); fall back to the ResNet default when no class list exists.
    class_names = getattr(train_loader.dataset, "classes", None)
    num_classes = len(class_names) if class_names else 10

    # Teacher: same architecture that train_resnet_50 saves (BasicBlock, [3, 4, 6, 3]).
    Pretrained_ResNet50_model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load('./saved_model/only_train_ResNet50.pth', map_location=device) # change the path to your own checkpoint file
    Pretrained_ResNet50_model.load_state_dict(state_dict['state_dict'])
    Pretrained_ResNet50_model = Pretrained_ResNet50_model.to(device)
    # The teacher is frozen: eval mode keeps its BatchNorm statistics fixed.
    Pretrained_ResNet50_model.eval()

    ResNet18_model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    ResNet18_model = ResNet18_model.to(device)

    optimizer = optim.SGD(
        ResNet18_model.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization
    )
    criterion = nn.CrossEntropyLoss()
    # 'batchmean' sums over classes and averages over the batch, which is the KL
    # divergence per sample; the default 'mean' also divides by the class count.
    DKL = nn.KLDivLoss(reduction='batchmean')

    current_learning_rate = init_lr
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"

    best_val_acc = 0.0
    best_val_epoch = 0

    print("==> Training starts!")
    print("="*50)
    for i in range(0, total_epochs):
        # Step schedule: decay at the start of every milestone epoch.
        # (A plain membership test; "i in xs == 0" would chain the comparisons.)
        if i in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" % current_learning_rate)

        print("Epoch %d Current learning rate %f" % (i, current_learning_rate))

        # ---- training pass (student only) ----
        ResNet18_model.train()
        total_examples = 0
        correct_examples = 0
        train_loss = 0.0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = ResNet18_model(inputs)
            # No gradients are needed through the teacher.
            with torch.no_grad():
                teacher_logits = Pretrained_ResNet50_model(inputs)

            # KLDivLoss expects log-probabilities as input and probabilities as target.
            soften_outputs = F.log_softmax(outputs / temp, dim=1)
            soft_targets = F.softmax(teacher_logits / temp, dim=1)
            # temp**2 keeps the soft-target gradients on the same scale as the
            # hard-label term when the temperature changes.
            distill_loss = DKL(soften_outputs, soft_targets) * (temp ** 2)
            loss = (1 - alpha) * criterion(outputs, targets) + alpha * distill_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() yields Python numbers, so no autograd graph is kept across batches.
            train_loss += loss.item()
            correct_examples += (outputs.argmax(dim=1) == targets).sum().item()
            total_examples += targets.shape[0]

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" % (avg_loss, avg_acc))

        # ---- validation pass (student, plain cross-entropy) ----
        ResNet18_model.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0

        # disable gradient during validation, which can save GPU memory
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = ResNet18_model(inputs)
                val_loss += criterion(outputs, targets).item()
                correct_examples += (outputs.argmax(dim=1) == targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples

        # Update the running best before reporting so the printed best includes this epoch.
        improved = avg_acc > best_val_acc
        if improved:
            best_val_acc = avg_acc
            best_val_epoch = i
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))

        # save the model checkpoint
        if improved:
            os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
            print("Saving ...")
            state = {'state_dict': ResNet18_model.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'KD_ResNet50_Teacher_ResNet18_Student.pth'))

        print('')

    print("="*50)
    print(f"==> Optimization finished! ResNet 50 Teacher for ResNet18 Student Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))

# Knowledge distillation: ResNet18 teacher, ResNet50 student

In [None]:
def train_resnet50_student_resnet18_teacher(init_lr: float, momentum: float, regularization: float, decay: float, total_epochs: int, change_lr_epochs: List[int], temp: float, alpha: float):
    """Train the "ResNet50" student by distilling from a pretrained ResNet18 teacher.

    The teacher weights are read from ``./saved_model/only_train_ResNet18.pth``.
    Every training step minimises::

        (1 - alpha) * CE(student_logits, targets)
            + alpha * KLDiv(log_softmax(student_logits / temp), softmax(teacher_logits / temp))

    After each epoch the student is evaluated on ``test_loader``; whenever the
    accuracy improves, the student is checkpointed to
    ``./saved_model/KD_ResNet18_Teacher_ResNet50_Student.pth``.

    Relies on the notebook-level names ``ResNet``, ``BasicBlock``, ``train_loader``
    and ``test_loader``.

    Args:
        init_lr: Initial SGD learning rate.
        momentum: SGD momentum.
        regularization: SGD weight decay.
        decay: Factor the learning rate is multiplied by at every epoch listed in
            ``change_lr_epochs``.
        total_epochs: Number of epochs to train.
        change_lr_epochs: 0-based epoch indices at which the learning rate decays.
        temp: Temperature dividing both student and teacher logits in the KD term.
        alpha: Weight of the KD term; the cross-entropy term gets ``1 - alpha``.

    Returns:
        None. Progress is printed and the best checkpoint is written to disk.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    # Teacher: pretrained ResNet18, used for inference only.
    Pretrained_ResNet18_model = ResNet(BasicBlock, [2, 2, 2, 2])
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load('./saved_model/only_train_ResNet18.pth', map_location=device)
    Pretrained_ResNet18_model.load_state_dict(state_dict['state_dict'])
    Pretrained_ResNet18_model = Pretrained_ResNet18_model.to(device)
    # eval() keeps the teacher's BatchNorm layers on their stored running statistics
    # instead of updating them from every student training batch.
    Pretrained_ResNet18_model.eval()

    # Student.
    # NOTE(review): BasicBlock with [3, 4, 6, 3] matches torchvision's ResNet-34 layout;
    # torchvision's ResNet-50 uses Bottleneck blocks - confirm this is intended.
    ResNet50_model = ResNet(BasicBlock, [3, 4, 6, 3])
    ResNet50_model = ResNet50_model.to(device)

    optimizer = optim.SGD(ResNet50_model.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the default reduction ('mean') averages over every element; the
    # PyTorch docs recommend reduction='batchmean' for a true KL value. Left unchanged
    # so that `alpha` keeps the same meaning as in the other KD functions of this notebook.
    DKL = nn.KLDivLoss()

    already_find_best_val_acc_count = 0  # epochs since the last improvement (informational only)
    current_learning_rate = init_lr
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"
    temperature = temp
    best_val_acc = 0
    best_val_epoch = 0
    # Per-epoch statistics; collected here but not returned.
    epoch_history = []
    train_accuracy_history = []
    validation_accuracy_history = []

    print("==> Training starts!")
    print("="*50)
    for i in range(0, total_epochs):

        # Step learning-rate schedule: decay at the listed epochs.
        if i in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" %current_learning_rate)

        print("Epoch %d Current learning rate %f" %(i, current_learning_rate))
        ResNet50_model.train()
        epoch_history.append(i)
        total_examples = 0
        correct_examples = 0
        train_loss = 0.0

        # Train the student for one epoch.
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = ResNet50_model(inputs)
            # The teacher only provides targets: no autograd graph is needed for it.
            with torch.no_grad():
                teacher_logits = Pretrained_ResNet18_model(inputs)

            soften_outputs = F.log_softmax(torch.div(outputs, temperature), dim=1)
            teacher_probs = F.softmax(torch.div(teacher_logits, temperature), dim=1)
            loss = (1 - alpha) * criterion(outputs, targets) + alpha * DKL(soften_outputs, teacher_probs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches the value so the batch's graph can be freed immediately.
            train_loss += loss.item()
            _, pre_out = outputs.max(1)
            correct_examples += pre_out.eq(targets).sum().item()
            total_examples += targets.shape[0]

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))
        train_accuracy_history.append(avg_acc)

        # Validate the student.
        ResNet50_model.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0
        # disable gradient during validation, which can save GPU memory
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = ResNet50_model(inputs)
                val_loss += criterion(outputs, targets).item()
                _, pre_out = outputs.max(1)
                correct_examples += pre_out.eq(targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))
        validation_accuracy_history.append(avg_acc)

        # Checkpoint the student whenever validation accuracy improves.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            best_val_epoch = i
            already_find_best_val_acc_count = 0
            if not os.path.exists(CHECKPOINT_FOLDER):
                os.makedirs(CHECKPOINT_FOLDER)
            print("Saving ...")
            state = {'state_dict': ResNet50_model.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'KD_ResNet18_Teacher_ResNet50_Student.pth'))
        else:
            already_find_best_val_acc_count = already_find_best_val_acc_count + 1

        print('')

    print("="*50)
    print(f"==> Optimization finished! ResNet18 Teacher for ResNet50 Student Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))

# Teacher-Free Self Knowledge Distillation

Normal KD: a ResNet_18 student learns from a pretrained ResNet_18 teacher.

- Student: ResNet_18
- Teacher: Self (pretrained ResNet_18)

In [None]:
def train_resnet18_self_kd(init_lr: float, momentum: float, regularization: float, decay: float, total_epochs: int, change_lr_epochs: List[int], temp: float, alpha: float):
    """Train a ResNet18 student by distilling from a pretrained copy of the same architecture.

    The teacher weights are read from ``./saved_model/only_train_ResNet18.pth``.
    Every training step minimises::

        (1 - alpha) * CE(student_logits, targets)
            + alpha * KLDiv(log_softmax(student_logits / temp), softmax(teacher_logits / temp))

    After each epoch the student is evaluated on ``test_loader``; whenever the
    accuracy improves, it is checkpointed to ``./saved_model/Tf_KD_self_ResNet18.pth``.

    Relies on the notebook-level names ``ResNet``, ``BasicBlock``, ``train_loader``
    and ``test_loader``.

    Args:
        init_lr: Initial SGD learning rate.
        momentum: SGD momentum.
        regularization: SGD weight decay.
        decay: Factor the learning rate is multiplied by at every epoch listed in
            ``change_lr_epochs``.
        total_epochs: Number of epochs to train.
        change_lr_epochs: 0-based epoch indices at which the learning rate decays.
        temp: Temperature dividing both student and teacher logits in the KD term.
        alpha: Weight of the KD term; the cross-entropy term gets ``1 - alpha``.

    Returns:
        None. Progress is printed and the best checkpoint is written to disk.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    # Teacher: pretrained ResNet18, used for inference only.
    Pretrained_ResNet18_model = ResNet(BasicBlock, [2, 2, 2, 2])
    # change the path to your own checkpoint file;
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load('./saved_model/only_train_ResNet18.pth', map_location=device)
    Pretrained_ResNet18_model.load_state_dict(state_dict['state_dict'])
    Pretrained_ResNet18_model = Pretrained_ResNet18_model.to(device)
    # eval() keeps the teacher's BatchNorm layers on their stored running statistics
    # instead of updating them from every student training batch.
    Pretrained_ResNet18_model.eval()

    # Student: same architecture, freshly constructed.
    ResNet18_model = ResNet(BasicBlock, [2, 2, 2, 2])
    ResNet18_model = ResNet18_model.to(device)

    optimizer = optim.SGD(ResNet18_model.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the default reduction ('mean') averages over every element; the
    # PyTorch docs recommend reduction='batchmean' for a true KL value. Left unchanged
    # so that `alpha` keeps the same meaning as in the other KD functions of this notebook.
    DKL = nn.KLDivLoss()

    already_find_best_val_acc_count = 0  # epochs since the last improvement (informational only)
    current_learning_rate = init_lr
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"
    temperature = temp
    best_val_acc = 0
    best_val_epoch = 0
    # Per-epoch statistics; collected here but not returned.
    epoch_history = []
    train_accuracy_history = []
    validation_accuracy_history = []

    print("==> Training starts!")
    print("="*50)
    for i in range(0, total_epochs):

        # Step learning-rate schedule: decay at the listed epochs.
        if i in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" %current_learning_rate)

        print("Epoch %d Current learning rate %f" %(i, current_learning_rate))
        ResNet18_model.train()
        epoch_history.append(i)
        total_examples = 0
        correct_examples = 0
        train_loss = 0.0

        # Train the student for one epoch.
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = ResNet18_model(inputs)
            # The teacher only provides targets: no autograd graph is needed for it.
            with torch.no_grad():
                teacher_logits = Pretrained_ResNet18_model(inputs)

            soften_outputs = F.log_softmax(torch.div(outputs, temperature), dim=1)
            teacher_probs = F.softmax(torch.div(teacher_logits, temperature), dim=1)
            loss = (1 - alpha) * criterion(outputs, targets) + alpha * DKL(soften_outputs, teacher_probs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches the value so the batch's graph can be freed immediately.
            train_loss += loss.item()
            _, pre_out = outputs.max(1)
            correct_examples += pre_out.eq(targets).sum().item()
            total_examples += targets.shape[0]

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))
        train_accuracy_history.append(avg_acc)

        # Validate the student.
        ResNet18_model.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0
        # disable gradient during validation, which can save GPU memory
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = ResNet18_model(inputs)
                val_loss += criterion(outputs, targets).item()
                _, pre_out = outputs.max(1)
                correct_examples += pre_out.eq(targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))
        validation_accuracy_history.append(avg_acc)

        # Checkpoint the student whenever validation accuracy improves.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            best_val_epoch = i
            already_find_best_val_acc_count = 0
            if not os.path.exists(CHECKPOINT_FOLDER):
                os.makedirs(CHECKPOINT_FOLDER)
            print("Saving ...")
            state = {'state_dict': ResNet18_model.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'Tf_KD_self_ResNet18.pth'))
        else:
            already_find_best_val_acc_count = already_find_best_val_acc_count + 1
        print('')

    print("="*50)
    print(f"==> Optimization finished! ResNet18 Self Knowledge Distillation Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))

# Teacher-Free Knowledge Distillation Manually-Designated Regularization

We design a virtual teacher from the ground-truth targets: the target class is assigned weight `a`, and each of the other classes is assigned weight `(1 - a) / (num_classes - 1)`.

In [None]:
def train_resnet18_reg_kd(init_lr: float, momentum: float, regularization: float, decay: float, total_epochs: int, change_lr_epochs: List[int], temp: float, alpha: float, a_val: float, num_classes: int):
    """Train a ResNet18 with teacher-free KD against a hand-designed "virtual teacher".

    For every sample the virtual teacher assigns ``a_val`` to the target class and
    ``(1 - a_val) / (num_classes - 1)`` to every other class. That distribution is
    divided by ``temp`` and passed through a softmax before being used as the KD
    target. Every training step minimises::

        (1 - alpha) * CE(logits, targets)
            + alpha * KLDiv(log_softmax(logits / temp), softmax(virtual_teacher / temp))

    After each epoch the model is evaluated on ``test_loader``; whenever the
    accuracy improves, it is checkpointed to ``./saved_model/Tf_KD_reg_ResNet18.pth``.

    Relies on the notebook-level names ``ResNet``, ``BasicBlock``, ``train_loader``
    and ``test_loader``.

    Args:
        init_lr: Initial SGD learning rate.
        momentum: SGD momentum.
        regularization: SGD weight decay.
        decay: Factor the learning rate is multiplied by at every epoch listed in
            ``change_lr_epochs``.
        total_epochs: Number of epochs to train.
        change_lr_epochs: 0-based epoch indices at which the learning rate decays.
        temp: Temperature applied to the logits and to the virtual teacher.
        alpha: Weight of the KD term; the cross-entropy term gets ``1 - alpha``.
        a_val: Weight the virtual teacher gives to the correct class.
        num_classes: Number of classes (width of the virtual-teacher distribution).

    Returns:
        None. Progress is printed and the best checkpoint is written to disk.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    ResNet18_model = ResNet(BasicBlock, [2, 2, 2, 2])
    ResNet18_model = ResNet18_model.to(device)

    optimizer = optim.SGD(ResNet18_model.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the default reduction ('mean') averages over every element; the
    # PyTorch docs recommend reduction='batchmean' for a true KL value. Left unchanged
    # so that `alpha` keeps the same meaning as in the other KD functions of this notebook.
    DKL = nn.KLDivLoss()

    already_find_best_val_acc_count = 0  # epochs since the last improvement (informational only)
    current_learning_rate = init_lr
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"

    temperature = temp
    a = a_val
    n_classes = num_classes
    # Weight given to each non-target class by the virtual teacher.
    off_target_weight = (1 - a) / (n_classes - 1)

    best_val_acc = 0
    best_val_epoch = 0

    # Per-epoch statistics; collected here but not returned.
    epoch_history = []
    train_accuracy_history = []
    validation_accuracy_history = []

    print("==> Training starts!")
    print("="*50)
    for i in range(0, total_epochs):
        # Step learning-rate schedule: decay at the listed epochs.
        if i in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" %current_learning_rate)

        print("Epoch %d Current learning rate %f" %(i, current_learning_rate))
        ResNet18_model.train()

        epoch_history.append(i)
        total_examples = 0
        correct_examples = 0
        train_loss = 0.0

        # Train the model for one epoch.
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = ResNet18_model(inputs)

            # Build the virtual teacher in one shot. scatter_ writes `a` at each row's
            # target column, so no per-sample Python loop (and no reuse of the epoch
            # index `i`) is needed.
            # assumes targets is a 1-D tensor of int64 class indices - TODO confirm
            Manual_output = torch.full((targets.size(0), n_classes), off_target_weight, device=device)
            Manual_output.scatter_(1, targets.unsqueeze(1), a)

            soften_outputs = F.log_softmax(torch.div(outputs, temperature), dim=1)
            Manual_output = F.softmax(torch.div(Manual_output, temperature), dim=1)
            loss = (1 - alpha) * criterion(outputs, targets) + alpha * DKL(soften_outputs, Manual_output)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches the value so the batch's graph can be freed immediately.
            train_loss += loss.item()
            _, pre_out = outputs.max(1)
            correct_examples += pre_out.eq(targets).sum().item()
            total_examples += targets.shape[0]

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))
        train_accuracy_history.append(avg_acc)

        # Validate the model.
        ResNet18_model.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0
        # disable gradient during validation, which can save GPU memory
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = ResNet18_model(inputs)
                val_loss += criterion(outputs, targets).item()
                _, pre_out = outputs.max(1)
                correct_examples += pre_out.eq(targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))
        validation_accuracy_history.append(avg_acc)

        # Checkpoint the model whenever validation accuracy improves.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            best_val_epoch = i
            already_find_best_val_acc_count = 0
            if not os.path.exists(CHECKPOINT_FOLDER):
                os.makedirs(CHECKPOINT_FOLDER)
            print("Saving ...")
            state = {'state_dict': ResNet18_model.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, 'Tf_KD_reg_ResNet18.pth'))
        else:
            already_find_best_val_acc_count = already_find_best_val_acc_count + 1

        print('')

    print("="*50)
    print(f"==> Optimization finished! Teacher Free Knowledge Distillation ResNet18 Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))

# Adversarial

In [None]:
import attacks
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import random

In [None]:
def test_model(mdl, loader, device):
    mdl.eval()
    running_correct = 0.
    running_loss = 0.
    running_total = 0.
    with torch.no_grad():
        for batch_idx,(data,labels) in enumerate(loader):
            data = data.to(device); labels = labels.to(device)
            clean_outputs = mdl(data)
            clean_loss = F.cross_entropy(clean_outputs, labels)
            _,clean_preds = clean_outputs.max(1)
            running_correct += clean_preds.eq(labels).sum().item()
            running_loss += clean_loss.item()
            running_total += labels.size(0)
    clean_acc = running_correct/running_total
    clean_loss = running_loss/len(loader)
    mdl.train()
    return clean_acc,clean_loss

In [None]:
# Scan the test set once to find the smallest and largest input values.
# `min_val` and `max_val` stay at module level: the adversarial training
# functions pass them as the last two arguments of attacks.PGD_attack.
max_val = float('-inf')
min_val = float('inf')
for batch_idx, (data, labels) in enumerate(test_loader):
    batch_max = torch.max(data).item()
    batch_min = torch.min(data).item()
    max_val = max(batch_max, max_val)
    min_val = min(batch_min, min_val)
print('Max: ', max_val)
print('Min: ', min_val)

Max:  2.7537312507629395
Min:  -2.429065704345703


# ResNet18 Adversarial

In [None]:
def train_resnet_18_adversarial(init_lr: float, momentum: float, regularization: float, decay_factor: float, total_epochs: int, change_lr_epochs: List[int], optim_type: str, eps: float, atk_iters: int, temp: float, alpha: float):
    """Adversarially train a ResNet18 on PGD examples and checkpoint the best epoch.

    Every training batch is replaced by the output of ``attacks.PGD_attack`` (random
    start, step size ``1.85 * eps / atk_iters``) before the forward pass, and the
    model is optimised with plain cross-entropy on those examples. Validation runs
    on the unmodified ``test_loader`` batches; whenever that accuracy improves, the
    model is checkpointed to ``./saved_model/Adv_only_train_ResNet18.pth``.

    Relies on the notebook-level names ``ResNet``, ``BasicBlock``, ``attacks``,
    ``train_loader``, ``test_loader``, ``min_val`` and ``max_val``.

    Args:
        init_lr: Initial learning rate.
        momentum: SGD momentum (ignored when Adam is selected).
        regularization: Weight decay.
        decay_factor: Factor the learning rate is multiplied by at every epoch listed
            in ``change_lr_epochs``.
        total_epochs: Number of epochs to train.
        change_lr_epochs: 0-based epoch indices at which the learning rate decays.
        optim_type: ``'sgd'`` (case-insensitive) selects SGD; anything else selects Adam.
        eps: Forwarded to ``attacks.PGD_attack`` as the attack epsilon.
        atk_iters: Number of PGD iterations; must be non-zero.
        temp: Unused here; accepted so the signature matches the KD variants.
        alpha: Unused here; accepted so the signature matches the KD variants.

    Returns:
        None. Progress is printed and the best checkpoint is written to disk.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    net = ResNet(BasicBlock, [2, 2, 2, 2])
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()
    if optim_type.lower() == 'sgd':
        optimizer = optim.SGD(net.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization)
    else:
        optimizer = optim.Adam(net.parameters(), lr=init_lr, weight_decay=regularization)

    # PGD attack configuration.
    ATK_EPS = eps
    ATK_ITERS = atk_iters
    ATK_ALPHA = 1.85 * ATK_EPS / ATK_ITERS
    RND_START = True

    current_learning_rate = init_lr
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"
    model_checkpoint = "Adv_only_train_ResNet18.pth"
    already_find_best_val_acc_count = 0  # epochs since the last improvement (informational only)
    best_val_acc = 0
    best_val_epoch = 0
    # Per-epoch statistics; collected here but not returned.
    epoch_history = []
    train_accuracy_history = []
    validation_accuracy_history = []

    print("==> Training starts!")
    print("="*50)
    for i in range(0, total_epochs):
        # Step learning-rate schedule: decay at the listed epochs.
        if i in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay_factor
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" %current_learning_rate)

        print("Epoch %d Current learning rate %f" %(i, current_learning_rate))
        net.train()
        epoch_history.append(i)
        total_examples = 0
        correct_examples = 0
        train_loss = 0.0

        # Train on adversarial examples for one epoch.
        for data, labels in train_loader:
            data = data.to(device)
            labels = labels.to(device)
            adv_data = attacks.PGD_attack(net, device, data, labels, ATK_EPS, ATK_ALPHA, ATK_ITERS, RND_START, min_val, max_val)

            outputs = net(adv_data)
            # Clear the gradients before backward, as the original loop does; the
            # optimizer holds every parameter of `net`, so one call covers them all.
            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, preds = outputs.max(1)
            correct_examples += preds.eq(labels).sum().item()
            train_loss += loss.item()
            total_examples += labels.size(0)

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        train_accuracy_history.append(avg_acc)
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))

        # Validate on the unmodified test batches.
        net.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                # Python scalars, matching the training loop: no device tensors end up
                # in the history list or in best_val_acc.
                val_loss += criterion(outputs, targets).item()
                _, pre_out = outputs.max(1)
                correct_examples += pre_out.eq(targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples
        validation_accuracy_history.append(avg_acc)
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))

        # Checkpoint the model whenever validation accuracy improves.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            best_val_epoch = i
            already_find_best_val_acc_count = 0
            if not os.path.exists(CHECKPOINT_FOLDER):
                os.makedirs(CHECKPOINT_FOLDER)
            print("Saving ...")
            state = {'state_dict': net.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, model_checkpoint))
        else:
            already_find_best_val_acc_count = already_find_best_val_acc_count + 1

        print('')

    print("="*50)
    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))

# Teacher-Free Self Knowledge Distillation Adversarial

In [None]:
def train_resnet_18_self_kd_adversarial(init_lr: float, momentum: float, regularization: float, decay_factor: float, total_epochs: int, change_lr_epochs: List[int], optim_type: str, eps: float, atk_iters: int, temp: float, alpha: float):
    """Adversarially train a ResNet18 with self-distillation from an adversarially trained copy.

    The teacher weights are read from ``./saved_model/Adv_only_train_ResNet18.pth``.
    The student sees PGD examples produced by ``attacks.PGD_attack`` (random start,
    step size ``1.85 * eps / atk_iters``), while the teacher sees the unmodified batch.
    Every training step minimises::

        (1 - alpha) * CE(student(adv), labels)
            + alpha * KLDiv(log_softmax(student(adv) / temp), softmax(teacher(clean) / temp))

    Validation runs on the unmodified ``test_loader`` batches; whenever that accuracy
    improves, the student is checkpointed to ``./saved_model/Adv_Self_KD_ResNet18.pth``.

    Relies on the notebook-level names ``ResNet``, ``BasicBlock``, ``attacks``,
    ``train_loader``, ``test_loader``, ``min_val`` and ``max_val``.

    Args:
        init_lr: Initial learning rate.
        momentum: SGD momentum (ignored when Adam is selected).
        regularization: Weight decay.
        decay_factor: Factor the learning rate is multiplied by at every epoch listed
            in ``change_lr_epochs``.
        total_epochs: Number of epochs to train.
        change_lr_epochs: 0-based epoch indices at which the learning rate decays.
        optim_type: ``'sgd'`` (case-insensitive) selects SGD; anything else selects Adam.
        eps: Forwarded to ``attacks.PGD_attack`` as the attack epsilon.
        atk_iters: Number of PGD iterations; must be non-zero.
        temp: Temperature dividing both student and teacher logits in the KD term.
        alpha: Weight of the KD term; the cross-entropy term gets ``1 - alpha``.

    Returns:
        None. Progress is printed and the best checkpoint is written to disk.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    # Student.
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    net = net.to(device)

    # Teacher: adversarially trained ResNet18, used for inference only.
    Pretrained_ResNet18_model = ResNet(BasicBlock, [2, 2, 2, 2])
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load('./saved_model/Adv_only_train_ResNet18.pth', map_location=device)
    Pretrained_ResNet18_model.load_state_dict(state_dict['state_dict'])
    Pretrained_ResNet18_model = Pretrained_ResNet18_model.to(device)
    # eval() keeps the teacher's BatchNorm layers on their stored running statistics
    # instead of updating them from every student training batch.
    Pretrained_ResNet18_model.eval()

    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the default reduction ('mean') averages over every element; the
    # PyTorch docs recommend reduction='batchmean' for a true KL value. Left unchanged
    # so that `alpha` keeps the same meaning as in the other KD functions of this notebook.
    DKL = nn.KLDivLoss()

    if optim_type.lower() == 'sgd':
        optimizer = optim.SGD(net.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization)
    else:
        optimizer = optim.Adam(net.parameters(), lr=init_lr, weight_decay=regularization)

    # PGD attack configuration.
    ATK_EPS = eps
    ATK_ITERS = atk_iters
    ATK_ALPHA = 1.85 * ATK_EPS / ATK_ITERS
    RND_START = True
    temperature = temp

    already_find_best_val_acc_count = 0  # epochs since the last improvement (informational only)
    current_learning_rate = init_lr
    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"
    model_checkpoint = "Adv_Self_KD_ResNet18.pth"

    best_val_acc = 0
    best_val_epoch = 0

    # Per-epoch statistics; collected here but not returned.
    epoch_history = []
    train_accuracy_history = []
    validation_accuracy_history = []

    print("==> Training starts!")
    print("="*50)
    for i in range(0, total_epochs):
        # Step learning-rate schedule: decay at the listed epochs.
        if i in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay_factor
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" %current_learning_rate)
        print("Epoch %d Current learning rate %f" %(i, current_learning_rate))

        net.train()
        epoch_history.append(i)
        total_examples = 0
        correct_examples = 0
        train_loss = 0.0

        # Train on adversarial examples for one epoch.
        for data, labels in train_loader:
            data = data.to(device)
            labels = labels.to(device)

            adv_data = attacks.PGD_attack(net, device, data, labels, ATK_EPS, ATK_ALPHA, ATK_ITERS, RND_START, min_val, max_val)

            outputs = net(adv_data)
            # The teacher only provides targets (from the clean batch): no autograd
            # graph is needed for it.
            with torch.no_grad():
                teacher_logits = Pretrained_ResNet18_model(data)

            soften_outputs = F.log_softmax(torch.div(outputs, temperature), dim=1)
            teacher_probs = F.softmax(torch.div(teacher_logits, temperature), dim=1)

            # Clear the gradients before backward, as the original loop does; the
            # optimizer holds every parameter of `net`, so one call covers them all.
            optimizer.zero_grad()
            loss = (1 - alpha) * criterion(outputs, labels) + alpha * DKL(soften_outputs, teacher_probs)
            loss.backward()
            optimizer.step()

            _, preds = outputs.max(1)
            correct_examples += preds.eq(labels).sum().item()
            train_loss += loss.item()
            total_examples += labels.size(0)

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        train_accuracy_history.append(avg_acc)
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))

        # Validate on the unmodified test batches.
        net.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                # Python scalars, matching the training loop: no device tensors end up
                # in the history list or in best_val_acc.
                val_loss += criterion(outputs, targets).item()
                _, pre_out = outputs.max(1)
                correct_examples += pre_out.eq(targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples
        validation_accuracy_history.append(avg_acc)
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))

        # Checkpoint the student whenever validation accuracy improves.
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            best_val_epoch = i
            already_find_best_val_acc_count = 0
            if not os.path.exists(CHECKPOINT_FOLDER):
                os.makedirs(CHECKPOINT_FOLDER)
            print("Saving ...")
            state = {'state_dict': net.state_dict(),
                     'epoch': i,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, model_checkpoint))
        else:
            already_find_best_val_acc_count = already_find_best_val_acc_count + 1

        print('')

    print("="*50)
    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))


# Teacher-Free Manually-Designated Regularization Adversarial

In [None]:
def train_resnet_18_reg_kd_adversarial(
    init_lr: float,
    momentum: float,
    regularization: float,
    decay_factor: float,
    total_epochs: int,
    change_lr_epochs: List[int],
    optim_type: str,
    eps: float,
    atk_iters: int,
    temp: float,
    alpha: float,
    a_val: float,
    num_classes: int,
):
    """Adversarially train a ResNet(BasicBlock, [2, 2, 2, 2]) against a hand-built teacher.

    Every training batch is first perturbed with ``attacks.PGD_attack``; the
    network is then optimised on
    ``(1 - alpha) * CrossEntropy(outputs, labels) + alpha * KLDiv(student, teacher)``
    where ``student = log_softmax(outputs / temp)`` and ``teacher`` is a fixed
    distribution (``a_val`` on the true class, the rest spread uniformly)
    passed through ``softmax(. / temp)``.

    After each epoch the model is evaluated on ``test_loader`` (clean inputs)
    and the weights are saved to
    ``./saved_model/PGD_Adv_Teacher_Free_KD_ResNet18.pth`` whenever the
    validation accuracy improves.

    Relies on names defined elsewhere in the notebook: ``ResNet``,
    ``BasicBlock``, ``attacks``, ``train_loader``, ``test_loader``,
    ``min_val`` and ``max_val``.

    Args:
        init_lr: Initial learning rate.
        momentum: Momentum for SGD (ignored when Adam is selected).
        regularization: Weight-decay coefficient handed to the optimizer.
        decay_factor: Multiplier applied to the learning rate at every epoch
            listed in ``change_lr_epochs``.
        total_epochs: Number of epochs to run.
        change_lr_epochs: Epoch indices (0-based) at which the learning rate
            is multiplied by ``decay_factor``.
        optim_type: ``'sgd'`` (case-insensitive) selects SGD; any other value
            selects Adam.
        eps: Forwarded to ``attacks.PGD_attack``; also sets the step size
            ``1.85 * eps / atk_iters``. Presumably the perturbation budget -
            confirm against ``attacks.PGD_attack``.
        atk_iters: Forwarded to ``attacks.PGD_attack`` (presumably the number
            of PGD steps - confirm against ``attacks.PGD_attack``).
        temp: Temperature dividing both the logits and the teacher
            distribution before softmax.
        alpha: Weight of the KL term; ``1 - alpha`` weights cross-entropy.
        a_val: Teacher probability assigned to the ground-truth class.
        num_classes: Number of classes; the remaining ``1 - a_val`` mass is
            split evenly over the other ``num_classes - 1`` classes.

    Returns:
        None. Progress is printed and the best checkpoint is written to disk.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda':
        print("Run on GPU...")
    else:
        print("Run on CPU...")

    net = ResNet(BasicBlock, [2, 2, 2, 2])
    net = net.to(device)

    if optim_type.lower() == 'sgd':
        optimizer = optim.SGD(net.parameters(), lr=init_lr, momentum=momentum, weight_decay=regularization)
    else:
        optimizer = optim.Adam(net.parameters(), lr=init_lr, weight_decay=regularization)

    # Attack configuration forwarded to attacks.PGD_attack.
    atk_alpha = 1.85 * eps / atk_iters
    rnd_start = True

    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the default reduction ('mean') averages over every element
    # rather than per sample; PyTorch recommends reduction='batchmean' for a
    # true KL value. Left unchanged so the loss scale (and therefore the
    # effective meaning of `alpha`) stays as before.
    DKL = nn.KLDivLoss()

    # Epochs since the last validation improvement (tracked only; nothing in
    # this function stops early based on it).
    already_find_best_val_acc_count = 0
    current_learning_rate = init_lr

    # the folder where the trained model is saved
    CHECKPOINT_FOLDER = "./saved_model"
    model_checkpoint = "PGD_Adv_Teacher_Free_KD_ResNet18.pth"

    best_val_acc = 0
    best_val_epoch = 0
    # Per-epoch curves; collected here but not returned by this function.
    epoch_history = []
    train_accuracy_history = []
    validation_accuracy_history = []

    # Probability given to each non-target class in the hand-built teacher.
    off_target_prob = (1 - a_val) / (num_classes - 1)

    print("==> Training starts!")
    print("="*50)
    # `epoch` is deliberately not reused for any inner loop: the checkpoint
    # and best-epoch bookkeeping below depend on it.
    for epoch in range(0, total_epochs):
        if epoch in change_lr_epochs:
            current_learning_rate = current_learning_rate * decay_factor
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_learning_rate
            print("Current learning rate has decayed to %f" %current_learning_rate)

        net.train()
        epoch_history.append(epoch)
        total_examples = 0
        correct_examples = 0
        train_loss = 0
        # Train the model for 1 epoch.
        for batch_idx, (data, labels) in enumerate(train_loader):
            data = data.to(device)
            labels = labels.to(device)
            adv_data = attacks.PGD_attack(net, device, data, labels, eps, atk_alpha, atk_iters, rnd_start, min_val, max_val)
            # Forward pass on the adversarial batch
            outputs = net(adv_data)

            # Hand-built teacher: `a_val` on the true class, the rest uniform.
            # Built directly on `device` and filled with one scatter instead
            # of a per-sample Python loop.
            manual_output = torch.full((labels.size(0), num_classes), off_target_prob, device=device)
            manual_output.scatter_(1, labels.unsqueeze(1), a_val)

            soften_outputs = F.log_softmax(outputs / temp, dim=1)
            manual_output = F.softmax(manual_output / temp, dim=1)

            # Clear gradients only now: the attack above may have left
            # gradients on the parameters.
            optimizer.zero_grad()
            loss = (1 - alpha) * criterion(outputs, labels) + alpha * DKL(soften_outputs, manual_output)
            loss.backward()
            optimizer.step()

            # Update stats
            _, preds = outputs.max(1)
            correct_examples += preds.eq(labels).sum().item()
            train_loss += loss.item()
            total_examples += labels.size(0)

        avg_loss = train_loss / len(train_loader)
        avg_acc = correct_examples / total_examples
        train_accuracy_history.append(avg_acc)
        print("Training loss: %.4f, Training accuracy: %.4f" %(avg_loss, avg_acc))

        # switch to eval mode
        net.eval()
        total_examples = 0
        correct_examples = 0
        val_loss = 0.0
        # No gradients are needed for validation; without this guard every
        # accumulated loss would keep its autograd graph alive.
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                # .item() keeps the running statistics as plain Python numbers.
                val_loss += loss.item()
                _, pre_out = torch.max(outputs, 1)
                correct_examples += (pre_out == targets).sum().item()
                total_examples += targets.shape[0]

        avg_loss = val_loss / len(test_loader)
        avg_acc = correct_examples / total_examples
        validation_accuracy_history.append(avg_acc)
        # Reports the best accuracy/epoch *before* this epoch is considered.
        print("Validation loss: %.4f, Validation accuracy: %.4f, Best Validation accuracy: %.4f best_val_epoch: %d" % (avg_loss, avg_acc, best_val_acc, best_val_epoch))

        # save the model checkpoint when validation accuracy improves
        if avg_acc > best_val_acc:
            best_val_acc = avg_acc
            best_val_epoch = epoch
            already_find_best_val_acc_count = 0
            os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
            print("Saving ...")
            state = {'state_dict': net.state_dict(),
                     'epoch': epoch,
                     'lr': current_learning_rate}
            torch.save(state, os.path.join(CHECKPOINT_FOLDER, model_checkpoint))
        else:
            already_find_best_val_acc_count = already_find_best_val_acc_count + 1

        print('')

    print("="*50)
    print(f"==> Optimization finished! Best validation accuracy: {best_val_acc:.4f}")
    print("best_val_epoch: " + str(best_val_epoch))


# Autoattacks

In [None]:
!pip install torchattacks

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchattacks
  Downloading torchattacks-3.3.0-py3-none-any.whl (155 kB)
[K     |████████████████████████████████| 155 kB 4.7 MB/s 
[?25hInstalling collected packages: torchattacks
Successfully installed torchattacks-3.3.0


In [None]:
import torchattacks

In [None]:
def attack_with_auto(whitebox_mdl, eps):
  """Evaluate a saved ResNet(BasicBlock, [2, 2, 2, 2]) checkpoint under torchattacks' AutoAttack.

  Loads the weights stored under the ``'state_dict'`` key of the checkpoint,
  prints the accuracy reported by ``test_model`` on ``test_loader``, then
  attacks every batch of ``test_loader`` with
  ``torchattacks.AutoAttack(norm='Linf', version='standard', n_classes=10)``
  and prints the fraction of adversarial examples still classified correctly.

  Relies on names defined elsewhere in the notebook: ``ResNet``,
  ``BasicBlock``, ``test_model``, ``test_loader`` and ``torchattacks``.

  Args:
      whitebox_mdl: Path (or file-like object) accepted by ``torch.load``
          pointing at a checkpoint dict that contains ``'state_dict'``.
      eps: Passed to ``torchattacks.AutoAttack`` as ``eps``.
  """
  # Resolved locally, with the same expression the training functions in this
  # file use, so the function does not depend on a notebook-global `device`.
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  whitebox = ResNet(BasicBlock, [2, 2, 2, 2])
  # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
  state_dict = torch.load(whitebox_mdl, map_location=device)
  whitebox.load_state_dict(state_dict['state_dict'])
  whitebox = whitebox.to(device)
  whitebox.eval()

  test_acc,_ = test_model(whitebox,test_loader,device)
  print("Initial Accuracy of Whitebox Model: ",test_acc)
  # NOTE(review): never populated in the code visible here.
  wb_accuracies = dict()
  for type_attack in ['Autoattack']:
    print(type_attack)

    whitebox_correct = 0.
    running_total = 0.
    for batch_idx,(data,labels) in enumerate(test_loader):
      print('Batch ' + str(batch_idx))
      data = data.to(device)
      labels = labels.to(device)
      # Adversarial attack. The attack object is rebuilt for every batch,
      # as in the original code: with seed=None, reusing a single instance
      # could change how each batch is seeded.
      attack = torchattacks.AutoAttack(whitebox, norm='Linf', eps=eps, version='standard', n_classes=10, seed=None, verbose=False)
      adv_data = attack(data, labels)

      # Compute accuracy on perturbed data
      with torch.no_grad():
        whitebox_outputs = whitebox(adv_data)
        _,whitebox_preds = whitebox_outputs.max(1)
        whitebox_correct += whitebox_preds.eq(labels).sum().item()
        running_total += labels.size(0)
    # Robust accuracy over the whole loader
    whitebox_acc = whitebox_correct/running_total

    print(whitebox_acc)
