In [None]:
import numpy as np
import pickle
from utils.Association import Association
import subprocess
from torchvision import transforms
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
from utils.mapping import simple_labels
import matplotlib.pyplot as plt
import time
import torchvision.models as models
import torch
from tqdm import tqdm
from itertools import chain, combinations
from collections import defaultdict

In [None]:
# Peek at one class file to see which keys a patch record carries.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open("data/val_set/1/patches_interpolated_filters.pkl", "rb") as f:
    tmp = pickle.load(f)
print(tmp[0].keys())

dict_keys(['mask_score', 'mask', 'name', 'features', 'perturbed_score', 'true_label', 'image_label'])


In [None]:
%%time

# Load all transactions

class_ids = range(1000)

# One pickle file per class directory; each file holds a list of patch records.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
patches = []
for i in tqdm(class_ids):
    # `with` closes each file right away instead of leaving 1000 handles to the GC.
    with open("data/val_set/" + str(i) + "/patches_interpolated_filters.pkl", "rb") as f:
        patches.extend(pickle.load(f))

def get_trans(patch, top_k=5):
    """Turn one patch record into a transaction for itemset mining.

    Args:
        patch: mapping with a "features" sequence (one value per filter) and
            an integer "true_label".
        top_k: number of strongest filters to keep (default 5).

    Returns:
        A list holding the indices of the `top_k` largest feature values
        (strongest first; ties keep index order), followed by the label
        encoded as ``-true_label - 10000`` so that labels are negative and
        cannot collide with the non-negative filter indices.
    """
    ranked = sorted(enumerate(patch["features"]), key=lambda pair: pair[1], reverse=True)
    trans = [filter_index for filter_index, _ in ranked[:top_k]]
    # Use the argument's own label; the previous version read a module-level
    # variable `p`, which is only correct inside a `for p in patches` loop.
    trans.append(patch["true_label"] * -1 - 10000)
    #trans.append(patch["image_label"] * -1 - 20000)
    return trans


# Build one transaction per loaded patch, in the same order as `patches`.
# NOTE(review): if get_trans still reads the module-level `p` for the label,
# this loop variable must keep the name `p`.
transactions = []
for p in tqdm(patches):
    transactions.append(get_trans(p))

# Show the first transaction as a quick sanity check.
print(transactions[0])

100%|██████████| 1000/1000 [00:05<00:00, 190.77it/s]
100%|██████████| 84616/84616 [00:11<00:00, 7542.55it/s]

[478, 100, 302, 432, 242, -10000]
CPU times: user 16.1 s, sys: 280 ms, total: 16.3 s
Wall time: 16.5 s





In [None]:
# author: Bart Goethals, University of Antwerp, Belgium
# Adapted by Toon Meynen & Stijn Rosaer


def eclat(prefix, minsup, items, start=True):
    """Depth-first Eclat search for frequent itemsets.

    Args:
        prefix: list of items already fixed on the current search branch.
        minsup: minimum number of transaction ids an itemset must cover.
        items: list of ``(item, tids)`` pairs; it is consumed (popped until
            empty) by this call.
        start: not used by the body; the recursive calls pass False.

    Returns:
        List of ``(frozenset(itemset), support)`` tuples in discovery order.
    """
    found = []
    while items:
        item, item_tids = items.pop()
        support = len(item_tids)
        if support < minsup:
            continue

        extended = prefix + [item]
        found.append((frozenset(extended), support))

        # Items still in `items` can extend this one if they share enough transactions.
        candidates = []
        for other, other_tids in items:
            shared = set(item_tids) & set(other_tids)
            if len(shared) >= minsup:
                candidates.append((other, shared))
        candidates.sort(key=lambda pair: len(pair[1]), reverse=True)

        found.extend(eclat(extended, minsup, candidates, False))
    return found


def subsets(itemset):
    """Enumerate the non-empty subsets of `itemset` holding 1 .. len(itemset) - 2 items.

        list(subsets([1, 2, 3])) --> [{1}, {2}, {3}]
        list(subsets([1, 2]))    --> []

    Returns an empty set when `itemset` has fewer than two items, otherwise a
    lazy iterator of sets ordered by size.

    NOTE(review): subsets of size len(itemset) - 1 are never produced. An
    earlier docstring listed them, so confirm that the upper bound is intended.
    """
    members = list(itemset)
    #if len(members) < 3:
    if len(members) < 2:
        return set()
    by_size = (combinations(members, size) for size in range(1, len(members) - 1))
    return map(set, chain.from_iterable(by_size))
    # return map(set, combinations(members, len(members)-2))


