In [55]:
import os
import copy
import torch
import numpy as np
from PIL import Image

from torch.autograd import Variable
from torch.optim import SGD
from src.models.cifar10.resnet import ResNet18

In [56]:
# Pick the compute device once, and make newly created float tensors default
# to the matching tensor type (CUDA when a GPU is visible, CPU otherwise).
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases in favor of torch.set_default_dtype / torch.set_default_device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if device == 'cuda' else 'torch.FloatTensor'
)

In [57]:
# Build the network on the selected device and load the pretrained weights.
# map_location remaps the stored tensors onto `device`; without it, a
# checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
MODEL_PATH = './pretrained/resnet18_cifar10_gvp_model_10.pth'
model = ResNet18(alpha=1).to(device)

model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [58]:
model.train(False)  # inference mode (BatchNorm uses running stats); same as model.eval(), returns the module

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [59]:
def preprocess_image(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Turn an H x W x C image into a normalized (1, C, H, W) float tensor that requires grad.

    Args:
        img: array-like image in HWC layout with pixel values in [0, 255]
            (e.g. a uint8 numpy array). The input is never modified.
        mean: per-channel means subtracted after scaling pixels to [0, 1].
            The defaults match the constants inverted by recreate_image
            (presumably the CIFAR-10 training normalization — confirm).
        std: per-channel standard deviations divided out after mean subtraction.

    Returns:
        torch.Tensor, float32, shape (1, C, H, W), on the CPU, with
        requires_grad=True so it can be handed directly to an optimizer.
    """
    # np.array copies by default. The previous np.float32(img) returned the
    # caller's own array when it was already float32, and the in-place
    # normalization below then silently corrupted it.
    im_as_arr = np.array(img, dtype=np.float32).transpose(2, 0, 1)

    # Channel normalization: scale to [0, 1], then standardize each channel.
    for channel in range(im_as_arr.shape[0]):
        im_as_arr[channel] /= 255
        im_as_arr[channel] -= mean[channel]
        im_as_arr[channel] /= std[channel]

    # Add the batch dimension and mark the tensor as a trainable leaf
    # (replaces the deprecated autograd.Variable wrapper).
    im_as_ten = torch.from_numpy(im_as_arr).float().unsqueeze(0)
    im_as_ten.requires_grad_()
    return im_as_ten

In [60]:
def recreate_image(im_as_var):
    """Undo preprocess_image's normalization and return an 8-bit HWC image.

    Args:
        im_as_var: tensor whose first batch element is a channel-first,
            normalized image with at least 3 channels (assumed shape
            (1, 3, H, W)); it may live on any device and may require grad.

    Returns:
        numpy uint8 array of shape (H, W, C) with values clipped to [0, 255].
    """
    inverse_std = [1/0.2023, 1/0.1994, 1/0.2010]
    negated_mean = [-0.4914, -0.4822, -0.4465]

    # Host-side, gradient-free copy of the first image in the batch, so the
    # in-place edits below never touch the tensor's storage.
    pixels = copy.copy(im_as_var.cpu().detach().numpy()[0])

    # x * std + mean, written as x / (1/std) - (-mean).
    for ch, (inv_s, neg_m) in enumerate(zip(inverse_std, negated_mean)):
        pixels[ch] /= inv_s
        pixels[ch] -= neg_m

    # Clamp to the valid [0, 1] range, then quantize to 0..255.
    np.clip(pixels, 0, 1, out=pixels)
    pixels = np.round(pixels * 255)

    return np.uint8(pixels).transpose(1, 2, 0)

In [61]:
def save_image(im, path):
    """Write an image to ``path``.

    Args:
        im: either a numpy array/scalar, which is converted with
            ``Image.fromarray``, or any object exposing ``save(path)``
            (such as a PIL image), which is saved as-is.
        path: destination path forwarded to ``save``.

    NOTE(review): any numpy input whose maximum is <= 1 — including integer
    arrays — is treated as [0, 1]-scaled and multiplied by 255 before the
    uint8 conversion.
    """
    if not isinstance(im, (np.ndarray, np.generic)):
        im.save(path)
        return

    # Heuristic: values that never exceed 1 are assumed to be normalized.
    pixels = (im * 255).astype(np.uint8) if np.max(im) <= 1 else im
    Image.fromarray(pixels).save(path)

In [62]:
# Class-model visualization: start from random noise and repeatedly step the
# input image in the direction that raises the model's output for
# `target_class`. The image is re-quantized to uint8 after every step.
target_class = 5
SEED = 0                 # fixed seed so Restart & Run All reproduces the saved images
IMAGE_SIZE = 224         # NOTE(review): the model comes from src.models.cifar10 (native 32x32 inputs) — confirm 224 is intended
NUM_ITERATIONS = 149     # iterations run are 1..NUM_ITERATIONS inclusive
SAVE_EVERY = 10          # save a snapshot every SAVE_EVERY iterations
OUTPUT_DIR = '../generated'

np.random.seed(SEED)
created_image = np.uint8(np.random.uniform(0, 255, (IMAGE_SIZE, IMAGE_SIZE, 3)))

# exist_ok makes this a no-op on re-runs and avoids a check-then-create race.
os.makedirs(OUTPUT_DIR, exist_ok=True)

initial_learning_rate = 20
for i in range(1, NUM_ITERATIONS + 1):
    # Re-normalize the quantized image into a fresh leaf tensor each step.
    processed_image = preprocess_image(created_image)
    # SGD with default settings keeps no per-parameter state, so rebuilding
    # it for each new leaf tensor is harmless.
    optimizer = SGD([processed_image], lr=initial_learning_rate)
    output = model(processed_image.to(device))
    # Minimizing the negated class score == gradient ascent on that score.
    class_loss = -output[0, target_class]
    print('Iteration:', str(i), 'Loss', "{0:.2f}".format(class_loss.item()))

    # Parameter grads are not used, but clear them so they don't accumulate.
    model.zero_grad()

    class_loss.backward()
    # Update image
    optimizer.step()
    # Recreate image (de-normalize, clip, quantize to uint8)
    created_image = recreate_image(processed_image)
    if i % SAVE_EVERY == 0:
        im_path = os.path.join(OUTPUT_DIR, 'c_specific_iteration_' + str(i) + '.jpg')
        save_image(created_image, im_path)

Iteration: 1 Loss 5.85
Iteration: 2 Loss 4.80
Iteration: 3 Loss 4.14
Iteration: 4 Loss 3.69
Iteration: 5 Loss 3.34
Iteration: 6 Loss 3.07
Iteration: 7 Loss 2.84
Iteration: 8 Loss 2.65
Iteration: 9 Loss 2.49
Iteration: 10 Loss 2.34
Iteration: 11 Loss 2.20
Iteration: 12 Loss 2.08
Iteration: 13 Loss 1.96
Iteration: 14 Loss 1.85
Iteration: 15 Loss 1.74
Iteration: 16 Loss 1.63
Iteration: 17 Loss 1.51
Iteration: 18 Loss 1.40
Iteration: 19 Loss 1.28
Iteration: 20 Loss 1.15
Iteration: 21 Loss 1.02
Iteration: 22 Loss 0.87
Iteration: 23 Loss 0.72
Iteration: 24 Loss 0.55
Iteration: 25 Loss 0.36
Iteration: 26 Loss 0.15
Iteration: 27 Loss -0.07
Iteration: 28 Loss -0.32
Iteration: 29 Loss -0.58
Iteration: 30 Loss -0.88
Iteration: 31 Loss -1.21
Iteration: 32 Loss -1.56
Iteration: 33 Loss -1.96
Iteration: 34 Loss -2.41
Iteration: 35 Loss -2.94
Iteration: 36 Loss -3.53
Iteration: 37 Loss -4.20
Iteration: 38 Loss -4.86
Iteration: 39 Loss -5.37
Iteration: 40 Loss -6.27
Iteration: 41 Loss -7.03
Iteration: