In [1]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import os
from torchvision.utils import save_image
from dataset import Adversarial_Dataset
from utils import adjust_lr, get_z_sets, get_z_star, Resize_Image
from model import CNN
from gan_model import Generator,Discriminator,Decoder
from torchsummary import summary
import copy

### set parameters

In [2]:
batch_size = 32
in_channel = 3
height = 32
width = 32

display_steps = 20

### Load classification model

In [3]:
# Send the model to GPU
model = CNN()

summary(model, input_size = (in_channel,height,width), device = 'cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
         LeakyReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          18,496
         LeakyReLU-5           [-1, 64, 32, 32]               0
         AvgPool2d-6           [-1, 64, 16, 16]               0
            Conv2d-7          [-1, 128, 16, 16]          73,856
       BatchNorm2d-8          [-1, 128, 16, 16]             256
         LeakyReLU-9          [-1, 128, 16, 16]               0
           Conv2d-10          [-1, 128, 16, 16]         147,584
        LeakyReLU-11          [-1, 128, 16, 16]               0
        AvgPool2d-12            [-1, 128, 8, 8]               0
           Conv2d-13            [-1, 256, 8, 8]         295,168
      BatchNorm2d-14            [-1, 25

In [4]:
device_model = torch.device(0)

In [5]:
model.load_state_dict(torch.load('./checkpoints/cifar10.pth'))

model = model.to(device_model)

### load defense-GAN model

In [6]:
learning_rate = 10.0
rec_iters = [800]
rec_rrs = [20]
decay_rate = 0.1
global_step = 3.0
generator_input_size = 32

INPUT_LATENT = 64 

In [7]:
device_generator = torch.device(0)
device_disc = torch.device(0)
device_dec = torch.device(0)

In [9]:
ModelG = Generator()
ModelD = Discriminator()
Dec = Decoder()

generator_path = './defensive_models/gen_cifar10_gp_61299.pth'
disc_path = './defensive_models/disc_cifar10_gp_61299.pth'
dec_path = './dec_checkpoints/cifar10_dec_100.pth'

ModelG.load_state_dict(torch.load(generator_path))
ModelD.load_state_dict(torch.load(disc_path))
Dec.load_state_dict(torch.load(dec_path))

summary(ModelG, input_size = (INPUT_LATENT,), device = 'cpu')
summary(ModelD, input_size = (3, 32, 32), device = 'cpu')
summary(Dec, input_size = (3, 32, 32), device = 'cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 4096]         266,240
       BatchNorm1d-2                 [-1, 4096]           8,192
              ReLU-3                 [-1, 4096]               0
   ConvTranspose2d-4            [-1, 128, 8, 8]         131,200
       BatchNorm2d-5            [-1, 128, 8, 8]             256
              ReLU-6            [-1, 128, 8, 8]               0
   ConvTranspose2d-7           [-1, 64, 16, 16]          32,832
       BatchNorm2d-8           [-1, 64, 16, 16]             128
              ReLU-9           [-1, 64, 16, 16]               0
  ConvTranspose2d-10            [-1, 3, 32, 32]             771
             Tanh-11            [-1, 3, 32, 32]               0
Total params: 439,619
Trainable params: 439,619
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/

In [10]:
ModelG = ModelG.to(device_generator)
ModelD = ModelD.to(device_disc)
Dec = Dec.to(device_dec)

In [11]:
loss = nn.MSELoss()

### load test dataset

In [12]:
# adversarial dataset path
root_dir = './adversarial/'

In [13]:
# Normalize the test set same as training set without augmentation
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [14]:
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

Files already downloaded and verified


In [15]:
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_test)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size*2, shuffle=True, num_workers=4,pin_memory=True)

Files already downloaded and verified


In [15]:
img_rec_saves = []
img_inputs = []
save_losses = []
learning_rate = 1e-3
criterion_dec = nn.MSELoss()
optim_dec = optim.Adam(Dec.parameters(), lr=learning_rate, betas=(0.5, 0.9))

In [16]:
for epoch in range(300):
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(torch.device(0))
        labels = labels.to(torch.device(0))
#         img_inputs.append(inputs)
        
        optim_dec.zero_grad()
        
        img_rec_code = Dec(inputs)
        
        img_rec = ModelG(img_rec_code)
        # img_rec = img_rec.view(-1, inputs.size(1), inputs.size(2), inputs.size(3))
        # fake_labels = model(img_rec)
        loss_rec = criterion_dec(img_rec,inputs)
        loss_rec.backward()
        optim_dec.step()
        
        del img_rec_code
        
        running_loss += loss_rec.item()
        if(batch_idx % 100 == 99):
            # print(inputs.size())
            print('[%d, %4d] loss:%.3f'%(epoch + 1, batch_idx+1, running_loss / 100))
            save_losses.append(running_loss / 100)
            running_loss = 0.0
            
    if(epoch % 50 == 0):
        torch.save(Dec.state_dict(),'./dec_checkpoints/cifar10_dec_{}.pth'.format(epoch))
        save_image(img_rec.data, './rec_img/cifar10_rec_{}.png'.format(epoch), nrow=4, normalize=True)
        save_image(inputs.data, './rec_img/cifar10_orig_{}.png'.format(epoch), nrow=4, normalize=True)
        img_rec_saves.append(img_rec)
        img_inputs.append(inputs)
#             print(score)
    #         print(img_rec[0].size())

[1,  100] loss:0.133
[1,  200] loss:0.108
[1,  300] loss:0.102
[1,  400] loss:0.097
[1,  500] loss:0.095
[1,  600] loss:0.093
[1,  700] loss:0.092
[2,  100] loss:0.090
[2,  200] loss:0.089
[2,  300] loss:0.090
[2,  400] loss:0.088
[2,  500] loss:0.087
[2,  600] loss:0.087
[2,  700] loss:0.088
[3,  100] loss:0.086
[3,  200] loss:0.086
[3,  300] loss:0.086
[3,  400] loss:0.084
[3,  500] loss:0.084
[3,  600] loss:0.085
[3,  700] loss:0.084
[4,  100] loss:0.083
[4,  200] loss:0.083
[4,  300] loss:0.084
[4,  400] loss:0.083
[4,  500] loss:0.082
[4,  600] loss:0.084
[4,  700] loss:0.083
[5,  100] loss:0.082
[5,  200] loss:0.083
[5,  300] loss:0.082
[5,  400] loss:0.082
[5,  500] loss:0.081
[5,  600] loss:0.082
[5,  700] loss:0.081
[6,  100] loss:0.081
[6,  200] loss:0.080
[6,  300] loss:0.080
[6,  400] loss:0.081
[6,  500] loss:0.081
[6,  600] loss:0.081
[6,  700] loss:0.079
[7,  100] loss:0.080
[7,  200] loss:0.081
[7,  300] loss:0.079
[7,  400] loss:0.080
[7,  500] loss:0.078
[7,  600] los

KeyboardInterrupt: 

In [17]:
print(len(save_losses))

1062


In [18]:
save_losses[500]

0.0678948038071394

### 测试重建正确率

In [19]:
epoch_size = 0
running_corrects = 0.0
for batch_idx, (inputs, labels) in enumerate(test_loader):
    inputs = inputs.to(torch.device(0))
    labels = labels.to(torch.device(0))
    with torch.no_grad():
        rec_code = Dec(inputs)
        rec_img = ModelG(rec_code)
        outputs = model(rec_img)
        _, preds = torch.max(outputs, 1)
         # statistics
        running_corrects += torch.sum(preds == labels.data)
        epoch_size += inputs.size(0)
        
        if batch_idx % 100 == 0:
            print(rec_img.data.shape)
            save_image(rec_img.data, './rec_img/test_rec_{}.png'.format(batch_idx), nrow=4, normalize=True)
            save_image(inputs.data, './rec_img/test_orig_{}.png'.format(batch_idx), nrow=4, normalize=True)
            print('{:>3}/{:>3} average acc {:.4f}\r'.format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))
            
        del labels, outputs, preds, rec_code, rec_img