def deriveRules(itemsets, minconf):
    """Derive association rules ``filters -> label(s)`` from weighted itemsets.

    Args:
        itemsets: iterable of ``(frozenset, supp)`` pairs. Non-negative items
            are filter indices, negative items are encoded labels.
        minconf: minimum confidence a rule needs to be kept.

    Returns:
        A set of ``Association(antecedent, consequent, conf, supp)`` objects.
        Only rules with at least two filters in the antecedent and one or two
        labels in the consequent are produced.

    Raises:
        KeyError: if an antecedent is not itself one of the given itemsets.
    """
    search_items = dict(itemsets)
    rules = set()
    for item_set, supp in itemsets:
        if len(item_set) < 2:
            continue
        for subset in subsets(item_set):  # for each subset generate a rule
            # Filter indices start at 0, so 0 must count as a filter here
            # (`count` already treats `i >= 0` as a filter).
            antecedent = frozenset(i for i in subset if i >= 0)
            if len(antecedent) < 2:
                continue

            consequent = frozenset(i for i in item_set - subset if i < 0)
            if len(consequent) not in (1, 2):
                continue

            conf = supp / search_items[antecedent]
            if conf >= minconf:
                rules.add(Association(antecedent, consequent, conf, supp))

    return rules

def count(items, tidlist):
    """Count the transactions in which all non-negative `items` occur together.

    Negative items (labels) are ignored. Returns 0 when `items` contains no
    non-negative item.
    """
    filter_tids = []
    for item in items:
        if item >= 0:
            filter_tids.append(tidlist[item])
    if not filter_tids:
        return 0
    return len(set.intersection(*filter_tids))

def tidlist(transactions):
    """Convert a list of transactions into {item -> set of transaction ids}.

    Transaction ids are 1-based positions in `transactions`.
    """
    data = {}
    for tid, transaction in enumerate(transactions, start=1):
        for item in transaction:
            data.setdefault(item, set()).add(tid)
    return data

def genRules(min_conf, minsup, transactions, all_data):
    """Mine rules from a subset of transactions, weighted against the full dataset.

    Args:
        min_conf: minimum confidence passed on to deriveRules.
        minsup: minimum support used when mining `transactions`.
        transactions: the subset of transactions to mine.
        all_data: tidlist ({item -> set of transaction ids}) of the full dataset.

    Returns:
        The set of rules produced by deriveRules.
    """
    # Tidlist of the subset, most frequent item first, as input for Eclat.
    subset_tids = tidlist(transactions)
    ordered_items = sorted(subset_tids.items(), key=lambda entry: len(entry[1]), reverse=True)
    frequent_itemsets = eclat([], minsup, ordered_items)

    # Replace each support by the count over the full dataset. `count` ignores
    # labels, so label-only itemsets come back as 0 and are dropped.
    reweighed_sets = []
    for itemset, _ in frequent_itemsets:
        occurrences = count(itemset, all_data)
        if occurrences:
            reweighed_sets.append((itemset, occurrences))

    return deriveRules(reweighed_sets, min_conf)

In [None]:
%%time

# Generate rules label by label

rules = set()
all_data = tidlist(transactions)

# Labels are encoded by get_trans as -class_id - 10000. Deriving the ids from
# class_ids avoids iterating over an id (-11000) that has no class behind it.
label_ids = sorted(-class_id - 10000 for class_id in class_ids)

for true_id in tqdm(label_ids):
    label_transactions = [t for t in transactions if true_id in t]
    # min confidence 0.5, min support 2 within this label's transactions
    label_rules = genRules(0.5, 2, label_transactions, all_data)
    # keep the 10 rules with the highest s * c for this label
    label_rules = sorted(label_rules, key=lambda x: x.s * x.c, reverse=True)
    rules.update(label_rules[:10])
rules = list(rules)

print(f"{len(rules)} rules generated ")

100%|██████████| 1001/1001 [00:17<00:00, 58.67it/s]

