In [6]:
import numpy as np
#import h5py
import time
import copy
from random import randint
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch import autograd
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Test-time preprocessing: centre-crop to 32x32, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
transform_test = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
batch_size = 128

# CIFAR-10 test split, read from the current directory (no download attempted).
testset = torchvision.datasets.CIFAR10(
    root='./',
    train=False,
    download=False,
    transform=transform_test,
)
testloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)
# NOTE: from here on `testloader` is a one-shot enumerate iterator yielding
# (batch_index, (images, labels)); later cells pull batches from it one at a time.
testloader = enumerate(testloader)

cuda:0


In [7]:
class discriminator(nn.Module):
    """Convolutional critic with an auxiliary 10-way classification head.

    Eight Conv2d -> LayerNorm -> LeakyReLU stages (196 channels each) are
    followed by a 4x4 max-pool and two independent linear heads. The
    LayerNorm shapes fix the expected input to 3x32x32 images.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        # One row per stage: (in_channels, out_channels, stride, output side).
        # Every stage uses a 3x3 kernel with padding 1, so only the stride-2
        # stages halve the spatial size (32 -> 16 -> 8 -> 4).
        stage_specs = (
            (3,   196, 1, 32),
            (196, 196, 2, 16),
            (196, 196, 1, 16),
            (196, 196, 2, 8),
            (196, 196, 1, 8),
            (196, 196, 1, 8),
            (196, 196, 1, 8),
            (196, 196, 2, 4),
        )
        # Registered as conv1 .. conv8, in this order, so submodule names and
        # parameter-creation order stay the same as a hand-written list.
        for number, (c_in, c_out, stride, side) in enumerate(stage_specs, 1):
            setattr(self, 'conv%d' % number,
                    self.conv(c_in, c_out, 3, stride, 1, (c_out, side, side)))

        # Collapses the final 4x4 feature map to 1x1.
        self.pool = nn.MaxPool2d(4, 4)

        self.fc1 = nn.Linear(196, 1)      # critic output (one scalar per image)
        self.fc10 = nn.Linear(196, 10)    # auxiliary classifier output (10 scores)

    def conv(self,i,o,k,s,p,ln_dim):
        """Build one Conv2d -> LayerNorm -> LeakyReLU stage.

        i, o   -- input / output channel counts
        k, s, p -- kernel size, stride and padding of the convolution
        ln_dim -- shape normalised by the LayerNorm (the conv output shape
                  without the batch dimension)
        """
        layers = [
            nn.Conv2d(i, o, k, s, p),
            nn.LayerNorm(ln_dim),
            nn.LeakyReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return (critic, classifier) outputs for a batch of images `x`."""
        for number in range(1, 9):
            x = getattr(self, 'conv%d' % number)(x)
        # Both heads read the same pooled 196-dimensional feature vector.
        features = self.pool(x).view(-1, 196)
        return self.fc1(features), self.fc10(features)




In [8]:

# Load the whole pickled discriminator module.
# SECURITY: torch.load unpickles arbitrary Python objects, so only load
# checkpoint files that come from a trusted source.
# map_location=device remaps the stored tensors onto the device selected
# earlier; without it a checkpoint saved from a GPU cannot be deserialised on
# a CPU-only machine, although the notebook otherwise falls back to the CPU.
model = torch.load('discriminator.model', map_location=device)
model.to(device)
# Switch to inference mode; as the last expression this also displays the
# module structure as the cell output.
model.eval()

  "type " + container_type.__name__ + ". It won't be checked "


