In [1]:
from __future__ import absolute_import, division, print_function
import argparse
import os
import re
from collections import defaultdict
import glob
import time
import pathlib
import imageio
import sys
import numpy as np
import fid
import inception as iscore
import imageio
import tensorflow as tf
from torchvision.datasets import CIFAR10, STL10
import torch
import torchvision.utils as vutils
import torch.utils.data as utils
import visdom
from torchvision import transforms
from GAN_training.models import resnet, resnet_extra, resnet_48
from tqdm import tqdm

import data

from sklearn.metrics import pairwise_distances

Instructions for updating:
Use tf.gfile.GFile.


In [2]:
%matplotlib ipympl

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import ipympl

In [3]:
class optclass:
    """Minimal attribute bag standing in for a parsed command-line namespace.

    The rest of the notebook reads its settings as ``opt.<name>`` and passes
    the instance to the generator constructors.
    """
    # Kept from the original definition; none of the visible cells read it.
    workaround = True


# Experiment settings: checkpoint/data locations, batch sizes and generator
# hyper-parameters.
optdict = {
    'outf': '/scratch0/ilya/locDoc/ACGAN/experiments/yogesh_acgan_0p2',
    'marygan': False,
    'data_root': '/scratch0/ilya/locDoc/data/cifar10',
    'dataset': 'cifar',
    'dev_batch_size': 100,
    'size_labeled_data': 4000,
    'train_batch_size': 128,
    'train_batch_size_2': 100,
    'cifar_fname': '/scratch0/ilya/locDoc/data/cifar10/fid_is_scores.npz',
    'nz': 128,
    'GAN_nz': 128,
    'netG': '',
    'imageSize': 32,
    'ngpu': 1,
    'nc': 3,
}

# Expose every setting as an instance attribute of a fresh container.
opt = optclass()
vars(opt).update(optdict)

## ACGAN

In [14]:
# Select the most recent generator checkpoint in the experiment directory.
netGfiles = glob.glob(os.path.join(opt.outf, 'netG_iter_*.pth'))
if not netGfiles:
    # Without this guard an empty/misspelled directory only shows up as an
    # opaque IndexError on the [-1] lookup below.
    raise FileNotFoundError(
        "no 'netG_iter_*.pth' checkpoints found in {!r}".format(opt.outf))
# File names look like netG_iter_<N>.pth; order numerically by <N> so that,
# for example, iteration 100000 sorts after 99999 (a string sort would not).
netGfiles.sort(key=lambda s: int(s.split('_')[-1].split('.')[0]))
opt.netG = netGfiles[-1]
print(opt.netG)

/scratch0/ilya/locDoc/ACGAN/experiments/yogesh_acgan_0p2/netG_iter_274200.pth


In [15]:
# Module providing the Generator class for each supported output resolution.
generator_modules = {
    32: resnet,
    64: resnet_extra,
    48: resnet_48,
}
if opt.imageSize not in generator_modules:
    # Fail loudly: otherwise netG would be undefined on a fresh kernel, or a
    # stale generator left in kernel state would silently be reused.
    raise ValueError(
        'unsupported imageSize {!r}; expected one of {}'.format(
            opt.imageSize, sorted(generator_modules)))
netG = generator_modules[opt.imageSize].Generator(opt)

# Restore the trained weights selected in opt.netG and move the model to GPU.
netG.load_state_dict(torch.load(opt.netG))
netG = netG.cuda()

In [496]:
# Sample a pool of images from the generator and store them as uint8 in `x`
# with layout (num_images, 3, imageSize, imageSize).
batch_size = opt.train_batch_size
nz = opt.nz
num_classes = 10
klass_picked = 9  # class index to condition on; None draws a random class per image