test_acc = running_corrects.double() / epoch_size
print('Test Acc: {:.4f}'.format(test_acc))


torch.Size([32, 3, 32, 32])
  1/313 average acc 0.4062
torch.Size([32, 3, 32, 32])
101/313 average acc 0.3373
torch.Size([32, 3, 32, 32])
201/313 average acc 0.3368
torch.Size([32, 3, 32, 32])
301/313 average acc 0.3393
Test Acc: 0.3400


In [20]:
del test_loader

### FGSM

In [21]:
adversarial_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [22]:
sample = Adversarial_Dataset(root_dir,'FGSM_0.2',adversarial_transform)

In [23]:
test_loader = DataLoader(
    sample,
    batch_size=batch_size,
    num_workers=0
)

In [24]:
print(len(sample))

9093


In [25]:
epoch_size = 0
running_corrects = 0.0
for batch_idx, (inputs, labels) in enumerate(test_loader):
    inputs = inputs.to(torch.device(0))
    labels = labels.to(torch.device(0))
    with torch.no_grad():
        rec_code = Dec(inputs)
        rec_img = ModelG(rec_code)
        outputs = model(rec_img)
        _, preds = torch.max(outputs, 1)
         # statistics
        running_corrects += torch.sum(preds == labels.data)
        epoch_size += inputs.size(0)
        
        if batch_idx % 20 == 0:
            save_image(rec_img.data, './rec_img/test_FGSM_0.2_rec_{}.png'.format(batch_idx), nrow=4, normalize=True)
            save_image(inputs.data, './rec_img/test_FGSM_0.2_orig_{}.png'.format(batch_idx), nrow=4, normalize=True)
            print('{:>3}/{:>3} average acc {:.4f}\r'.format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))
            
        del labels, outputs, preds, rec_code, rec_img

