In [1]:
# KnockoffNet Attack on CIFAR10
# import packages
import torch
import timm
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from torchsummary import summary


In [2]:
# settings
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

batch_size = 1000
epochs = 100 # 1K X 100 = 100K
learning_rate = 0.001

cuda


In [3]:
# dataset loader
data_path = '../data/'

train_data_cifar10 = datasets.CIFAR10(data_path, train=True, download=True, transform=transforms.ToTensor())
test_data_cifar10 = datasets.CIFAR10(data_path, train=False, download=True, transform=transforms.ToTensor())
train_data_mnist = datasets.MNIST(data_path, train=True, download=True, transform=transforms.ToTensor())
test_data_mnist = datasets.MNIST(data_path, train=False, download=True, transform=transforms.ToTensor())

len(train_data_cifar10), len(test_data_cifar10), len(train_data_mnist), len(test_data_mnist)

Files already downloaded and verified
Files already downloaded and verified


(50000, 10000, 60000, 10000)

In [4]:
# Normalize CIFAR10
mean_cifar10 = train_data_cifar10.data.mean(axis=(0,1,2))/255
std_cifar10 = train_data_cifar10.data.std(axis=(0,1,2))/255

train_data_cifar10.transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean_cifar10, std_cifar10)
])

test_data_cifar10.transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean_cifar10, std_cifar10)
])

train_loader_cifar10 = torch.utils.data.DataLoader(train_data_cifar10, batch_size=batch_size, shuffle=True)
test_loader_cifar10 = torch.utils.data.DataLoader(test_data_cifar10, batch_size=batch_size, shuffle=True)

print (mean_cifar10, std_cifar10)

[0.49139968 0.48215841 0.44653091] [0.24703223 0.24348513 0.26158784]


In [5]:
#load pretrained model(vitcim model)
victim_model = timm.create_model("resnet18", pretrained=False)

# override model
victim_model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
victim_model.maxpool = nn.Identity()  # type: ignore
victim_model.fc = nn.Linear(512,  10)

victim_model.load_state_dict(
            torch.hub.load_state_dict_from_url(
                      "https://huggingface.co/edadaltocg/resnet18_cifar10/resolve/main/pytorch_model.bin",
                       map_location="cuda", 
                       file_name="resnet18_cifar10.pth",
             )
)

victim_model.to(device)
summary(victim_model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
          Identity-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
          Identity-7           [-1, 64, 32, 32]               0
              ReLU-8           [-1, 64, 32, 32]               0
          Identity-9           [-1, 64, 32, 32]               0
           Conv2d-10           [-1, 64, 32, 32]          36,864
      BatchNorm2d-11           [-1, 64, 32, 32]             128
             ReLU-12           [-1, 64, 32, 32]               0
       BasicBlock-13           [-1, 64, 32, 32]               0
           Conv2d-14           [-1, 64,

In [6]:
# define attacker model
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # BatchNorm에 bias가 포함되어 있으므로, conv2d는 bias=False로 설정합니다.
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion),
        )

        # identity mapping, input과 output의 feature map size, filter 수가 동일한 경우 사용.
        self.shortcut = nn.Sequential()

        self.relu = nn.ReLU()

        # projection mapping using 1x1conv
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        x = self.residual_function(x) + self.shortcut(x)
        x = self.relu(x)
        return x


class BottleNeck(nn.Module):
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        self.relu = nn.ReLU()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels*BottleNeck.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels*BottleNeck.expansion)
            )
            
    def forward(self, x):
        x = self.residual_function(x) + self.shortcut(x)
        x = self.relu(x)
        return x
    
