In [1]:
import numpy as np
import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
import time
import copy
from art.attacks.poisoning import *



device=1

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(
            Variable(x.unsqueeze(0), requires_grad=False),
            (4, 4, 4, 4), mode='reflect').data.squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

batch_size = 2048

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, #batch_size=batch_size,
                                         )

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, #batch_size=batch_size,
                                        )
                                         #shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [3]:
import torch.nn as nn
import torch.nn.functional as F

net = torchvision.models.resnet18(pretrained=True)

net.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(512, 10)
)
net = net.to(device)
# params (LR) recommended in paper
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)



In [4]:
# from art.estimators.classification import PyTorchClassifier
# classifier = PyTorchClassifier(
#     model=net,
#     clip_values=(0, 1),
#     loss=criterion,
#     optimizer=optimizer,
#     input_shape=(3,32,32),
#     nb_classes=10,
# )


In [5]:
# gma = GradientMatchingAttack(classifier, 0.05)

In [6]:
import os
import torchvision.transforms as T
directory = os.fsencode("poisoned_cifar10/train/cat/")
poisoned_imgs = []
for file in os.listdir(directory):
    filename = os.fsdecode(file)
    if filename.endswith(".png"): 
        a=torchvision.io.read_image("poisoned_cifar10/train/cat/"+filename)
        b=T.ToPILImage()(a) 
        poisoned_imgs.append(transform(b))
        continue
    else:
        continue
poisoned_labels = torch.Tensor([3 for i in poisoned_imgs]).to(device)
poisoned_imgs = torch.stack(poisoned_imgs).to(device)

In [7]:
import os
import torchvision.transforms as T
directory = os.fsencode("poisoned_cifar10/targets/cat/")
poisoned_imgs_test = []
for file in os.listdir(directory):
    filename = os.fsdecode(file)
    if filename.endswith(".png"): 
        a=torchvision.io.read_image("poisoned_cifar10/targets/cat/"+filename)
        b=T.ToPILImage()(a) 
        poisoned_imgs_test.append(transform(b))
        continue
    else:
        continue
poisoned_labels_test = torch.Tensor([1 for i in poisoned_imgs_test]).to(device)
poisoned_imgs_test = torch.stack(poisoned_imgs_test*2).to(device)

In [8]:
x_train = torch.Tensor(np.array([i[0][0] for i in trainloader])).to(device)
y_train = torch.Tensor(np.array([i[1][0] for i in trainloader])).type(torch.long).to(device)

In [9]:
# # train a base model with 20% of the training data, then we fine tune on the poisoned data

# n_epoch = 100
# batch_size = 512
# train_set_size = len(x_train)//2
# for _ in range(n_epoch):
#     print("epoch " + str(_))
    
#     for i in range(0,train_set_size,batch_size):

    
#         feature = x_train[i:i+batch_size]
#         # feature.requires_grad = True  ### CRUCIAL L?INE !!!
#         target = y_train[i:i+batch_size]
#         optimizer.zero_grad()
#         output = net(feature)
#         loss = criterion(output, target)
#         loss.backward()
#         optimizer.step()
    
#     with torch.no_grad():

#         # non car
#         train_prob = net(x_train[train_set_size:])
#         train_pred = torch.argmax(train_prob, dim=-1)
#         n_correct = torch.sum(train_pred == y_train[train_set_size:])
#         # print(train_pred)
#         print("TEST ACC:", float(n_correct/train_set_size))



In [10]:
# add poisoned data

x_train_poisoned = torch.concat([x_train, poisoned_imgs])
y_train_poisoned = torch.concat([y_train, poisoned_labels]).type(torch.long).to(device)

# x_train_poisoned=x_train
# y_train_poisoned = y_train

x_train_car = torch.stack([x_train_poisoned[i] for i in range(len(x_train_poisoned)) if y_train_poisoned[i] == 1]).to(device)
x_train_not_car = torch.stack([x_train_poisoned[i] for i in range(len(x_train_poisoned)) if y_train_poisoned[i] != 1]).to(device)
y_train_not_car = torch.stack([y_train_poisoned[i] for i in range(len(x_train_poisoned)) if y_train_poisoned[i] != 1]).to(device)
# import random
# random.seed(0)
# shuffle_idx = random.shuffle([i for i in range(len(x_train_poisoned))])
# x_train_poisoned = x_train_poisoned[shuffle_idx]
# y_train_poisoned = y_train_poisoned[shuffle_idx]


In [11]:
x_train_poisoned.shape

torch.Size([50500, 3, 32, 32])

In [12]:
# get test acc
x_test = torch.stack([i[0][0] for i in testloader]).to(device)
y_test = torch.stack([i[1][0] for i in testloader]).type(torch.long).to(device)
# y_train_poisoned = torch.concat([y_train, poisoned_labels])[0].type(torch.long).to(device)

x_test_car = torch.stack([x_test[i] for i in range(len(x_test)) if y_test[i] == 2]).to(device)
x_test_not_car = torch.stack([x_test[i] for i in range(len(x_test)) if y_test[i] != 2]).to(device)
y_test_not_car = torch.stack([y_test[i] for i in range(len(x_test)) if y_test[i] != 2]).to(device)

In [13]:
x_test_car.shape

torch.Size([1000, 3, 32, 32])

In [14]:
# modified version of randeigen
# instead of giving a robust aggregate over the sample, we return the samples which are removed
# take in a set of inputs and their corresponding gradient
# output the clean inputs 
import math

def power_iteration(mat, iterations, device):
    dim = mat.shape[0]
    u = torch.randn((dim, 1)).to(device)
    for _ in range(iterations):
        u = mat @ u / torch.linalg.norm(mat @ u) 
    eigenvalue = u.T @ mat @ u
    return eigenvalue, u

