In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import reload
import IPython
mpl.rcParams['lines.linewidth'] = 0.25
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.linewidth'] = 0.25

mpl.rcParams['font.sans-serif'] = "Arial"
mpl.rcParams['font.family'] = "Arial"

In [None]:
import torch, argparse, os, shutil, inspect, json, numpy
import netdissect
from netdissect.easydict import EasyDict
from experiment import dissect_experiment as experiment
from netdissect import pbar, nethook, renormalize, parallelfolder, pidfile
from netdissect import upsample, tally, imgviz, imgsave, bargraph, show

# Prefer deterministic cuDNN kernels over autotuned ones so reruns are repeatable.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Experiment settings passed to the `experiment` helpers below. Together they also
# name the results directory; the quantile is encoded as int(quantile * 1000),
# e.g. 0.01 -> 10.
args = EasyDict(model='progan', dataset='kitchen', seg='netpqc', layer='layer5', quantile=0.01)
resdir = 'results/%s-%s-%s-%s-%s' % (args.model, args.dataset, args.seg, args.layer, int(args.quantile * 1000))
def resfile(f):
    """Return the path of file `f` inside this experiment's results directory."""
    return os.path.join(resdir, f)

In [None]:
# Load the model, choose the layer to instrument, and load the dataset
# (built with `model.model`).
model = experiment.load_model(args)
layername = experiment.instrumented_layername(args)
# Keep this layer's output after each forward pass so it can be read back with
# `model.retained_layer(layername)` (as `max_unit` does below).
model.retain_layer(layername)
dataset = experiment.load_dataset(args, model.model)
upfn = experiment.make_upfn(args, dataset, model, layername)
# NOTE(review): `upfn`, `sample_size` and `percent_level` are not referenced in the
# cells visible here (the tallies below pass their own sizes) — confirm they are needed.
sample_size = len(dataset)
percent_level = 0.995

In [None]:
# NOTE(review): redundant — `renormalize` is already imported in the setup cell.
from netdissect import renormalize

# Segmentation model plus its label names (`seglabels` is searched for 'window' below).
segmodel, seglabels, segcatlabels = experiment.setting.load_segmenter(args.seg)
# Applied to generated images before they are segmented (`renorm(...)` below);
# presumably converts from the dataset's normalization to the 'zc' target — TODO confirm.
renorm = renormalize.renormalizer(dataset, target='zc')

In [None]:
from netdissect import renormalize

# Preview the model output for dataset items 200-211.
indices = range(200,212)
# Stack the first field of each dataset item into one batch
# (equivalent to concatenating x[None, ...]).
batch = torch.stack([dataset[i][0] for i in indices])
# Inference only: no_grad keeps autograd from recording a graph for this pass.
with torch.no_grad():
    outs = model(batch.cuda())
imgs = [renormalize.as_image(t) for t in outs]
show([[img] for img in imgs])

In [None]:
from netdissect import imgviz

iv = imgviz.ImageVisualizer(120)
# Keep only the first segmentation channel for each preview image.
seg = segmodel.segment_batch(renorm(outs).cuda(), downsample=4)[:, 0:1]

# One row per preview image: the image, its segmentation, and the label key.
show([
    (iv.image(image), iv.segmentation(image_seg[-1]), iv.segment_key(image_seg[-1], segmodel))
    for image, image_seg in zip(outs, seg)
])

In [None]:
# Segmenter label indices named 'window'; the window-pixel counting below relies on them.
window_segclasses = [i for i, n in enumerate(seglabels) if n == 'window']
# Fail here with a clear message instead of an IndexError deep inside the tally.
assert window_segclasses, "segmenter %r has no 'window' label" % args.seg
window_segclasses

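Generate up to 20,000 images (the `sample_size` passed to `tally_cat` below) and split them into those with windows (more than 204 window-labelled pixels in the downsampled segmentation) and those without.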

In [None]:
def window_present(data_batch, *args):
    """Count the window-labelled segmentation pixels in each generated image.

    Used as the compute function for `tally.tally_cat` below.

    Args:
        data_batch: batch of dataset items, fed directly to `model`.
        *args: any further batch fields supplied by `tally_cat`; ignored.

    Returns:
        1-D tensor with one entry per image: the number of pixels in the first
        channel of the downsampled segmentation whose label is any class in
        `window_segclasses`.

    Raises:
        ValueError: if `window_segclasses` is empty.
    """
    if not window_segclasses:
        raise ValueError("segmenter has no 'window' label; cannot count window pixels")
    image_batch = model(data_batch.cuda())
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)[:, 0:1]
    # Match every window class, not only the first entry of the list.
    is_window = torch.zeros_like(seg, dtype=torch.bool)
    for window_class in window_segclasses:
        is_window |= (seg == window_class)
    # reshape (not view) so a non-contiguous mask cannot raise.
    return is_window.reshape(seg.shape[0], -1).sum(1)

# With `cachefile` set, a saved result is presumably reused instead of recomputed —
# delete window_presence.npz after changing `window_present` or the dataset.
window_presence = tally.tally_cat(
    window_present, dataset, sample_size=20000, pin_memory=True,
    cachefile=resfile('window_presence.npz'))

In [None]:
# Per-image window pixel count above which an image counts as "has windows";
# the figure labels below describe this split as 5% window pixels.
# NOTE(review): 204 is presumably ~5% of a 64x64 (4096-pixel) segmentation at
# downsample=4 — confirm against the generator's output resolution.
window_pixel_threshold = 204
in_examples = (window_presence > window_pixel_threshold).nonzero()[:, 0]
out_examples = (window_presence <= window_pixel_threshold).nonzero()[:, 0]
len(in_examples), len(out_examples)

In [None]:
def make_image(data_batch, *args):
    """Return the model output for one batch; extra tally fields are ignored."""
    generated = model(data_batch.cuda())
    return generated

# NOTE(review): this presumably keeps every generated image of the sample in
# memory at once — consider a smaller sample if memory is tight.
generated_image = tally.tally_cat(
    make_image,
    dataset,
    sample_size=len(window_presence),
    pin_memory=True,
)

In [None]:
unit_number = 314

def max_unit(data_batch, *args):
    """Peak activation of channel `unit_number` in each image of the batch.

    Runs `model` only to refresh the retained layer, then flattens that
    channel's activations after the batch dimension and takes the maximum.
    Extra positional fields from `tally.tally_cat` are ignored.
    """
    model(data_batch.cuda())
    unit_acts = model.retained_layer(layername)[:, unit_number]
    per_image = unit_acts.view(len(unit_acts), -1)
    peaks, _ = per_image.max(1)
    return peaks

maxvals_all = tally.tally_cat(
    max_unit,
    dataset,
    sample_size=len(window_presence),
    pin_memory=True,
)

In [None]:
# First ten window images followed by the first ten window-free images,
# each paired with its peak activation.
result = [
    [maxvals_all[i].item(), renormalize.as_image(generated_image[i])]
    for i in torch.cat([in_examples[:10], out_examples[:10]])
]
show(result)

In [None]:
import seaborn as sns
import random

# Peak activations for window images (first 1000) and window-free images.
in_vals = maxvals_all[in_examples].numpy()[:1000]
out_vals = maxvals_all[out_examples].numpy()
# First len(in_vals) window-free values, so both groups plot with similar counts.
nbvals = out_vals[:len(in_vals)]
in_mean = float(numpy.mean(in_vals))
out_mean = float(numpy.mean(out_vals))

fig, [ax1, ax2] = plt.subplots(nrows=2, ncols=1, figsize=(8, 3), dpi=300,
                               sharex='all')

# Bottom panel: jitterplot of individual peaks (row index used as the y coordinate).
ax2.scatter(nbvals, range(len(nbvals)), alpha=0.5, s=10, c='#F0883B')
ax2.scatter(in_vals, range(len(in_vals)), alpha=0.5, s=10, c='#4B4CBF')
ax2.get_yaxis().set_ticks([])
ax2.set_ylabel('Jitterplot')

# Top panel: dashed lines at each group's mean, plus kernel density curves.
ax1.axvline(in_mean, color='#B6B6F2', linewidth=1.5, linestyle='--')
ax1.axvline(out_mean, color='#F2CFB6', linewidth=1.5, linestyle='--')

# `sns.distplot` is deprecated since seaborn 0.11; `kdeplot` draws the same KDE curve.
sns.kdeplot(out_vals, ax=ax1, linewidth=3, color="#F0883B",
            label="kitchens with < 5%% window pixels, mean=%.1f" % out_mean)
sns.kdeplot(in_vals, ax=ax1, linewidth=3, color="#4B4CBF",
            label="kitchens with > 5%% window pixels, mean=%.1f" % in_mean)
# Newer seaborn does not add a legend for `label=` by itself, so draw it explicitly.
ax1.legend()
ax1.set_ylabel('Density')
ax1.get_yaxis().set_ticks([])
ax1.set_xlim(-3, 32)
ax2.set_xlabel('Unit %d peak activation in image' % unit_number)

fig.savefig("sgup.pdf", bbox_inches='tight')

In [None]:
# Number of window images in the plot (capped at 1000 when `in_vals` is built).
in_vals.shape[0]

In [None]:
# Show window images whose peak activation exceeds 20.
# `in_vals` is ordered like `in_examples`, so a position in `in_vals` must be
# mapped through `in_examples` to get the matching `generated_image` index.
for pos in (in_vals > 20).nonzero()[0]:
    image_index = int(in_examples[int(pos)])
    print(image_index, in_vals[pos])
    show(renormalize.as_image(generated_image[image_index]))

In [None]:
# `in_examples` / `out_examples` index `window_presence`, `maxvals_all` and
# `generated_image` interchangeably, so the three tallies must have equal length.
assert len(generated_image) == len(window_presence) == len(maxvals_all)
len(generated_image)