In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import torch, time, torchattacks, random, argparse
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from models.resnet import ResNet18, ResNet34, ResNet50
from model.networks import Generator, Discriminator
import utils.misc as misc
import model.losses as gan_losses
import torchvision as tv
import torch.nn as nn
from util import get_mask_list

device = torch.device("cuda")

# Seed all random number generators used by this script.
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs with the same value.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    
# Fix every RNG before any model or data object is created.
seed_everything(777)

# Basic run settings.
image_nc, batch_size, epochs = 3, 8, 300

# The argument list is passed explicitly instead of being read from sys.argv,
# so the YAML path below is always the one that gets loaded.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--config',
    type=str,
    default="configs/train.yaml",
    help="Path to yaml config file",
)
args = parser.parse_args(args=['--config', 'configs/train.yaml'])
config = misc.get_config(args.config)

# Override the checkpoint directory and batch size coming from the YAML file.
checkpoint_dir = 'checkpoint/cifar_tmp'
config.checkpoint_dir = checkpoint_dir
config.batch_size = batch_size

# Tag embedded in the names of saved checkpoints and example images.
tmp_n = 10
# ----------------------------- start ------------------------------

# Both pipelines resize with Resize(256) and convert to tensors; only the
# training pipeline adds a random horizontal flip.
transform_train = transforms.Compose(
    [transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.ToTensor()]
)
transform_test = transforms.Compose(
    [transforms.Resize(256), transforms.ToTensor()]
)

# CIFAR-10 train / test splits (downloaded on first use).
train_set = datasets.CIFAR10(
    "../../data", download=True, transform=transform_train)
test_set = datasets.CIFAR10(
    "../../data", download=True, transform=transform_test, train=False)

# Both loaders shuffle; the test loader therefore yields a different subset
# order on every pass.
train_loader = DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=8)
test_loader = DataLoader(
    test_set, batch_size=batch_size, shuffle=True, num_workers=8)

# Discriminator input has 4 channels (image channels concatenated with the mask).
discriminator = Discriminator(cnum_in=4, cnum=64).to(device)


class Model_g(nn.Module):
    """Inpainting generator followed by a ResNet18 classifier.

    The generator is created from the ``pretrained/states_pt_places2.pth``
    checkpoint and put in train mode; the classifier is loaded from
    ``./models/state_dicts/resnet18_cifar_932.pth`` and put in eval mode.
    """

    def __init__(self):
        super().__init__()

        sd_path = 'pretrained/states_pt_places2.pth'
        self.generator = Generator(cnum_in=5, cnum=48, return_flow=False, checkpoint=sd_path)
        self.generator.train()

        self.res = ResNet18()
        # Load onto the CPU first: a checkpoint saved from a different CUDA
        # index would otherwise fail to deserialize when only one GPU is
        # visible. load_state_dict copies the values into the module's own
        # parameters, and the caller moves the whole model to its device.
        self.res.load_state_dict(
            torch.load('./models/state_dicts/resnet18_cifar_932.pth',
                       map_location='cpu'))
        self.res.eval()

    def forward(self, x, adv_x, mask, test_flag):
        """Inpaint ``x`` and classify an inpainted image.

        Args:
            x: Generator input built from the (noisy) clean batch.
            adv_x: Generator input built from the adversarial batch; only
                used when ``test_flag`` is false.
            mask: Mask tensor passed to the generator alongside the input.
            test_flag: When true, the classifier sees the refined output of
                ``x``; otherwise it sees the refined output of ``adv_x``.

        Returns:
            Tuple ``(x1, x2, batch_complete, outputs_cls)`` where ``x1`` and
            ``x2`` are the coarse and refined generator outputs for ``x``,
            ``batch_complete`` is the same tensor as ``x2``, and
            ``outputs_cls`` are the classifier outputs.
        """
        x1, x2 = self.generator(x, mask)
        if test_flag:
            outputs_cls = self.res(x2)
        else:
            # Training: classify the reconstruction of the adversarial input.
            _, adv_x2 = self.generator(adv_x, mask)
            outputs_cls = self.res(adv_x2)

        batch_complete = x2
        return x1, x2, batch_complete, outputs_cls