9972 rules generated 
CPU times: user 17.1 s, sys: 27.3 ms, total: 17.2 s
Wall time: 17.1 s





In [None]:
# Transactions are no longer needed after generating rules; drop the name so
# the list can be garbage collected.
# NOTE: after this cell the rule-generation cell cannot be re-run without
# rebuilding `transactions` first.
del transactions

In [None]:
# Remove rules that are contained within other rules
# A -> C & AB -> C
#
# Rules are visited longest antecedent first. A rule is dropped when an already
# kept rule has the same consequent, support and confidence and an antecedent
# that contains all of this rule's items.

print("---Removing subset rules---")
t = time.time()
new_rules = set()
for candidate in sorted(rules, key=lambda x: len(x.left), reverse=True):
    covered = any(
        candidate.right == kept.right
        and candidate.s == kept.s
        and candidate.c == kept.c
        and all(item in kept.left for item in candidate.left)
        for kept in new_rules
    )
    if not covered:
        new_rules.add(candidate)
print(f"{len(new_rules)} remain after {time.time() - t:.2f} seconds")
del rules

---Removing subset rules---
9941 remain after 4.94 seconds


In [None]:
### Generate basic association rule based system
def generate_filters():
    """Write the filters used by the rules of each label to "VEBI/assocL.pickle".

    The pickled object is ``{"cnn_18": {label: [filter, ...]}}`` with one entry
    for every label in 0..999. Reads the module-level `new_rules`.
    """
    per_label = {label: set() for label in range(1000)}
    for rule in new_rules:
        # cnn_18 as it only affects final layer
        # right side is the label (only its first element is used),
        # left side is the set of filters
        label = rule.right[0] * -1 - 10000
        per_label[label].update(rule.left)

    filters = {"cnn_18": {label: list(per_label[label]) for label in range(1000)}}

    with open("VEBI/assocL.pickle", "wb") as f:
        pickle.dump(filters, f)

In [None]:
# values to normalize input
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean, std)
toTensor = transforms.ToTensor()

# preprocessing for images read from disk: resize, centre crop, convert to tensor
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)


In [None]:
%load_ext line_profiler

In [None]:
# Order the remaining rules by c * s, largest first. `sorted` returns a list,
# so `new_rules` is a list from here on.
new_rules = sorted(new_rules, key=lambda x: x.c * x.s, reverse=True)

In [None]:
# Group patches by rule antecedent: cnn_groups[filters] is the set of indices
# (into `patches`) whose transaction contains all of those filters.
t = time.time()
cnn_groups = dict()
# The rule order does not depend on the patch, so sort once instead of once per patch.
rules_by_length = sorted(new_rules, key=lambda x: len(x.left), reverse=True)
# Wrapping `patches` (not the enumerate object) lets tqdm show a total.
for index, patch in enumerate(tqdm(patches)):
    # A set gives constant-time membership tests for the rule check below.
    trans = set(get_trans(patch))
    for rule in rules_by_length:
        if all(item in trans for item in rule.left):
            cnn_groups.setdefault(rule.left, set()).add(index)
print(f"{len(cnn_groups)} groups in {time.time() - t:.2f} seconds")

84616it [06:35, 214.21it/s]

5474 groups in 395.02 seconds





In [None]:
# Keep only the groups that contain at least 20 patches.
new_groups = {
    filters: members
    for filters, members in cnn_groups.items()
    if len(members) >= 20
}

print(len(new_groups))

1996


In [None]:
# label_dict[class_id] lists the groups (rule antecedents) whose most frequent
# true label is that class, in cnn_groups order.
label_dict = {class_id: [] for class_id in class_ids}
# A group's majority label does not depend on the class being inspected, so
# compute it once per group instead of once per (class, group) pair.
for label, indices in tqdm(cnn_groups.items()):
    scores = dict()
    for patch_index in indices:
        true_label = patches[patch_index]["true_label"]
        scores[true_label] = scores.get(true_label, 0) + 1
    majority_label = max(scores, key=scores.get)
    # Groups whose majority label is not in class_ids are skipped, as before.
    if majority_label in label_dict:
        label_dict[majority_label].append(label)

100%|██████████| 1000/1000 [00:45<00:00, 21.76it/s]


