In [45]:
import os
import sys
import copy
import numpy as np
import pickle
import torch
import torchvision.models as models

from PIL import Image

from torch.autograd import Variable
from torch.optim import SGD

In [46]:
if torch.cuda.is_available():
    device = 'cuda'
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    device = 'cpu'
    torch.set_default_tensor_type('torch.FloatTensor')

In [47]:
with open("imagenet_labels.pkl", "rb") as f:
    class_name = pickle.load(f)

In [48]:
model = models.resnet18(pretrained=True)

In [49]:
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [50]:
def preprocess_image(img):
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    
    im_as_arr = np.float32(img)
    im_as_arr = im_as_arr.transpose(2, 0, 1)
    
    # 채널 정규화
    for channel, _ in enumerate(im_as_arr):
        im_as_arr[channel] /= 255
        im_as_arr[channel] -= mean[channel]
        im_as_arr[channel] /= std[channel]
        
    # tensor
    im_as_ten = torch.from_numpy(im_as_arr).float()
    im_as_ten.unsqueeze_(0)
    im_as_var = Variable(im_as_ten, requires_grad=True)
    
    return im_as_var

In [51]:
def recreate_image(im_as_var):
    reverse_mean = [-0.4914, -0.4822, -0.4465]
    reverse_std = [1/0.2023, 1/0.1994, 1/0.2010]
    recreated_im = copy.copy(im_as_var.cpu().data.numpy()[0])
    for c in range(3):
        recreated_im[c] /= reverse_std[c]
        recreated_im[c] -= reverse_mean[c]
    recreated_im[recreated_im > 1] = 1
    recreated_im[recreated_im < 0] = 0
    recreated_im = np.round(recreated_im * 255)

    recreated_im = np.uint8(recreated_im).transpose(1, 2, 0)
    
    return recreated_im

In [52]:
def save_image(im, path):
    if isinstance(im, (np.ndarray, np.generic)):
        if np.max(im) <= 1:
            im = (im*255).astype(np.uint8)
        im = Image.fromarray(im)
    im.save(path)

In [53]:
target_class = 5
created_image = np.uint8(np.random.uniform(0, 255, (224, 224, 3)))

if not os.path.exists('./generated'):
    os.makedirs('./generated')

initial_learning_rate = 20
for i in range(1, 150):
    processed_image = preprocess_image(created_image)
    optimizer = SGD([processed_image], lr=initial_learning_rate)
    output = model(processed_image.to(device))
    class_loss = -output[0, target_class]
    print('Iteration:', str(i), 'Loss', "{0:.2f}".format(class_loss.cpu().data.numpy()))
    
    model.zero_grad()
    
    class_loss.backward()
    # Update image
    optimizer.step()
    # Recreate image
    created_image = recreate_image(processed_image)
    if i % 10 == 0:
        # Save image
        im_path = './generated/c_specific_iteration_'+str(i)+'.jpg'
        save_image(created_image, im_path)

Iteration: 1 Loss -3.60
Iteration: 2 Loss -11.63
Iteration: 3 Loss -21.11
Iteration: 4 Loss -26.70
Iteration: 5 Loss -38.74
Iteration: 6 Loss -40.35
Iteration: 7 Loss -40.43
Iteration: 8 Loss -41.93
Iteration: 9 Loss -51.07
Iteration: 10 Loss -48.10
Iteration: 11 Loss -54.98
Iteration: 12 Loss -58.98
Iteration: 13 Loss -69.85
Iteration: 14 Loss -71.10
Iteration: 15 Loss -71.78
Iteration: 16 Loss -72.12
Iteration: 17 Loss -77.57
Iteration: 18 Loss -77.44
Iteration: 19 Loss -78.17
Iteration: 20 Loss -79.12
Iteration: 21 Loss -86.56
Iteration: 22 Loss -85.33
Iteration: 23 Loss -82.33
Iteration: 24 Loss -81.62
Iteration: 25 Loss -85.79
Iteration: 26 Loss -90.49
Iteration: 27 Loss -84.15
Iteration: 28 Loss -88.30
Iteration: 29 Loss -97.00
Iteration: 30 Loss -100.08
Iteration: 31 Loss -101.35
Iteration: 32 Loss -96.13
Iteration: 33 Loss -104.04
Iteration: 34 Loss -104.55
Iteration: 35 Loss -102.60
Iteration: 36 Loss -108.58
Iteration: 37 Loss -111.86
Iteration: 38 Loss -113.13
Iteration: 39 