In [None]:
from netdissect import parallelfolder, show, tally, nethook, renormalize
from experiment import readdissect, setting
import copy, PIL.Image
from netdissect import upsample, imgsave, imgviz
import re, torchvision, torch, os
from IPython.display import SVG
from matplotlib import pyplot as plt

def normalize_filename(n):
    """Reduce an image path to its ``...Places365_<word>_<digits>`` stem.

    Passed as ``normalize_filename`` to ``ParallelImageFolders`` below, presumably so
    that an original image and its stylized copy map to the same key.  Anything after
    the trailing digit run (e.g. an extension or an extra suffix) is dropped.

    Args:
        n: the file path to normalize.

    Returns:
        The prefix of ``n`` up to and including the ``Places365_<word>_<digits>`` part.

    Raises:
        ValueError: if ``n`` does not contain a ``Places365_<word>_<digits>`` stem.
    """
    match = re.match(r'^(.*Places365_\w+_\d+)', n)
    if match is None:
        # Fail with a readable message instead of "'NoneType' object has no attribute 'group'".
        raise ValueError('cannot derive a Places365 image key from filename %r' % (n,))
    return match.group(1)

# Preprocessing for paired (original, stylized) images: scale the shorter side to 256,
# take the central 256x256 square, convert to a tensor and apply ImageNet normalization.
_paired_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(256),
    torchvision.transforms.ToTensor(),
    renormalize.NORMALIZER['imagenet'],
])

# Original Places validation images zipped with their stylized counterparts; files are
# matched across the two folders via normalize_filename.
ds = parallelfolder.ParallelImageFolders(
    ['datasets/places/val', 'datasets/stylized-places/val'],
    transform=_paired_transform,
    normalize_filename=normalize_filename,
    shuffle=True,
)


# VGG16 conv layers to analyze, ordered from the deepest (conv5_3) down to conv1_1.
layers = [
    'conv5_3', 'conv5_2', 'conv5_1',  # block 5
    'conv4_3', 'conv4_2', 'conv4_1',  # block 4
    'conv3_3', 'conv3_2', 'conv3_1',  # block 3
    'conv2_2', 'conv2_1',             # block 2
    'conv1_2', 'conv1_1',             # block 1
]
# Dissection results for those layers (labels, quantiles and top images are read from
# it below), plus the VGG16 classifier being analyzed.
qd = readdissect.DissectVis(layers=layers)
net = setting.load_classifier('vgg16')

# Same preprocessing as the paired dataset: shorter side -> 256, central 256x256 crop,
# tensor conversion and ImageNet normalization.
_stylized_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(256),
    torchvision.transforms.ToTensor(),
    renormalize.NORMALIZER['imagenet'],
])

# Stylized Places validation images on their own (used for activation quantiles).
sds = parallelfolder.ParallelImageFolders(
    ['datasets/stylized-places/val'],
    transform=_stylized_transform,
    normalize_filename=normalize_filename,
    shuffle=True,
)

def _unit_image(layername, unit, subdir):
    """Open ``<qd.dir(layername)>/<subdir>/unit<unit>.jpg`` and decode it eagerly.

    ``unit`` is formatted with ``%d``, so an int or any value ``%d`` accepts works.
    """
    path = os.path.join(qd.dir(layername), subdir, 'unit%d.jpg' % unit)
    image = PIL.Image.open(path)
    # PIL opens lazily; load() decodes the pixels now rather than on first use.
    image.load()
    return image


def s_image(layername, unit):
    """Return the image saved under ``s_imgs/`` for ``unit`` of ``layername``.

    NOTE(review): presumably the unit's top-activating stylized images — confirm.
    """
    return _unit_image(layername, unit, 's_imgs')


def su_image(layername, unit):
    """Return the image saved under ``su_imgs/`` for ``unit`` of ``layername``."""
    return _unit_image(layername, unit, 'su_imgs')


