In [None]:
# mini MNIST
# sampled from original MNIST data, e.g., https://huggingface.co/datasets/ylecun/mnist

import pandas as pd
import numpy as np

# Seed the global RNGs once so that Restart & Run All reproduces every sample:
# numpy's global RNG drives the row sampling below (and libraries that fall back
# to it), `random` picks the examples that get explained.
import random
SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# Images: one flattened image per row, grey values scaled from 0..255 to [0, 1].
x = pd.read_csv("mnist-imgs-mini.csv.bz2", index_col=0)
x = x/255.0
x = x.to_numpy()
print(x.shape, x.dtype)

# Labels: presumably a single label column, i.e. a (n_rows, 1) array
# (later cells index it as y[i][0]).
y = pd.read_csv("mnist-lbls-mini.csv.bz2", index_col=0)
y = y.to_numpy()
print(y.shape, y.dtype)

# Digit classes 0..9; label_names maps a class index to its display name.
label_names = range(10)
labels = list(label_names)
print(label_names)

# MNIST images are single-channel grey scale.
n_channels = 1

In [None]:
# Plot some random examples
import matplotlib as mpl
import matplotlib.pyplot as plt
import math

# based on https://www.kaggle.com/code/gainknowledge/mnist-scikit-learn-tutorial
# Apache 2.0 License

