In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import reload
import IPython
# Thin lines and axes with no top/right spines, for compact figures.
mpl.rcParams.update({
    'lines.linewidth': 0.25,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.linewidth': 0.25,
})

In [None]:
import torch, argparse, os, shutil, inspect, json, numpy
import netdissect
from netdissect.easydict import EasyDict
from netdissect import pbar, nethook, renormalize, parallelfolder, pidfile
from netdissect import upsample, tally, imgviz, imgsave, bargraph, show
from experiment import dissect_experiment as experiment

args = EasyDict(model='progan', dataset='church', seg='netpqc', quantile=0.01, layer='layer4')
# One results directory per (model, dataset, segmenter, layer, quantile in thousandths) setting.
resdir = 'results/' + '-'.join(
    str(part) for part in
    (args.model, args.dataset, args.seg, args.layer, int(args.quantile * 1000)))

def resfile(f):
    """Return the path of file `f` inside this experiment's results directory."""
    return os.path.join(resdir, f)

In [None]:
# Load the generator, instrument the layer under study, and build the sample dataset.
model = experiment.load_model(args)
layername = experiment.instrumented_layername(args)
# Retain this layer's output so later cells can read it via model.retained_layer(layername).
model.retain_layer(layername)
dataset = experiment.load_dataset(args, model.model)
# upfn is applied to retained activations before thresholding/flattening below;
# presumably it upsamples them to the comparison resolution — confirm in experiment.make_upfn.
upfn = experiment.make_upfn(args, dataset, model, layername)
sample_size = len(dataset)
# Upper quantile level (0.99 when args.quantile == 0.01).
percent_level = 1 - args.quantile

In [None]:
from netdissect import renormalize

# Segmentation model, its label names, and per-label category entries
# (segcatlabels entries are unpacked as 2-tuples with the category second in a later cell).
segmodel, seglabels, segcatlabels = experiment.setting.load_segmenter(args.seg)
# NOTE(review): `renorm` is not used by any cell in this notebook view.
renorm = renormalize.renormalizer(target='zc')

In [None]:
pbar.descnext('rq')

def compute_samples(batch, *args):
    """Return the retained layer's upsampled activations as a (pixels, units) matrix.

    The generator is run on `batch` only to populate the retained layer; every
    spatial position of every image then becomes one row.
    """
    model(batch.cuda())
    layer_acts = model.retained_layer(layername)
    # (batch, units, y, x) -> (batch, y, x, units) -> (batch * y * x, units)
    return upfn(layer_acts).permute(0, 2, 3, 1).reshape(-1, layer_acts.shape[1])

# Per-unit activation quantile statistics over the dataset, cached on disk.
rq = tally.tally_quantile(
    compute_samples,
    dataset,
    sample_size=sample_size,
    r=8192,
    num_workers=100,
    pin_memory=True,
    cachefile=resfile('rq.npz'),
)

In [None]:
# Per-unit activation threshold at the `percent_level` quantile (0.99 when
# args.quantile == 0.01), shaped [1, units, 1, 1] to broadcast over the
# upsampled (batch, units, y, x) activations. This name was previously used
# below without ever being defined, so the cell raised NameError on a fresh kernel.
level_at_99 = rq.quantiles(percent_level).cuda()[None, :, None, None]

def compute_conditional_indicator(batch, *args):
    """Segment generated images and pair the segmentation with binary unit-activation maps.

    Returns tally.conditional_samples(iacts, seg), where iacts is 1.0 wherever a
    unit's upsampled activation exceeds that unit's `level_at_99` threshold.
    """
    data_batch = batch.cuda()
    image_batch = model(data_batch)
    seg = segmodel.segment_batch(image_batch, downsample=4)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    iacts = (hacts > level_at_99).float()  # indicator
    return tally.conditional_samples(iacts, seg)

pbar.descnext('condi99')
condi99 = tally.tally_conditional_mean(compute_conditional_indicator,
        dataset, sample_size=sample_size,
        num_workers=3, pin_memory=True,
        cachefile=resfile('condi99.npz'))

