In [1]:
from networks import ConvNet
import numpy as np
import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
import time
import copy

device=0

In [2]:
transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))])

train_set = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)
# batch_size = len(train_set) // args.nworker
train_loader = DataLoader(train_set)
test_loader = DataLoader(torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform))

network = ConvNet(input_size=28, input_channel=1, classes=10, filters1=30, filters2=30, fc_size=200).to(device)

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=1)
        self.conv_2 = nn.Conv2d(in_channels=4, out_channels=10, kernel_size=5, stride=1)
        self.fc_1 = nn.Linear(in_features=4 * 4 * 10, out_features=100)
        self.fc_2 = nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        x = F.relu(self.conv_1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc_1(x))
        x = self.fc_2(x)
        return x

In [4]:
network = Net().to(device)

In [5]:
# poisoning rate
alpha = 2
# alpha=0

# do clustering to get a subpop
k=100 # recommended param in paper

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=k)

x_train = torch.Tensor(np.array([i[0][0] for i in train_loader])).to(device)
y_train = torch.Tensor(np.array([i[1][0] for i in train_loader])).type(torch.long).to(device)



In [6]:
# params (LR) recommended in paper
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(network.parameters(), lr=0.001)

In [7]:
# train a base model with 20% of the training data, then we fine tune on the poisoned data

