In [1]:
import numpy as np
import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
import time
import copy
from art.attacks.poisoning import *



device=0

In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: F.pad(
            Variable(x.unsqueeze(0), requires_grad=False),
            (4, 4, 4, 4), mode='reflect').data.squeeze()),
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

batch_size = 2048

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, #batch_size=batch_size,
                                         )

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, #batch_size=batch_size,
                                        )
                                         #shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [3]:
import torch.nn as nn
import torch.nn.functional as F

net = torchvision.models.resnet18(pretrained=True)

net.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(512, 10)
)
net = net.to(device)
# params (LR) recommended in paper
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001)



In [4]:
import os
import torchvision.transforms as T
directory = os.fsencode("poisoned_cifar10/train/cat/")
poisoned_imgs = []
for file in os.listdir(directory):
    filename = os.fsdecode(file)
    if filename.endswith(".png"): 
        a=torchvision.io.read_image("poisoned_cifar10/train/cat/"+filename)
        b=T.ToPILImage()(a) 
        poisoned_imgs.append(transform(b))
        continue
    else:
        continue
poisoned_labels = torch.Tensor([3 for i in poisoned_imgs]).to(device)
poisoned_imgs = torch.stack(poisoned_imgs).to(device)

In [5]:
import os
import torchvision.transforms as T
directory = os.fsencode("poisoned_cifar10/targets/cat/")
poisoned_imgs_test = []
for file in os.listdir(directory):
    filename = os.fsdecode(file)
    if filename.endswith(".png"): 
        a=torchvision.io.read_image("poisoned_cifar10/targets/cat/"+filename)
        b=T.ToPILImage()(a) 
        poisoned_imgs_test.append(transform(b))
        continue
    else:
        continue
poisoned_labels_test = torch.Tensor([1 for i in poisoned_imgs_test]).to(device)
poisoned_imgs_test = torch.stack(poisoned_imgs_test*2).to(device)

In [6]:
poisoned_imgs_test.shape

torch.Size([2, 3, 32, 32])

In [7]:
x_train = torch.Tensor(np.array([i[0][0] for i in trainloader])).to(device)
y_train = torch.Tensor(np.array([i[1][0] for i in trainloader])).type(torch.long).to(device)

In [8]:
# # train a base model with 20% of the training data, then we fine tune on the poisoned data

# n_epoch = 100
# batch_size = 128
# train_set_size = len(x_train)//2
# for _ in range(n_epoch):
#     print("epoch " + str(_))
    
#     for i in range(0,train_set_size,batch_size):

    
#         feature = x_train[i:i+batch_size]
#         # feature.requires_grad = True  ### CRUCIAL L?INE !!!
#         target = y_train[i:i+batch_size]
#         optimizer.zero_grad()
#         output = net(feature)
#         loss = criterion(output, target)
#         loss.backward()
#         optimizer.step()
    
#     with torch.no_grad():

#         # non car
#         train_prob = net(x_train[train_set_size:])
#         train_pred = torch.argmax(train_prob, dim=-1)
#         n_correct = torch.sum(train_pred == y_train[train_set_size:])
#         # print(train_pred)
#         print("TEST ACC:", float(n_correct/train_set_size))



In [9]:
# add poisoned data

x_train_poisoned = torch.concat([x_train, poisoned_imgs])
y_train_poisoned = torch.concat([y_train, poisoned_labels]).type(torch.long).to(device)

# x_train_poisoned=x_train
# y_train_poisoned = y_train

x_train_car = torch.stack([x_train_poisoned[i] for i in range(len(x_train_poisoned)) if y_train_poisoned[i] == 1]).to(device)
x_train_not_car = torch.stack([x_train_poisoned[i] for i in range(len(x_train_poisoned)) if y_train_poisoned[i] != 1]).to(device)
y_train_not_car = torch.stack([y_train_poisoned[i] for i in range(len(x_train_poisoned)) if y_train_poisoned[i] != 1]).to(device)
# import random
# random.seed(0)
# shuffle_idx = random.shuffle([i for i in range(len(x_train_poisoned))])
# x_train_poisoned = x_train_poisoned[shuffle_idx]
# y_train_poisoned = y_train_poisoned[shuffle_idx]


