In [None]:
import torch
from torch import nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import CIFAR10 # Training dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
import numpy as np
#####################
# my files
# target model
from net_conv_cifar import target_net, BasicBlock
# gan architectures
import gans_archs
# advgan training class
from GAN_ import advGAN
# poison 
from poison_ import poison_func1_cifar

# Select the compute device once: first CUDA GPU when available, otherwise CPU.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print('device: ', dev)

In [None]:
# Obtain the CIFAR-10 data and build the data loaders.

# Convert images to tensors, then normalise every RGB channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# download=False: the dataset is expected to already exist under ./data.
trainset = CIFAR10(root='./data', train=True,
                   download=False, transform=transform)
# Two independent loaders over the same training split: one for the
# target classifier and one for the GAN.
data_loader_target = torch.utils.data.DataLoader(trainset, batch_size=150,
                                                 shuffle=True, num_workers=0)
data_loader_gan = torch.utils.data.DataLoader(trainset, batch_size=150,
                                              shuffle=True, num_workers=0)

testset = CIFAR10(root='./data', train=False,
                  download=False, transform=transform)
# The evaluation loader must wrap the held-out test split (not trainset),
# otherwise the reported accuracy is measured on the training data.
data_loader_test = torch.utils.data.DataLoader(testset, batch_size=75,
                                               shuffle=True, num_workers=0)

# Human-readable labels, indexed by CIFAR-10 class id.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
num_of_classes = len(classes)

In [None]:
%%time 
# Target model: ResNet-style network adapted from
# https://github.com/kuangliu/pytorch-cifar
net = target_net(BasicBlock, [2, 2, 2, 2]).to(dev)
criterion_tar = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# The checkpoint name embeds the device string so CPU and GPU runs do not
# overwrite each other.
PATH = './target_models/basic_net_convolutional_CIFAR_device-'+dev+'.pth'
# Train and save the model.
# NOTE(review): `train` here is the project's own method on target_net (not
# the stock nn.Module.train); 25 is presumably the epoch count and `poison`
# the poisoning rate -- confirm against net_conv_cifar.
net.train(data_loader_target, criterion_tar, optimizer, dev, 25,poison=0.1)
torch.save(net.state_dict(), PATH)
# Load the model back into a fresh instance.
net = target_net(BasicBlock, [2, 2, 2, 2]).to(dev)
# The device belongs to torch.load as map_location; the second positional
# parameter of load_state_dict is `strict`, not a device.
net.load_state_dict(torch.load(PATH, map_location=dev))

print('model accuracy: ', net.accuracy(data_loader_test,dev))

In [None]:
# Instantiate the generator / discriminator architectures.
gen = gans_archs.Generator3()
disc = gans_archs.Discriminator3()

# Arguments for GAN training (forwarded to advGAN by keyword below).
tar_criterion=nn.CrossEntropyLoss()
criterion=nn.BCEWithLogitsLoss()
n_epochs=200
# NOTE(review): data_loader_gan is built with batch_size=150 elsewhere in this
# notebook; confirm whether advGAN expects this value to match the loader.
batch_size=128
lr=0.00001
device=dev
display_step=500
gen_arch='cov'
###############################
gen_arch_num=3
# NOTE(review): the *_coeff values and `c` look like loss weights / a
# perturbation bound -- confirm their meaning against GAN_.advGAN.
disc_coeff=1850.
hinge_coeff=50.
adv_coeff=200.
c=0.2
gen_path_extra='cifar10_genarch_'+str(gen_arch_num)
# CIFAR-10 images are 3-channel 32x32 tensors; (1,28,28) is the MNIST shape
# and does not match the data used in this notebook.
shape=(3,32,32)
p = True
################################

# Initiate advGAN with the trained target network and the settings above.
advgan = advGAN(net,gen,disc,tar_criterion=tar_criterion,
                criterion=criterion,n_epochs=n_epochs,
                batch_size=batch_size,num_of_classes=num_of_classes,
                lr=lr,disc_coeff=disc_coeff,hinge_coeff=hinge_coeff,
                adv_coeff=adv_coeff,c=c,gen_path_extra=gen_path_extra,
                device=device,display_step=display_step,shape=shape,gen_arch=gen_arch,poison=p)

In [None]:
%%time 
# Run advGAN training on the GAN loader; the call hands back the generator
# and discriminator, which replace the instances created earlier.
gen, disc = advgan.train(data_loader_gan)