In [None]:
# For every conv layer: score each unit by the IoU between the pixels where it exceeds its
# 99th percentile on an original Places image and on the stylized version of that image,
# then plot the score distribution, compare concept labels of low- vs high-scoring units,
# and show the extreme units.
for layername in layers:
    inst_net = nethook.InstrumentedModel(copy.deepcopy(net)).cuda()
    inst_net.retain_layer('features.' + layername)

    # One probe forward pass, only to learn this layer's spatial activation size.
    with torch.no_grad():
        inst_net(ds[0][0][None].cuda())
    sample_act = inst_net.retained_layer('features.' + layername).cpu()
    upfn = upsample.upsampler((64, 64), sample_act.shape[2:])

    def flat_acts(batch):
        """Upsample the layer's activations to 64x64 and flatten them to (pixels, units)."""
        with torch.no_grad():
            inst_net(batch.cuda())
            acts = upfn(inst_net.retained_layer('features.' + layername))
            return acts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    # Activation quantiles on stylized images (computed and cached) and on original images.
    s_rq = tally.tally_quantile(flat_acts, sds, cachefile=os.path.join(qd.dir(layername), 's_rq.npz'))
    u_rq = qd.rq(layername)

    # 99th-percentile thresholds reshaped to (1, units, 1, 1) so they broadcast over the
    # (batch, units, h, w) activations; presumably quantiles() yields one value per unit.
    # Computed once here instead of on every batch.
    s_99 = s_rq.quantiles(0.99)[None, :, None, None].cuda()
    u_99 = u_rq.quantiles(0.99)[None, :, None, None].cuda()

    def intersect_99_fn(uimg, simg):
        """Per pixel and unit: 1.0 where both the original and stylized activations exceed their thresholds."""
        with torch.no_grad():
            inst_net(uimg.cuda())
            ur = inst_net.retained_layer('features.' + layername)
            inst_net(simg.cuda())
            sr = inst_net.retained_layer('features.' + layername)
            both = ((sr > s_99) & (ur > u_99)).float()
            return both.permute(0, 2, 3, 1).reshape(-1, ur.size(1))

    intersect_99 = tally.tally_mean(intersect_99_fn, ds,
        cachefile=os.path.join(qd.dir(layername), 'intersect_99.npz'))
    print(layername)

    # IoU = I / (|A| + |B| - I).  Each mask should cover ~1% of pixels (0.99 quantile), so
    # |A| + |B| ~= 0.02.  NOTE(review): s_rq was measured on 64x64-upsampled activations
    # while the masks use raw activations, so that 1% is only approximate.
    numerator = intersect_99.mean()
    denominator = 0.02 - numerator
    score = (numerator / denominator).clamp(0, 1)

    plt.plot(score)
    plt.show()

    fig, ax = plt.subplots(1, 1, figsize=(3, 1.2), dpi=300)
    ax.hist(score)
    ax.set_ylabel('%s units' % layername)
    ax.set_xlabel('unit IoU (stylized vs original)')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.show()

    # Among units whose dissection IoU is at least 0.04, split the (label, category) pairs
    # by whether the unit's stylized-vs-original score exceeds 0.1.
    labelcat_list_h = []
    labelcat_list_l = []
    for i, rec in enumerate(qd.labels[layername]):
        if rec['iou'] and float(rec['iou']) >= 0.04:
            bucket = labelcat_list_h if score[i] > 0.1 else labelcat_list_l
            bucket.append((rec['label'], rec['cat']))
    display(SVG(qd.bargraph_from_conceptcatlist(labelcat_list_l)))
    display(SVG(qd.bargraph_from_conceptcatlist(labelcat_list_h)))

    # The 5 lowest- and 10 highest-scoring units: score, label, dissection IoU, images.
    ordering = score.sort()[1]
    for i in torch.cat([ordering[:5], ordering[-10:]]):
        print(i.item(), score[i].item(), qd.label(layername, i), qd.iou(layername, i))
        display(qd.image(layername, i))
        display(s_image(layername, i))

