In [None]:
import itertools
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import scipy.special
import shap
import torch
import torchvision
from PIL import Image
from skimage.segmentation import slic, mark_boundaries
from torchvision.transforms import v2

# Project-local helpers (segmenters, samplers, perturbers, attribution values).
# NOTE: the module name `test` shadows the stdlib `test` package; it only
# resolves to the local file when the notebook's directory precedes the stdlib
# on sys.path.
import test

In [None]:
IMAGE_PATH = 'images/cat_guitar.jpg'

# Image.open is lazy and keeps the file handle open; load inside a context
# manager so the handle is released. Force 3-channel RGB because later cells
# use a 3-value fill colour and a 3-channel ImageNet normalisation.
with Image.open(IMAGE_PATH) as raw_img:
    img = raw_img.convert('RGB')

plt.imshow(img)
plt.title('Input image')
plt.axis('off');

In [None]:
img_array = np.array(img)
# Later cells assume an (H, W, 3) RGB array (3-value fill colour, 3-channel
# normalisation); fail here with the actual shape rather than deep inside torch.
assert img_array.ndim == 3 and img_array.shape[-1] == 3, img_array.shape

# Standalone SLIC preview using the same parameters that are passed to
# test.slic_segmenter in the next cell. Named explicitly instead of `a`,
# which the next cell reuses for an unrelated value.
slic_preview = slic(img_array, n_segments=50, compactness=10, start_label=0)

print(img_array.shape)
print(slic_preview.shape)

plt.imshow(mark_boundaries(img_array, slic_preview))
plt.title('SLIC superpixels (n_segments=50, compactness=10)')
plt.axis('off');

In [None]:
# Grey used to fill "removed" regions in every perturbed image.
FILL_COLOR = np.array((190, 190, 190))

segments, masks = test.slic_segmenter(img_array, nbr_segments=50, compactness=10)
# Give grid_segmenter's first output a real name: `_` is IPython's
# "last output" variable and must not carry data between cells.
grid_segments, rise_masks = test.grid_segmenter(img_array, 10, 10, True)
print(grid_segments.shape)

# Number of SLIC segments actually produced (n_segments is only a target).
M = np.unique(segments).shape[0]
samples = test.shap_sampler(M, sample_size=M*2+2)
ciu_samples = test.naive_ciu_sampler(M, inverse=False)
rise_samples = test.random_sampler(rise_masks.shape[0], sample_size=300)

# Computed from the RISE samples *before* the perturber below returns its
# (possibly updated) rise_samples, exactly as before; only used for the preview.
rise_mask_preview = test.perturbation_masks(rise_masks, rise_samples)

perturbed_image, samples = test.single_color_pertuber(img_array, masks, samples, FILL_COLOR)
ciu_perturbed_image, ciu_samples = test.single_color_pertuber(img_array, masks, ciu_samples, FILL_COLOR)
rise_perturbed_image, rise_samples = test.single_color_pertuber(img_array, rise_masks, rise_samples, FILL_COLOR)

print(perturbed_image.shape, rise_masks.shape)

# Spot-check a few perturbed inputs. The SHAP batch holds M*2+2 images and M
# depends on how many segments SLIC produced, so a fixed index such as 73 can
# be out of range; skip those instead of raising IndexError.
previews = [
    ('SHAP sample', perturbed_image, 1),
    ('SHAP sample', perturbed_image, 14),
    ('SHAP sample', perturbed_image, 73),
    ('RISE sample', rise_perturbed_image, 0),
    ('RISE perturbation mask', rise_mask_preview, 0),
]
for label, batch, i in previews:
    if i >= len(batch):
        print(f'skipping {label} {i}: batch has only {len(batch)} entries')
        continue
    fig, ax = plt.subplots()
    ax.imshow(batch[i])
    ax.set_title(f'{label} {i}')
    ax.axis('off')

In [None]:
# This is a ResNet-50 (the variable was previously, misleadingly, `alexnet`).
weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
model = torchvision.models.resnet50(weights=weights)
model.eval()
# ImageNet class names, indexed like the model's output logits.
categories = weights.meta['categories']