In [None]:
#frozendict.frozendict({'filename': 'ILSVRC2012_val_00001906.JPEG', 'true_label': 8, 'labels': (-1012, -12), 'patch': (44, 100, 99, 155), 'score': 0.8469841})
def bar_dict(d, labels=True):
    """Print the share of each key of `d` and return the normalised distribution.

    Args:
        d: mapping key -> count.
        labels: when true, keys are printed through `simple_labels`;
            otherwise the raw key is printed.

    Entries are printed largest first, and only those above 1% of the total.

    Returns:
        dict mapping every key of `d` to its fraction of the total.
    """
    totals = sum(d.values())

    if labels:
        def shown_name(key):
            return simple_labels[key]
    else:
        def shown_name(key):
            return key

    rows = [(shown_name(key), f"{d[key] / totals * 100:.1f}%", d[key]) for key in d]
    rows.sort(key=lambda row: row[2], reverse=True)
    for name, share, amount in rows:
        if amount / totals > 0.01:
            print(share, "\t", name)

    return {key: d[key] / totals for key in d}


class AddAll:
    """Running sum of images, used to build an averaged image.

    All images passed to `add` must be addable to the first one.
    """

    def __init__(self):
        self.image = None   # running sum (the mean after normalize())
        self.counter = 0    # number of images added

    def add(self, im):
        """Add one image to the running sum."""
        if self.image is None:
            # Copy the first image: `+=` below works in place and would
            # otherwise overwrite the caller's array.
            self.image = np.array(im)
        else:
            self.image += im
        self.counter += 1

    def normalize(self):
        """Turn the running sum into the mean; does nothing if no image was added."""
        if self.image is None:
            return
        self.image /= self.counter

    def display(self, name):
        """Show the accumulated image and save it as plots/<name>.png."""
        plt.figure(figsize=(5, 5))
        plt.imshow(self.image)
        plt.savefig(f"plots/{name}.png")
        plt.show()


class Score:
    """Accumulates a distribution-weighted score over buffered patches."""

    def __init__(self):
        self.score = 0
        self.count = 0
        self.buffer = []

    def add(self, original_values, label):
        """Buffer one (original_values, label) pair until process() is called."""
        entry = (original_values, label)
        self.buffer.append(entry)

    def process(self, distribution):
        """Fold every buffered entry into the score, then clear the buffer.

        For each key of `distribution`, a value v = original_values[key] with
        0 < v <= 1 contributes v * distribution[key].
        """
        for values, _ in self.buffer:
            for key in distribution:
                value = values[key]
                if 0 < value <= 1:
                    self.score += value * distribution[key]
            self.count += 1

        self.buffer = []

    def normalize(self, distribution):
        """Process pending entries and divide the score by the number processed.

        Does nothing when the buffer is empty.
        """
        if not self.buffer:
            return
        self.process(distribution)
        self.score /= self.count

    def display(self):
        """Print the score as a percentage."""
        print(f"Score: {self.score*100:.3f}%")

In [None]:
def display_all_layers():
    """Display every group whose filter set has exactly five filters."""
    five_filter_groups = []
    for filters, members in cnn_groups.items():
        if len(filters) == 5:
            five_filter_groups.append((filters, members))
    display_groups(five_filter_groups, "len = 5")

def display_layers(layers, DISPLAY=True):
    """Display every group whose filter set contains all filters in `layers`."""
    matching_groups = []
    for filters, members in cnn_groups.items():
        if all(item in filters for item in layers):
            matching_groups.append((filters, members))
    display_groups(matching_groups, layers, DISPLAY)

