# Interpreting Image Classifier

In [19]:
import json
import torch
import numpy as np
from PIL import Image as PilImage

from omnixai.preprocessing.image import Resize
from omnixai.data.image import Image
from omnixai.explainers.vision import VisionExplainer
from omnixai.visualization.dashboard import Dashboard

In [42]:
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image(PilImage.open(requests.get(url, stream=True).raw).convert("RGB"))

#feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
#pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

In [46]:
from torchvision import models, transforms

model = models.resnet34(pretrained=True).to(device)

device = "cuda" if torch.cuda.is_available() else "cpu"
# The preprocessing function
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims]).to(device)

# The postprocessing function
postprocess = lambda logits: torch.nn.functional.softmax(logits, dim=1)


The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /Users/nazneenrajani/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]

In [53]:
explainer = VisionExplainer(
    explainers=[ "gradcam", "shap", "lime","ig"],
    mode="classification",
    model=model,
    preprocess=preprocess,
    postprocess=postprocess,
    params={
        "gradcam": {"target_layer": model.layer4[-1]},

    }

)
# Generate explanations
local_explanations = explainer.explain(Image(
    data=np.concatenate([
        image.to_numpy()]),
    batched=True
))


Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.



  0%|          | 0/1000 [00:00<?, ?it/s]


Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.



In [56]:
index=0
print("LIME results:")
local_explanations["lime"].ipython_plot(index)
print("Integrated-gradient results:")
local_explanations["ig"].ipython_plot(index)
print("SHAP results:")
local_explanations["shap"].ipython_plot(index)
print("Gradcam results:")
local_explanations["gradcam"].ipython_plot(index)


LIME results:


Integrated-gradient results:


SHAP results:


Gradcam results:


In [None]:
# Document time for each explanation

# LIME

# Integrated-gradient

# Grad-CAM

# SHAP

# Visualize the explanations

# Compare the explanations  

# Which method do you agree with the most? Why?

In [57]:
# Launch a dashboard for visualization using streamlit or gradio

dashboard = Dashboard(
    instances=image,
    local_explanations=local_explanations,

)
dashboard.show()

Dash is running on http://127.0.0.1:8050/

 * Serving Flask app 'omnixai.visualization.dashboard'
 * Debug mode: off


 * Running on http://127.0.0.1:8050
Press CTRL+C to quit
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET /_dash-component-suites/dash/deps/polyfill@7.v2_6_1m1663266855.12.1.min.js HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET /assets/xai.css?m=1663266542.9749486 HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET /assets/styles.css?m=1663266542.9748507 HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET /assets/base.css?m=1663266542.974423 HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET /_dash-component-suites/dash/deps/react@16.v2_6_1m1663266855.14.0.min.js HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET /_dash-component-suites/dash/deps/react-dom@16.v2_6_1m1663266855.14.0.min.js HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET /_dash-component-suites/dash/deps/prop-types@15.v2_6_1m1663266855.8.1.min.js HTTP/1.1" 200 -
127.0.0.1 - - [17/Sep/2022 10:48:25] "GET /_dash-component-su