test_acc = running_corrects.double() / epoch_size
print('Test Acc: {:.4f}'.format(test_acc))


  1/285 average acc 0.2812
 21/285 average acc 0.2738
 41/285 average acc 0.2797
 61/285 average acc 0.2772
 81/285 average acc 0.2840
101/285 average acc 0.2853
121/285 average acc 0.2926
141/285 average acc 0.2932
161/285 average acc 0.2880
181/285 average acc 0.2882
201/285 average acc 0.2862
221/285 average acc 0.2872
241/285 average acc 0.2880
261/285 average acc 0.2909
281/285 average acc 0.2911
Test Acc: 0.2917


In [26]:
del test_loader

### CW

In [27]:
adversarial_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

In [28]:
sample = Adversarial_Dataset(root_dir,'CW_0.2',adversarial_transform)

In [29]:
test_loader = DataLoader(
    sample,
    batch_size=batch_size,
    num_workers=0,
)

In [30]:
print(len(sample))

7019


In [31]:
epoch_size = 0
running_corrects = 0.0
for batch_idx, (inputs, labels) in enumerate(test_loader):
    inputs = inputs.to(torch.device(0))
    labels = labels.to(torch.device(0))
    with torch.no_grad():
        rec_code = Dec(inputs)
        rec_img = ModelG(rec_code)
        outputs = model(rec_img)
        _, preds = torch.max(outputs, 1)
         # statistics
        running_corrects += torch.sum(preds == labels.data)
        epoch_size += inputs.size(0)
        
        if batch_idx % 20 == 0:
            save_image(rec_img.data, './rec_img/test_CW_0.2_rec_{}.png'.format(batch_idx), nrow=4, normalize=True)
            save_image(inputs.data, './rec_img/test_CW_0.2_orig_{}.png'.format(batch_idx), nrow=4, normalize=True)
            print('{:>3}/{:>3} average acc {:.4f}\r'.format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))
            
        del labels, outputs, preds, rec_code, rec_img

test_acc = running_corrects.double() / epoch_size
print('Test Acc: {:.4f}'.format(test_acc))

  1/220 average acc 0.2188
 21/220 average acc 0.2560
 41/220 average acc 0.2325
 61/220 average acc 0.2387
 81/220 average acc 0.2342
101/220 average acc 0.2351
121/220 average acc 0.2358
141/220 average acc 0.2343
161/220 average acc 0.2349
181/220 average acc 0.2341
201/220 average acc 0.2329
Test Acc: 0.2308


In [32]:
del test_loader

### DF

In [16]:
adversarial_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

In [17]:
sample = Adversarial_Dataset(root_dir,'DF_0.2',adversarial_transform)

In [18]:
test_loader = DataLoader(
    sample,
    batch_size=batch_size,
    num_workers=0,
)

In [19]:
print(len(sample))

6357


In [20]:
epoch_size = 0
running_corrects = 0.0
for batch_idx, (inputs, labels) in enumerate(test_loader):
    inputs = inputs.to(torch.device(0))
    labels = labels.to(torch.device(0))
    with torch.no_grad():
        rec_code = Dec(inputs)
        rec_img = ModelG(rec_code)
        outputs = model(rec_img)
        _, preds = torch.max(outputs, 1)
         # statistics
        running_corrects += torch.sum(preds == labels.data)
        epoch_size += inputs.size(0)
        
        if batch_idx % 20 == 0:
            save_image(rec_img.data, './rec_img/test_DF_0.2_rec_{}.png'.format(batch_idx), nrow=4, normalize=True)
            save_image(inputs.data, './rec_img/test_DF_0.2_orig_{}.png'.format(batch_idx), nrow=4, normalize=True)
            print('{:>3}/{:>3} average acc {:.4f}\r'.format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))
            
        del labels, outputs, preds, rec_code, rec_img