class ResNet(nn.Module):
    def __init__(self, block, num_block, num_classes=10, init_weights=True):
        super().__init__()

        self.in_channels=64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # weights inittialization
        if init_weights:
            self._initialize_weights()

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self,x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        x = self.conv3_x(output)
        x = self.conv4_x(x)
        x = self.conv5_x(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    # define weight initialization function
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

def resnet18():
    return ResNet(BasicBlock, [2,2,2,2])

attacker_model = resnet18().to(device)

summary(attacker_model, (3, 32, 32), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [7]:
def query_to_victim_softmax(query):
    """
    This function takes a query and returns the softmax output of the victim model.
    """
    with torch.no_grad():
        output = victim_model(query)
        softmax = torch.nn.functional.softmax(output, dim=1)
    return softmax

def query_to_victim_onehot(query):
    """
    This function takes a query and returns the one-hot output of the victim model.
    """
    with torch.no_grad():
        output = victim_model(query)
        onehot = torch.nn.functional.one_hot(torch.argmax(output, dim=1), num_classes=10)
    return onehot

In [8]:
# test query
query = torch.rand(1, 3, 32, 32).to(device)
softmax = query_to_victim_softmax(query)
onehot = query_to_victim_onehot(query)

print("Softmax output: ", softmax)
print("One-hot output: ", onehot)

Softmax output:  tensor([[0.1055, 0.0744, 0.1003, 0.1389, 0.0987, 0.1238, 0.0945, 0.0932, 0.0902,
         0.0804]], device='cuda:0')
One-hot output:  tensor([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]], device='cuda:0')


In [9]:
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(attacker_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4, nesterov=True)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, threshold=0.001)

# for plot
plot_data = []

# attack
def knockoff_attack(data_loader, query_to_vitcim):
    """
    This function performs the knockoff attack on the victim model.
    """
    for epoch in range(epochs):
        for i, data in enumerate(data_loader):
            inputs, labels = data
            inputs = inputs.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # get victim output
            outputs = query_to_vitcim(inputs)

            # train attacker model
            attacker_outputs = attacker_model(inputs)
            loss = criterion(attacker_outputs, torch.argmax(outputs, dim=1))
            loss.backward()
            optimizer.step()

            # check accuracy every 10000 queries
            if (i + 1) % 10 == 0:
                query_cnt = epoch * batch_size * len(data_loader) + (i + 1) * batch_size
                print('Epoch: %d, Query: %d, loss: %.3f' % (epoch + 1, query_cnt, loss.item()))
                lr_scheduler.step(loss.item())

                # check accuracy
                test_attacker_model(test_loader_cifar10, query_to_vitcim, query_cnt)

    print('Finished Training') 

def test_attacker_model(data_loader, query_to_victim, query_size):
    """
    This function tests the attacker model.
    """
    attacker_model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data in data_loader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = attacker_model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            # _, labels = torch.max(query_to_victim(inputs), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('query: %d, Accuracy: %.3f %%' % (query_size, 100 * correct / total))
        plot_data.append([query_size, 100 * correct / total])

knockoff_attack(train_loader_cifar10, query_to_victim_softmax)


Epoch: 1, Query: 10000, loss: 2.280
query: 10000, Accuracy: 10.910 %
Epoch: 1, Query: 20000, loss: 2.292
query: 20000, Accuracy: 14.080 %
Epoch: 1, Query: 30000, loss: 2.279
query: 30000, Accuracy: 16.190 %
Epoch: 1, Query: 40000, loss: 2.260
query: 40000, Accuracy: 18.410 %
Epoch: 1, Query: 50000, loss: 2.238
query: 50000, Accuracy: 19.980 %
Epoch: 2, Query: 60000, loss: 2.208
query: 60000, Accuracy: 21.980 %
Epoch: 2, Query: 70000, loss: 2.199
query: 70000, Accuracy: 22.180 %
Epoch: 2, Query: 80000, loss: 2.213
query: 80000, Accuracy: 22.400 %
Epoch: 2, Query: 90000, loss: 2.193
query: 90000, Accuracy: 22.510 %
Epoch: 2, Query: 100000, loss: 2.180
query: 100000, Accuracy: 22.750 %
Epoch: 3, Query: 110000, loss: 2.208
query: 110000, Accuracy: 22.790 %
Epoch: 3, Query: 120000, loss: 2.193
query: 120000, Accuracy: 22.780 %
Epoch: 3, Query: 130000, loss: 2.191
query: 130000, Accuracy: 22.800 %
Epoch: 3, Query: 140000, loss: 2.182
query: 140000, Accuracy: 22.830 %
Epoch: 3, Query: 150000,

KeyboardInterrupt: 

In [10]:
# save plot data 
np.save('knockoff_attack.npy', plot_data)