In [None]:
import os
import numpy as np
import pickle
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms as T
from torch.utils.data import DataLoader

from utils_AR_GAN import adjust_lr, get_z_sets, get_z_star, Resize_Image
from Split_data import random_split

from WGAN_GP import Generator
from torchsummary import summary
import copy

## Set Parameters

In [None]:
# Image geometry. `height` / `width` are the size reconstructions are resized
# back to before they are handed to the classifier.
# NOTE(review): in_channel is presumably the number of image channels — confirm.
in_channel, height, width = 3, 32, 32

# Second argument given to ResNet9 when the classifier is built.
num_classes = 2

# Batch size handed to the test DataLoader.
batch_size = 128

# The evaluation loop prints the running accuracy every `display_steps` batches.
display_steps = 20

## Load Data

In [None]:
# Location of the pickled dataset, relative to the notebook's working directory.
data_dir, data_name = "./data", "stop_speed.pkl"
data_file_path = os.path.join(data_dir, data_name)

# SECURITY NOTE: unpickling executes arbitrary code stored in the file, so
# this path must only ever point at a trusted, locally produced pickle.
with open(data_file_path, mode="rb") as data_file:
    reduced_data = pickle.load(data_file)

# Partition the loaded data into train / validation / test subsets using the
# project's Split_data.random_split helper.
train_ds, val_ds, test_ds = random_split(reduced_data)

In [None]:
# Build the test DataLoader and wrap it with the project's DeviceDataLoader
# (presumably moves every batch to the selected device — confirm in deviceSelector).
from deviceSelector import DeviceDataLoader, to_device

# Release cached CUDA allocator blocks before evaluation starts.
torch.cuda.empty_cache()

# os.cpu_count() returns None when the core count cannot be determined;
# fall back to a single core so the worker computation below never raises.
n_cores = os.cpu_count() or 1

test_loader = DataLoader(
    test_ds,
    batch_size,
    shuffle=False,              # keep the evaluation order fixed
    num_workers=n_cores // 2,   # half of the cores; 0 loads in the main process
    pin_memory=True,
)
test_loader = DeviceDataLoader(test_loader)

In [None]:
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# Default figure background colour for figures created after this point: white.
matplotlib.rcParams['figure.facecolor'] = '#ffffff'

# Two 3-tuples: denorm() adds the first entry of stats[0] and multiplies by the
# first entry of stats[1].
# NOTE(review): presumably the per-channel (mean, std) used to normalise the
# images — confirm against the data pipeline.
stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def denorm(img_tensors):
    """Map normalised image values back by scaling and then shifting.

    Only the first element of each tuple in `stats` is used, and it is
    applied uniformly to every element of `img_tensors`.
    """
    offsets, scales = stats
    return img_tensors * scales[0] + offsets[0]