# Side length of the square images, derived from the flattened pixel count.
width = math.isqrt(x.shape[1] // n_channels)
assert width * width * n_channels == x.shape[1], "images are expected to be square"

def displayData(X, Y, grid=7):
    """Plot a grid x grid panel of randomly drawn images titled with their labels.

    X: (n_rows, width*width*n_channels) pixel matrix, values in [0, 1].
    Y: labels, shape (n_rows,) or (n_rows, 1).
    grid: number of panel rows and columns (default 7).
    Returns the matplotlib Figure.
    """
    n_panels = grid * grid
    # distinct rows whenever possible, so no image is shown twice
    rows = np.random.choice(X.shape[0], size=n_panels, replace=X.shape[0] < n_panels)
    flat_labels = np.ravel(Y)
    fig, ax = plt.subplots(nrows=grid, ncols=grid, figsize=(15, 15), squeeze=False)
    for axis, ind in zip(ax.flat, rows):
        if n_channels == 1:
            axis.imshow(X[ind, :].reshape(width, width), cmap='gray_r')  # dark digits on white
        else:
            axis.imshow(X[ind, :].reshape(width, width, n_channels))  # rgb colors
        axis.set_title(label_names[flat_labels[ind]])
        axis.axis('off')

    fig.subplots_adjust(hspace=0.5, wspace=0.5)
    return fig

print(x.shape)
# trailing ';' avoids displaying the returned figure a second time
displayData(x, y);

In [None]:
# train SVM on classes
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.svm import SVC
from sklearn.metrics import f1_score

# Train on n_samples distinct rows and score on rows the SVM has not seen, so the
# confusion matrix and f1 reflect generalisation instead of training fit.
n_samples = min(1000, x.shape[0])
shuffled_rows = np.random.permutation(x.shape[0])
train_rows = shuffled_rows[:n_samples]
test_rows = shuffled_rows[n_samples:2 * n_samples]
if len(test_rows) == 0:
    # too little data for a hold-out split: fall back to (optimistic) training-set scores
    test_rows = train_rows

classifier_svm = SVC(probability=True)
# sklearn expects a 1-d target; y is a column vector
classifier_svm.fit(x[train_rows], y[train_rows].ravel())
pred = classifier_svm.predict(x[test_rows])
y_test = y[test_rows].ravel()
ConfusionMatrixDisplay.from_predictions(y_test, pred, labels=labels, normalize="true", values_format='.2f')
print(label_names)
# micro-averaged f1 over all classes equals accuracy for single-label data
print(f'f1 = {f1_score(y_test, pred, average="micro", labels=labels)}')
plt.show()

In [None]:
# optional: save svm
import pickle

# Persist the fitted classifier; protocol 5 needs Python >= 3.8 to load it again.
with open("cls.pkl", mode="wb") as pkl_out:
    pickle.dump(classifier_svm, pkl_out, protocol=5)

In [None]:
# optional: load svm
import pickle

# NOTE: unpickling can execute arbitrary code — only load a cls.pkl you created yourself.
with open("cls.pkl", mode="rb") as pkl_in:
    classifier_svm = pickle.load(pkl_in)

In [None]:
# find interesting classifications: misclassified examples whose true class is the classifier's close second guess
import math

n_samples = min(x.shape[0], 10000)
# distinct rows: sampling with replacement would count some examples twice
selected_rows = np.random.choice(x.shape[0], size=n_samples, replace=False)
pred_proba = classifier_svm.predict_proba(x[selected_rows])
pred_cls = classifier_svm.predict(x[selected_rows])
ground_truth = y[selected_rows].ravel()

# NOTE(review): argmax column indices are compared with labels directly, which
# assumes classifier_svm.classes_ is 0..9 in order — confirm if labels change.
cls_of_interest = []  # entries: [row in x, 1st guess, 2nd guess, true label, [p_1st, p_2nd]]
correct = 0
true_mismatch_proba = 0
for idx in range(n_samples):
    proba = pred_proba[idx]
    # top two classes without mutating pred_proba; a stable sort keeps the
    # lowest index first on ties, exactly like np.argmax
    pred_cls_1st, pred_cls_2nd = np.argsort(-proba, kind="stable")[:2]
    truth = ground_truth[idx]

    if pred_cls_1st == truth:
        correct += 1

    # SVC.predict and argmax(predict_proba) can disagree; skip those examples
    if pred_cls_1st != pred_cls[idx]:
        true_mismatch_proba += 1
        continue

    proba_1st = proba[pred_cls_1st]
    proba_2nd = proba[pred_cls_2nd]

    # ignore if ground truth is not in first two guesses
    if truth not in (pred_cls_1st, pred_cls_2nd):
        continue

    # ignore if classifier is very sure
    if proba_1st - proba_2nd > .5:
        continue

    # keep misclassifications where the true class was the runner-up
    if pred_cls_1st != truth:
        cls_of_interest.append([selected_rows[idx], pred_cls_1st, pred_cls_2nd, truth, [proba_1st, proba_2nd]])

print(f"correct: {correct/n_samples}")
print(f"mismatch between pred and probs: true={true_mismatch_proba/n_samples}")
print(f"interesting examples {len(cls_of_interest) / n_samples}")

In [None]:
# explain with lime

# Ribeiro, M. T., Singh, S. & Guestrin, C. (2016). Why should I trust you? Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining (S. 1135–1144). doi: 10.1145/2939672.2939778
# compare with the "official" tutorial at https://github.com/marcotcr/lime/blob/master/doc/notebooks/Tutorial%20-%20MNIST%20and%20RF.ipynb
# BSD 2-clause

from lime.lime_image import LimeImageExplainer
from lime.wrappers.scikit_image import SegmentationAlgorithm
from skimage import color as sk_color
from skimage.color import label2rgb
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt
import random

explainer = LimeImageExplainer()

# [classifier id, fitted classifier] pairs to explain; make_plot reads the
# probability pair of classifier `id` from trial[3 + id].
setup = [
    [1, classifier_svm],
]

# Flattened row length, and side length of the square images.
n_pixels, width = x.shape[1], int(math.sqrt(x.shape[1] / n_channels))
# segmenter: takes an image and returns an int matrix with the image's height and width, holding each pixel's superpixel index

def split_evenly(img, n_splits):
    """Segment an image into an n_splits x n_splits grid of rectangular superpixels.

    Only img.shape[0] and img.shape[1] are used, so grey (h, w) and colour
    (h, w, c) images both work. Pixel (i, j) gets the superpixel id
    ``i * n_splits // h + (j * n_splits // w) * n_splits``.

    Returns an int array of shape (h, w) with ids in [0, n_splits**2).
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    # exact integer band indices; equivalent to trunc(i / (h / n_splits)) without
    # relying on floating-point rounding at band boundaries
    band_x = np.arange(n_rows) * n_splits // n_rows
    band_y = np.arange(n_cols) * n_splits // n_cols
    return (band_x[:, None] + band_y[None, :] * n_splits).astype(int)


# Grid segmentations keyed by (height, width, splits), so each grid is built only once.
grids = {}
def grid_n_parts(img, splits):
    """Return the cached split_evenly segmentation of img into splits x splits cells.

    The same array object is returned for every image of the same size, so
    callers must not modify it.
    """
    key = (img.shape[0], img.shape[1], splits)
    try:
        return grids[key]
    except KeyError:
        grids[key] = split_evenly(img, splits)
        return grids[key]

def grid_5parts(img):
    """Segmenter: 5 x 5 grid (25 superpixels)."""
    return grid_n_parts(img, splits=5)

def grid_7parts(img):
    """Segmenter: 7 x 7 grid (49 superpixels)."""
    return grid_n_parts(img, splits=7)

def grid_11parts(img):
    """Segmenter: 11 x 11 grid (121 superpixels)."""
    return grid_n_parts(img, splits=11)

# LIME wrappers around scikit-image segmentations: 'quickshift', 'slic', 'felzenszwalb'.

# https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.quickshift
# ratio: color=1 space=0
quickshift = SegmentationAlgorithm('quickshift')

# https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.felzenszwalb
# smoothing (sigma) and scale are tied to the image width
felzenszwalb = SegmentationAlgorithm(
    'felzenszwalb',
    sigma=width/20,
    scale=width/50,
    channel_axis=2,
)

# https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.slic
# three SLIC variants asking for 25 segments that differ only in compactness (small / mid / large)
slic_sma, slic_mid, slic_lrg = (
    SegmentationAlgorithm('slic', n_segments=25, compactness=compactness,
                          convert2lab=False, enforce_connectivity=True)
    for compactness in (.01, 0.1, 1.0)
)


## in case different aggregation functions should be used for some of the channels (e.g. value)
def average_hsvs(imgs):
    """Pixel-wise arithmetic mean of equally shaped images (all channels averaged alike).

    imgs: non-empty sequence of arrays with identical shape.
    Returns a new array; float inputs keep their dtype, integer inputs are
    promoted to a floating dtype.

    Raises ValueError if imgs is empty.
    """
    if len(imgs) == 0:
        raise ValueError("average_hsvs needs at least one image")
    first = np.asarray(imgs[0])
    # accumulate in a floating dtype: numpy refuses `int_array /= n`
    c = np.zeros(first.shape, dtype=np.result_type(first.dtype, np.float32))
    for img in imgs:
        c += img
    c /= len(imgs)
    return c

## combine several hsv images, with hue being weighted by the saturation
def combine_hsvs(imgs):
    """Merge equally shaped HSV images (channels last, hue as a fraction of a turn) into one.

    Hue is circular, so it is averaged as a vector: each pixel's hue angle
    contributes a unit vector scaled by its saturation, and the result's hue is
    the angle of the vector sum. Saturation and value are plain means.

    imgs: non-empty sequence of arrays of shape (h, w, 3).
    Returns a new array of the same shape with hue wrapped into [0, 1]
    (1.0 and 0.0 denote the same hue).

    Raises ValueError if imgs is empty.
    """
    if len(imgs) == 0:
        raise ValueError("combine_hsvs needs at least one image")
    stack = np.asarray(imgs, dtype=np.float64)  # (n_imgs, h, w, 3)
    angle = stack[..., 0] * (2 * np.pi)
    sat = stack[..., 1]
    hue_x = (np.cos(angle) * sat).sum(axis=0)
    hue_y = (np.sin(angle) * sat).sum(axis=0)

    first = np.asarray(imgs[0])
    c = np.zeros(first.shape, dtype=np.result_type(first.dtype, np.float32))
    # atan2 yields (-pi, pi]; wrap into [0, 1] because HSV hue must not be negative
    c[..., 0] = (np.arctan2(hue_y, hue_x) / (2 * np.pi)) % 1.0
    c[..., 1] = sat.mean(axis=0)
    c[..., 2] = stack[..., 2].mean(axis=0)
    return c

def combine_weights(ws):
    """Sum equally shaped weight maps and rescale the sum into [-1, 1].

    The sum is divided by its largest absolute value, so the strongest weight
    becomes +1 or -1 and signs are preserved.

    ws: non-empty sequence of arrays with identical shape.
    Returns a new floating array; an all-zero sum is returned as zeros rather
    than the NaNs a 0/0 division would produce.
    """
    first = np.asarray(ws[0])
    # floating accumulator so integer inputs can be divided in place
    c = np.zeros(first.shape, dtype=np.result_type(first.dtype, np.float32))
    for w in ws:
        c += w
    norm = np.max(np.abs(c))
    if norm > 0:
        c /= norm
    return c

# all-white grey image; `white - digit` inverts MNIST digits for display
white = np.ones(shape=(width, width))

# HSV hues as fractions of a full turn
HUE_RED, HUE_GREEN = 0.0, 1.0 / 3.0  # 0 deg, 120 deg
# how much the computed weight also increases the HSV value (brightness)
VALUE_ADD = 0.5

# number of interesting examples to explain
N_SAMPLES = 4

print(f"{len(cls_of_interest)} examples")

import os

# Folder for the saved explanation figures. Creating it up front keeps a missing
# folder from being reported as a failed explanation for every trial.
LIME_IMG_DIR = "lime-imgs"
os.makedirs(LIME_IMG_DIR, exist_ok=True)

# random.sample raises ValueError when asked for more items than exist.
n_trials = min(N_SAMPLES, len(cls_of_interest))

for trial in random.sample(cls_of_interest, n_trials):
    # trial = [row in x, 1st guess, 2nd guess, true label, [p_1st, p_2nd]]
    idx = trial[0]
    pred_1st = trial[1]
    pred_2nd = trial[2]
    true_class = trial[3]

    print(f'{idx}: 1st {pred_1st}, 2nd {pred_2nd}, true {true_class}, c_1 {trial[4]}')

    for s in setup:
        clsf_id = s[0]
        classifier = s[1]

        def predict_instance(img_batch):
            """LIME prediction callback: class probabilities for a batch of perturbed images."""
            if n_channels == 1:
                # LIME hands over RGB images: back to grey scale, then flatten
                gray = sk_color.rgb2gray(img_batch).reshape((-1, width*width))
                return classifier.predict_proba(gray)
            # flatten to the classifier's input layout
            return classifier.predict_proba(img_batch.reshape((-1, width*width*n_channels)))

        if n_channels == 1:
            img_data = x[idx].reshape((width, width))
            # LIME explains an RGB copy of the digit as the classifier sees it
            img = sk_color.gray2rgb(img_data)
            # MNIST is inverted: display dark digits on white
            orig_img = sk_color.gray2rgb(white - img_data)
        else:
            img_data = x[idx].reshape((width, width, n_channels))
            # no need to invert
            img = img_data
            orig_img = img

        def get_expl_weights(segm_fn):
            """Explain `img` with LIME, using `segm_fn` to build the superpixels.

            Returns (w1, w2): float32 maps with the image's height and width holding,
            for every pixel, the LIME weight of its superpixel for pred_1st / pred_2nd,
            scaled so each class's largest-magnitude weight is +1 or -1.
            """
            explanation = explainer.explain_instance(img, predict_instance,
                                                     num_features=100,
                                                     labels=(pred_1st, pred_2nd),
                                                     hide_color=None, num_samples=500,
                                                     segmentation_fn=segm_fn,
                                                     batch_size=100)
            segments = explanation.segments

            def pixel_weights(label):
                # local_exp[label] is a list of (superpixel id, weight)
                local_weights = explanation.local_exp[label]
                max_weight = max((abs(weight) for _, weight in local_weights), default=0.0)
                w = np.zeros_like(segments, dtype=np.float32)
                # all-zero weights stay zero instead of turning into NaNs
                if max_weight > 0:
                    for superpixel, weight in local_weights:
                        w[segments == superpixel] = weight / max_weight
                return w

            return pixel_weights(pred_1st), pixel_weights(pred_2nd)

        def weights2mask(w):
            """RGB image showing orig_img where w > 0, faded towards white as w drops; white where w <= 0."""
            # weights <= 0 become 0, and 0 * orig + (1 - 0) * white is plain white
            positive = np.clip(w, 0.0, None)[..., np.newaxis]
            result_img = positive * orig_img + (1.0 - positive)
            return result_img.astype(orig_img.dtype, copy=False)

        def weights2hsv(w):
            """HSV version of orig_img: saturation = |w|, hue = sign of w (green +, red -),
            value raised by |w| * VALUE_ADD to keep colours visible in dark areas."""
            result_img = sk_color.rgb2hsv(orig_img)
            magnitude = np.abs(w)
            result_img[..., 1] = magnitude
            result_img[..., 2] = np.minimum(1, magnitude * VALUE_ADD + result_img[..., 2])
            # TODO blend between weight color and original color
            # pixels with w == 0 keep their original hue
            result_img[..., 0] = np.where(w > 0, HUE_GREEN,
                                          np.where(w < 0, HUE_RED, result_img[..., 0]))
            return result_img

        def make_plot(wimg1, wimg2, with_p, seg_id):
            """Show and save original | explanation for pred_1st | explanation for pred_2nd.

            with_p: title the panels with the rounded class probability (True) or A/B (False).
            seg_id: tag naming the segmentation/visualisation in the file name.
            The figure is left open so the notebook displays it.
            """
            # trial[3 + clsf_id] holds [p_1st, p_2nd]; round half up to whole percent
            p1 = math.trunc(100 * trial[3+clsf_id][0] + .5)
            p2 = math.trunc(100 * trial[3+clsf_id][1] + .5)

            fig, (orig, img_1st, img_2nd) = plt.subplots(1, 3, figsize=(4, 2))
            orig.set_title('Original')
            orig.imshow(orig_img)
            orig.axis('off')

            if with_p:
                img_1st.set_title(f'{p1}% {label_names[pred_1st]}')
            else:
                img_1st.set_title(f'A: {label_names[pred_1st]}')
            img_1st.imshow(wimg1)
            img_1st.axis('off')

            if with_p:
                img_2nd.set_title(f'{p2}% {label_names[pred_2nd]}')
            else:
                img_2nd.set_title(f'B: {label_names[pred_2nd]}')
            img_2nd.imshow(wimg2)
            img_2nd.axis('off')

            p_on = 1 if with_p else 0
            plt.tight_layout()
            plt.savefig(f"{LIME_IMG_DIR}/expl-i{idx}-t{true_class}-c{clsf_id}-{seg_id}-p{p_on}.png")

        try:
            # explanations from several segmentations, summed and rescaled to [-1, 1]
            # (felzenszwalb and quickshift are defined above as further options)
            segmenters = [slic_sma, slic_mid, slic_lrg, grid_5parts, grid_7parts, grid_11parts]
            weight_pairs = [get_expl_weights(segm_fn) for segm_fn in segmenters]
            w1af = combine_weights([w1 for w1, _ in weight_pairs])
            w2af = combine_weights([w2 for _, w2 in weight_pairs])

            # HSV version
            make_plot(sk_color.hsv2rgb(weights2hsv(w1af)),
                      sk_color.hsv2rgb(weights2hsv(w2af)), True, "haf")

            # Mask parts of the image
            make_plot(weights2mask(w1af),
                      weights2mask(w2af), True, "maf")

        except Exception as err:
            # best effort: report and continue with the next trial/classifier
            print(f"could not create explanation for trial {idx}, classifier {clsf_id}", err)