def display_single_group(totals, all_score, label, indices, layers, DISPLAY):
    """Accumulate label counts and scores for one group of patches, optionally plotting it.

    Args:
        totals: dict label -> count shared across groups; updated in place.
        all_score: Score shared across groups; every patch of this group is added to it.
        label: identifier of the group, used in the figure title and file names.
        indices: indices into the module-level `patches` that form the group.
        layers: not used by the body; kept so existing callers keep working.
        DISPLAY: when true, plot up to 128 crops, show the averaged crop and
            print the label distribution and score of the group. Figures are
            saved under plots/.

    Returns:
        ``(totals, all_score)``, the same objects that were passed in.
    """
    group = [patches[i] for i in indices]

    if DISPLAY:
        print("="*30)
        print("Groupsize:", len(group))
    group_totals = dict()

    if DISPLAY:
        width = 16
        height = int(np.ceil(min(len(group), 128) / width))
        axes = []
        plt.rcParams['figure.figsize'] = [14, 14 * (height / width)]
        fig = plt.figure()

    # for average image
    add_all = AddAll()
    score = Score()
    for index, patch in enumerate(sorted(group, key=lambda x: x["mask_score"], reverse=True)):

        if DISPLAY:
            # Read file; the context manager closes the file handle right after decoding.
            image_path = "data/val_set/" + str(patch["true_label"]) + "/img/" + str(patch["name"])
            with Image.open(image_path) as image_file:
                image_tensor = transform(image_file.convert('RGB'))[None, :, :]
            p = patch["mask"]

            # Channels-last view of the image, cut down to the patch region
            # (computed once and reused for the grid and the average).
            crop = np.transpose(image_tensor.data[0].cpu().numpy(), (1, 2, 0))[p[0]:p[1], p[2]:p[3], :]

            # add to visual output
            if index < 128:
                axes.append(fig.add_subplot(height, width, index + 1))
                plt.imshow(crop)
                axes[-1].set_xticks([])
                axes[-1].set_yticks([])

            # add to averaged image
            # NOTE(review): summing crops assumes they all have the same shape; confirm.
            add_all.add(crop)

        # calculate groups score
        score.add(patch["perturbed_score"], patch["true_label"])
        all_score.add(patch["perturbed_score"], patch["true_label"])

        # count labels
        potential_label = patch["true_label"]
        if potential_label < 0:
            # Undo the negative label encoding (offset 10000, or 20000 for
            # what appears to be the image-label variant).
            potential_label *= -1
            potential_label -= 10000
            if potential_label > 10000:
                potential_label -= 10000

        group_totals[potential_label] = group_totals.get(potential_label, 0) + 1
        totals[potential_label] = totals.get(potential_label, 0) + 1

    if DISPLAY:
        fig.suptitle("filters: " + str(label), fontsize=16)

        plt.tight_layout(pad=0, h_pad=0, w_pad=0, rect=[0, 0, 1, 0.95])
        plt.savefig(f"plots/{str(label)}_full.png")
        plt.show()

        add_all.normalize()
        add_all.display(str(label))

        score.normalize(bar_dict(group_totals))
        score.display()

    return totals, all_score

def display_groups(groups_containing_layer, layers, DISPLAY=True):
    """Display a list of groups, then print their combined label distribution and score.

    Args:
        groups_containing_layer: list of ``(label, indices)`` pairs.
        layers: description of the selection; shown in the header.
        DISPLAY: passed on to display_single_group.
    """
    group_count = len(groups_containing_layer)
    print(group_count)
    print(f"---Displaying {group_count} groups containing {str(layers)}---")

    totals = dict()
    all_score = Score()
    for group_label, group_indices in groups_containing_layer:
        totals, all_score = display_single_group(
            totals, all_score, group_label, group_indices, layers, DISPLAY)

    distribution = bar_dict(totals)
    all_score.normalize(distribution)
    all_score.display()

def display_label(label, DISPLAY=True):
    """Display all groups assigned to `label` and report which filters they use.

    Returns:
        defaultdict mapping each filter to the number of this label's groups
        that contain it.
    """
    print(f"--- Displaying {simple_labels[label]} ({label}) ---")

    totals = dict()
    all_score = Score()

    for layers in label_dict[label]:
        totals, all_score = display_single_group(
            totals, all_score, layers, cnn_groups[layers], layers, DISPLAY)

    distribution = bar_dict(totals)
    all_score.normalize(distribution)
    all_score.display()

    # Count in how many of this label's groups each filter appears.
    label_importance = defaultdict(int)
    for filters in label_dict[label]:
        for single_filter in filters:
            label_importance[single_filter] += 1

    print(f"--- Filters used for {simple_labels[label]} ---")
    bar_dict(label_importance, False)
    return label_importance

def display_filter(layer, DISPLAY=True):
    """Display every group whose filter set contains the single filter `layer`."""
    display_layers((layer,), DISPLAY)