In [None]:
# IoU of every (concept, unit) pair; rows are concepts (indexed by seglabels), columns are units.
iou_99 = tally.iou_from_conditional_indicator_mean(condi99)
# For each unit, its best-matching concept: (concept index, label, segcatlabels entry, best IoU).
unit_label_99 = [
        (concept.item(), seglabels[concept], segcatlabels[concept], bestiou.item())
        for (bestiou, concept) in zip(*iou_99.max(0))]
# Category entries of units whose best IoU exceeds 0.04 (hard-coded threshold).
label_list = [labelcat for concept, label, labelcat, iou in unit_label_99 if iou > 0.04]
display(IPython.display.SVG(experiment.graph_conceptcatlist(label_list)))
len(label_list)

In [None]:
# Indices of the 20 units with the highest IoU for the 'tree' concept.
segindex = seglabels.index('tree')
tree_units = (-iou_99[segindex]).sort(0)[1][:20]
tree_units

In [None]:
from netdissect import renormalize

# Render a few fixed samples, then the same samples with the tree units zeroed.
indices = [489, 200, 726, 803, 920, 926] #range(200,224)
batch = torch.cat([dataset[i][0][None,...] for i in indices])
outs = model(batch.cuda())
imgs = [renormalize.as_image(t) for t in outs]
show([[img] for img in imgs])

# Edit rule: zero the `tree_units` channels in place. It reads the global
# `tree_units` at call time, so reassigning that name changes this rule.
def zero_tree_units(x, *args):
    x[:, tree_units] = 0
    return x
model.edit_layer(layername, rule=zero_tree_units)
outs = model(batch.cuda())
imgs = [renormalize.as_image(t) for t in outs]
show([[img] for img in imgs])
# Restore the unedited generator for later cells.
model.remove_edits()

In [None]:
def test_segclass_with_zeroed_units(segclass, zeroed_units, sample_size=100):
    """Measure how much `segclass` the generator draws while `zeroed_units` are zeroed.

    Zeroes the given channels of `layername`, generates images from `dataset`,
    segments them, and tallies the fraction of segmentation pixels labeled
    `segclass` in each image.

    Args:
        segclass: segmentation label index to measure.
        zeroed_units: channel indices of `layername` to set to zero.
        sample_size: number of dataset samples to tally.

    Returns:
        The result of tally.tally_mean over per-image fractions (one [batch, 1]
        tensor per batch). The model's edits are always removed before returning.
    """
    model.remove_edits()
    def zero_some_units(x, *args):
        x[:, zeroed_units] = 0
        return x
    def compute_mean_seg_in_images(batch_z, *args):
        img = model(batch_z.cuda())
        seg = segmodel.segment_batch(img, downsample=4)
        # A pixel matches if any entry along dim 1 (presumably the segmenter's
        # several label channels — confirm) equals segclass.
        segmatch = (seg == segclass).any(1).float().view(seg.shape[0], -1).sum(1)
        # Express in units of fractions of an image
        return segmatch[:, None] / (seg.shape[2] * seg.shape[3])
    model.edit_layer(layername, rule=zero_some_units)
    try:
        return tally.tally_mean(compute_mean_seg_in_images, dataset,
                                batch_size=30, sample_size=sample_size, pin_memory=True)
    finally:
        # Restore the unedited generator even if tallying raises.
        model.remove_edits()


In [None]:
def measure_segclasses_with_zeroed_units(zeroed_units, sample_size=100):
    """Tally per-image, per-class segmentation pixel fractions with `zeroed_units` zeroed.

    Args:
        zeroed_units: channel indices of `layername` to zero (an empty list
            measures the unedited generator).
        sample_size: number of dataset samples to tally.

    Returns:
        The result of tally.tally_mean over per-image vectors of length
        num_seglabels. Each entry is a label's pixel count summed over all of
        seg's dim-1 channels, divided by H*W. The model's edits are always
        removed before returning.
    """
    model.remove_edits()
    def zero_some_units(x, *args):
        x[:, zeroed_units] = 0
        return x
    num_seglabels = len(segmodel.get_label_and_category_names()[0])
    def compute_mean_seg_in_images(batch_z, *args):
        img = model(batch_z.cuda())
        seg = segmodel.segment_batch(img, downsample=4)
        seg_area = seg.shape[2] * seg.shape[3]
        # Offset image i's labels by i * num_seglabels so that a single bincount
        # produces a separate histogram for every image in the batch.
        offsets = num_seglabels * torch.arange(seg.shape[0], dtype=seg.dtype, device=seg.device)
        seg_counts = torch.bincount((seg + offsets[:, None, None, None]).view(-1),
                                    minlength=num_seglabels * seg.shape[0]).view(seg.shape[0], -1)
        return seg_counts.float() / seg_area
    model.edit_layer(layername, rule=zero_some_units)
    try:
        return tally.tally_mean(compute_mean_seg_in_images, dataset,
                                batch_size=30, sample_size=sample_size, pin_memory=True)
    finally:
        # Restore the unedited generator even if tallying raises.
        model.remove_edits()


