# In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch, time, torchattacks, random, argparse
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from models.resnet import ResNet18, ResNet34, ResNet50
from model.networks import Generator, Discriminator
import utils.misc as misc
import model.losses as gan_losses
import torchvision as tv
import torch.nn as nn
from util import get_mask_list

# Target device for every model and tensor in this script (no explicit index,
# so the current CUDA device among those made visible above is used).
device = torch.device(type="cuda")

def seed_everything(seed):
    """Seed every RNG this script relies on and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators, then configures cuDNN for reproducible kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # NOTE: the running interpreter's hash seed is fixed at start-up, so this
    # only influences Python interpreters launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select different algorithms from run to run; the
    # PyTorch reproducibility notes recommend disabling it together with
    # `deterministic`.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs before any model, dataset or loader is created; the order of
# the statements below therefore matters for reproducibility.
seed_everything(777)
image_nc = 3  # number of image channels; not referenced in the code shown here
batch_size = 8
epochs = 300

# The config is read from a YAML file via misc.get_config. The argv list is
# hard-coded, presumably so the cell also runs inside Jupyter (whose own argv
# would otherwise be parsed) — TODO confirm.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str,
                    default="configs/train.yaml", help="Path to yaml config file")
args = parser.parse_args(args=['--config', 'configs/train.yaml'])
config = misc.get_config(args.config)

# Override the checkpoint directory and batch size from the YAML config.
checkpoint_dir = 'checkpoint/cifar_tmp'
config.checkpoint_dir = checkpoint_dir
config.batch_size = batch_size
tmp_n = 10  # experiment tag embedded in saved image/checkpoint file names below

# CIFAR-10 images are resized to 256 px on the short side (the mask code below
# assumes 256x256 inputs) and converted to tensors; training also applies
# random horizontal flips.
transform_train = transforms.Compose([
     transforms.Resize(256),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
])

transform_test = transforms.Compose([
     transforms.Resize(256),
     transforms.ToTensor(),
])

# Both splits are downloaded to ./data if not already present.
train_set = datasets.CIFAR10("./data", download=True, transform=transform_train)
test_set = datasets.CIFAR10("./data", download=True, transform=transform_test, train=False)


train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8)
# NOTE(review): the test loader shuffles, and evaluation later stops after a
# fixed number of batches, so each epoch is scored on a different random
# subset of the test set. Consider shuffle=False for comparable numbers.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=8)

# Discriminator input: image channels plus one mask channel (cnum_in=4),
# matching the image/mask concatenation done in the training loop.
discriminator = Discriminator(cnum_in=4, cnum=64)

discriminator = discriminator.to(device)

class Model_g(nn.Module):
    """Inpainting generator chained with a ResNet-18 classifier.

    ``generator`` is built from the checkpoint at
    ``pretrained/states_pt_places2.pth`` and put in train mode; ``res`` is a
    ResNet18 restored from ``./models/state_dicts/resnet18_cifar10.pth`` and
    put in eval mode.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are created and registered generator-first, as before,
        # so parameter order and RNG consumption are unchanged.
        self.generator = Generator(
            cnum_in=5,
            cnum=48,
            return_flow=False,
            checkpoint='pretrained/states_pt_places2.pth',
        )
        self.generator.train()

        classifier = ResNet18()
        classifier.load_state_dict(
            torch.load('./models/state_dicts/resnet18_cifar10.pth'))
        classifier.eval()
        self.res = classifier

    def forward(self, x, adv_x, mask, test_flag):
        """Inpaint ``x`` and classify an inpainted image.

        Args:
            x: Generator input for the (noisy, masked) batch.
            adv_x: Generator input for the adversarial batch; only used when
                ``test_flag`` is falsy.
            mask: Inpainting mask passed to the generator.
            test_flag: If truthy, classify the stage-2 output for ``x``;
                otherwise also run the generator on ``adv_x`` and classify
                that stage-2 output instead.

        Returns:
            ``(stage1, stage2, stage2, logits)``: both generator outputs for
            ``x``, the stage-2 output repeated (kept for callers that unpack
            four values), and the classifier output.
        """
        stage1, stage2 = self.generator(x, mask)
        if test_flag:
            logits = self.res(stage2)
        else:
            _, adv_stage2 = self.generator(adv_x, mask)
            logits = self.res(adv_stage2)
        return stage1, stage2, stage2, logits

model = Model_g().to(device)
# Freeze exactly the classifier sub-module. Matching on `'res' in name` would
# also freeze any generator parameter whose name merely contains "res".
for param in model.res.parameters():
    param.requires_grad = False

# Stand-alone copy of the same CIFAR-10 ResNet-18; the attacks below are
# computed against this classifier, and it scores purified test images.
model_res = ResNet18().to(device)
model_res.load_state_dict(torch.load('./models/state_dicts/resnet18_cifar10.pth'))
model_res.eval()

# Only trainable (i.e. generator) parameters are handed to the G optimizer.
g_optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=config.g_lr, betas=(config.g_beta1, config.g_beta2))
d_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=config.d_lr, betas=(config.d_beta1, config.d_beta2))