# Resize to 224x224, then apply ImageNet channel mean/std normalisation.
# Resizing first means normalisation runs on the small batch, not the
# full-resolution one.
transforms = v2.Compose([
    v2.Resize((224, 224)),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def to_model_input(images):
    """Convert a batch of channels-last 0-255 images into model input.

    Args:
        images: numpy array indexed (N, H, W, C); values assumed to be 0-255
            (TODO confirm against test.single_color_pertuber).

    Returns:
        float32 tensor of shape (N, C, 224, 224), resized and normalised.
    """
    # Cast to float32 before the transforms so they don't run in float64.
    batch = torch.from_numpy(images / 255).permute((0, 3, 1, 2)).float()
    return transforms(batch)


def show(image, title, colorbar=False):
    """Draw `image` in a new titled figure, optionally with a colour bar."""
    fig, ax = plt.subplots()
    drawn = ax.imshow(image)
    ax.set_title(title)
    if colorbar:
        fig.colorbar(drawn, ax=ax)


x = to_model_input(perturbed_image)
ciu_x = to_model_input(ciu_perturbed_image)
rise_x = to_model_input(rise_perturbed_image)

# Pure inference: disable autograd so these large forward passes don't keep
# every intermediate activation alive for a backward pass that never happens.
with torch.no_grad():
    y = model(x)
    ciu_y = model(ciu_x)
    rise_y = model(rise_x)

# NOTE(review): sample 1 of the SHAP batch is used as the reference prediction;
# presumably the unperturbed image in shap_sampler's ordering -- confirm in test.py.
idx = y[1].argmax()
print(idx, categories[int(idx)])

# `top_scores` rather than `sorted`, which would shadow the builtin.
top_scores, indices = y[1].sort(descending=True)
print(indices)

y = y.detach().numpy()
ciu_y = ciu_y.detach().numpy()
rise_y = rise_y.detach().numpy()

# ImageNet-1k class indices for the two objects in the photo; their labels are
# printed so the magic numbers can be checked at a glance.
cat_idx = 281
guitar_idx = 402
print(cat_idx, categories[cat_idx], '|', guitar_idx, categories[guitar_idx])

# Per-sample logits of each target class, for each explainer's batch.
cat_ys = y[:, cat_idx]
guitar_ys = y[:, guitar_idx]

print(ciu_y[:, cat_idx])

ciu_cat_ys = ciu_y[:, cat_idx]
ciu_guitar_ys = ciu_y[:, guitar_idx]

rise_cat_ys = rise_y[:, cat_idx]
rise_guitar_ys = rise_y[:, guitar_idx]

cat_shaps = test.shap_values(cat_ys, samples)
guitar_shaps = test.shap_values(guitar_ys, samples)

cat_ciu = test.original_ciu_values(ciu_cat_ys, ciu_samples, inverse=False)
guitar_ciu = test.original_ciu_values(ciu_guitar_ys, ciu_samples, inverse=False)

rise_perturbed_masks = test.perturbation_masks(rise_masks, rise_samples)

cat_rise = test.rise_values(rise_cat_ys, rise_perturbed_masks)
guitar_rise = test.rise_values(rise_guitar_ys, rise_perturbed_masks)

# NOTE(review): the +2 presumably skips two leading reference samples in
# shap_sampler's output, and [:-1] drops the last entry of the SHAP vector;
# confirm both offsets against test.py.
show(perturbed_image[np.argmax(cat_shaps[0][:-1]) + 2], 'SHAP batch: top cat segment')
show(perturbed_image[np.argmax(guitar_shaps[0][:-1]) + 2], 'SHAP batch: top guitar segment')
# NOTE(review): these index the CIU batch with the *SHAP* argmax (no +2 offset).
# Kept as originally written -- confirm it is intentional, not a copy-paste slip.
show(ciu_perturbed_image[np.argmax(cat_shaps[0][:-1])], 'CIU batch at top cat SHAP segment')
show(ciu_perturbed_image[np.argmax(guitar_shaps[0][:-1])], 'CIU batch at top guitar SHAP segment')
show(cat_rise, 'RISE values: cat', colorbar=True)
show(guitar_rise, 'RISE values: guitar', colorbar=True)
show(ciu_perturbed_image[np.argmax(cat_ciu)], 'CIU: top cat segment')
show(ciu_perturbed_image[np.argmax(guitar_ciu)], 'CIU: top guitar segment')