In [None]:
# Per-class segmentation fractions for the unedited generator and with `tree_units` zeroed.
segs_baseline = measure_segclasses_with_zeroed_units([])
segs_without_treeunits = measure_segclasses_with_zeroed_units(tree_units)

In [None]:
# First 10 per-class means and the full stdev vector for each condition.
print(segs_baseline.mean()[:10], segs_baseline.stdev())
print(segs_without_treeunits.mean()[:10], segs_without_treeunits.stdev())

In [None]:
num_units = len(unit_label_99)
# Mean per-class segmentation statistics for the unedited generator (cached).
baseline_segmean = experiment.test_generator_segclass_stats(model, dataset, segmodel,
            layername=layername,
            cachefile=resfile('segstats/baseline.npz')).mean()
# Row u: the same statistics with only unit u zeroed (one cache file per unit).
unit_ablation_segmean = torch.zeros(num_units, len(baseline_segmean))
for unit in range(num_units):
    unit_ablation_segmean[unit] = experiment.test_generator_segclass_stats(model, dataset, segmodel,
                layername=layername, zeroed_units=[unit],
                cachefile=resfile('segstats/ablated_unit_%d.npz' % unit)).mean()

In [None]:
# Scratch check of assigning through a None index on a 0-d tensor; `a` is not used by any other cell here.
a = torch.zeros([])
a[None] = 1.0
a

In [None]:
# Concept whose ablation curves are studied below.
ablate_segclass_name = 'tree'
ablate_segclass = seglabels.index(ablate_segclass_name)

In [None]:
# Units sorted ascending by the mean `ablate_segclass` statistic left after ablating each unit
# alone, so units whose removal leaves the least of the class come first.
best_ss_units = unit_ablation_segmean[:,ablate_segclass].sort(0)[1]
best_ss_units

In [None]:
unit_ablation_segmean.shape

In [None]:
# Units sorted by IoU with `ablate_segclass`, highest first.
best_iou_units = iou_99[ablate_segclass,:].sort(0)[1].flip(0)
best_iou_units

In [None]:
import math
# Baseline `ablate_segclass` statistics. Only unitcount == 0 runs (best_iou_units[:0] is
# empty, so nothing is zeroed); index 0 is the normalizer used by the plot below.
byiou_unit_ablation_seg = torch.zeros(1)
byiou_unit_ablation_seg_stdev = torch.zeros(1)
for unitcount in range(0,1):
    zero_units = best_iou_units[:unitcount].tolist()
    stats = experiment.test_generator_segclass_stats(model, dataset, segmodel,
                layername=layername, zeroed_units=zero_units,
                cachefile=resfile('segstats/ablated_best_%d_iou_%s.npz' %
                    (unitcount, ablate_segclass_name)))
    byiou_unit_ablation_seg[unitcount] = stats.mean()[ablate_segclass]  
    byiou_unit_ablation_seg_stdev[unitcount] = stats.stdev()[ablate_segclass]
# For k = 0..30, zero the k highest-IoU units and record the mean, stdev and
# standard error (stdev / sqrt(n)) of the per-sample change in the class statistic.
byiou_unit_ablation_delta_seg = torch.zeros(31)
byiou_unit_ablation_delta_seg_stdev = torch.zeros(31)
byiou_unit_ablation_delta_seg_stderr = torch.zeros(31)
for unitcount in range(0,31):
    zero_units = best_iou_units[:unitcount].tolist()
    stats = experiment.test_generator_segclass_delta_stats(model, dataset, segmodel,
                layername=layername, zeroed_units=zero_units,
                cachefile=resfile('deltasegstats/ablated_best_%d_iou_%s.npz' %
                    (unitcount, ablate_segclass_name)))
    byiou_unit_ablation_delta_seg[unitcount] = stats.mean()[ablate_segclass]  
    byiou_unit_ablation_delta_seg_stdev[unitcount] = stats.stdev()[ablate_segclass]
    byiou_unit_ablation_delta_seg_stderr[unitcount] = stats.stdev()[ablate_segclass] / math.sqrt(stats.size())
    

