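This is an example of the LIME explainer for image data. This explainer only supports image classification tasks. If you use this explainer, please cite the original work: Ribeiro, Singh and Guestrin, "Why Should I Trust You?": Explaining the Predictions of Any Classifier (KDD 2016); code: https://github.com/marcotcr/lime.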

In [49]:
import json
import unittest
import torch
from torchvision import models, transforms
from PIL import Image as PilImage
from omnixai.data.image import Image
from omnixai.explainers.vision import LimeImage

We recommend using `Image` to represent a batch of images. An `Image` can be constructed from a numpy array or a Pillow image. The following code loads one test image and the class names of this two-class model.

In [50]:
# Load the test image. Forward slashes work on both Windows and POSIX.
# The `with` block closes the file handle opened by PilImage.open;
# convert('RGB') returns a separate, fully loaded image, so `img` stays valid
# after the file is closed.
TEST_IMAGE_PATH = 'images_resize0.5_contrastReduce1-20_num_image10000/part_whole_test/1/g_min_l_plus+8.png'
with PilImage.open(TEST_IMAGE_PATH) as pil_img:
    img = Image(pil_img.convert('RGB'))
# Class names, indexed by the model's output column (class index -> name).
idx2label = ['0_plus', '1_min']

In [51]:
img_size = 224
# map_location='cpu' remaps any CUDA-saved tensors while unpickling. Without it,
# torch.load raises on a CPU-only machine before `.to('cpu')` is ever reached.
# NOTE(review): this unpickles a whole nn.Module, so only load trusted files.
# On torch>=2.6, torch.load defaults to weights_only=True and will refuse a
# full pickled model unless weights_only=False is passed explicitly.
model = torch.load(
    'models/resnet18_is224_bs4_e10_i10000_resize0.5_contrastReduce1-20_num_image10000.pth',
    map_location='cpu',
).to('cpu')
# Preprocessing applied to every image before it reaches the model.
# Resize with an int scales the shorter side to img_size and keeps the aspect
# ratio. No normalization is applied; presumably this matches the training
# pipeline, so confirm against the training code.
transform = transforms.Compose([
    transforms.Resize(size=img_size),
    transforms.ToTensor(),
])

In [52]:
model.eval()
input_img = transform(img.to_pil()).unsqueeze(dim=0)
# Put the input on the model's current device. batch_predict moves the global
# model to CUDA when one is available, so re-running this cell after that one
# would otherwise feed a CPU tensor to GPU weights.
input_img = input_img.to(next(model.parameters()).device)
# Inference only: no_grad skips building the autograd graph, so no .detach()
# is needed before converting to numpy.
with torch.no_grad():
    probs_top_2 = torch.nn.functional.softmax(model(input_img), dim=1).topk(2)
# (probability, class index, class name) for the two most likely classes.
# .cpu() is required before .numpy() if the tensors live on a GPU.
r = tuple((p, c, idx2label[c]) for p, c in
          zip(probs_top_2[0][0].cpu().numpy(), probs_top_2[1][0].cpu().numpy()))
print(r)

((0.8095921, 0, '0_plus'), (0.19040795, 1, '1_min'))


In [53]:
def batch_predict(images):
    """Return softmax class probabilities for a batch of images.

    Used as LimeImage's ``predict_function``.

    Parameters
    ----------
    images : iterable
        Batch of images. Each item must provide ``to_pil()``; presumably these
        are omnixai ``Image`` objects built by LimeImage from its perturbed
        samples (confirm against the omnixai version in use).

    Returns
    -------
    numpy.ndarray
        Softmax over dim 1 of the model logits. Presumably of shape
        ``(len(images), num_classes)``.

    Notes
    -----
    Side effect: puts the global ``model`` in eval mode and moves it to CUDA
    when available. The model stays there after this call returns.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)
    batch = torch.stack([transform(image.to_pil()) for image in images]).to(device)
    # Inference only. LIME calls this repeatedly (num_samples perturbed
    # images), so skip building an autograd graph for every batch.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)
    return probs.cpu().numpy()

# LimeImage only needs a callable that maps a batch of images to probabilities.
explainer = LimeImage(predict_function=batch_predict)
# Fit LIME on 1000 perturbed copies of `img`. hide_color=0 is presumably the
# fill value for hidden segments (black). Which labels are explained is left
# to LimeImage's defaults.
# NOTE(review): LIME's perturbation sampling is random and no seed is set
# here, so the explanation may differ between runs.
explanations = explainer.explain(
    img,
    num_samples=1000,
    hide_color=0,
)
# Plot the explanation for the first image in the batch (the only one loaded).
explanations.ipython_plot(class_names=idx2label, index=0)

  0%|          | 0/1000 [00:00<?, ?it/s]