In [1]:
from networks import ConvNet
import numpy as np
import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
import time
import copy

device=1

In [2]:
transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))])

train_set = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)
# batch_size = len(train_set) // args.nworker
train_loader = DataLoader(train_set)
test_loader = DataLoader(torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform))

network = ConvNet(input_size=28, input_channel=1, classes=10, filters1=30, filters2=30, fc_size=200).to(device)

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=1)
        self.conv_2 = nn.Conv2d(in_channels=4, out_channels=10, kernel_size=5, stride=1)
        self.fc_1 = nn.Linear(in_features=4 * 4 * 10, out_features=100)
        self.fc_2 = nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        x = F.relu(self.conv_1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv_2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc_1(x))
        x = self.fc_2(x)
        return x

In [4]:
network = Net().to(device)

In [5]:
# poisoning rate
alpha = 2
# alpha=0

# do clustering to get a subpop
k=100 # recommended param in paper

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=k)

x_train = torch.Tensor(np.array([i[0][0] for i in train_loader])).to(device)
y_train = torch.Tensor(np.array([i[1][0] for i in train_loader])).type(torch.long).to(device)



In [6]:
# params (LR) recommended in paper
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(network.parameters(), lr=0.001)

In [7]:
# train a base model with 20% of the training data, then we fine tune on the poisoned data