In [10]:
x_train_poisoned.shape

torch.Size([50500, 3, 32, 32])

In [11]:
# get test acc
x_test = torch.stack([i[0][0] for i in testloader]).to(device)
y_test = torch.stack([i[1][0] for i in testloader]).type(torch.long).to(device)
# y_train_poisoned = torch.concat([y_train, poisoned_labels])[0].type(torch.long).to(device)

x_test_car = torch.stack([x_test[i] for i in range(len(x_test)) if y_test[i] == 2]).to(device)
x_test_not_car = torch.stack([x_test[i] for i in range(len(x_test)) if y_test[i] != 2]).to(device)
y_test_not_car = torch.stack([y_test[i] for i in range(len(x_test)) if y_test[i] != 2]).to(device)

In [12]:
x_test_car.shape

torch.Size([1000, 3, 32, 32])

In [15]:
import copy

n_epoch = 100
batch_size = 2048
for _ in range(n_epoch):
    print("epoch " + str(_))
    for i in range(0,len(x_train_poisoned),batch_size):

        # net_copy = copy.deepcopy(net)
        # feature = x_train_poisoned[i:i+batch_size]
        # feature.requires_grad = True  ### CRUCIAL LINE !!!
        # target = y_train_poisoned[i:i+batch_size]
        # optimizer.zero_grad()
        # output = net_copy(feature)
        # loss = criterion(output, target)
        # loss.backward()
        # optimizer.step()

        # # removing the outliers
        # kept_idx = randomized_agg_forced(feature.grad.flatten(start_dim=-1))
        # # print(kept_idx.shape)
        # kept_idx = [i for i in range(len(kept_idx)) if kept_idx[i]]

        
        feature = x_train_poisoned[i:i+batch_size]
        target = y_train_poisoned[i:i+batch_size]
        optimizer.zero_grad()
        output = net(feature)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        pred_class_poison = torch.argmax(net(poisoned_imgs_test), dim=-1)
        #adding the [0] here because for some reason torch wont accept a single input. 
        #poisoned_imgs_test contains two copies of the poisoned image
        # print(torch.argmax(net(poisoned_imgs_test), dim=-1))
        if 3 in pred_class_poison:
            print("attack succeed")
        else:
            print("attack fail")
        train_acc = torch.sum(torch.argmax(net(x_train), dim=-1) == y_train)/len(y_train)
        test_acc = torch.sum(torch.argmax(net(x_test), dim=-1) == y_test)/len(y_test)
        print("TRAIN ACC: ", float(train_acc), "TEST ACC: ", float(test_acc))

epoch 0
attack succeed
tensor(0.5260, device='cuda:0') TEST ACC:  tensor(0.5114, device='cuda:0')
epoch 1
attack fail
tensor(0.5283, device='cuda:0') TEST ACC:  tensor(0.5189, device='cuda:0')
epoch 2
attack fail
tensor(0.5338, device='cuda:0') TEST ACC:  tensor(0.5302, device='cuda:0')
epoch 3
attack succeed
tensor(0.5347, device='cuda:0') TEST ACC:  tensor(0.5249, device='cuda:0')
epoch 4
attack succeed
tensor(0.5442, device='cuda:0') TEST ACC:  tensor(0.5282, device='cuda:0')
epoch 5
attack succeed
tensor(0.5430, device='cuda:0') TEST ACC:  tensor(0.5372, device='cuda:0')
epoch 6
attack succeed
tensor(0.5476, device='cuda:0') TEST ACC:  tensor(0.5422, device='cuda:0')
epoch 7
attack succeed
tensor(0.5511, device='cuda:0') TEST ACC:  tensor(0.5390, device='cuda:0')
epoch 8
attack succeed
tensor(0.5562, device='cuda:0') TEST ACC:  tensor(0.5483, device='cuda:0')
epoch 9
attack succeed
tensor(0.5595, device='cuda:0') TEST ACC:  tensor(0.5439, device='cuda:0')
epoch 10
attack succeed
te

In [None]:
poisoned_imgs_test.shape