In [None]:
# Summary figure: for a few representative layers, a histogram of each unit's IoU between
# its above-99th-percentile pixels on original and on stylized Places images.
# NOTE(review): these names carry the 'features.' prefix, whereas the per-layer cell passes
# bare names (e.g. 'conv1_2') to qd.dir / qd.rq; confirm DissectVis maps both to the same
# directory, otherwise the cached s_rq.npz / intersect_99.npz are looked up elsewhere.
plotlayers = [
    'features.conv1_2',
    'features.conv2_2',
    'features.conv3_3',
    'features.conv4_3',
    'features.conv5_3',
]
fig, axes = plt.subplots(len(plotlayers), 1, figsize=(5, 6), dpi=300, sharex=True, squeeze=False)
for i, layername in enumerate(plotlayers):
    inst_net = nethook.InstrumentedModel(copy.deepcopy(net)).cuda()
    inst_net.retain_layer(layername)

    # Build the 64x64 upsampler for *this* layer's spatial size from one probe forward pass,
    # so the cell does not depend on an `upfn` left over from another cell or layer.
    with torch.no_grad():
        inst_net(ds[0][0][None].cuda())
    sample_act = inst_net.retained_layer(layername)
    upfn = upsample.upsampler((64, 64), sample_act.shape[2:])

    def flat_acts(batch):
        """Upsample the layer's activations to 64x64 and flatten them to (pixels, units)."""
        with torch.no_grad():
            inst_net(batch.cuda())
            acts = upfn(inst_net.retained_layer(layername))
            return acts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    s_rq = tally.tally_quantile(flat_acts, sds, cachefile=os.path.join(qd.dir(layername), 's_rq.npz'))
    u_rq = qd.rq(layername)

    # 99th-percentile thresholds reshaped to (1, units, 1, 1) for broadcasting; computed
    # once per layer instead of once per batch.
    s_99 = s_rq.quantiles(0.99)[None, :, None, None].cuda()
    u_99 = u_rq.quantiles(0.99)[None, :, None, None].cuda()

    def intersect_99_fn(uimg, simg):
        """Per pixel and unit: 1.0 where both the original and stylized activations exceed their thresholds."""
        with torch.no_grad():
            inst_net(uimg.cuda())
            ur = inst_net.retained_layer(layername)
            inst_net(simg.cuda())
            sr = inst_net.retained_layer(layername)
            both = ((sr > s_99) & (ur > u_99)).float()
            return both.permute(0, 2, 3, 1).reshape(-1, ur.size(1))

    intersect_99 = tally.tally_mean(intersect_99_fn, ds,
        cachefile=os.path.join(qd.dir(layername), 'intersect_99.npz'))
    # IoU = I / (|A| + |B| - I) with |A| ~= |B| ~= 0.01 (fraction above the 0.99 quantile).
    numerator = intersect_99.mean()
    denominator = 0.02 - numerator
    # Clamped to [0, 0.5]: any larger value lands in the top histogram bin.
    score = (numerator / denominator).clamp(0, 0.5)

    ax = axes[i, 0]
    ax.hist(score)
    ax.set_ylabel(layername.replace('features.', ''))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
axes[-1, 0].set_xlabel('unit IoU (stylized vs original)')
plt.show()


In [None]:
# Spot-check hand-picked units.  `score` is whatever the most recent loop left at module
# level (after the summary-figure cell: features.conv5_3, clamped to [0, 0.5]); re-run
# that cell first if in doubt.
for unit_index in (166, 107, 268, 434, 436, 437, 73, 220, 299, 494, 485, 477, 462, 338):
    print(unit_index, score[unit_index].item())

In [None]:
# Display qd.dirs — NOTE(review): presumably the per-layer result directories that
# qd.dir(layername) resolves to; confirm against readdissect.DissectVis.
qd.dirs