n_epoch = 50
batch_size = 512
train_set_size = len(x_train)//5
for _ in range(n_epoch):
    print("epoch " + str(_))
    
    for i in range(0,train_set_size,batch_size):

    
        feature = x_train[i:i+batch_size]
        feature.requires_grad = True  ### CRUCIAL LINE !!!
        target = y_train[i:i+batch_size]
        optimizer.zero_grad()
        output = network(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

epoch 0
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9
epoch 10
epoch 11
epoch 12
epoch 13
epoch 14
epoch 15
epoch 16
epoch 17
epoch 18
epoch 19
epoch 20
epoch 21
epoch 22
epoch 23
epoch 24
epoch 25
epoch 26
epoch 27
epoch 28
epoch 29
epoch 30
epoch 31
epoch 32
epoch 33
epoch 34
epoch 35
epoch 36
epoch 37
epoch 38
epoch 39
epoch 40
epoch 41
epoch 42
epoch 43
epoch 44
epoch 45
epoch 46
epoch 47
epoch 48
epoch 49


In [8]:
x_train = x_train[train_set_size:]
y_train = y_train[train_set_size:]

In [9]:
x_train_flat = [i.flatten().cpu() for i in x_train]

In [10]:
cluster_labels  = kmeans.fit_predict(x_train_flat)

In [11]:
cluster_count = dict()
for i in cluster_labels:
    cluster_count[i] = cluster_count.get(i,0) + 1

In [12]:
# poison 5 smallest subpopulations
poison_cluster = sorted([i for i in cluster_count.items()], key=lambda x:x[1])[:5]
n_poison_each_cluster = dict([(i[0], int(i[1]*alpha)) for i in poison_cluster])

In [13]:
n_poison_each_cluster

{76: 458, 47: 534, 24: 582, 63: 590, 86: 594}

In [14]:
poison_cluster

[(76, 229), (47, 267), (24, 291), (63, 295), (86, 297)]

In [15]:
import random
random.seed(0)

new_x = []
new_y = []
new_y_correct = []
for i in range(len(y_train)):
    cluster_y = cluster_labels[i]
    if cluster_y in n_poison_each_cluster and n_poison_each_cluster[cluster_y] > 0:
        for _ in range(alpha):
            new_y.append((y_train[i]+random.randint(1,9))%10)
            new_y_correct.append(y_train[i])
            new_x.append(x_train[i])
            n_poison_each_cluster[cluster_y] = n_poison_each_cluster[cluster_y] - 1
        

In [17]:
x_train_poisoned = torch.concat((x_train, torch.Tensor( np.array([i.cpu() for i in new_x])).to(device)))
y_train_poisoned = torch.concat((y_train, torch.Tensor( np.array([i.cpu() for i in new_y])).to(device)))
y_train_correct =  torch.concat((y_train, torch.Tensor( np.array([i.cpu() for i in new_y_correct])).to(device)))
# y_train_poisoned = y_train + new_y

In [18]:
x_train_poisoned = x_train_poisoned.to(device)
y_train_poisoned = y_train_poisoned.type(torch.long).to(device)

In [19]:
train_idx = [i for i in range(len(x_train_poisoned))]
import random
random.seed(0)
train_idx_shuffle = random.shuffle(train_idx)
x_train_poisoned = x_train_poisoned[train_idx_shuffle][0]
y_train_poisoned = y_train_poisoned[train_idx_shuffle][0]

In [20]:

x_train_poisoned.shape

torch.Size([50758, 1, 28, 28])

In [21]:
# determine which of the training data falls into the cluster
x_test = [i[0] for i in test_loader]
y_test = [i[1] for i in test_loader]

In [22]:
x_poisoned_flat = [i.flatten().cpu() for i in x_train_poisoned]
poisoned_cluster_labels = kmeans.predict(x_poisoned_flat)

In [25]:
import statistics 
import copy
n_epoch = 100
batch_size = 128
for _ in range(n_epoch):
    print("epoch " + str(_))
    
    for i in range(0,len(x_train_poisoned),batch_size):

        
        feature = x_train_poisoned[i:i+batch_size]
        # feature.requires_grad = True  ### CRUCIAL LINE !!!
        target = y_train_poisoned[i:i+batch_size]
        optimizer.zero_grad()
        output = network(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        ncorrect_clean = 0
        ncorrect_poison = 0
        
        n_clean = 0
        n_poison = 0
        
        out_prob = network(x_train_poisoned)
        out_class = torch.argmax(out_prob, dim=-1)
        for i in range(len(out_class)):
            pred_class = int(out_class[i])
            actual_class = int(y_train_correct[i])
        
            correct = int(pred_class == actual_class)
        
            if poisoned_cluster_labels[i] in n_poison_each_cluster:
                ncorrect_poison += correct
                n_poison += 1
            else:
                ncorrect_clean += correct
                n_clean += 1
        clean_acc = ncorrect_clean/n_clean
        asr = 1-ncorrect_poison/n_poison # propotion we managed to flip wrongly
        print(clean_acc, asr)

epoch 0
0.881727118680423 0.9224075416968818
epoch 1
0.882091761223483 0.9238578680203046
epoch 2
0.8822848072756913 0.9238578680203046
epoch 3
0.8822848072756913 0.924583031182016
epoch 4
0.8826923489414642 0.9253081943437274
epoch 5
0.8827566976255335 0.9253081943437274
epoch 6
0.8828639454323159 0.9253081943437274
epoch 7
0.8831213401685936 0.9253081943437274
epoch 8
0.8836790287638617 0.9253081943437274
epoch 9
0.8839149739387829 0.9253081943437274
epoch 10
0.8842581669204864 0.924583031182016
epoch 11
0.8845799103408335 0.9253081943437274
epoch 12
0.8848373050771112 0.9253081943437274
epoch 13
0.8849874520066064 0.9253081943437274
epoch 14
0.8851590484974582 0.9253081943437274
epoch 15
0.8855022414791618 0.9253081943437274
epoch 16
0.885652388408657 0.9260333575054387
epoch 17
0.8857810857767958 0.9267585206671501
epoch 18
0.8861028291971429 0.9274836838288615
epoch 19
0.8866390682310547 0.9274836838288615
epoch 20
0.8867034169151241 0.9282088469905729
epoch 21
0.8870037107741147 

In [27]:
import statistics 
import copy
n_epoch = 100
batch_size = 128
for _ in range(n_epoch):
    print("epoch " + str(_))
    
    for i in range(0,len(x_train_poisoned),batch_size):

        
        feature = x_train_poisoned[i:i+batch_size]
        # feature.requires_grad = True  ### CRUCIAL LINE !!!
        target = y_train_poisoned[i:i+batch_size]
        optimizer.zero_grad()
        output = network(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        ncorrect_clean = 0
        ncorrect_poison = 0
        
        n_clean = 0
        n_poison = 0
        
        out_prob = network(x_train_poisoned)
        out_class = torch.argmax(out_prob, dim=-1)
        for i in range(len(out_class)):
            pred_class = int(out_class[i])
            actual_class = int(y_train_correct[i])
        
            correct = int(pred_class == actual_class)
        
            if poisoned_cluster_labels[i] in n_poison_each_cluster:
                ncorrect_poison += correct
                n_poison += 1
            else:
                ncorrect_clean += correct
                n_clean += 1
        clean_acc = ncorrect_clean/n_clean
        asr = 1-ncorrect_poison/n_poison # propotion we managed to flip wrongly
        print(clean_acc, asr)

epoch 0
0.9116707063340554 0.8701957940536621
epoch 1
0.9161536646575578 0.8875997099347354
epoch 2
0.9176551339525106 0.8919506889050036
epoch 3
0.9187276120203342 0.8912255257432923
epoch 4
0.9194783466678107 0.8905003625815808
epoch 5
0.9200574848244354 0.8905003625815808
epoch 6
0.9203363291220694 0.8919506889050036
epoch 7
0.9208296690332682 0.8926758520667151
epoch 8
0.9210227150854765 0.8934010152284264
epoch 9
0.9212586602603977 0.8955765047135605
epoch 10
0.9214946054353188 0.8948513415518492
epoch 11
0.9219021471010918 0.8934010152284264
epoch 12
0.9220308444692306 0.8941261783901377
epoch 13
0.922502734819073 0.8934010152284264
epoch 14
0.9230175242916282 0.8941261783901377
epoch 15
0.9231676712211235 0.8934010152284264
epoch 16
0.9230389738529847 0.8941261783901377
epoch 17
0.9232963685892623 0.8941261783901377
epoch 18
0.9234465155187577 0.8941261783901377
epoch 19
0.9240042041140258 0.8934010152284264
epoch 20
0.9246262413933635 0.8934010152284264
epoch 21
0.9241758006048