# In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch, time, torchattacks, random, argparse
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from models.resnet import ResNet18, ResNet34, ResNet50
from model.networks import Generator, Discriminator
import utils.misc as misc
import model.losses as gan_losses
import torchvision as tv
import torch.nn as nn
from util import get_mask_list

# Target device for all models and tensors in this cell; assumes a CUDA GPU is available.
device = torch.device("cuda")

def seed_everything(seed):
    """Make this notebook's random number generation reproducible.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU and CUDA generators with ``seed``, exports ``PYTHONHASHSEED`` and
    asks cuDNN to use deterministic algorithms.

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it
    here only affects Python processes started afterwards, not this one.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(777)
image_nc = 3      # presumably the number of image channels; not used in this cell
batch_size = 16
epochs = 300      # not used in this cell

# Arguments are passed explicitly so the real command line (e.g. the
# arguments of the Jupyter kernel process) is ignored.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str,
                    default="configs/train.yaml", help="Path to yaml config file")
args = parser.parse_args(args=['--config', 'configs/train.yaml'])
config = misc.get_config(args.config)

# Override the YAML values with this cell's checkpoint directory and batch size.
checkpoint_dir = 'checkpoint/cifar_tmp'
config.checkpoint_dir = checkpoint_dir
config.batch_size = batch_size

# Images are upscaled with Resize(256) and converted to tensors with
# ToTensor() (values in [0, 1]); no normalization is applied.
# transform_train is not used in this cell.
transform_train = transforms.Compose([
     transforms.Resize(256),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
])

transform_test = transforms.Compose([
     transforms.Resize(256),
     transforms.ToTensor(),
])

# shuffle=True only changes batch order; the accuracies below are computed
# over the whole test split.
test_set = datasets.CIFAR10("./data", download=True, transform=transform_test, train=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=8)


# Plain ResNet18 classifier loaded from a CIFAR-10 checkpoint, in eval mode.
model = ResNet18().to(device)
model.load_state_dict(torch.load('./models/state_dicts/resnet18_cifar10.pth'))
model.eval()

# These values are not read by the code in this cell.
block_n = 32
gau_n = 0.25
missing_rate = 0.25
cls_n = 0.1
# AutoAttack with eps = 8/255; norm/version use the torchattacks defaults.
atk1 = torchattacks.AutoAttack(model, eps=8/255.)

# Count correct predictions of the plain classifier on clean and AutoAttack
# inputs over the whole test split. acc_atk2..acc_atk4 are not updated here.
acc = 0
acc_atk1 = 0
acc_atk2 = 0
acc_atk3 = 0
acc_atk4 = 0

for i, data in enumerate(test_loader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    # The attack needs gradients w.r.t. the inputs, so it runs outside no_grad.
    adv_inputs1 = atk1(inputs, labels)

    # Plain evaluation needs no autograd graph; skipping it saves GPU memory.
    with torch.no_grad():
        _, preds = torch.max(model(inputs), 1)
        _, preds_atk1 = torch.max(model(adv_inputs1), 1)

    acc += torch.sum(preds == labels).item()
    acc_atk1 += torch.sum(preds_atk1 == labels).item()
print("test acc on clean examples (%): {:.2f}".format(
        acc / len(test_set) * 100.0))
print("test acc on adversarial examples (%):  {:.2f}".format(
        acc_atk1 / len(test_set) * 100.0))


# In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch, time, torchattacks, random, argparse
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from models.resnet import ResNet18, ResNet34, ResNet50
from model.networks import Generator, Discriminator
import utils.misc as misc
import model.losses as gan_losses
import torchvision as tv
import torch.nn as nn
from util import get_mask_list

# Target device for all models and tensors in this cell; assumes a CUDA GPU is available.
device = torch.device("cuda")

def seed_everything(seed):
    """Seed every random number generator used here and force deterministic cuDNN.

    ``seed`` is applied to Python's ``random``, NumPy's global RNG and the
    PyTorch CPU/CUDA generators, and is exported as ``PYTHONHASHSEED`` (which
    only influences Python processes launched after this call).
    """
    for seed_python_side in (random.seed, np.random.seed):
        seed_python_side(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.manual_seed(seed)
    # Explicit CUDA seeding: current device first, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(777)
image_nc = 3      # presumably the number of image channels; not used in this cell
batch_size = 8
epochs = 300      # not used in this cell

# Arguments are passed explicitly so the real command line (e.g. the
# arguments of the Jupyter kernel process) is ignored.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str,
                    default="configs/train.yaml", help="Path to yaml config file")
args = parser.parse_args(args=['--config', 'configs/train.yaml'])
config = misc.get_config(args.config)

# Override the YAML values with this cell's checkpoint directory and batch size.
checkpoint_dir = 'checkpoint/cifar_tmp'
config.checkpoint_dir = checkpoint_dir
config.batch_size = batch_size

# Images are upscaled with Resize(256) and converted to tensors with
# ToTensor() (values in [0, 1]); no normalization is applied.
# transform_train is not used in this cell.
transform_train = transforms.Compose([
     transforms.Resize(256),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
])

transform_test = transforms.Compose([
     transforms.Resize(256),
     transforms.ToTensor(),
])

test_set = datasets.CIFAR10("./data", download=True, transform=transform_test, train=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=8)

# Discriminator with 4 input channels (presumably image + mask; confirm
# against model.networks). In this cell it is only referenced when
# building d_optimizer.
discriminator = Discriminator(cnum_in=4, cnum=64)
discriminator = discriminator.to(device)

class Model_g(nn.Module):
    """Inpainting purifier (generator) followed by a ResNet18 classifier.

    On every forward pass the image is split into a grid of
    ``block_n x block_n`` pixel blocks. The block order is shuffled with
    NumPy's global RNG, the shuffled blocks are split into ``num_of_mask``
    equal partitions, and one partition is masked out. The generator
    inpaints the masked region and the ResNet classifies the result. The
    shuffle is redrawn on every call, so the model's output is stochastic.

    Args:
        block_n: side length, in pixels, of one square mask block. The input
            height and width must be multiples of it.
        num_of_mask: number of partitions the shuffled blocks are split into.
            One partition (a ``1 / num_of_mask`` share of the blocks) is masked.
    """

    def __init__(self, block_n=32, num_of_mask=4):
        super(Model_g, self).__init__()
        self.block_n = block_n
        self.num_of_mask = num_of_mask

        sd_path = None
        # 5 input channels = masked image + all-ones channel + mask channel
        # (see forward); this assumes 3-channel input images.
        self.generator = Generator(cnum_in=5, cnum=48, return_flow=False, checkpoint=sd_path)
        self.generator.train()

        self.res = ResNet18()
        # Load on CPU so construction does not depend on the device the
        # checkpoint was saved from; the caller moves the module with .to(...).
        self.res.load_state_dict(
            torch.load('./models/state_dicts/resnet18_cifar10.pth', map_location='cpu'))
        self.res.eval()

    @staticmethod
    def _block_mask(height, width, block_n, num_of_mask, mask_idx):
        """Return a (1, 1, height, width) float32 0/1 mask covering one random block partition.

        Raises:
            ValueError: if the image size is not a multiple of ``block_n``,
                ``mask_idx`` is outside ``[0, num_of_mask)``, or there are
                fewer blocks than partitions.
        """
        if height % block_n or width % block_n:
            raise ValueError(
                "image size {}x{} is not a multiple of block_n={}".format(height, width, block_n))
        if not 0 <= mask_idx < num_of_mask:
            raise ValueError(
                "mask_idx={} must be in [0, {})".format(mask_idx, num_of_mask))

        rows, cols = height // block_n, width // block_n
        len_num = rows * cols
        interval = len_num // num_of_mask
        if interval == 0:
            raise ValueError(
                "{} blocks cannot be split into {} partitions".format(len_num, num_of_mask))

        nums = np.arange(len_num)
        np.random.shuffle(nums)

        mask = np.zeros(len_num, dtype=np.float32)
        mask[nums[mask_idx * interval:(mask_idx + 1) * interval]] = 1
        mask = mask.reshape((1, 1, rows, cols))
        # Expand each grid cell to a block_n x block_n pixel block.
        mask = np.repeat(mask, block_n, axis=2)
        mask = np.repeat(mask, block_n, axis=3)
        return mask

    def forward(self, inputs, mask_idx=0):
        """Mask one random block partition of ``inputs``, inpaint it, and classify.

        Args:
            inputs: image batch, presumably (B, 3, H, W) with values in [0, 1]
                (TODO confirm). H and W must be multiples of ``block_n``.
            mask_idx: which of the ``num_of_mask`` partitions to mask. The
                block order is reshuffled on every call, so any index masks a
                uniformly random ``1 / num_of_mask`` share of the blocks.

        Returns:
            The outputs of ``self.res`` for the inpainted images.
        """
        height, width = inputs.shape[-2:]
        mask_np = self._block_mask(height, width, self.block_n, self.num_of_mask, mask_idx)
        mask = torch.from_numpy(mask_np).to(inputs.device)

        atk_batch_incomplete = inputs * (1 - mask)
        ones = torch.ones_like(atk_batch_incomplete[:, 0:1])
        x = torch.cat([atk_batch_incomplete, ones, ones * mask], dim=1)

        # The generator returns two outputs; only the second one is used.
        _, x2 = self.generator(x, mask)

        # Map the generator output from (presumably) [-1, 1] to [0, 1].
        x2 = (x2 + 1) * 0.5
        return self.res(x2)

# Build the purifier + classifier pipeline. Note: the ResNet parameters are
# NOT frozen here; nothing in this cell sets requires_grad=False on model.res.
model = Model_g().to(device)

# The checkpoint's 'G' entry appears to be a full Model_g state dict:
# generator weights under 'generator.' and ResNet weights under 'res.'.
# Strip the generator prefix and drop the ResNet weights. Any other key is
# kept, so load_state_dict (strict) still reports it as unexpected.
g_prefix = 'generator.'
res_prefix = 'res.'
g_state = torch.load('checkpoint/cifar_tmp/states_9.pth', map_location='cpu')['G']
g_dict = {}
for key, value in g_state.items():
    if key.startswith(g_prefix):
        g_dict[key[len(g_prefix):]] = value
    elif not key.startswith(res_prefix):
        g_dict[key] = value
model.generator.load_state_dict(g_dict)
model.generator.eval()

# Stand-alone ResNet18 with the same CIFAR-10 weights (not used by the
# evaluation in this cell).
model_res = ResNet18().to(device)
model_res.load_state_dict(torch.load('./models/state_dicts/resnet18_cifar10.pth'))
model_res.eval()

# Optimizers for GAN-style training; they are not stepped in this cell.
# NOTE(review): as far as this file shows, model.res parameters keep
# requires_grad=True, so this filter does not exclude the ResNet.
g_optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=config.g_lr, betas=(config.g_beta1, config.g_beta2))
d_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=config.d_lr, betas=(config.d_beta1, config.d_beta2))

# Loss functions; not used by the evaluation in this cell.
gan_loss_d, gan_loss_g = gan_losses.hinge_loss_d, gan_losses.hinge_loss_g
loss_fun = torch.nn.CrossEntropyLoss()

# Optional TensorBoard logging, controlled by the YAML config.
if config.tb_logging:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(config.log_dir)

last_n_iter = -1

# training loop (only iteration counters and a start time are set here;
# no training step runs in this cell)
init_n_iter = last_n_iter + 1
n_iter = 0
time0 = time.time()

# These globals are not read by the code in this cell.
block_n = 32
gau_n = 0.25
missing_rate = 0.25
cls_n = 0.1
# AutoAttack against the full purifier + classifier pipeline, eps = 8/255.
atk = torchattacks.AutoAttack(model, eps=8/255.) # version = ['rand', 'standard'], norm = ['Linf', 'L2']

# Accuracy of the purifier + classifier pipeline on clean and AutoAttack
# inputs, over at most `max_batches` test batches.
acc = 0
acc_atk = 0
n_seen = 0        # number of evaluated samples (denominator for the accuracies)
max_batches = 64

# NOTE: the loop variable is deliberately named `i`. Some versions of
# Model_g.forward read a module-level `i` to pick the mask partition.
for i, data in enumerate(test_loader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    org_inputs = inputs.clone().detach()
    # The attack needs input gradients, so it runs outside torch.no_grad().
    adv_inputs = atk(inputs, labels)

    with torch.no_grad():
        outputs_cls0 = model(org_inputs)
        outputs_cls1 = model(adv_inputs)

    _, preds = torch.max(outputs_cls0, 1)
    _, preds_atk1 = torch.max(outputs_cls1, 1)

    acc += torch.sum(preds == labels).item()
    acc_atk += torch.sum(preds_atk1 == labels).item()
    n_seen += labels.size(0)
    if i + 1 >= max_batches:
        break

# Divide by the number of samples actually evaluated rather than a hard-coded
# 512 (= 64 batches x 8), which is wrong whenever batch_size or the number of
# evaluated batches differs. max(..., 1) guards against an empty loader.
denom = max(n_seen, 1)
print("test acc on clean examples (%): {:.2f}".format(
        (acc / denom) * 100.0))
print("test acc on AutoAttack examples (%):  {:.2f}".format(
        (acc_atk / denom) * 100.0))