gan_loss_d, gan_loss_g = gan_losses.hinge_loss_d, gan_losses.hinge_loss_g
loss_fun = torch.nn.CrossEntropyLoss()

if config.tb_logging:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(config.log_dir)

losses = {}  # loss tensors of the current iteration, keyed by name

discriminator.train()

last_n_iter = -1
# Per-key history of loss values since the last console print.
losses_log = {'d_loss':   [],
              'g_loss':   [],
              'ae_loss':  [],
              'ae_loss1': [],
              'ae_loss2': [],
              'cls_loss': [],
              }

# training loop state
init_n_iter = last_n_iter + 1
n_iter = 0
time0 = time.time()

# Corruption / loss hyper-parameters.
block_n = 32         # side length (pixels) of each square mask block
gau_n = 0.25         # std of the additive Gaussian noise
missing_rate = 0.25  # fraction of mask blocks that are masked out
cls_n = 0.05         # weight of the classification loss in the G loss
# All attacks target the stand-alone classifier with an L-inf budget of 8/255.
atk1 = torchattacks.FGSM(model_res, eps=8/255.)
atk2 = torchattacks.BIM(model_res, eps=8/255.)
atk3 = torchattacks.EOTPGD(model_res, eps=8/255.)
atk4 = torchattacks.PGD(model_res, eps=8/255., steps=20)
# Mini-batches per training epoch / evaluation pass (with batch_size = 8:
# 5,000 training images and 512 test images).
train_batches_per_epoch = 625
eval_batches = 64


def _random_block_mask(size=256):
    """Sample a random square-block inpainting mask.

    The ``size`` x ``size`` canvas is split into ``block_n`` x ``block_n``
    blocks; ``int(missing_rate * n_blocks)`` of them, chosen with NumPy's
    global RNG, are set to 1 (missing) and the rest to 0.

    Returns:
        Float tensor of shape ``(1, 1, size, size)`` on ``device``.
    """
    blocks_per_side = int(size / block_n)
    nums = np.zeros(int((size**2) / (block_n**2)))
    nums[:int(missing_rate * len(nums))] = 1
    np.random.shuffle(nums)
    mask = nums.reshape((1, 1, blocks_per_side, blocks_per_side))
    mask = np.repeat(mask, block_n, axis=2)
    mask = np.repeat(mask, block_n, axis=3)
    return torch.from_numpy(mask).type(torch.float).to(device)


def _generator_input(images, mask):
    """Build the generator input: masked images, a ones channel, the mask.

    Masked pixels (mask == 1) are zeroed; the single-channel ``mask`` is
    broadcast over the batch.
    """
    incomplete = images * (1. - mask)
    ones = torch.ones_like(incomplete)[:, 0:1, :, :]
    return torch.cat([incomplete, ones, ones * mask], dim=1)