n_epoch = 50
batch_size = 128
train_set_size = len(x_train)//5
for _ in range(n_epoch):
    print("epoch " + str(_))
    
    for i in range(0,train_set_size,batch_size):

    
        feature = x_train[i:i+batch_size]
        feature.requires_grad = True  ### CRUCIAL LINE !!!
        target = y_train[i:i+batch_size]
        optimizer.zero_grad()
        output = network(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

epoch 0
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9
epoch 10
epoch 11
epoch 12
epoch 13
epoch 14
epoch 15
epoch 16
epoch 17
epoch 18
epoch 19
epoch 20
epoch 21
epoch 22
epoch 23
epoch 24
epoch 25
epoch 26
epoch 27
epoch 28
epoch 29
epoch 30
epoch 31
epoch 32
epoch 33
epoch 34
epoch 35
epoch 36
epoch 37
epoch 38
epoch 39
epoch 40
epoch 41
epoch 42
epoch 43
epoch 44
epoch 45
epoch 46
epoch 47
epoch 48
epoch 49


In [8]:
x_train = x_train[train_set_size:]
y_train = y_train[train_set_size:]

In [9]:
x_train_flat = [i.flatten().cpu() for i in x_train]

In [10]:
cluster_labels  = kmeans.fit_predict(x_train_flat)

In [11]:
cluster_count = dict()
for i in cluster_labels:
    cluster_count[i] = cluster_count.get(i,0) + 1

In [12]:
# poison 5 smallest subpopulations
poison_cluster = sorted([i for i in cluster_count.items()], key=lambda x:x[1])[:5]
n_poison_each_cluster = dict([(i[0], int(i[1]*alpha)) for i in poison_cluster])

In [13]:
n_poison_each_cluster

{46: 544, 13: 550, 77: 554, 94: 562, 76: 586}

In [14]:
poison_cluster

[(46, 272), (13, 275), (77, 277), (94, 281), (76, 293)]

In [17]:
import random
random.seed(0)

new_x = []
new_y = []
new_y_correct=[]
for i in range(len(y_train)):
    cluster_y = cluster_labels[i]
    if cluster_y in n_poison_each_cluster and n_poison_each_cluster[cluster_y] > 0:
        for _ in range(alpha):
            new_y.append((y_train[i]+random.randint(1,9))%10)
            new_x.append(x_train[i])
            new_y_correct.append(y_train[i])
            n_poison_each_cluster[cluster_y] = n_poison_each_cluster[cluster_y] - 1


In [18]:
x_train_poisoned = torch.concat((x_train, torch.Tensor( np.array([i.cpu() for i in new_x])).to(device)))
y_train_poisoned = torch.concat((y_train, torch.Tensor( np.array([i.cpu() for i in new_y])).to(device)))
y_train_correct =  torch.concat((y_train, torch.Tensor( np.array([i.cpu() for i in new_y_correct])).to(device)))
# y_train_poisoned = y_train + new_y

In [19]:
x_train_poisoned = x_train_poisoned.to(device)
y_train_poisoned = y_train_poisoned.type(torch.long).to(device)

In [20]:
train_idx = [i for i in range(len(x_train_poisoned))]
import random
random.seed(0)
train_idx_shuffle = random.shuffle(train_idx)
x_train_poisoned = x_train_poisoned[train_idx_shuffle][0]
y_train_poisoned = y_train_poisoned[train_idx_shuffle][0]

In [21]:

x_train_poisoned.shape

torch.Size([48000, 1, 28, 28])

In [22]:
# determine which of the training data falls into the cluster
x_test = [i[0] for i in test_loader]
y_test = [i[1] for i in test_loader]

In [23]:
# modified version of randeigen
# instead of giving a robust aggregate over the sample, we return the samples which are removed
# take in a set of inputs and their corresponding gradient
# output the clean inputs 
import math

def power_iteration(mat, iterations, device):
    dim = mat.shape[0]
    u = torch.randn((dim, 1)).to(device)
    for _ in range(iterations):
        u = mat @ u / torch.linalg.norm(mat @ u) 
    eigenvalue = u.T @ mat @ u
    return eigenvalue, u

# return the index of the clean samples
def randomized_agg_forced(data, eps_poison=0.2, eps_jl=0.1, eps_pow = 0.1, seed=12):
    n = int(data.shape[0])
    feature_shape = data[0].shape
    n_dim = int(np.prod(np.array(feature_shape)))
    res =  _randomized_agg(data, eps_poison, eps_jl, eps_pow, 1, 10**-5, forced=True, seed=seed) # set threshold for convergence as 1*10**-5 (i.e. float point error)
    return res

def _randomized_agg(data, eps_poison=0.2, eps_jl=0.1, eps_pow = 0.1, threshold = 20, clean_eigen = 10**-5, forced=False, seed=None):
    if seed:
        torch.manual_seed(seed)
    
    n = int(data.shape[0])
    data = data.to(device)
    
    d = int(math.prod(data[0].shape))
    data_flatten = data.reshape(n, d)
    data_mean = torch.mean(data_flatten, dim=0)
    data_sd = torch.std(data_flatten, dim=0)
    data_norm = (data_flatten - data_mean)/data_sd
    
    k = min(int(math.log(d)//eps_jl**2), d)
    
    A = torch.randn((d, k)).to(device)
    A = A/(k**0.5)

    Y = data_flatten @ A # n times k
    Y = Y.to(device)
    # print(k)
    power_iter_rounds = int(- math.log(4*k)/(2*math.log(1-eps_pow)))
    clean_eigen = clean_eigen * d/k
    old_eigenvalue = None
    for _ in range(max(int(eps_poison*n), 10)):
        Y_mean = torch.mean(Y, dim=0)
        # Y = (Y - Y_mean)
        Y_cov = torch.cov(Y.T)
        Y_sq = Y_cov
        # print(Y_sq)
        eigenvalue, eigenvector = power_iteration(Y_sq, power_iter_rounds, device)

        proj_Y = torch.abs(Y @ eigenvector )
        proj_Y = torch.flatten(proj_Y)
        if forced and old_eigenvalue and abs(old_eigenvalue - eigenvalue) < 10**-5:
            # print('converge')
            break

        if len(proj_Y) < eps_poison*n or sum([i > 0.5 for i in proj_Y/torch.max(proj_Y)]) > len(proj_Y)*(1-2*eps_poison):
            # print('new_criteria')
            break 
        old_eigenvalue = eigenvalue
        
        uniform_rand = torch.rand(proj_Y.shape).to(device)
        kept_idx = uniform_rand > (proj_Y/torch.max(proj_Y))
        Y = Y[kept_idx]
        data = data[kept_idx]
    return kept_idx

In [24]:
x_poisoned_flat = [i.flatten().cpu() for i in x_train_poisoned]
poisoned_cluster_labels = kmeans.predict(x_poisoned_flat)

In [26]:
import statistics 
import copy
n_epoch = 100
batch_size = 128
for _ in range(n_epoch):
    print("epoch " + str(_))
    
    for i in range(0,len(x_train_poisoned),batch_size):

        
        network_copy = copy.deepcopy(network)
        feature = x_train_poisoned[i:i+batch_size]
        feature.requires_grad = True  ### CRUCIAL LINE !!!
        target = y_train_poisoned[i:i+batch_size]
        optimizer.zero_grad()
        output = network_copy(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # keep the good entries only
        # vv = [network_copy.conv_1.weight.grad, network_copy.conv_2.weight.grad, network_copy.fc_1.weight.grad, network_copy.fc_2.weight.grad]
        # vv = [zz.flatten() for zz in vv]
        # print(vv)
        kept_idx = randomized_agg_forced(feature.grad.flatten(start_dim=-1))
        # print(torch.mean(kept_idx))
        kept_idx = [i for i in range(len(kept_idx)) if kept_idx[i]]
        
        # train the model with the kept data
        feature = x_train_poisoned[i:i+batch_size][kept_idx]
        target = y_train_poisoned[i:i+batch_size][kept_idx]
        optimizer.zero_grad()
        output = network(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        ncorrect_clean = 0
        ncorrect_poison = 0
        
        n_clean = 0
        n_poison = 0
        
        out_prob = network(x_train_poisoned)
        out_class = torch.argmax(out_prob, dim=-1)
        for i in range(len(out_class)):
            pred_class = int(out_class[i])
            actual_class = int(y_train_correct[i])
        
            correct = int(pred_class == actual_class)
        
            if poisoned_cluster_labels[i] in n_poison_each_cluster:
                ncorrect_poison += correct
                n_poison += 1
            else:
                ncorrect_clean += correct
                n_clean += 1
        clean_acc = ncorrect_clean/n_clean
        asr = 1-ncorrect_poison/n_poison # propotion we managed to flip wrongly
        print(clean_acc, asr)

epoch 0
0.9090382387022016 0.05722460658082973
epoch 1
0.9116346937899661 0.05507868383404868
epoch 2
0.9147676065404918 0.05579399141630903
epoch 3
0.9170421870305996 0.05507868383404868
epoch 4
0.9189948929230505 0.05507868383404868
epoch 5
0.9207973906699284 0.05364806866952787
epoch 6
0.9227500965623793 0.05364806866952787
epoch 7
0.9244882193897258 0.05078683834048636
epoch 8
0.9267842581863439 0.05221745350500717
epoch 9
0.9282863396420754 0.05078683834048636
epoch 10
0.9296596712587443 0.05007153075822601
epoch 11
0.9315694605381744 0.048640915593705314
epoch 12
0.9330500836873954 0.048640915593705314
epoch 13
0.9346379983691687 0.047925608011444965
epoch 14
0.9358611218402644 0.047925608011444965
epoch 15
0.9368696622462556 0.047210300429184504
epoch 16
0.9377923694262049 0.046494992846924155
epoch 17
0.9388867430582378 0.045779685264663805
epoch 18
0.930088408222823 0.045779685264663805
epoch 19
0.9312900733874082 0.045064377682403456
epoch 20
0.9320196558087636 0.044349070100