test_acc = running_corrects.double() / epoch_size
print('Test Acc: {:.4f}'.format(test_acc))

  1/199 average acc 0.3438
 21/199 average acc 0.2113
 41/199 average acc 0.2012
 61/199 average acc 0.2085
 81/199 average acc 0.2068
101/199 average acc 0.2132
121/199 average acc 0.2157
141/199 average acc 0.2201
161/199 average acc 0.2180
181/199 average acc 0.2175
Test Acc: 0.2185


In [38]:
del test_loader

### clean Image

In [31]:
model.eval()

running_corrects = 0
epoch_size = 0

is_input_size_diff = False

save_test_results = []
save_test_rec_loss = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:
        
        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change

            if inputs.size(2) != generator_input_size :

                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)

                data = Resize_Image(target_shape, inputs)
                data = data.to(device_generator)

                is_input_size_diff = True

            else :
                data = inputs.to(device_generator)

            # find z*

            z_orig_sets, z_rec_sets, rec_loss = get_z_sets(ModelG, data, learning_rate, \
                                        device_generator, rec_iter = rec_iter, \
                                        rec_rr = rec_rr, input_latent = INPUT_LATENT, global_step = global_step)
            
            save_test_rec_loss.append(rec_loss)
            
            z_star = get_z_star(ModelG, data, z_rec_sets, rec_loss, device_generator, rec_rr)

            # generate data

            data_hat = ModelG(z_star.to(device_generator)).cpu().detach()
            # print(data_hat)
            # size back

            if is_input_size_diff:

                target_shape = (inputs.size(0), inputs.size(1), height, width)
                data_hat = Resize_Image(target_shape, data_hat)

            # classifier 
            data_hat = data_hat.to(device_model)

            labels = labels.to(device_model)

            # evaluate 

            outputs = model(data_hat)

            _, preds = torch.max(outputs, 1)

            # statistics
            running_corrects += torch.sum(preds == labels.data)
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                save_image(data.data, './rec_img/new_lgd_L{}_R{}_orig_{}.png'.format(rec_iter,rec_rr,batch_idx), nrow=5, normalize=True)
                save_image(data_hat.data, './rec_img/new_lgd_L{}_R{}_rec_{}.png'.format(rec_iter,rec_rr,batch_idx), nrow=5, normalize=True)
#                 print('rec_loss: ',rec_loss)
                print('{:>3}/{:>3} average acc {:.4f}\r'\
                      .format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

            del labels, outputs, preds, data, data_hat,z_star

        test_acc = running_corrects.double() / epoch_size

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))
        
        save_test_results.append(test_acc)

  1/313 average acc 0.1875
 21/313 average acc 0.1949
 41/313 average acc 0.1982
 61/313 average acc 0.1952
 81/313 average acc 0.1975
101/313 average acc 0.1934
121/313 average acc 0.1872
141/313 average acc 0.1871
161/313 average acc 0.1867
181/313 average acc 0.1856
201/313 average acc 0.1844
221/313 average acc 0.1818
241/313 average acc 0.1821
261/313 average acc 0.1815
281/313 average acc 0.1841
301/313 average acc 0.1835
rec_iter : 800, rec_rr : 20, Test Acc: 0.1825


In [18]:
print(save_test_results)

[tensor(0.1353, device='cuda:0', dtype=torch.float64)]


In [None]:
del test_loader

### FGSM

In [None]:
adversarial_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

In [None]:
sample = Adversarial_Dataset(root_dir,'FGSM_0.002',adversarial_transform)

In [None]:
test_loader = DataLoader(
    sample,
    batch_size=batch_size,
    num_workers=4
)

In [None]:
model.eval()

running_corrects = 0
epoch_size = 0

is_input_size_diff = False

save_test_results = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:
        
        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change

            if inputs.size(2) != generator_input_size :

                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)

                data = Resize_Image(target_shape, inputs)
                data = data.to(device_generator)

                is_input_size_diff = True

            else :
                data = inputs.to(device_generator)

            # find z*

            _, z_sets = get_z_sets(ModelG, data, learning_rate, \
                                        loss, device_generator, rec_iter = rec_iter, \
                                        rec_rr = rec_rr, input_latent = INPUT_LATENT, global_step = global_step)

            z_star = get_z_star(ModelG, data, z_sets, loss, device_generator)

            # generate data

            data_hat = ModelG(z_star.to(device_generator)).cpu().detach()

            # size back

            if is_input_size_diff:

                target_shape = (inputs.size(0), inputs.size(1), height, width)
                data_hat = Resize_Image(target_shape, data_hat)

            # classifier 
            data_hat = data_hat.to(device_model)

            labels = labels.to(device_model)

            # evaluate 

            outputs = model(data_hat)

            _, preds = torch.max(outputs, 1)

            # statistics
            running_corrects += torch.sum(preds == labels.data)
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                print('{:>3}/{:>3} average acc {:.4f}\r'\
                      .format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

            del labels, outputs, preds, data, data_hat,z_star

        test_acc = running_corrects.double() / epoch_size

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))
        
        save_test_results.append(test_acc)