In [None]:
# Sample count of the last stats tallied above (the 30-unit ablation).
stats.size()

In [None]:
# Portion of tree pixels removed as the k highest-IoU tree units are zeroed,
# normalized by the unablated amount, with a 99% normal confidence band.
Z_99 = 2.58  # two-sided 99% normal critical value
fig, ax = plt.subplots(figsize=(7,3), dpi=300)
baseline_tree = byiou_unit_ablation_seg.numpy()[0]
y = -byiou_unit_ablation_delta_seg.numpy() / baseline_tree
yerr = byiou_unit_ablation_delta_seg_stderr.numpy() / baseline_tree
ax.plot(y, linewidth=2, color="#4B4CBF")
# Same z on both sides (the band used to be 2.58 below but 2.56 above).
ax.fill_between(range(len(y)), y - yerr * Z_99, y + yerr * Z_99,
                edgecolor='#55B05B', facecolor='#55B05B',
                antialiased=True)
ax.set_ylim([0, 0.7])
#ax.set_xlabel('Number of units removed (units ranked by IoU with trees)')
ax.set_ylabel('Portion of tree pixels removed')
ax.set_xticks([0, 2, 4, 8, 20, 30])
# A formatter stays correct whatever ticks are drawn; set_yticklabels(get_yticks())
# pins labels to the ticks that happen to exist at call time.
ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1.0, decimals=0))
ax.grid(linewidth=0.5)

In [None]:
# Normalized standard error behind the confidence band.
yerr

In [None]:
# Normalized portion of tree pixels removed, indexed by number of ablated units.
y

In [None]:
byiou_unit_ablation_delta_seg_stdev

In [None]:
byiou_unit_ablation_seg_stdev

In [None]:
byiou_unit_ablation_delta_seg

In [None]:
# Ratio of ablated to baseline mean for segmentation class 4, per single-unit ablation,
# sorted ascending and truncated to the 30 smallest ratios.
# NOTE(review): class index 4 is hard-coded — confirm it is the intended label (perhaps ablate_segclass).
plt.figure(figsize=(15,5))
plt.plot((unit_ablation_segmean[:,4] / baseline_segmean[4]).sort(0)[0].numpy()[:30], linewidth=3)

In [None]:
from netdissect import renormalize
indices = [489, 726, 920, 926] #range(200,224)
batch = torch.cat([dataset[i][0][None,...] for i in indices])

# Render the same samples with the k highest-IoU units zeroed, for several k.
for unit_count in [0, 2, 4, 8, 20]:
    ablated_units = best_iou_units[:unit_count]
    # Bind this iteration's units as a default argument. Earlier versions reassigned the
    # global `tree_units` / `zero_tree_units` here, silently changing earlier cells on re-run.
    def zero_ablated_units(x, *args, units=ablated_units):
        x[:, units] = 0
        return x
    model.remove_edits()
    model.edit_layer(layername, rule=zero_ablated_units)
    outs = model(batch.cuda())
    imgs = [renormalize.as_image(t) for t in outs]
    show([[img] for img in imgs])
    model.remove_edits()

In [None]:
door_segclass = seglabels.index('door')
door_segclass

In [None]:
# The 20 units with the highest IoU for 'door', and each one's 0.995-quantile activation,
# which the insertion experiments below write into the layer.
door_units = iou_99[door_segclass].sort(0)[1].flip(0)[:20]
door_high_values = rq.quantiles(0.995)[door_units].cuda()

In [None]:
door_high_values

In [None]:
from netdissect import segviz