# create images
n_used_imgs = 10000
# Whole batches only: one more batch than fits into n_used_imgs, so slightly
# more than n_used_imgs images are produced.
n_batches = (n_used_imgs // batch_size) + 1
n_gen_imgs = n_batches * batch_size
x = np.empty((n_gen_imgs, 3, opt.imageSize, opt.imageSize), dtype=np.uint8)

# Reusable GPU buffer for the latent input; fully overwritten on every step.
noise = torch.FloatTensor(batch_size, nz).cuda()

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for l in tqdm(range(n_batches), desc='Generating'):
        start = l * batch_size
        end = start + batch_size
        if klass_picked is None:
            label = np.random.randint(0, num_classes, batch_size)
        else:
            label = np.full(batch_size, klass_picked, dtype=int)
        noise_ = np.random.normal(0, 1, (batch_size, nz))
        if not opt.marygan:
            # The first num_classes latent entries carry the one-hot class code.
            class_onehot = np.zeros((batch_size, num_classes))
            class_onehot[np.arange(batch_size), label] = 1
            noise_[:, :num_classes] = class_onehot
        noise.copy_(torch.from_numpy(noise_))
        fake = netG(noise).cpu().numpy()
        # Map generator output to 0..255 -- assumes it lies in [-1, 1]
        # (TODO confirm the generator's final activation).
        fake = np.floor((fake + 1) * 255/2.0).astype(np.uint8)
        x[start:end] = fake

Generating: 100%|██████████| 79/79 [00:07<00:00, 11.29it/s]


In [448]:
# Tile every generated image into a single grid and show it.
plt.figure()
fake_grid = vutils.make_grid(
    torch.Tensor(x),
    nrow=int(np.sqrt(n_used_imgs)),
    padding=2,
    normalize=True,
)
# make_grid returns channels-first data; imshow wants channels last.
grid_hwc = fake_grid.data.cpu().numpy().transpose(1, 2, 0)
plt.imshow(grid_hwc)



FigureCanvasNbAgg()

<matplotlib.image.AxesImage at 0x7fe45e89af28>

### Load feature extracting network

In [342]:
# First CUDA device; the feature-extraction network and its input batches are
# moved here in the cells below.
device = torch.device("cuda:0")

In [510]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [472]:
from classification.models.vgg5 import VGG

# Feature-extractor option: project VGG restored from a local checkpoint.
# NOTE(review): the architecture string is 'VGG19' but the checkpoint sits in a
# directory named 'vgg16' -- confirm that the two really belong together.
ckpt_path = os.path.join('/scratch0/ilya/locDoc/classifiers/vgg16','ckpt_200.t7')
checkpoint = torch.load(ckpt_path)

# The model is wrapped in DataParallel before the weights stored under the
# 'net' key are loaded, then moved to the device and put in eval mode.
compnet = torch.nn.DataParallel(VGG('VGG19'))
compnet.load_state_dict(checkpoint['net'])
compnet = compnet.to(device).eval()

# Input pipeline: convert to tensor, then normalise each channel.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

In [480]:
from classification.models.densenet import DenseNet121

# Feature-extractor option: project DenseNet121 restored from a local checkpoint.
ckpt_path = os.path.join('/scratch0/ilya/locDoc/classifiers/densenet121','ckpt_47.t7')
checkpoint = torch.load(ckpt_path)

# The model is wrapped in DataParallel before the weights stored under the
# 'net' key are loaded, then moved to the device and put in eval mode.
compnet = torch.nn.DataParallel(DenseNet121())
compnet.load_state_dict(checkpoint['net'])
compnet = compnet.to(device).eval()

# Input pipeline: convert to tensor, then normalise each channel.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

In [458]:
from classification.models.vgg_official2 import vgg16

# Feature-extractor option: vgg16 with pretrained weights, in eval mode.
# NOTE(review): in file order this cell is the last to assign compnet and
# transform_test, so on a top-to-bottom run it overrides the two options above
# and feeds 224-pixel inputs to the cells below -- confirm which extractor is
# intended before re-running.
compnet = vgg16(pretrained=True).to(device)
compnet.eval()

# Input pipeline: array -> PIL image -> resize to 224 -> tensor -> normalise.
preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform_test = transforms.Compose(preprocess_steps)

### Pass images through the network

In [497]:
# Run every generated image through the active `transform_test` pipeline and
# wrap the result in a DataLoader for batched feature extraction.
#
# The buffer shape is taken from the first transformed image instead of being
# copied from x.shape, so pipelines that resize their input work without
# editing this cell.  float32 matches what the transform produces, so nothing
# is lost compared with a float64 staging buffer.
first_transformed = transform_test(np.moveaxis(x[0], 0, -1))
net_in = np.empty((x.shape[0],) + tuple(first_transformed.shape), dtype=np.float32)
for i in tqdm(range(x.shape[0]), desc='Preprocess'):
    # x is stored channels-first; the transform is given channels-last arrays.
    net_in[i] = transform_test(np.moveaxis(x[i], 0, -1))

# from_numpy shares memory with net_in, avoiding a second full-size copy.
my_dataset = utils.TensorDataset(torch.from_numpy(net_in))
my_dataloader = utils.DataLoader(my_dataset, batch_size=opt.train_batch_size, shuffle=False)

Preprocess: 100%|██████████| 10112/10112 [00:01<00:00, 7429.75it/s]


In [471]:
# Quick check of the image array shape: (num_images, channels, height, width).
x.shape

(128, 3, 32, 32)

In [498]:
# Push every preprocessed batch through the feature network and collect one
# feature row per image in `net_out`.
#
# The feature width is read from the first batch instead of being hard-coded,
# so switching the extractor does not require editing this cell.
net_out = None
start = 0
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for i, batch in enumerate(tqdm(my_dataloader, desc='Extract Feat')):
        batch_in = batch[0].to(device)
        # The extra positional True is passed through unchanged; presumably it
        # asks the project models for features rather than class scores --
        # verify against their forward().
        batch_out = compnet(batch_in, True).cpu().numpy()
        # Flatten per-image features (a no-op when the output is already 2-D).
        batch_out = batch_out.reshape(batch_out.shape[0], -1)
        if net_out is None:
            net_out = np.empty((x.shape[0], batch_out.shape[1]))
        end = start + batch_out.shape[0]
        net_out[start:end] = batch_out
        start = end

Extract Feat: 100%|██████████| 79/79 [00:09<00:00,  8.89it/s]


In [504]:
### MSE style
# Use the raw pixels as "features": one flattened row per image.  This replaces
# whatever net_out held before, so the distances below are pixel-space ones.
net_out = x.reshape(len(x), -1)

In [505]:
# Confirm the feature matrix is 2-D with one row per image.
net_out.shape

(10112, 3072)

### Plot closest pairs

In [506]:
# Full N x N matrix of distances between all rows of net_out
# (scikit-learn's default metric for pairwise_distances is Euclidean).
D = pairwise_distances(net_out)

In [507]:
# Find the closest_N nearest distinct image pairs.
#
# Mask the diagonal and lower triangle so self-distances are ignored and each
# unordered pair is considered once.  A boolean mask keeps the temporary at one
# byte per entry, and +inf (rather than D.max()) guarantees a masked entry can
# never tie with a real pair.  Note that D is modified in place.
n_points = D.shape[0]
to_del = np.tri(n_points, dtype=bool)
D[to_del] = np.inf

# ravel() is a view for a contiguous D, so no second N x N copy is made.
dists = D.ravel()
closest_N = 20
# argpartition moves the closest_N smallest entries to the front (unordered);
# only that handful is then sorted by distance.
idxs = np.argpartition(dists, closest_N)
min_idxs = sorted(idxs[:closest_N], key=lambda i: dists[i])
# Convert flat positions back to (row, col) pairs; row < col by construction.
closest_idxs = [(idx // n_points, idx % n_points) for idx in min_idxs]

In [508]:
# Gather the images of each closest pair next to each other: rows 2k and 2k+1
# hold the k-th pair, lower image index first.
closest_imgs = np.empty((2 * closest_N,) + x.shape[1:])
for pair_no, pair in enumerate(closest_idxs):
    lo_idx, hi_idx = sorted(pair)
    closest_imgs[2 * pair_no: 2 * pair_no + 2] = x[[lo_idx, hi_idx]]

In [509]:
# plot closest pairs
# First 20 images = the 10 nearest pairs, one pair per grid row.
plt.figure()
fake_grid = vutils.make_grid(
    torch.Tensor(closest_imgs[:20]),
    nrow=2,
    padding=2,
    normalize=True,
)
# make_grid returns channels-first data; imshow wants channels last.
grid_hwc = fake_grid.data.cpu().numpy().transpose(1, 2, 0)
plt.imshow(grid_hwc)



FigureCanvasNbAgg()

<matplotlib.image.AxesImage at 0x7fe45e353f60>

In [468]:
# Inspect the ten nearest (row, col) image-index pairs, closest first.
closest_idxs[:10]

[(3711, 9061),
 (4745, 9061),
 (3711, 4745),
 (8305, 9061),
 (3711, 8305),
 (2460, 3711),
 (2460, 9061),
 (4745, 8305),
 (2460, 4745),
 (2460, 8305)]

In [503]:
# Remaining images (index 20 onward) = the next-nearest pairs, one per grid row.
plt.figure()
fake_grid = vutils.make_grid(
    torch.Tensor(closest_imgs[20:]),
    nrow=2,
    padding=2,
    normalize=True,
)
# make_grid returns channels-first data; imshow wants channels last.
grid_hwc = fake_grid.data.cpu().numpy().transpose(1, 2, 0)
plt.imshow(grid_hwc)



FigureCanvasNbAgg()

<matplotlib.image.AxesImage at 0x7fe45e313518>

In [229]:
# Recompute the Euclidean distance of the closest pair by hand as a sanity
# check.  Cast to float first: when net_out holds raw uint8 pixels (the
# "MSE style" cell) the subtraction and squaring would otherwise wrap around
# modulo 256 and give a meaningless value.
feat_a = net_out[closest_idxs[0][0]].astype(np.float64)
feat_b = net_out[closest_idxs[0][1]].astype(np.float64)
np.sqrt(((feat_a - feat_b)**2).sum())

0.5097062616022188

In [204]:
# Shape of one image after a leading batch axis of length 1 is added.
x[closest_idxs[0][0]][np.newaxis].shape

(1, 3, 32, 32)

In [243]:
# Preprocess both images of the closest pair (channels moved last for the
# transform), then run each through the network as a batch of one.
xa, xb = [transform_test(np.moveaxis(x[img_idx], 0, -1)) for img_idx in closest_idxs[0]]
# compnet is called here without the extra positional flag used during feature
# extraction, so these are the network's default outputs.
a, b = [
    compnet(torch.FloatTensor(np.expand_dims(img, 0)).to(device)).detach().data.cpu()
    for img in (xa, xb)
]

In [245]:
# Show the network output computed for the second image of the closest pair.
b

tensor([[-1.0944, -1.3794,  2.8354,  0.2152, -1.2497,  2.4409,  3.7628, -0.8741,
         -1.6860, -2.9780]])