def show_images(images, nmax=64):
    """Display at most `nmax` images from a batch as a grid, 8 per row.

    The batch is copied to the CPU, detached, de-normalised with denorm()
    and drawn on a fresh 8x8-inch figure with the axis ticks hidden.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])

    shown = denorm(images.cpu().detach()[:nmax])
    grid = make_grid(shown, nrow=8)
    # Move axis 0 to the end so imshow receives the layout it expects.
    ax.imshow(grid.permute(1, 2, 0))

def show_batch(dl, nmax=64):
    """Display the first batch yielded by `dl`; do nothing if `dl` is empty.

    Each item of `dl` must unpack into two values; the first is passed to
    show_images() and the second is ignored.
    """
    batches = iter(dl)
    try:
        first = next(batches)
    except StopIteration:
        # Nothing to show for an empty iterable.
        return
    images, _ = first
    show_images(images, nmax)

In [None]:
# Visual sanity check: draw the first batch produced by the test loader.
show_batch(test_loader)

## Load Classifier

In [None]:
from deviceSelector import DeviceDataLoader, to_device
from ResNet9 import ResNet9

# Single source of truth for where the classifier lives; the evaluation loop
# moves reconstructions and labels to this same device.
device_model = 'cuda'

# Build the classifier from the configured channel / class counts and place
# it on device_model (instead of repeating the literals 3 and 'cuda').
model = to_device(ResNet9(in_channel, num_classes), device=device_model)

# map_location makes the checkpoint tensors deserialize straight onto the
# classifier's device, independent of the device they were saved from.
model.load_state_dict(
    torch.load('./trained_models/ResNet9/resnet9_m19_retrained.pth',
               map_location=device_model)
)

## Load Defense-GAN

In [None]:
# --- Latent-search settings (consumed by the evaluation loop) ---------------
learning_rate = 10.0    # third positional argument of the latent-search call
global_step = 3.0       # forwarded to the latent search as `global_step`
decay_rate = 0.1        # not referenced by any cell visible in this notebook
rec_iters = [1000]      # values the evaluation loop sweeps as `rec_iter`
rec_rrs = [20]          # values the evaluation loop sweeps as `rec_rr`

# --- Generator settings -----------------------------------------------------
# Batches whose spatial size differs from this are resized before the
# latent search and resized back afterwards.
generator_input_size = 32

# Forwarded as `input_latent` to the latent search and used as the first
# dimension of the input size given to summary().
INPUT_LATENT = 128

# Device the generator is moved to and on which the latent search runs.
device_generator = torch.device('cuda')

In [None]:
generator_path = './trained_models/WGAN_GP/G_lisa_gp_4519.pth'

# The generator is created on the CPU here and is only moved to
# device_generator in a later cell, so the checkpoint is deserialized onto the
# CPU as well. Without map_location, tensors saved from a GPU are first
# restored to that GPU, which fails when it is unavailable and otherwise
# occupies GPU memory for no benefit.
ModelG = Generator()
ModelG.load_state_dict(torch.load(generator_path, map_location='cpu'))

# Print a layer-by-layer summary for an input of size (INPUT_LATENT, 1, 1).
summary(ModelG, input_size=(INPUT_LATENT, 1, 1), device='cpu')

In [None]:
# Mean-squared-error criterion handed to the latent search and to get_z_star.
loss = nn.MSELoss()

# Place the generator on the device used for the latent search.
# NOTE(review): ModelG is never switched to eval() in the visible cells; if
# Generator contains BatchNorm or Dropout layers, confirm whether train-mode
# behaviour is really intended during reconstruction.
ModelG = ModelG.to(device_generator)

## Clean Image

In [None]:
# Evaluate the classifier on generator reconstructions of the test set.
#
# For every (rec_iter, rec_rr) configuration, each test batch is
#   1. resized to the generator's resolution when necessary,
#   2. projected to a latent code z* via get_z_sets / get_z_star,
#   3. regenerated as ModelG(z*), resized back, and classified.
# One accuracy value per configuration is appended to save_test_results.
model.eval()

save_test_results = []

for rec_iter in rec_iters:
    for rec_rr in rec_rrs:

        # Counters are reset for every configuration so that each reported
        # accuracy reflects that configuration only, not the running total
        # over all configurations evaluated so far. Starting from a tensor
        # keeps `.double()` valid even if the loader yields no batches (the
        # resulting accuracy is then NaN rather than an exception).
        running_corrects = torch.zeros((), dtype=torch.long, device=device_model)
        epoch_size = 0

        for batch_idx, (inputs, labels) in enumerate(test_loader):

            # size change
            # Decided per batch, and on both spatial dimensions, so a
            # mismatching batch cannot force a resize of later batches and a
            # width-only mismatch is not missed.
            is_input_size_diff = (
                tuple(inputs.shape[-2:]) != (generator_input_size, generator_input_size)
            )

            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1),
                                generator_input_size, generator_input_size)
                data = Resize_Image(target_shape, inputs)
                data = data.to(device_generator)
            else:
                data = inputs.to(device_generator)

            # find z*
            # get_z_sets is the function this notebook imports from
            # utils_AR_GAN; the previously used name get_z_sets2 is not
            # imported anywhere and raised NameError on a fresh kernel.
            # NOTE(review): assumes get_z_sets accepts these same arguments —
            # confirm against utils_AR_GAN.
            _, z_sets = get_z_sets(ModelG, data, learning_rate,
                                   loss, device_generator, rec_iter=rec_iter,
                                   rec_rr=rec_rr, input_latent=INPUT_LATENT,
                                   global_step=global_step)

            z_star = get_z_star(ModelG, data, z_sets, loss, device_generator)

            # generate data
            # No gradients are needed past this point; no_grad avoids building
            # an autograd graph for the generator forward pass.
            with torch.no_grad():
                data_hat = ModelG(z_star.to(device_generator)).cpu()

            # size back
            if is_input_size_diff:
                target_shape = (inputs.size(0), inputs.size(1), height, width)
                data_hat = Resize_Image(target_shape, data_hat)

            # classifier
            data_hat = data_hat.to(device_model)
            labels = labels.to(device_model)

            # evaluate (inference only, so no autograd graph is recorded)
            with torch.no_grad():
                outputs = model(data_hat)

            _, preds = torch.max(outputs, 1)

            # statistics
            running_corrects += torch.sum(preds == labels)
            epoch_size += inputs.size(0)

            if batch_idx % display_steps == 0:
                print('{:>3}/{:>3} average acc {:.4f}\r'\
                      .format(batch_idx+1, len(test_loader), running_corrects.double() / epoch_size))

            # Drop per-batch tensors before the next iteration allocates new ones.
            del labels, outputs, preds, data, data_hat, z_star

        test_acc = running_corrects.double() / epoch_size

        print('rec_iter : {}, rec_rr : {}, Test Acc: {:.4f}'.format(rec_iter, rec_rr, test_acc))

        save_test_results.append(test_acc)

In [None]:
# Remove the notebook's reference to the test loader so it can be
# garbage-collected once nothing else refers to it.
del test_loader