for epoch in range(1, epochs + 1):
    # ------------------------------------------------------------ training
    for i, (batch_real, labels) in enumerate(train_loader):
        batch_real = batch_real.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        batch_adv = atk1(batch_real, labels)
        # NOTE(review): this replaces the clean batch with its FGSM version,
        # so batch_gau and adv_gau below are identical and the reconstruction
        # targets / discriminator "real" samples are adversarial images.
        # Confirm this is the intended experiment.
        batch_real = batch_adv.clone().detach()
        n = batch_real.size(0)

        gau_noise = (torch.normal(0, 1, batch_real.shape) * gau_n).to(device)
        batch_gau = torch.clip(batch_real + gau_noise, 0, 1)
        adv_gau = torch.clip(batch_adv + gau_noise, 0, 1)
        # Rescale [0, 1] -> [-1, 1] in place.
        batch_gau.mul_(2).sub_(1)
        adv_gau.mul_(2).sub_(1)
        batch_real.mul_(2).sub_(1)

        mask = _random_block_mask()
        x = _generator_input(batch_gau, mask)
        adv_x = _generator_input(adv_gau, mask)

        # x1/x2: generator outputs for x; outputs_cls: classifier logits for
        # the generator output of adv_x.
        x1, x2, _, outputs_cls = model(x, adv_x, mask, False)
        batch_complete = x2.detach()
        losses['cls_loss'] = loss_fun(outputs_cls, labels)

        # ---- discriminator update (generator output detached)
        mask_batch = torch.tile(mask, [n, 1, 1, 1])
        batch_real_mask = torch.cat((batch_real, mask_batch), dim=1)
        batch_filled_mask = torch.cat((batch_complete, mask_batch), dim=1)
        d_real_gen = discriminator(torch.cat((batch_real_mask, batch_filled_mask)))
        d_real, d_gen = torch.split(d_real_gen, n)
        losses['d_loss'] = gan_loss_d(d_real, d_gen)

        d_optimizer.zero_grad()
        losses['d_loss'].backward()
        d_optimizer.step()

        # ---- generator update
        losses['ae_loss1'] = config.l1_loss_alpha * torch.mean(torch.abs(batch_real - x1))
        losses['ae_loss2'] = config.l1_loss_alpha * torch.mean(torch.abs(batch_real - x2))
        losses['ae_loss'] = losses['ae_loss1'] + losses['ae_loss2']

        # x2 is deliberately NOT detached here: the adversarial term must
        # back-propagate through the discriminator into the generator.
        d_gen = discriminator(torch.cat((x2, mask_batch), dim=1))
        g_loss = config.gan_loss_alpha * gan_loss_g(d_gen)
        if config.ae_loss:
            g_loss = g_loss + losses['ae_loss']
            if epoch > 2:
                g_loss = g_loss + losses['cls_loss'] * cls_n
        losses['g_loss'] = g_loss

        g_optimizer.zero_grad()
        losses['g_loss'].backward()
        g_optimizer.step()

        # LOGGING
        for k in losses_log.keys():
            losses_log[k].append(losses[k].item())

        # (tensorboard) logging
        if n_iter % config.print_iter == 0:
            # measure iterations/second
            dt = time.time() - time0
            print(f"@iter: {n_iter}: {(config.print_iter/dt):.4f} it/s")
            time0 = time.time()

            # write loss terms to console and tensorboard
            for k, loss_log in losses_log.items():
                loss_log_mean = sum(loss_log)/len(loss_log)
                print(f"{k}: {loss_log_mean:.4f}")
                if config.tb_logging:
                    writer.add_scalar(
                        f"losses/{k}", loss_log_mean, global_step=n_iter)
                losses_log[k].clear()

        # save example image grids to tensorboard
        if config.tb_logging \
            and config.save_imgs_to_tb_iter \
            and n_iter % config.save_imgs_to_tb_iter == 0:
            viz_images = [misc.pt_to_image(batch_complete),
                          misc.pt_to_image(x1), misc.pt_to_image(x2)]
            img_grids = [tv.utils.make_grid(images[:config.viz_max_out], nrow=2)
                         for images in viz_images]

            writer.add_image(
                "Inpainted", img_grids[0], global_step=n_iter, dataformats="CHW")
            writer.add_image(
                "Stage 1", img_grids[1], global_step=n_iter, dataformats="CHW")
            writer.add_image(
                "Stage 2", img_grids[2], global_step=n_iter, dataformats="CHW")

        # save example image grids to disk
        if config.save_imgs_to_disc_iter \
            and n_iter % config.save_imgs_to_disc_iter == 0:
            viz_images = [misc.pt_to_image(batch_real),
                          misc.pt_to_image(batch_complete)]
            img_grids = [tv.utils.make_grid(images[:config.viz_max_out], nrow=2)
                         for images in viz_images]
            tv.utils.save_image(img_grids,
                                f"{checkpoint_dir}/iter_{tmp_n}_{n_iter}.png",
                                nrow=2)

        # save state dict snapshot
        if n_iter % config.save_checkpoint_iter == 0 \
            and n_iter > init_n_iter:
            misc.save_states(f"states_{tmp_n}.pth",
                             model, discriminator,
                             g_optimizer, d_optimizer,
                             n_iter, config)
        # save state dict snapshot backup
        if config.save_cp_backup_iter \
            and n_iter % config.save_cp_backup_iter == 0 \
            and n_iter > init_n_iter:
            misc.save_states(f"states_{tmp_n}_{n_iter}.pth",
                             model, discriminator,
                             g_optimizer, d_optimizer,
                             n_iter, config)
        n_iter += 1
        if i + 1 == train_batches_per_epoch:
            break

    # ---------------------------------------------------------- evaluation
    attacks = (atk1, atk2, atk3, atk4)
    report_formats = (
        "test acc on clean examples (%): {:.2f}",
        "test acc on FGSM examples (%):  {:.2f}",
        "test acc on BIM examples (%):   {:.2f}",
        "test acc on EoT examples (%):   {:.2f}",
        "test acc on PGD examples (%):  {:.2f}",
    )
    correct = [0] * len(report_formats)
    n_eval = 0
    for i, (inputs, labels) in enumerate(test_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        # Gradient-based attacks need autograd, so they run before no_grad();
        # they are called in a fixed order so RNG use stays reproducible.
        variants = [inputs] + [atk(inputs, labels) for atk in attacks]

        # One noise sample and one mask are shared by all five variants.
        gau_noise = (torch.normal(0, 1, inputs.shape) * gau_n).to(device)
        noisy = [torch.clip(v + gau_noise, 0, 1).mul_(2).sub_(1) for v in variants]
        mask = _random_block_mask()

        # Purification and classification only: no gradients are needed.
        with torch.no_grad():
            for k, images in enumerate(noisy):
                gen_in = _generator_input(images, mask)
                _, purified, _, _ = model(gen_in, gen_in, mask, True)
                _, preds = torch.max(model_res(purified), 1)
                correct[k] += torch.sum(preds == labels).item()
        n_eval += labels.size(0)
        if i + 1 == eval_batches:
            break

    for fmt, n_correct in zip(report_formats, correct):
        print(fmt.format((n_correct / max(n_eval, 1)) * 100.0))
    print('________epoch '+str(epoch)+'________')
    misc.save_states(f"states_{tmp_n}_{epoch}.pth",
                     model, discriminator,
                     g_optimizer, d_optimizer,
                     epoch, config)
