## LIME for images

For images, perturbing individual pixels is not a good way to understand the behaviour of our model. A class prediction typically depends on many pixels at once, so randomly changing individual pixels would probably barely change the prediction. Instead, variations of the image are created by **segmenting** it into **“superpixels”** and turning superpixels on or off. Let us take a brief detour to discuss superpixels.

<img src="images/super_pixels.jpg" alt="super_pixel" style="float: left; margin-right: 10px;" align="right" width="200"/>

Superpixel algorithms group pixels into perceptually meaningful regions while respecting potential object contours, and thereby can replace the rigid pixel grid structure. More formally,

**Superpixel.** Superpixels are an **oversegmentation** of an image: each superpixel is a perceptual grouping of pixels. Instead of finding segments that correspond to whole objects (as done in instance segmentation), superpixel algorithms split the image into many more segments, typically a few hundred to a few thousand (e.g., 2500). The objective of this oversegmentation is to partition the image such that **1)** no superpixel crosses an object boundary, while **2)** an object may be divided into multiple superpixels.
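To make the two conditions concrete, here is a small numpy sketch that checks them for a given segmentation. The helpers `respects_boundaries` and `pieces_of_object` are hypothetical names for illustration, and the toy 4x4 label maps stand in for the output of a real superpixel algorithm.

```python
import numpy as np

def respects_boundaries(segments, object_mask):
    """True if every superpixel lies entirely inside or entirely outside the object (condition 1)."""
    for label in np.unique(segments):
        inside = object_mask[segments == label]
        if inside.any() and not inside.all():  # this superpixel straddles the object boundary
            return False
    return True

def pieces_of_object(segments, object_mask):
    """Number of superpixels covering the object; more than one is allowed (condition 2)."""
    return len(np.unique(segments[object_mask]))

# Toy 4x4 image whose "object" occupies the two left columns
object_mask = np.zeros((4, 4), dtype=bool)
object_mask[:, :2] = True

good = np.array([[0, 1, 2, 2],
                 [0, 1, 2, 2],
                 [3, 3, 4, 4],
                 [3, 3, 4, 4]])      # object split into superpixels 0, 1 and 3
bad = np.array([[0, 0, 0, 1]] * 4)  # superpixel 0 leaks into the background

print(respects_boundaries(good, object_mask), pieces_of_object(good, object_mask))
print(respects_boundaries(bad, object_mask))
```

The `good` map is a valid oversegmentation (the object is cut into three pieces, none of which crosses the boundary), whereas `bad` violates condition 1.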



The LIME algorithm for images uses superpixels as interpretable features to explain a black-box model. Let's see how.
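A rough sketch of the idea: each superpixel becomes one binary feature, a random 0/1 vector decides which superpixels stay on, and the switched-off superpixels are painted with a fill colour (`hide_color`). As far as we understand, `lime_image` does something similar internally and keeps the unperturbed image as the first sample; the helper `perturb_image` below is a hypothetical illustration, not part of LIME's API, and it assumes superpixel labels run from 0 to K-1.

```python
import numpy as np

def perturb_image(image, segments, z, hide_color=0):
    """Copy of `image` in which every superpixel k with z[k] == 0 is set to hide_color."""
    out = image.copy()
    off = np.isin(segments, np.flatnonzero(z == 0))  # pixels of switched-off superpixels
    out[off] = hide_color                             # paints all 3 channels of those pixels
    return out

rng = np.random.default_rng(0)
image = rng.integers(1, 256, size=(4, 4, 3), dtype=np.uint8)  # toy RGB image, no zero pixels
segments = np.array([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [2, 2, 3, 3],
                     [2, 2, 3, 3]])

# Random on/off patterns over the K superpixels; row 0 keeps everything on (the original image)
num_samples, K = 5, segments.max() + 1
Z = rng.integers(0, 2, size=(num_samples, K))
Z[0, :] = 1
perturbed = [perturb_image(image, segments, z) for z in Z]

# A hand-picked pattern: keep superpixels 0 and 2, hide 1 and 3 (the right half)
example = perturb_image(image, segments, np.array([1, 0, 1, 0]))
```

Each row of `Z` is the interpretable representation of one perturbed image; the black-box model is then queried on `perturbed`.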


## Setting up 

As always, we start by loading the relevant packages. From the LIME package, we import `lime_image`. We also use the `skimage` package for visualization; see the installation note in the code below.

```python
import json
import os

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.models import resnet18

# LIME's explainer for image classifiers (install with: pip install lime)
from lime import lime_image

# skimage is used to draw superpixel boundaries; install it in conda with:
#     conda install scikit-image
from skimage.segmentation import mark_boundaries
```

```python
# Set random seeds for reproducibility.
np.random.seed(0)
torch.manual_seed(0)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
```

## Interpreting a ResNet model

We focus on interpreting a ResNet model (ResNet-18). Below, we load the pretrained model. We also need the class descriptions of the ImageNet dataset, on which the ResNet was trained. They are stored in a JSON file, which we also read below.
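For reference, the class-index file is assumed to map each class index (as a string) to a `[WordNet ID, label]` pair, as in the widely used `imagenet_class_index.json`. The snippet below builds the same three lookups from a tiny inline excerpt so the structure is visible without the file.

```python
import json

# Tiny excerpt in the assumed layout: {"<index>": ["<WordNet ID>", "<label>"], ...}
class_idx = json.loads('{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"]}')

idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]         # index -> label
cls2label = {wnid: label for wnid, label in class_idx.values()}            # WordNet ID -> label
cls2idx = {class_idx[str(k)][0]: k for k in range(len(class_idx))}        # WordNet ID -> index

print(idx2label, cls2idx)
```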

```python
from torchvision.models import ResNet18_Weights

# Load ResNet-18 with ImageNet weights (replaces the deprecated pretrained=True;
# requires torchvision >= 0.13, on older versions use resnet18(pretrained=True))
net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
net = net.eval().to(device)

# Read the ImageNet class index: {"<index>": ["<WordNet ID>", "<label>"], ...}
with open(os.path.join("data", "imagenet_class_index.json"), "r") as read_file:
    class_idx = json.load(read_file)

idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
cls2label = {class_idx[str(k)][0]: class_idx[str(k)][1] for k in range(len(class_idx))}
cls2idx = {class_idx[str(k)][0]: k for k in range(len(class_idx))}
```


### Utility functions

We need a couple of utility functions. First, a function that reads an image and constructs a torch tensor for the ResNet model. Recall that the pretrained ResNet expects each RGB channel to be normalized with the ImageNet mean and standard deviation.

We also need a function that applies the same resizing and cropping but returns a PIL image, without converting it to a tensor or normalizing it. This un-normalized image is what we pass to the LIME package.
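To see what the normalization does numerically, the numpy sketch below mimics `T.ToTensor()` followed by `T.Normalize(...)` on an `H x W x 3` uint8 array: scale to [0, 1], subtract the per-channel ImageNet mean, divide by the per-channel standard deviation, and move the channels first. `to_normalized_chw` is a hypothetical helper for illustration; the notebook itself relies on torchvision.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])  # ImageNet per-channel mean (RGB)
STD = np.array([0.229, 0.224, 0.225])   # ImageNet per-channel std (RGB)

def to_normalized_chw(img_uint8):
    """numpy analogue of T.ToTensor() + T.Normalize(MEAN, STD) for an HxWx3 uint8 image."""
    x = img_uint8.astype(np.float64) / 255.0  # scale to [0, 1], as ToTensor does
    x = (x - MEAN) / STD                      # per-channel standardization (broadcast over H, W)
    return x.transpose(2, 0, 1)               # HWC -> CHW, the layout torch models expect

white = np.full((2, 2, 3), 255, dtype=np.uint8)
black = np.zeros((2, 2, 3), dtype=np.uint8)
print(to_normalized_chw(white)[:, 0, 0])
```

A white pixel therefore maps to roughly (2.25, 2.43, 2.64) rather than 1.0, which is why LIME must receive the un-normalized image and the normalization is applied only inside the prediction function.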

```python
# Resize to 256x256 and take the central 224x224 crop, the input size ResNet expects
def pil_to_torch(img):
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    transf = T.Compose([
        T.Resize((256, 256)),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize
    ])
    # unsqueeze(0) turns a single image into a batch of 1
    return transf(img).unsqueeze(0)

# Same resize and crop, but returns a PIL image (no tensor conversion, no normalization)
def pil_transform(img):
    transf = T.Compose([
        T.Resize((256, 256)),
        T.CenterCrop(224)
    ])
    return transf(img)
```

### Read an image

Below, we read the "puppy_kitten.jpg" from the data folder and display it.

```python
img_file_name = "puppy_kitten.jpg"
img_pil = Image.open(os.path.join("data", img_file_name)).convert("RGB")

plt.imshow(img_pil)
plt.show()
```

### What does ResNet predict?

Let's see how ResNet classifies the image: is it a cat or a dog?

```python
img = pil_to_torch(img_pil).to(device)  # move the input to the same device as the model

with torch.no_grad():
    logits = net(img)
probs = F.softmax(logits, dim=1)
probs5 = probs.topk(5)
# (probability, class index, label) for the top-5 classes
tuple((p, c, idx2label[c]) for p, c in zip(probs5.values[0].cpu().numpy(),
                                          probs5.indices[0].cpu().numpy()))
```

### Utility function for LIME

We need one more utility function to work with LIME. LIME generates images in which superpixels are turned on and off, and it needs the class probabilities of these images to train its local model. Hence, we write a function that takes a batch of images as numpy arrays and returns the class probabilities predicted by our ResNet model.
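The contract LIME expects from the classification function can be illustrated without a real network. In the sketch below, `toy_predict` is a hypothetical stand-in for `cnn_predict`: as far as we know, LIME passes it a batch of `H x W x 3` uint8 images as one numpy array, and it must return one row of class probabilities per image. The toy linear "model" and `NUM_CLASSES = 5` are assumptions for illustration only.

```python
import numpy as np

NUM_CLASSES = 5  # hypothetical; the ImageNet ResNet-18 used here outputs 1000 classes

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Fixed random weights of a toy linear "model" on 3 colour features
_W = np.random.default_rng(0).normal(size=(3, NUM_CLASSES))

def toy_predict(images):
    """Same contract as cnn_predict: (N, H, W, 3) uint8 in, (N, NUM_CLASSES) probabilities out."""
    x = np.asarray(images, dtype=np.float64) / 255.0
    features = x.mean(axis=(1, 2))  # mean intensity per colour channel, shape (N, 3)
    return softmax(features @ _W)

batch = np.random.default_rng(1).integers(0, 256, size=(4, 8, 8, 3), dtype=np.uint8)
probs = toy_predict(batch)
print(probs.shape, probs.sum(axis=1))
```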

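```python
def cnn_predict(images):
    """Classification function for LIME: takes a batch of HxWx3 uint8 numpy images
    and returns an (N, 1000) numpy array of class probabilities."""
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    transf = T.Compose([
        T.ToTensor(),
        normalize
    ])
    # Stack the images into one batch on the model's device
    batch = torch.stack(tuple(transf(img) for img in images), dim=0).to(device)

    with torch.no_grad():
        logits = net(batch)
    probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()
```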

### Use LIME to explain the ResNet

To use LIME, we first define an explainer object (see the cell below).
Then we call its [explain_instance](https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_image) method to understand the behaviour of ResNet. Run `explain_instance` and discuss the results.
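Under the hood, `explain_instance` fits a simple weighted linear model on the perturbed samples. The sketch below shows that core step in numpy, assuming the defaults we believe `lime_image` uses (cosine distance to the unperturbed image, kernel width 0.25, ridge regression with `alpha=1`); LIME's feature-selection step and other details are omitted, and `lime_surrogate` is a hypothetical name. Each row of `Z` says which superpixels were on, and `y` holds the black-box probability of the class being explained.

```python
import numpy as np

def lime_surrogate(Z, y, kernel_width=0.25, alpha=1.0):
    """Weighted ridge fit y ~ b + Z @ w; returns (w, b), one weight per superpixel."""
    Z = np.asarray(Z, dtype=float)
    y = np.asarray(y, dtype=float)
    # 1) Cosine distance between each sample and the original image (all superpixels on)
    ones = np.ones(Z.shape[1])
    norms = np.linalg.norm(Z, axis=1) * np.linalg.norm(ones)
    cos_sim = (Z @ ones) / np.where(norms > 0, norms, 1.0)  # all-off rows get similarity 0
    d = 1.0 - cos_sim
    # 2) Exponential kernel: samples close to the original get weight near 1
    pi = np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
    # 3) Weighted ridge with an unpenalized intercept: center with weighted means,
    #    then solve (Zc^T P Zc + alpha I) w = Zc^T P yc
    z_bar = pi @ Z / pi.sum()
    y_bar = pi @ y / pi.sum()
    Zc, yc = Z - z_bar, y - y_bar
    A = Zc.T @ (pi[:, None] * Zc) + alpha * np.eye(Z.shape[1])
    w = np.linalg.solve(A, Zc.T @ (pi * yc))
    b = y_bar - z_bar @ w
    return w, b

# Toy check: a black box that is exactly linear in 4 superpixels
rng = np.random.default_rng(0)
Z = rng.integers(0, 2, size=(500, 4))
Z[0, :] = 1
true_w = np.array([0.5, -0.3, 0.0, 0.2])
y = 0.1 + Z @ true_w
w, b = lime_surrogate(Z, y, alpha=1e-8)  # tiny alpha so the fit is essentially unregularized
print(np.round(w, 3), round(b, 3))
```

Positive entries of `w` mark superpixels that push the class probability up, negative entries those that push it down; these are the weights the masks below are built from.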

```python
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(np.array(pil_transform(img_pil)),
                                         cnn_predict,  # classification function
                                         top_labels=2,
                                         hide_color=0,
                                         num_samples=1000)  # number of images sent to the classification function
```

After running LIME on our image, we can use various methods to visualize which parts of the image contributed to the decision. For example, the method [get_image_and_mask](https://lime-ml.readthedocs.io/en/latest/lime.html#lime.lime_image.ImageExplanation.get_image_and_mask) can be used to visualize the superpixels that **positively** or **negatively** contribute to the prediction of a label.
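As far as we can tell from the LIME source, `get_image_and_mask` ranks superpixels by the absolute value of their surrogate weight, keeps the first `num_features` with the requested sign, and (with `hide_rest=True`) blacks out everything else. The sketch below reproduces that logic in plain numpy under those assumptions; `image_and_mask` and the weights are hypothetical, not the library's code.

```python
import numpy as np

def image_and_mask(image, segments, weights, num_features=3, sign=+1, hide_rest=True):
    """Keep the `num_features` superpixels whose weight has the given sign, ranked by |weight|.
    weights[k] is the surrogate coefficient of superpixel k."""
    weights = np.asarray(weights, dtype=float)
    order = np.argsort(-np.abs(weights))  # superpixel ids, largest |weight| first
    chosen = [k for k in order if np.sign(weights[k]) == sign][:num_features]
    mask = np.isin(segments, chosen).astype(int)
    # Start from a black canvas (hide_rest) or the full image, then copy the chosen superpixels
    temp = np.zeros(image.shape) if hide_rest else image.astype(float)
    temp[mask == 1] = image[mask == 1]
    return temp, mask

segments = np.array([[0, 0, 1, 1],
                     [0, 0, 1, 1],
                     [2, 2, 3, 3],
                     [2, 2, 3, 3]])
image = np.full((4, 4, 3), 200, dtype=np.uint8)
weights = [0.4, -0.7, 0.1, 0.2]  # made-up coefficients for superpixels 0..3

pos_img, pos_mask = image_and_mask(image, segments, weights, num_features=2, sign=+1)
neg_img, neg_mask = image_and_mask(image, segments, weights, num_features=1, sign=-1)
print(pos_mask)
print(neg_mask)
```

Here the positive mask keeps superpixels 0 and 3 (the two largest positive weights) and the negative mask keeps superpixel 1. Because the returned image keeps the 0-255 scale, the notebook divides it by 255 before calling `mark_boundaries`.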

```python
fig, axes = plt.subplots(1, 3)
axes[0].imshow(pil_transform(img_pil))
axes[0].axis('off')

# Top-3 superpixels supporting the top predicted label
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, negative_only=False, num_features=3, hide_rest=True)
img_boundary = mark_boundaries(temp / 255.0, mask)
axes[1].imshow(img_boundary)
axes[1].set_title("Positive mask")
axes[1].axis('off')

# The superpixel speaking most strongly against the top predicted label
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, negative_only=True, num_features=1, hide_rest=True)
img_boundary = mark_boundaries(temp / 255.0, mask)
axes[2].imshow(img_boundary)
axes[2].set_title("Negative mask")
axes[2].axis('off')

plt.show()
```

For the fun of it, I also tried LIME on the image below. 

```python
# ===== Load a second image
img_file_name = "penguin2.jpg"
img_pil = Image.open(os.path.join("data", img_file_name)).convert("RGB")

# ===== Top-5 ResNet predictions (input moved to the model's device)
img = pil_to_torch(img_pil).to(device)
with torch.no_grad():
    logits = net(img)
probs = F.softmax(logits, dim=1)
probs5 = probs.topk(5)

# ===== Explain the prediction with LIME
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(np.array(pil_transform(img_pil)),
                                         cnn_predict,  # classification function
                                         top_labels=2,
                                         hide_color=0,
                                         num_samples=1000)  # number of images sent to the classification function

# ===== Plot the image, the positive mask and the negative mask
fig, axes = plt.subplots(1, 3)
axes[0].imshow(pil_transform(img_pil))
axes[0].axis('off')

temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, negative_only=False, num_features=3, hide_rest=True)
img_boundary = mark_boundaries(temp / 255.0, mask)
axes[1].imshow(img_boundary)
axes[1].set_title("Positive mask")
axes[1].axis('off')

temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, negative_only=True, num_features=1, hide_rest=True)
img_boundary = mark_boundaries(temp / 255.0, mask)
axes[2].imshow(img_boundary)
axes[2].set_title("Negative mask")
axes[2].axis('off')

plt.show()

# (probability, class index, label) for the top-5 classes
tuple((p, c, idx2label[c]) for p, c in zip(probs5.values[0].cpu().numpy(),
                                          probs5.indices[0].cpu().numpy()))
```