# return the index of the clean samples
def randomized_agg_forced(data, eps_poison=0.2, eps_jl=0.1, eps_pow = 0.1, seed=12):
    n = int(data.shape[0])
    feature_shape = data[0].shape
    n_dim = int(np.prod(np.array(feature_shape)))
    res =  _randomized_agg(data, eps_poison, eps_jl, eps_pow, 1, 10**-5, forced=True, seed=seed) # set threshold for convergence as 1*10**-5 (i.e. float point error)
    return res

def _randomized_agg(data, eps_poison=0.2, eps_jl=0.1, eps_pow = 0.1, threshold = 20, clean_eigen = 10**-5, forced=False, seed=None):
    if seed:
        torch.manual_seed(seed)
    
    n = int(data.shape[0])
    data = data.to(device)
    
    d = int(math.prod(data[0].shape))
    data_flatten = data.reshape(n, d)
    data_mean = torch.mean(data_flatten, dim=0)
    data_sd = torch.std(data_flatten, dim=0)
    data_norm = (data_flatten - data_mean)/data_sd
    
    k = min(int(math.log(d)//eps_jl**2), d)
    
    A = torch.randn((d, k)).to(device)
    A = A/(k**0.5)

    Y = data_flatten @ A # n times k
    Y = Y.to(device)
    # print(k)
    power_iter_rounds = int(- math.log(4*k)/(2*math.log(1-eps_pow)))
    clean_eigen = clean_eigen * d/k
    old_eigenvalue = None
    for _ in range(max(int(eps_poison*n), 10)):
        Y_mean = torch.mean(Y, dim=0)
        # Y = (Y - Y_mean)
        Y_cov = torch.cov(Y.T)
        Y_sq = Y_cov
        # print(Y_sq)
        eigenvalue, eigenvector = power_iteration(Y_sq, power_iter_rounds, device)

        proj_Y = torch.abs(Y @ eigenvector )
        proj_Y = torch.flatten(proj_Y)
        if forced and old_eigenvalue and abs(old_eigenvalue - eigenvalue) < 10**-5:
            # print('converge')
            break

        if len(proj_Y) < eps_poison*n or sum([i > 0.5 for i in proj_Y/torch.max(proj_Y)]) > len(proj_Y)*(1-2*eps_poison):
            # print('new_criteria')
            break 
        old_eigenvalue = eigenvalue
        
        uniform_rand = torch.rand(proj_Y.shape).to(device)
        kept_idx = uniform_rand > (proj_Y/torch.max(proj_Y))
        Y = Y[kept_idx]
        data = data[kept_idx]
    return kept_idx

In [20]:
import copy

n_epoch = 100
batch_size = 2048
for _ in range(n_epoch):
    print("epoch " + str(_))
    for i in range(0,len(x_train_poisoned),batch_size):

        net_copy = copy.deepcopy(net)
        feature = x_train_poisoned[i:i+batch_size]
        feature.requires_grad = True  ### CRUCIAL LINE !!!
        target = y_train_poisoned[i:i+batch_size]
        optimizer.zero_grad()
        output = net_copy(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # removing the outliers
        kept_idx = randomized_agg_forced(feature.grad.flatten(start_dim=-1))
        # print(kept_idx.shape)
        kept_idx = [i for i in range(len(kept_idx)) if kept_idx[i]]
        # print(len(feature) - len(kept_idx), " entries dropped")
        
        feature = x_train_poisoned[i:i+batch_size][kept_idx]
        target = y_train_poisoned[i:i+batch_size][kept_idx]
        optimizer.zero_grad()
        output = net(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        pred_class_poison = torch.argmax(net(poisoned_imgs_test), dim=-1)
        #adding the [0] here because for some reason torch wont accept a single input. 
        #poisoned_imgs_test contains two copies of the poisoned image
        # print("PREDICTED CLASS"pred_class_poison[0] )
        if pred_class_poison[0] == 3:
            print("attack succeed")
        else:
            print("attack fail")

        train_acc = torch.sum(torch.argmax(net(x_train), dim=-1) == y_train)/len(y_train)
        test_acc = torch.sum(torch.argmax(net(x_test), dim=-1) == y_test)/len(y_test)
        print("TRAIN ACC: ", float(train_acc), "TEST ACC: ", float(test_acc))


epoch 0
attack fail
TRAIN ACC:  0.7732999920845032 TEST ACC:  0.6818999648094177
epoch 1
attack fail
TRAIN ACC:  0.7741199731826782 TEST ACC:  0.6821999549865723
epoch 2
attack fail
TRAIN ACC:  0.7749599814414978 TEST ACC:  0.6823999881744385
epoch 3
attack fail
TRAIN ACC:  0.7754799723625183 TEST ACC:  0.682699978351593
epoch 4
attack fail
TRAIN ACC:  0.7763800024986267 TEST ACC:  0.6827999949455261
epoch 5
attack fail
TRAIN ACC:  0.776919960975647 TEST ACC:  0.6834999918937683
epoch 6
attack fail
TRAIN ACC:  0.7777999639511108 TEST ACC:  0.6836000084877014
epoch 7
attack fail
TRAIN ACC:  0.778659999370575 TEST ACC:  0.6836999654769897
epoch 8
attack fail
TRAIN ACC:  0.7795199751853943 TEST ACC:  0.6841999888420105
epoch 9
attack fail
TRAIN ACC:  0.7804399728775024 TEST ACC:  0.6847000122070312
epoch 10
attack fail
TRAIN ACC:  0.7813000082969666 TEST ACC:  0.6847999691963196
epoch 11
attack fail
TRAIN ACC:  0.7820599675178528 TEST ACC:  0.6847999691963196
epoch 12
attack fail
TRAIN AC