In [None]:
del test_loader

### CW

In [None]:
adversarial_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

In [None]:
sample = Adversarial_Dataset(root_dir,'CW_0.2',adversarial_transform)

In [None]:
test_loader = DataLoader(
    sample,
    batch_size=batch_size,
    num_workers=0,
)

In [None]:
print(len(sample))

### Deep Fool

In [None]:
adversarial_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

In [None]:
sample = Adversarial_Dataset(root_dir,'DF_0.002',adversarial_transform)

In [None]:
test_loader = DataLoader(
    sample,
    batch_size=batch_size,
    num_workers=4
)

In [None]:
model.eval()

running_corrects = 0
epoch_size = 0

is_input_size_diff = False

save_test_results = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:
        
        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change

            if inputs.size(2) != generator_input_size :

                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)

                data = Resize_Image(target_shape, inputs)
                data = data.to(device_generator)

                is_input_size_diff = True

            else :
                data = inputs.to(device_generator)

            # find z*

            _, z_sets = get_z_sets(ModelG, data, learning_rate, \
                                        loss, device_generator, rec_iter = rec_iter, \
                                        rec_rr = rec_rr, input_latent = INPUT_LATENT, global_step = global_step)

            z_star = get_z_star(ModelG, data, z_sets, loss, device_generator)

            # generate data

            data_hat = ModelG(z_star.to(device_generator)).cpu().detach()

            # size back

            if is_input_size_diff:

                target_shape = (inputs.size(0), inputs.size(1), height, width)
                data_hat = Resize_Image(target_shape, data_hat)

            # classifier 
            data_hat = data_hat.to(device_model)

            labels = labels.to(device_model)

            # evaluate 

            outputs = model(data_hat)

            _, preds = torch.max(outputs, 1)

            # statistics
            running_corrects += torch.sum(preds == labels.data)
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                print('{:>3}/{:>3} average acc {:.4f}\r'\
                      .format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

            del labels, outputs, preds, data, data_hat,z_star

        test_acc = running_corrects.double() / epoch_size

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))
        
        save_test_results.append(test_acc)

In [None]:
del test_loader

### Saliency Map

In [None]:
adversarial_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])

In [None]:
sample = Adversarial_Dataset(root_dir,'SM',adversarial_transform)

In [None]:
test_loader = DataLoader(
    sample,
    batch_size=batch_size,
    num_workers=4
)

In [None]:
model.eval()

running_corrects = 0
epoch_size = 0

is_input_size_diff = False

save_test_results = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:
        
        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change

            if inputs.size(2) != generator_input_size :

                target_shape = (inputs.size(0), inputs.size(1), generator_input_size, generator_input_size)

                data = Resize_Image(target_shape, inputs)
                data = data.to(device_generator)

                is_input_size_diff = True

            else :
                data = inputs.to(device_generator)

            # find z*

            _, z_sets = get_z_sets(ModelG, data, learning_rate, \
                                        loss, device_generator, rec_iter = rec_iter, \
                                        rec_rr = rec_rr, input_latent = INPUT_LATENT, global_step = global_step)

            z_star = get_z_star(ModelG, data, z_sets, loss, device_generator)

            # generate data

            data_hat = ModelG(z_star.to(device_generator)).cpu().detach()

            # size back

            if is_input_size_diff:

                target_shape = (inputs.size(0), inputs.size(1), height, width)
                data_hat = Resize_Image(target_shape, data_hat)

            # classifier 
            data_hat = data_hat.to(device_model)

            labels = labels.to(device_model)

            # evaluate 

            outputs = model(data_hat)

            _, preds = torch.max(outputs, 1)

            # statistics
            running_corrects += torch.sum(preds == labels.data)
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                print('{:>3}/{:>3} average acc {:.4f}\r'\
                      .format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

            del labels, outputs, preds, data, data_hat,z_star

        test_acc = running_corrects.double() / epoch_size

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))
        
        save_test_results.append(test_acc)

In [None]:
del test_loader