discriminator(
  (conv1): Sequential(
    (0): Conv2d(3, 196, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LayerNorm(torch.Size([196, 32, 32]), eps=1e-05, elementwise_affine=True)
    (2): LeakyReLU(negative_slope=0.01)
  )
  (conv2): Sequential(
    (0): Conv2d(196, 196, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LayerNorm(torch.Size([196, 16, 16]), eps=1e-05, elementwise_affine=True)
    (2): LeakyReLU(negative_slope=0.01)
  )
  (conv3): Sequential(
    (0): Conv2d(196, 196, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LayerNorm(torch.Size([196, 16, 16]), eps=1e-05, elementwise_affine=True)
    (2): LeakyReLU(negative_slope=0.01)
  )
  (conv4): Sequential(
    (0): Conv2d(196, 196, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LayerNorm(torch.Size([196, 8, 8]), eps=1e-05, elementwise_affine=True)
    (2): LeakyReLU(negative_slope=0.01)
  )
  (conv5): Sequential(
    (0): Conv2d(196, 196, kernel_size=(3, 3), stride=(1, 1), p

In [9]:
# Pull the next (batch_index, (images, labels)) item from the enumerate
# iterator built in the setup cell.
batch_idx, (X_batch, Y_batch) = next(testloader)
# Plain tensors replace the deprecated Variable wrapper. Gradients are enabled
# on the images because the synthetic images below are derived from them.
X_batch = X_batch.to(device).requires_grad_(True)
Y_batch = Y_batch.to(device)

# Starting point for the synthetic images: the batch-mean image, one copy per class.
X = X_batch.mean(dim=0)
X = X.repeat(10,1,1,1)

# Target label for each of the ten copies: 0, 1, ..., 9.
Y = torch.arange(10, dtype=torch.int64, device=device)
_, output = model(X_batch)
prediction = output.data.max(1)[1] # index of the highest-scoring class per image
# Divide by the number of images actually in this batch, not the configured
# batch_size, so a short final batch is not under-reported.
accuracy = ( float( prediction.eq(Y_batch.data).sum() ) /float(Y_batch.size(0)))*100.0
print(accuracy)


In [10]:
def plot(samples):
    """Draw each image in `samples` into consecutive cells of a 10x10 grid.

    samples -- iterable of images in a form `imshow` accepts; image number n
               goes into grid cell n (row-major), so at most 100 fit.
    Returns the 10x10-inch figure holding the axes; the caller is responsible
    for saving and closing it.
    """
    fig = plt.figure(figsize=(10, 10))
    grid = gridspec.GridSpec(10, 10)
    grid.update(wspace=0.02, hspace=0.02)  # near-zero gaps between tiles

    for position, image in enumerate(samples):
        tile = fig.add_subplot(grid[position])
        # Hide the axis frame and tick labels so only the image is visible.
        tile.axis('off')
        tile.set_xticklabels([])
        tile.set_yticklabels([])
        tile.set_aspect('equal')
        tile.imshow(image)
    return fig

# Gradient ascent on the input images: copy k of X is pushed towards a larger
# classifier score for class k.
lr = 0.1
weight_decay = 0.001
# Row k selects class k (the diagonal of the 10x10 score matrix). Built once,
# on the same device as the model output, instead of on every iteration.
class_idx = torch.arange(10, dtype=torch.int64, device=device)
for i in range(200):
    _, output = model(X)

    # Negative target-class score, so descending the loss raises the score.
    loss = -output[class_idx, class_idx]
    # Only first-order gradients are used below, so no double-backward graph
    # is requested (create_graph / retain_graph stay at their defaults).
    gradients = torch.autograd.grad(outputs=loss, inputs=X,
                              grad_outputs=torch.ones_like(loss))[0]

    prediction = output.data.max(1)[1] # index of the highest-scoring class per image
    accuracy = ( float( prediction.eq(Y.data).sum() ) /float(10.0))*100.0
    print(i,accuracy,-loss)

    # Descent step plus a decay term proportional to X*|X|, then clip back
    # into the [-1, 1] range.
    X = X - lr*gradients.data - weight_decay*X.data*torch.abs(X.data)
    X[X>1.0] = 1.0
    X[X<-1.0] = -1.0

## save new images
# Map [-1, 1] to [0, 1] out of place: on a CPU run the NumPy view shares
# memory with X, so in-place arithmetic here would silently overwrite X.
samples = (X.detach().cpu().numpy() + 1.0) / 2.0
# Move the channel axis last for imshow.
samples = samples.transpose(0,2,3,1)

# savefig does not create missing directories.
os.makedirs('visualization', exist_ok=True)
fig = plot(samples)
fig.savefig('visualization/max_class_dis.png', bbox_inches='tight')
plt.close(fig)

0 10.0 tensor([ 7.4868, -0.1894,  6.5896,  4.5246,  8.4412,  3.0818,  5.2118, -0.4515,
         7.7461,  1.0045], device='cuda:0', grad_fn=<NegBackward>)
1 40.0 tensor([ 9.4723,  0.9185, 10.6275,  5.5657, 12.6575,  4.2017,  6.6812,  0.7278,
        10.4475,  2.0195], device='cuda:0', grad_fn=<NegBackward>)
2 50.0 tensor([10.8040,  1.7884, 14.7934,  6.4649, 17.1496,  5.8783,  8.0086,  2.2690,
        12.8965,  3.0311], device='cuda:0', grad_fn=<NegBackward>)
3 70.0 tensor([12.1182,  2.5546, 18.2371,  7.2204, 20.4293,  7.9703,  9.2522,  5.2247,
        15.4674,  4.2395], device='cuda:0', grad_fn=<NegBackward>)
4 80.0 tensor([13.5515,  3.2589, 21.3748,  7.8397, 22.6068, 10.0078, 11.1615, 10.2289,
        18.0450,  5.5425], device='cuda:0', grad_fn=<NegBackward>)
5 80.0 tensor([15.1480,  3.9923, 23.7729,  8.6710, 24.2019, 12.0454, 13.1999, 15.3409,
        20.4784,  6.8311], device='cuda:0', grad_fn=<NegBackward>)
6 90.0 tensor([17.0437,  4.8927, 25.6707,  9.6314, 25.2608, 13.9871, 15.7112

59 100.0 tensor([37.5582, 40.6715, 41.0063, 29.6375, 37.3729, 35.3806, 39.7638, 42.2153,
        42.1454, 32.9835], device='cuda:0', grad_fn=<NegBackward>)
60 100.0 tensor([37.6335, 40.7970, 41.0713, 29.7343, 37.4703, 35.4738, 39.8769, 42.3425,
        42.2378, 33.1177], device='cuda:0', grad_fn=<NegBackward>)
61 100.0 tensor([37.7246, 40.9374, 41.1632, 29.8251, 37.5383, 35.5633, 39.9874, 42.4647,
        42.3157, 33.2181], device='cuda:0', grad_fn=<NegBackward>)
62 100.0 tensor([37.7867, 41.0662, 41.2149, 29.9297, 37.6658, 35.6435, 40.0945, 42.5576,
        42.3492, 33.3251], device='cuda:0', grad_fn=<NegBackward>)
63 100.0 tensor([37.8849, 41.1908, 41.2898, 30.0041, 37.7420, 35.7322, 40.2143, 42.6908,
        42.4366, 33.4271], device='cuda:0', grad_fn=<NegBackward>)
64 100.0 tensor([37.9592, 41.3079, 41.3561, 30.1021, 37.8465, 35.8162, 40.2887, 42.7912,
        42.4966, 33.5223], device='cuda:0', grad_fn=<NegBackward>)
65 100.0 tensor([38.0434, 41.4233, 41.4429, 30.1876, 37.8983, 35

122 100.0 tensor([41.0069, 45.8941, 43.7769, 33.6055, 42.0583, 39.1367, 43.9694, 47.0068,
        44.9600, 37.3696], device='cuda:0', grad_fn=<NegBackward>)
123 100.0 tensor([41.0845, 45.9568, 43.7960, 33.6485, 42.0906, 39.1781, 44.0252, 47.0536,
        45.0142, 37.4082], device='cuda:0', grad_fn=<NegBackward>)
124 100.0 tensor([41.1129, 46.0193, 43.8334, 33.6936, 42.1516, 39.2151, 44.0448, 47.0903,
        45.0109, 37.4318], device='cuda:0', grad_fn=<NegBackward>)
125 100.0 tensor([41.1564, 46.0781, 43.8320, 33.7365, 42.1991, 39.2527, 44.0713, 47.1524,
        45.0794, 37.5120], device='cuda:0', grad_fn=<NegBackward>)
126 100.0 tensor([41.1915, 46.1391, 43.8654, 33.7755, 42.2286, 39.2849, 44.1268, 47.2023,
        45.1276, 37.5774], device='cuda:0', grad_fn=<NegBackward>)
127 100.0 tensor([41.2337, 46.1970, 43.8984, 33.8194, 42.2797, 39.3233, 44.1785, 47.2529,
        45.1401, 37.5873], device='cuda:0', grad_fn=<NegBackward>)
128 100.0 tensor([41.2744, 46.2578, 43.9001, 33.8648, 42.3

186 100.0 tensor([43.3771, 49.1860, 45.3693, 36.0440, 44.8804, 41.1202, 45.8326, 49.6601,
        46.4877, 39.7187], device='cuda:0', grad_fn=<NegBackward>)
187 100.0 tensor([43.4154, 49.2275, 45.4076, 36.0747, 44.9142, 41.1469, 45.8837, 49.6948,
        46.5261, 39.7437], device='cuda:0', grad_fn=<NegBackward>)
188 100.0 tensor([43.4306, 49.2611, 45.4372, 36.1037, 44.9399, 41.1770, 45.9021, 49.7144,
        46.5425, 39.7487], device='cuda:0', grad_fn=<NegBackward>)
189 100.0 tensor([43.4929, 49.2962, 45.4458, 36.1422, 44.9798, 41.2081, 45.9280, 49.7540,
        46.5704, 39.7929], device='cuda:0', grad_fn=<NegBackward>)
190 100.0 tensor([43.5032, 49.3363, 45.4479, 36.1673, 45.0371, 41.2346, 45.9272, 49.7843,
        46.5871, 39.7949], device='cuda:0', grad_fn=<NegBackward>)
191 100.0 tensor([43.5158, 49.3515, 45.4923, 36.2008, 45.0501, 41.2614, 45.9726, 49.8150,
        46.5776, 39.8556], device='cuda:0', grad_fn=<NegBackward>)
192 100.0 tensor([43.5572, 49.3993, 45.5014, 36.2370, 45.1