def add_yellow_box(timg, y1, y2, x1, x2, thickness):
    """Paint an unfilled rectangle with channel values (1, 1, 0) onto `timg`, in place.

    `timg` is indexed as (batch, 3, height, width). Each edge is `thickness`
    pixels wide; the rectangle spans rows y1..y2+thickness and columns
    x1..x2+thickness, and every bound is clamped to the image. Returns `timg`.
    """
    height, width = timg.shape[2], timg.shape[3]

    def clamp(value, limit):
        return max(0, min(limit, value))

    color = torch.tensor([1.0, 1.0, 0.0], dtype=timg.dtype, device=timg.device).view(1, 3, 1, 1)
    top, bottom = clamp(y1, height), clamp(y2 + thickness, height)
    left, right = clamp(x1, width), clamp(x2 + thickness, width)
    edges = [
        (slice(top, bottom), slice(left, clamp(x1 + thickness, width))),   # left edge
        (slice(top, bottom), slice(clamp(x2, width), right)),              # right edge
        (slice(top, clamp(y1 + thickness, height)), slice(left, right)),   # top edge
        (slice(clamp(y2, height), bottom), slice(left, right)),            # bottom edge
    ]
    for rows, cols in edges:
        timg[:, :, rows, cols] = color
    return timg
# For each sample, write the door units' high values into one feature-map location at a
# time and show locations where the segmenter then finds more than 2 extra door pixels.
for index in pbar([591, 589, 561, 422, 499, 315, 361, 396, 19, 25, 71, 151, 159, 167, 188, 279]):
    indices = [index]
    batch = torch.cat([dataset[i][0][None,...] for i in indices])
    batchc = batch.cuda()[:1]
    model.remove_edits()
    orig_img = model(batchc)
    orig_seg = segmodel.segment_batch(orig_img, downsample=4)
    orig_door = (orig_seg == door_segclass).view(len(batchc), -1).sum(1)
    rep = model.retained_layer(layername).clone()
    # Size of one feature-map cell in segmentation pixels and in image pixels
    # (the image cell size replaces a hard-coded 32).
    ysize = orig_seg.shape[2] // rep.shape[2]
    xsize = orig_seg.shape[3] // rep.shape[3]
    img_ysize = orig_img.shape[2] // rep.shape[2]
    img_xsize = orig_img.shape[3] // rep.shape[3]
    for y in range(rep.shape[2]):
        for x in range(rep.shape[3]):
            changed_rep = rep.clone()
            changed_rep[:,door_units,y,x] = door_high_values[None,:]
            # Presumably substitutes changed_rep for the layer output (ablation=1.0) — confirm in nethook.
            model.edit_layer(layername, ablation=1.0, replacement=changed_rep)
            changed_img = model(batchc)
            changed_seg = segmodel.segment_batch(changed_img, downsample=4)
            changed_door = (changed_seg == door_segclass).view(len(batchc), -1).sum(1)
            if (changed_door - orig_door).max().item() > 2:
                selsegs = orig_seg[:,:,y*ysize+ysize//2,x*xsize+xsize//2].view(-1)
                orig_img_copy = orig_img.clone()
                add_yellow_box(orig_img_copy, y*img_ysize-1, (y+1)*img_ysize-1,
                               x*img_xsize-1, (x+1)*img_xsize-1, 2)
                existing = ' '.join([seglabels[sc] for sc in selsegs if sc != 0])
                show([['#%d %d %d repd %.2f rgbd %.2f doord %.1f %s' %
                       (index, y, x, (changed_rep - rep).max().item(),
                        (changed_img - orig_img).max().item(),
                        (changed_door - orig_door).max().item(),
                        existing
                       ),
                       [renormalize.as_image(orig_img_copy[0])],
                       [renormalize.as_image(img)],
                       [segviz.seg_as_image(changed_seg[i, 2:3], size=256)],
                       [segviz.segment_key(changed_seg[i, 2:3], segmodel, 10)]]
                      for i, img in enumerate(changed_img)])
# Don't leave the last replacement edit installed on the model.
model.remove_edits()


In [None]:
# Re-render selected (sample, feature-map location) door insertions.
for index, coordlist in pbar([
    (151, [(5, 3), (5, 5)]),
    (279, [(5, 3), (5, 5)]),
]):
    indices = [index]
    batch = torch.cat([dataset[i][0][None,...] for i in indices])
    batchc = batch.cuda()[:1]
    model.remove_edits()
    orig_img = model(batchc)
    orig_seg = segmodel.segment_batch(orig_img, downsample=4)
    orig_door = (orig_seg == door_segclass).view(len(batchc), -1).sum(1)
    rep = model.retained_layer(layername).clone()
    ysize = orig_seg.shape[2] // rep.shape[2]
    xsize = orig_seg.shape[3] // rep.shape[3]
    img_ysize = orig_img.shape[2] // rep.shape[2]
    img_xsize = orig_img.shape[3] // rep.shape[3]
    for y, x in coordlist:
        changed_rep = rep.clone()
        changed_rep[:,door_units,y,x] = door_high_values[None,:]
        # Presumably substitutes changed_rep for the layer output (ablation=1.0) — confirm in nethook.
        model.edit_layer(layername, ablation=1.0, replacement=changed_rep)
        changed_img = model(batchc)
        changed_seg = segmodel.segment_batch(changed_img, downsample=4)
        changed_door = (changed_seg == door_segclass).view(len(batchc), -1).sum(1)
        selsegs = orig_seg[:,:,y*ysize+ysize//2,x*xsize+xsize//2].view(-1)
        # Draw on a copy. Drawing on orig_img itself polluted the rgbd difference
        # below and accumulated boxes from earlier coordinates.
        orig_img_copy = orig_img.clone()
        add_yellow_box(orig_img_copy, y*img_ysize-1, (y+1)*img_ysize-1,
                       x*img_xsize-1, (x+1)*img_xsize-1, 2)
        existing = ' '.join([seglabels[sc] for sc in selsegs if sc != 0])
        show([['#%d %d %d repd %.2f rgbd %.2f doord %.1f %s' %
               (index, y, x, (changed_rep - rep).max().item(),
                (changed_img - orig_img).max().item(),
                (changed_door - orig_door).max().item(),
                existing
               ),
               [renormalize.as_image(orig_img_copy[0])],
               [renormalize.as_image(img)],
               [segviz.seg_as_image(changed_seg[i, 2:3], size=256)],
               [segviz.segment_key(changed_seg[i, 2:3], segmodel, 10)]]
              for i, img in enumerate(changed_img)])
# Don't leave the last replacement edit installed on the model.
model.remove_edits()

In [None]:
num_segclass = len(seglabels)
num_segclass  # NOTE: not the cell's last expression, so this line displays nothing

def batch_bincount(data, num_labels):
    """Count occurrences of each label separately for every item in a batch.

    Args:
        data: integer tensor whose first dimension is the batch; values are
            assumed to lie in [0, num_labels).
        num_labels: number of distinct labels.

    Returns:
        A (batch, num_labels) tensor of counts.
    """
    batch_size = len(data)
    flat = data.view(batch_size, -1)
    # Shift row i's labels into its own range [i*num_labels, (i+1)*num_labels)
    # so a single bincount call counts every row at once.
    offsets = num_labels * torch.arange(batch_size, dtype=flat.dtype, device=flat.device)
    shifted = flat + offsets.unsqueeze(1)
    totals = torch.bincount(shifted.view(-1), minlength=batch_size * num_labels)
    return totals.view(batch_size, num_labels)

def compute_seg_impact(zbatch, *args):
    """Yield (context label, per-class count change) pairs for door insertion at every location.

    For each spatial location (y, x) of `layername`:
      - set `door_units` to `door_high_values` at that location and regenerate;
      - compute delta = changed minus original per-class segmentation counts;
      - for each image, yield (label, delta) once for every distinct nonzero label found
        at the center of that location in the original segmentation.
    The generator removes the model's edits when it finishes (or is closed early);
    previously the last replacement edit was left installed.
    """
    zbatch = zbatch.cuda()
    model.remove_edits()
    orig_img = model(zbatch)
    orig_seg = segmodel.segment_batch(orig_img, downsample=4)
    orig_segcount = batch_bincount(orig_seg, num_segclass)
    rep = model.retained_layer(layername).clone()
    ysize = orig_seg.shape[2] // rep.shape[2]
    xsize = orig_seg.shape[3] // rep.shape[3]
    def gen_conditions():
        try:
            for y in range(rep.shape[2]):
                for x in range(rep.shape[3]):
                    # Take as the context location the segmentation labels at the center of the square.
                    selsegs = orig_seg[:,:,y*ysize+ysize//2,x*xsize+xsize//2]
                    changed_rep = rep.clone()
                    changed_rep[:,door_units,y,x] = door_high_values[None,:]
                    model.edit_layer(layername, ablation=1.0, replacement=changed_rep)
                    changed_img = model(zbatch)
                    changed_seg = segmodel.segment_batch(changed_img, downsample=4)
                    changed_segcount = batch_bincount(changed_seg, num_segclass)
                    delta_segcount = (changed_segcount - orig_segcount).float()
                    for sel, delta in zip(selsegs, delta_segcount):
                        # Each distinct label at this location is one condition; label 0
                        # is skipped (presumably "unlabeled" — confirm).
                        for cond in sel.unique():
                            if cond == 0:
                                continue
                            yield (cond.item(), delta)
        finally:
            model.remove_edits()
    return gen_conditions()

# Mean per-class segmentation count change from door insertion, conditioned on the
# context label at the edited location, over 10000 samples (cached on disk).
cond_changes = tally.tally_conditional_mean(compute_seg_impact, dataset, sample_size=10000, batch_size=20,
                                           cachefile=resfile('big_door_cond_changes.npz'))
cond_changes
            
            



In [None]:
def compute_seg_counts(zbatch, *args):
    """Return per-image segmentation label counts (as floats) for the unedited generator."""
    model.remove_edits()
    generated = model(zbatch.cuda())
    labels = segmodel.segment_batch(generated, downsample=4)
    return batch_bincount(labels, num_segclass).float()

# Mean per-class segmentation pixel counts of unedited generated images (cached on disk).
baseline_segcounts = tally.tally_mean(compute_seg_counts, dataset, sample_size=10000, batch_size=100,
                                     cachefile=resfile('baseline_segcounts.npz'))
baseline_segcounts
 

In [None]:
# Mean door-pixel count per unedited image; used to normalize door changes below.
baseline_door = baseline_segcounts.mean()[seglabels.index('door')].item()
baseline_door

In [None]:
# Context labels with at least 1000 samples, as (label index, name, sample count,
# mean door-pixel change / baseline_door), sorted by that ratio, largest first.
sorted([(k, seglabels[k], cond_changes.conditional(k).size(),
         cond_changes.conditional(k).mean()[seglabels.index('door')].item() / baseline_door)
 for k in cond_changes.keys()
 if cond_changes.conditional(k).size() >= 1000], key=lambda x: -x[-1])




In [None]:
# Added door area (relative to the baseline door area) when door units are inserted at
# locations with each context label. Error bars are one standard deviation, not a standard error.
fig, ax = plt.subplots(figsize=(6,3), dpi=300)
glabels = ['window', 'stairway', 'building', 'grass', 'tree', 'sky']
door_index = seglabels.index('door')
glabel_stats = [cond_changes.conditional(seglabels.index(gl)) for gl in glabels]
ax.bar(range(len(glabels)),
       [s.mean()[door_index].item() / baseline_door for s in glabel_stats],
       yerr=[math.sqrt(s.variance()[door_index].item()) / baseline_door for s in glabel_stats],
       error_kw=dict(lw=1, capsize=5, capthick=1),
       color="#4B4CBF")
# Pin one tick per bar; prepending '' to the labels only lined up with matplotlib's default ticks by luck.
ax.set_xticks(range(len(glabels)))
ax.set_xticklabels(glabels)
# A formatter stays correct after set_ylim; set_yticklabels(get_yticks()) was computed
# before the limits changed, leaving stale labels.
ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1.0, decimals=0))
ax.set_ylabel('Added door area')
ax.set_ylim([0,0.19])

In [None]:
import math
# Standard error of the mean door-pixel change at 'window'-context locations: sqrt(variance / n).
# (This previously computed variance / sqrt(n), which is not a standard error.)
window_stats = cond_changes.conditional(seglabels.index('window'))
math.sqrt(window_stats.variance()[seglabels.index('door')].item() / window_stats.size())

In [None]:
from collections import defaultdict
# Number of segmentation labels per category, skipping entry 0 of segcatlabels.
catcount = defaultdict(int)
for _, cat in segcatlabels[1:]:
    catcount[cat] += 1
print(catcount)