# Build the combined model and freeze the ResNet classifier inside it.
# The classifier submodule is addressed directly instead of matching the
# substring 'res' in parameter names, which could also freeze any generator
# parameter whose name happens to contain "res".
model = Model_g().to(device)
for param in model.res.parameters():
    param.requires_grad = False

# Stand-alone copy of the same classifier, used by the attacks and for
# evaluation. map_location keeps loading independent of the CUDA index the
# checkpoint was saved from.
model_res = ResNet18().to(device)
model_res.load_state_dict(
    torch.load('./models/state_dicts/resnet18_cifar_932.pth', map_location=device))
model_res.eval()

# Adam optimizers: the generator optimizer only receives parameters that
# still require gradients; the discriminator optimizer receives all of its
# parameters.
g_optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=config.g_lr,
    betas=(config.g_beta1, config.g_beta2),
)
d_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=config.d_lr,
    betas=(config.d_beta1, config.d_beta2),
)

# Loss functions: hinge GAN losses plus cross-entropy for the classifier.
gan_loss_d = gan_losses.hinge_loss_d
gan_loss_g = gan_losses.hinge_loss_g
loss_fun = torch.nn.CrossEntropyLoss()

# Optional tensorboard writer for visualising the statistics.
if config.tb_logging:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(config.log_dir)

# Training state. `losses` holds the loss tensors of the current iteration;
# `losses_log` accumulates their scalar values between two console reports.
losses = {}

# model.generator was already put in train mode inside Model_g.__init__.
discriminator.train()

last_n_iter = -1
losses_log = {
    name: []
    for name in ('d_loss', 'g_loss', 'ae_loss', 'ae_loss1', 'ae_loss2', 'cls_loss')
}

# Iteration counters and the timer used for the it/s report.
init_n_iter = last_n_iter + 1
n_iter = 0
time0 = time.time()

block_n = 32          # side length (pixels) of one mask block
gau_n = 0.25          # scale applied to the standard-normal noise
missing_rate = 0.25   # fraction of blocks that are masked out
cls_n = 0.05          # weight of the classification loss in the generator loss

# Attacks, all run against the stand-alone classifier with eps = 8/255.
atk1 = torchattacks.FGSM(model_res, eps=8/255.)
atk2 = torchattacks.BIM(model_res, eps=8/255.)
atk3 = torchattacks.EOTPGD(model_res, eps=8/255.)
atk4 = torchattacks.PGD(model_res, eps=8/255., steps=20)
atk4 = torchattacks.PGD(model_res, eps=8/255., steps=20)
# Number of batches consumed per epoch. The original loops stopped after
# batch index 624 (training) and 63 (testing); the limits are named here so
# they can be changed in one place.
max_train_batches = 625
max_test_batches = 64
# Side length used for the masks; matches the Resize(256) transforms.
image_size = 256

# Console report lines, in the order: clean, atk1, atk2, atk3, atk4.
# NOTE(review): atk4 is a plain PGD attack on model_res although its report
# line says "BPDA" — confirm the label is intended.
report_templates = (
    "test acc on clean examples (%): {:.2f}",
    "test acc on FGSM examples (%):  {:.2f}",
    "test acc on BIM examples (%):   {:.2f}",
    "test acc on EoT examples (%):   {:.2f}",
    "test acc on BPDA examples (%):  {:.2f}",
)


def _random_block_mask():
    """Return a (1, 1, image_size, image_size) float mask on `device`.

    The image is divided into block_n x block_n blocks; a `missing_rate`
    fraction of them, chosen with NumPy's global RNG, is set to 1 (masked
    out) and the rest to 0.
    """
    blocks_per_side = int(image_size / block_n)
    nums = np.zeros(int((image_size**2) / (block_n**2)))
    nums[:int(missing_rate*(len(nums)))] = 1
    np.random.shuffle(nums)
    block_mask = nums.reshape((1, 1, blocks_per_side, blocks_per_side))
    # Blow every block flag up to block_n x block_n pixels.
    block_mask = np.repeat(block_mask, block_n, axis=2)
    block_mask = np.repeat(block_mask, block_n, axis=3)
    return torch.from_numpy(block_mask).type(torch.float).to(device)


def _generator_input(images, mask):
    """Build the generator input for `images` under `mask`.

    Concatenates, along the channel dimension, the images with the masked
    pixels zeroed, one all-ones channel, and the mask channel.
    """
    incomplete = images*(1.-mask)
    ones_x = torch.ones_like(incomplete)[:, 0:1, :, :]
    return torch.cat([incomplete, ones_x, ones_x*mask], dim=1)


for epoch in range(1, epochs+1):
    # ------------------------------ training ------------------------------
    for i, data in enumerate(train_loader, 0):
        batch_real, labels = data
        batch_real = batch_real.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        # Use the real batch size so a smaller final batch would still work.
        n_batch = batch_real.size(0)

        # Adversarial examples are crafted on the [0, 1] images, before any
        # noise is added or the range is changed.
        batch_adv = atk1(batch_real, labels)

        # Add the same Gaussian noise to the clean and adversarial batches.
        gau_noise = torch.normal(0, 1, batch_real.shape)*gau_n
        gau_noise = gau_noise.to(device)
        batch_gau = torch.clip((batch_real+gau_noise), 0, 1)
        adv_gau = torch.clip((batch_adv+gau_noise), 0, 1)

        # Rescale everything from [0, 1] to [-1, 1].
        batch_gau.mul_(2).sub_(1)
        adv_gau.mul_(2).sub_(1)
        batch_real.mul_(2).sub_(1)

        # One random mask per iteration, shared by the whole batch.
        mask = _random_block_mask()
        x = _generator_input(batch_gau, mask)
        adv_x = _generator_input(adv_gau, mask)

        # x1 is the coarse result, x2 the refined result.
        x1, x2, batch_complete, outputs_cls = model(x, adv_x, mask, False)
        # NOTE(review): both tensors below are detached from the generator
        # graph. batch_predicted feeds the adversarial generator loss further
        # down, so that term contributes no gradient to the generator
        # parameters — confirm this is intended before changing it.
        batch_predicted = x2.clone().detach()
        batch_complete = x2.clone().detach()

        loss_cls = loss_fun(outputs_cls, labels)
        losses['cls_loss'] = loss_cls

        # ---- discriminator step ----
        mask_tiled = torch.tile(mask, [n_batch, 1, 1, 1])
        batch_real_mask = torch.cat((batch_real, mask_tiled), dim=1)
        batch_filled_mask = torch.cat((batch_complete.detach(), mask_tiled), dim=1)

        # Real and generated samples go through the discriminator together
        # and are split again afterwards.
        batch_real_filled = torch.cat((batch_real_mask, batch_filled_mask))
        d_real_gen = discriminator(batch_real_filled)
        d_real, d_gen = torch.split(d_real_gen, n_batch)

        d_loss = gan_loss_d(d_real, d_gen)
        losses['d_loss'] = d_loss

        # update D parameters
        d_optimizer.zero_grad()
        losses['d_loss'].backward()
        d_optimizer.step()

        # ---- generator step ----
        losses['ae_loss1'] = config.l1_loss_alpha * \
            torch.mean((torch.abs(batch_real - x1)))
        losses['ae_loss2'] = config.l1_loss_alpha * \
            torch.mean((torch.abs(batch_real - x2)))
        losses['ae_loss'] = losses['ae_loss1'] + losses['ae_loss2']

        batch_gen = torch.cat((batch_predicted, mask_tiled), dim=1)
        d_gen = discriminator(batch_gen)

        g_loss = gan_loss_g(d_gen)
        losses['g_loss'] = g_loss
        losses['g_loss'] = config.gan_loss_alpha * losses['g_loss']
        if config.ae_loss:
            losses['g_loss'] += losses['ae_loss']
            # The classification loss only joins in from the third epoch on.
            if epoch > 2:
                losses['g_loss'] += losses['cls_loss']*cls_n

        # update G parameters
        g_optimizer.zero_grad()
        losses['g_loss'].backward()
        g_optimizer.step()

        # LOGGING
        for k in losses_log.keys():
            losses_log[k].append(losses[k].item())

        # (tensorboard) logging
        if n_iter % config.print_iter == 0:
            # measure iterations/second
            dt = time.time() - time0
            print(f"@iter: {n_iter}: {(config.print_iter/dt):.4f} it/s")
            time0 = time.time()

            # write the mean of each loss term since the last report to the
            # console and to tensorboard, then reset the accumulators
            for k, loss_log in losses_log.items():
                loss_log_mean = sum(loss_log)/len(loss_log)
                print(f"{k}: {loss_log_mean:.4f}")
                if config.tb_logging:
                    writer.add_scalar(
                        f"losses/{k}", loss_log_mean, global_step=n_iter)
                losses_log[k].clear()

        # save example image grids to tensorboard
        if config.tb_logging \
            and config.save_imgs_to_tb_iter \
            and n_iter % config.save_imgs_to_tb_iter == 0:
            viz_images = [misc.pt_to_image(batch_complete),
                          misc.pt_to_image(x1), misc.pt_to_image(x2)]
            img_grids = [tv.utils.make_grid(images[:config.viz_max_out], nrow=2)
                         for images in viz_images]

            writer.add_image(
                "Inpainted", img_grids[0], global_step=n_iter, dataformats="CHW")
            writer.add_image(
                "Stage 1", img_grids[1], global_step=n_iter, dataformats="CHW")
            writer.add_image(
                "Stage 2", img_grids[2], global_step=n_iter, dataformats="CHW")

        # save example image grids to disk
        if config.save_imgs_to_disc_iter \
            and n_iter % config.save_imgs_to_disc_iter == 0:
            viz_images = [misc.pt_to_image(batch_real),
                          misc.pt_to_image(batch_complete)]
            img_grids = [tv.utils.make_grid(images[:config.viz_max_out], nrow=2)
                         for images in viz_images]
            tv.utils.save_image(img_grids,
                                f"{checkpoint_dir}/iter_{tmp_n}_{n_iter}.png",
                                nrow=2)

        # save state dict snapshot
        if n_iter % config.save_checkpoint_iter == 0 \
            and n_iter > init_n_iter:
            misc.save_states(f"states_{tmp_n}.pth",
                             model, discriminator,
                             g_optimizer, d_optimizer,
                             n_iter, config)
        # save state dict snapshot backup
        if config.save_cp_backup_iter \
            and n_iter % config.save_cp_backup_iter == 0 \
            and n_iter > init_n_iter:
            misc.save_states(f"states_{tmp_n}_{n_iter}.pth",
                             model, discriminator,
                             g_optimizer, d_optimizer,
                             n_iter, config)
        n_iter += 1
        # Stop after the last allowed batch without fetching another one.
        if i + 1 >= max_train_batches:
            break

    # ----------------------------- evaluation -----------------------------
    # correct[0] counts clean hits, correct[1..4] hits under atk1..atk4.
    correct = [0, 0, 0, 0, 0]
    n_eval = 0
    for i, data in enumerate(test_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # The attacks need gradients, so they run outside torch.no_grad().
        adv_batches = [atk(inputs, labels) for atk in (atk1, atk2, atk3, atk4)]

        gau_noise = torch.normal(0, 1, inputs.shape)*gau_n
        gau_noise = gau_noise.to(device)

        # Same noise for the clean batch and every adversarial batch, then
        # rescale from [0, 1] to [-1, 1]. torch.clip returns a new tensor, so
        # the in-place rescaling does not touch `inputs`.
        noisy_batches = [torch.clip((batch+gau_noise), 0, 1).mul_(2).sub_(1)
                         for batch in [inputs] + adv_batches]

        mask = _random_block_mask()

        # No parameter is updated during evaluation, so skip building
        # autograd graphs for the reconstruction and classification passes.
        with torch.no_grad():
            restored = []
            for noisy in noisy_batches:
                x = _generator_input(noisy, mask)
                x1, x2, batch_complete, outputs_cls = model(x, x, mask, True)
                restored.append(x2)

            for k, images in enumerate(restored):
                _, preds = torch.max(model_res(images), 1)
                correct[k] += torch.sum(preds == labels).item()

        n_eval += labels.size(0)
        if i + 1 >= max_test_batches:
            break

    # Keep the original per-attack counter names available.
    acc, acc_atk1, acc_atk2, acc_atk3, acc_atk4 = correct

    # Divide by the number of samples actually evaluated instead of a
    # hard-coded 512, so the numbers stay right if the batch size or the
    # batch limit changes.
    n_eval = max(n_eval, 1)
    for template, hits in zip(report_templates, correct):
        print(template.format((hits / n_eval) * 100.0))
    print('________epoch '+str(epoch)+'________')
    misc.save_states(f"states_{tmp_n}_{epoch}.pth",
                     model, discriminator,
                     g_optimizer, d_optimizer,
                     epoch, config)


  warn(f"Failed to load image Python extension: {e}")


Files already downloaded and verified
Files already downloaded and verified
@iter: 0: 32.0716 it/s
d_loss: 1.0004
g_loss: 0.4020
ae_loss: 0.4109
ae_loss1: 0.1552
ae_loss2: 0.2558
cls_loss: 2.8541
@iter: 100: 1.3900 it/s
d_loss: 0.9992
g_loss: 0.0541
ae_loss: 0.1281
ae_loss1: 0.0582
ae_loss2: 0.0699
cls_loss: 1.6255
Saved state dicts!
@iter: 200: 1.4268 it/s
d_loss: 0.9903
g_loss: 0.0994
ae_loss: 0.0950
ae_loss1: 0.0461
ae_loss2: 0.0489
cls_loss: 1.0380
Saved state dicts!
@iter: 300: 1.4271 it/s
d_loss: 0.9658
g_loss: 0.2570
ae_loss: 0.0918
ae_loss1: 0.0447
ae_loss2: 0.0471
cls_loss: 0.7819
Saved state dicts!
@iter: 400: 1.4124 it/s
d_loss: 0.9270
g_loss: -0.0245
ae_loss: 0.0888
ae_loss1: 0.0433
ae_loss2: 0.0456
cls_loss: 0.8451
Saved state dicts!
@iter: 500: 1.3872 it/s
d_loss: 0.9264
g_loss: 0.4548
ae_loss: 0.0871
ae_loss1: 0.0425
ae_loss2: 0.0446
cls_loss: 0.7750
Saved state dicts!
@iter: 600: 1.2979 it/s
d_loss: 0.9246
g_loss: 0.4399
ae_loss: 0.0874
ae_loss1: 0.0428
ae_loss2: 0.0445

Saved state dicts!
@iter: 4700: 1.0912 it/s
d_loss: 0.7392
g_loss: 0.0356
ae_loss: 0.0865
ae_loss1: 0.0415
ae_loss2: 0.0450
cls_loss: 0.1689
Saved state dicts!
@iter: 4800: 1.1709 it/s
d_loss: 0.7476
g_loss: 0.0315
ae_loss: 0.0853
ae_loss1: 0.0408
ae_loss2: 0.0445
cls_loss: 0.1588
Saved state dicts!
@iter: 4900: 1.1379 it/s
d_loss: 0.7425
g_loss: 0.0290
ae_loss: 0.0860
ae_loss1: 0.0412
ae_loss2: 0.0448
cls_loss: 0.1869
Saved state dicts!
test acc on clean examples (%): 86.72
test acc on FGSM examples (%):  85.94
test acc on BIM examples (%):   83.40
test acc on EoT examples (%):   83.01
test acc on BPDA examples (%):  82.62
________epoch 8________
Saved state dicts!
@iter: 5000: 0.4283 it/s
d_loss: 0.7289
g_loss: 0.0291
ae_loss: 0.0854
ae_loss1: 0.0408
ae_loss2: 0.0446
cls_loss: 0.1655
Saved state dicts!
Saved state dicts!
@iter: 5100: 1.1716 it/s
d_loss: 0.7312
g_loss: 0.0285
ae_loss: 0.0863
ae_loss1: 0.0414
ae_loss2: 0.0449
cls_loss: 0.1551
Saved state dicts!
@iter: 5200: 1.1167 it/s

@iter: 9300: 1.1051 it/s
d_loss: 0.7356
g_loss: -0.0967
ae_loss: 0.0830
ae_loss1: 0.0400
ae_loss2: 0.0430
cls_loss: 0.1406
Saved state dicts!
test acc on clean examples (%): 87.70
test acc on FGSM examples (%):  89.06
test acc on BIM examples (%):   85.35
test acc on EoT examples (%):   85.16
test acc on BPDA examples (%):  84.57
________epoch 15________
Saved state dicts!
@iter: 9400: 0.4185 it/s
d_loss: 0.7354
g_loss: -0.0871
ae_loss: 0.0851
ae_loss1: 0.0407
ae_loss2: 0.0444
cls_loss: 0.1945
Saved state dicts!
@iter: 9500: 1.1152 it/s
d_loss: 0.7437
g_loss: -0.1095
ae_loss: 0.0851
ae_loss1: 0.0406
ae_loss2: 0.0444
cls_loss: 0.1691
Saved state dicts!
@iter: 9600: 1.0838 it/s
d_loss: 0.7323
g_loss: -0.0935
ae_loss: 0.0844
ae_loss1: 0.0404
ae_loss2: 0.0440
cls_loss: 0.1656
Saved state dicts!
@iter: 9700: 1.1109 it/s
d_loss: 0.7250
g_loss: -0.0958
ae_loss: 0.0843
ae_loss1: 0.0402
ae_loss2: 0.0441
cls_loss: 0.1736
Saved state dicts!
@iter: 9800: 1.1064 it/s
d_loss: 0.7348
g_loss: -0.1017


Saved state dicts!
@iter: 13800: 0.4139 it/s
d_loss: 0.7203
g_loss: -0.1326
ae_loss: 0.0828
ae_loss1: 0.0395
ae_loss2: 0.0434
cls_loss: 0.1331
Saved state dicts!
@iter: 13900: 1.1134 it/s
d_loss: 0.7109
g_loss: -0.1235
ae_loss: 0.0821
ae_loss1: 0.0392
ae_loss2: 0.0428
cls_loss: 0.1161
Saved state dicts!
@iter: 14000: 1.1136 it/s
d_loss: 0.7183
g_loss: -0.1298
ae_loss: 0.0820
ae_loss1: 0.0392
ae_loss2: 0.0428
cls_loss: 0.1493
Saved state dicts!
@iter: 14100: 1.0820 it/s
d_loss: 0.7098
g_loss: -0.1276
ae_loss: 0.0839
ae_loss1: 0.0400
ae_loss2: 0.0438
cls_loss: 0.1704
Saved state dicts!
@iter: 14200: 1.1057 it/s
d_loss: 0.7117
g_loss: -0.1305
ae_loss: 0.0806
ae_loss1: 0.0386
ae_loss2: 0.0420
cls_loss: 0.1269
Saved state dicts!
@iter: 14300: 1.1121 it/s
d_loss: 0.7108
g_loss: -0.1243
ae_loss: 0.0835
ae_loss1: 0.0396
ae_loss2: 0.0439
cls_loss: 0.1368
Saved state dicts!
test acc on clean examples (%): 88.28
test acc on FGSM examples (%):  88.48
test acc on BIM examples (%):   86.72
test acc 

Saved state dicts!
@iter: 18400: 1.1087 it/s
d_loss: 0.7290
g_loss: -0.1617
ae_loss: 0.0820
ae_loss1: 0.0393
ae_loss2: 0.0428
cls_loss: 0.1369
Saved state dicts!
@iter: 18500: 1.0989 it/s
d_loss: 0.7280
g_loss: -0.1535
ae_loss: 0.0809
ae_loss1: 0.0387
ae_loss2: 0.0422
cls_loss: 0.1229
Saved state dicts!
@iter: 18600: 1.0881 it/s
d_loss: 0.7248
g_loss: -0.1630
ae_loss: 0.0808
ae_loss1: 0.0388
ae_loss2: 0.0420
cls_loss: 0.1077
Saved state dicts!
@iter: 18700: 1.1105 it/s
d_loss: 0.7295
g_loss: -0.1670
ae_loss: 0.0819
ae_loss1: 0.0393
ae_loss2: 0.0427
cls_loss: 0.1162
Saved state dicts!
test acc on clean examples (%): 89.45
test acc on FGSM examples (%):  89.84
test acc on BIM examples (%):   86.91
test acc on EoT examples (%):   86.33
test acc on BPDA examples (%):  85.74
________epoch 30________
Saved state dicts!
@iter: 18800: 0.4170 it/s
d_loss: 0.7346
g_loss: -0.1662
ae_loss: 0.0807
ae_loss1: 0.0386
ae_loss2: 0.0421
cls_loss: 0.1348
Saved state dicts!
@iter: 18900: 1.1134 it/s
d_loss

Saved state dicts!
@iter: 23000: 1.1131 it/s
d_loss: 0.7449
g_loss: -0.1818
ae_loss: 0.0803
ae_loss1: 0.0387
ae_loss2: 0.0416
cls_loss: 0.0835
Saved state dicts!
@iter: 23100: 1.0860 it/s
d_loss: 0.7390
g_loss: -0.1761
ae_loss: 0.0823
ae_loss1: 0.0396
ae_loss2: 0.0427
cls_loss: 0.1021
Saved state dicts!
test acc on clean examples (%): 89.45
test acc on FGSM examples (%):  90.04
test acc on BIM examples (%):   87.50
test acc on EoT examples (%):   87.50
test acc on BPDA examples (%):  86.13
________epoch 37________
Saved state dicts!
@iter: 23200: 0.4162 it/s
d_loss: 0.7566
g_loss: -0.1899
ae_loss: 0.0808
ae_loss1: 0.0386
ae_loss2: 0.0422
cls_loss: 0.1218
Saved state dicts!
@iter: 23300: 1.1066 it/s
d_loss: 0.7535
g_loss: -0.1923
ae_loss: 0.0812
ae_loss1: 0.0391
ae_loss2: 0.0422
cls_loss: 0.1252
Saved state dicts!
@iter: 23400: 1.1073 it/s
d_loss: 0.7404
g_loss: -0.1788
ae_loss: 0.0806
ae_loss1: 0.0386
ae_loss2: 0.0420
cls_loss: 0.1252
Saved state dicts!
@iter: 23500: 1.1071 it/s
d_loss

Saved state dicts!
@iter: 27500: 0.4103 it/s
d_loss: 0.7561
g_loss: -0.2001
ae_loss: 0.0802
ae_loss1: 0.0384
ae_loss2: 0.0418
cls_loss: 0.1102
Saved state dicts!
@iter: 27600: 1.0815 it/s
d_loss: 0.7472
g_loss: -0.1844
ae_loss: 0.0812
ae_loss1: 0.0391
ae_loss2: 0.0421
cls_loss: 0.1184
Saved state dicts!
@iter: 27700: 1.0998 it/s
d_loss: 0.7454
g_loss: -0.1853
ae_loss: 0.0794
ae_loss1: 0.0382
ae_loss2: 0.0412
cls_loss: 0.1085
Saved state dicts!
@iter: 27800: 1.0885 it/s
d_loss: 0.7507
g_loss: -0.1951
ae_loss: 0.0789
ae_loss1: 0.0380
ae_loss2: 0.0409
cls_loss: 0.1142
Saved state dicts!
@iter: 27900: 1.0846 it/s
d_loss: 0.7457
g_loss: -0.1898
ae_loss: 0.0806
ae_loss1: 0.0388
ae_loss2: 0.0418
cls_loss: 0.1231
Saved state dicts!
@iter: 28000: 1.0844 it/s
d_loss: 0.7499
g_loss: -0.1960
ae_loss: 0.0800
ae_loss1: 0.0384
ae_loss2: 0.0416
cls_loss: 0.0944
Saved state dicts!
@iter: 28100: 1.0520 it/s
d_loss: 0.7558
g_loss: -0.2048
ae_loss: 0.0783
ae_loss1: 0.0377
ae_loss2: 0.0406
cls_loss: 0.0748

Saved state dicts!
@iter: 32100: 1.0717 it/s
d_loss: 0.7600
g_loss: -0.2077
ae_loss: 0.0795
ae_loss1: 0.0382
ae_loss2: 0.0412
cls_loss: 0.0995
Saved state dicts!
@iter: 32200: 1.0954 it/s
d_loss: 0.7670
g_loss: -0.2112
ae_loss: 0.0801
ae_loss1: 0.0386
ae_loss2: 0.0414
cls_loss: 0.1222
Saved state dicts!
@iter: 32400: 1.0948 it/s
d_loss: 0.7599
g_loss: -0.1935
ae_loss: 0.0807
ae_loss1: 0.0388
ae_loss2: 0.0419
cls_loss: 0.1109
Saved state dicts!
test acc on clean examples (%): 87.70
test acc on FGSM examples (%):  89.45
test acc on BIM examples (%):   85.94
test acc on EoT examples (%):   84.38
test acc on BPDA examples (%):  85.35
________epoch 52________
Saved state dicts!
@iter: 32500: 0.4126 it/s
d_loss: 0.7568
g_loss: -0.1965
ae_loss: 0.0791
ae_loss1: 0.0382
ae_loss2: 0.0409
cls_loss: 0.1239
Saved state dicts!
@iter: 32600: 1.0703 it/s
d_loss: 0.7550
g_loss: -0.1975
ae_loss: 0.0791
ae_loss1: 0.0382
ae_loss2: 0.0410
cls_loss: 0.0830
Saved state dicts!
@iter: 32700: 1.0950 it/s
d_loss

Saved state dicts!
test acc on clean examples (%): 87.89
test acc on FGSM examples (%):  86.91
test acc on BIM examples (%):   86.91
test acc on EoT examples (%):   85.74
test acc on BPDA examples (%):  86.33
________epoch 59________
Saved state dicts!
@iter: 36900: 0.5012 it/s
d_loss: 0.7627
g_loss: -0.2035
ae_loss: 0.0799
ae_loss1: 0.0384
ae_loss2: 0.0414
cls_loss: 0.1006
Saved state dicts!
@iter: 37000: 1.2263 it/s
d_loss: 0.7592
g_loss: -0.2013
ae_loss: 0.0787
ae_loss1: 0.0379
ae_loss2: 0.0408
cls_loss: 0.1380
Saved state dicts!
@iter: 37100: 1.2009 it/s
d_loss: 0.7639
g_loss: -0.2128
ae_loss: 0.0778
ae_loss1: 0.0377
ae_loss2: 0.0401
cls_loss: 0.0788
Saved state dicts!
@iter: 37200: 1.2353 it/s
d_loss: 0.7657
g_loss: -0.2093
ae_loss: 0.0791
ae_loss1: 0.0382
ae_loss2: 0.0409
cls_loss: 0.0984
Saved state dicts!
@iter: 37300: 1.2241 it/s
d_loss: 0.7609
g_loss: -0.2070
ae_loss: 0.0783
ae_loss1: 0.0377
ae_loss2: 0.0405
cls_loss: 0.1178
Saved state dicts!
@iter: 37400: 1.2243 it/s
d_loss