In [None]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models
from skimage.segmentation import quickshift
import torch
# Make the Archipelago sources importable. Override the location with the
# ARCHIPELAGO_SRC environment variable instead of editing this absolute path.
ARCHIPELAGO_SRC = os.environ.get("ARCHIPELAGO_SRC", "/home/mengfeiz/archipelago/src")
# Guard so re-running this cell does not append duplicate entries.
if ARCHIPELAGO_SRC not in sys.path:
    sys.path.append(ARCHIPELAGO_SRC)
sys.path


In [None]:
from explainer import Archipelago
from application_utils.image_utils import *
from application_utils.utils_torch import ModelWrapperTorch

import warnings
# NOTE(review): this silences *all* warnings (including deprecation warnings
# from torchvision/gradio); consider narrowing the filter.
warnings.filterwarnings("ignore")

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The ResNet-152 that used to be loaded here was dead code: `model` and
# `model_wrapper` are reassigned to the fine-tuned burnCNN model further down
# before anything uses them, so loading it only cost a weight download and
# device memory.

import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np

class FineTuneResNet(nn.Module):
    """Backbone of a pretrained ResNet topped with a small MLP classifier.

    The feature extractor is every child of ``original_model`` except the last
    one (for a torchvision ResNet that drops the final ``fc`` layer). Its output
    is flattened per sample and fed through
    ``Linear(2048, 1000) -> ReLU -> Linear(1000, 256) -> ReLU -> Linear(256, num_classes)``.

    Args:
        original_model: module whose children, minus the last, form the
            feature extractor; the flattened features must have 2048 elements
            per sample (e.g. torchvision ResNet-50).
        num_classes: number of output logits.
    """

    def __init__(self, original_model, num_classes):
        super().__init__()
        # The Linear layers are created before anything else, in this order,
        # so weight initialisation draws from the RNG exactly as before.
        head = [
            nn.Linear(2048, 1000),
            nn.ReLU(),
            nn.Linear(1000, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        ]
        backbone = list(original_model.children())[:-1]
        self.features = nn.Sequential(*backbone)
        # The Sequential indices (Linear layers at 0, 2, 4) are baked into the
        # state_dict keys, so this layout must match saved checkpoints.
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = models.resnet50(pretrained=True)
model = FineTuneResNet(base_model, 2)

# Override with the BURN_CNN_CHECKPOINT environment variable instead of
# editing this absolute path.
CHECKPOINT_PATH = os.environ.get(
    "BURN_CNN_CHECKPOINT", "/home/mengfeiz/archipelago/src/burnCNN_fold4.pt"
)
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only host.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
# eval() is required for inference: in train mode the ResNet BatchNorm layers
# normalise with per-batch statistics (making predictions depend on the other
# images in the batch) and keep updating their running stats.
model = model.to(device).eval()
model_wrapper = ModelWrapperTorch(model, device)

In [None]:
def show_image_explanation(
    inter_effects,
    image,
    segments,
    figsize=0.4,
    spacing=0.15,
    main_effects=None,
    savepath="",
):
    """
    Format image visualizations for plotting.

    Builds the rows of ``(image_array, title)`` pairs that ``get_explanations``
    lays out: the original image, then (if given) one "Main effects" panel,
    then one panel per interaction.

    Args:
        inter_effects: iterable of ``(segment_ids, attribution)`` pairs, e.g.
            ``dict.items()`` of Archipelago's interaction effects.
        image: the image the explanation refers to.
        segments: segmentation map passed through to ``get_set_img``.
        figsize, spacing, savepath: accepted for call compatibility; unused here.
        main_effects: optional iterable of ``(segment_id, attribution)`` pairs.

    Returns:
        A list of rows, each a list of ``(image_array, title)`` tuples. The
        main-effects / interaction rows are omitted when there is nothing to show.
    """
    img_arrays = [[(image, "Original image")]]

    ## main effects
    # The max |attribution| is computed only inside this branch; computing it
    # unconditionally raised NameError whenever main_effects was None.
    if main_effects is not None:
        main_pairs = list(main_effects)
        if main_pairs:
            main_id_list = [main_id for main_id, _ in main_pairs]
            main_att_list = [main_att for _, main_att in main_pairs]
            # Shared max |attribution| for all main-effect segments
            # (presumably get_set_img's colour scale — confirm in image_utils).
            max_att_main = np.amax(np.abs(main_att_list))
            img_arrays.append(
                [
                    (
                        get_set_img(
                            image, segments, main_id_list, main_att_list, max_att_main
                        ),
                        "Main effects",
                    )
                ]
            )

    ## interaction effects
    inter_pairs = list(inter_effects)
    if inter_pairs:
        # One max |attribution| shared by every interaction panel.
        max_att_inter = np.amax(np.abs([att for _, att in inter_pairs]))
        img_arrays.append(
            [
                (
                    get_set_img(image, segments, inter_set, inter_att, max_att_inter),
                    # Raw string: "\m" is an invalid escape in a regular literal.
                    r"Interaction $\mathcal{I}_" + str(i + 1) + "$",
                )
                for i, (inter_set, inter_att) in enumerate(inter_pairs)
            ]
        )
    return img_arrays

In [None]:
# Module-level buffer of the images displayed by the most recent
# get_explanations() call, in display order.
img_main = []


def get_explanations(img_arrays, figsize=0.4, spacing=0.15, savepath=""):
    """Lay out explanation panels on a single matplotlib figure.

    Each element of ``img_arrays`` is a row (list) of ``(image, title)`` pairs.
    Rows are placed left to right, each as its own 1xN group of axes. Images are
    shown as ``img / 2 + 0.5`` (i.e. assumed to lie in [-1, 1] — TODO confirm for
    the inputs used here).

    Every displayed image is also stored, in display order, in the module-level
    list ``img_main``, which is reset at the start of each call.

    Args:
        img_arrays: list of rows of ``(image, title)`` pairs, e.g. the output of
            ``show_image_explanation``.
        figsize: width of each panel as a fraction of the figure width; also
            scales the title font size.
        spacing: horizontal gap between rows, as a figure fraction.
        savepath: if non-empty, the figure is saved to this path.

    Returns:
        ``img_main``: the list of displayed (rescaled) images.
    """
    # Reset the shared buffer. It previously grew on every call, so index-based
    # lookups (e.g. img_main[1]) kept returning images from the first call.
    img_main.clear()

    w_spacing = (2 / 3) * spacing
    fig = plt.figure()

    left = 0
    row_axes = []
    for row in img_arrays:
        num_imgs = len(row)
        right = left + figsize * num_imgs + (num_imgs - 1) * 0.4 * w_spacing
        # squeeze=False always returns a 2-D array of Axes, so single-image rows
        # need no special case.
        row_axes.append(
            fig.subplots(
                1,
                num_imgs,
                squeeze=False,
                gridspec_kw=dict(left=left, right=right, wspace=w_spacing),
            )
        )
        left = right + spacing

    for row, axes in zip(img_arrays, row_axes):
        for ax, (img, title) in zip(axes.flat, row):
            shown = img / 2 + 0.5
            ax.imshow(shown)
            img_main.append(shown)
            ax.set_title(title, fontsize=55 * figsize)
            ax.axis("off")

    if savepath:
        fig.savefig(savepath, bbox_inches="tight")
    return img_main

In [None]:
def ExplainPrediction(image):
    """Gradio handler: run Archipelago on one image and return the main-effects panel.

    Segments the image with quickshift and explains model output index 0 with
    Archipelago (top_k=15, interaction and main effects separated). It then
    renders the effects with ``show_image_explanation`` / ``get_explanations``
    and returns the "Main effects" visualisation.

    Args:
        image: array from the Gradio image input (presumably HxWx3 uint8 RGB —
            TODO confirm), or None when nothing was uploaded.

    Returns:
        The main-effects visualisation as an integer array, or None if no image
        was given.
    """
    if image is None:
        return None

    image = image.astype(np.double)
    # All-zero baseline in the same float dtype as the image.
    baseline = np.zeros_like(image)
    # NOTE(review): pixels reach the model without any normalisation. If the
    # Gradio input is 0-255 this likely differs from burnCNN's training
    # preprocessing — confirm.
    segments = quickshift(image, kernel_size=3, max_dist=300, ratio=0.2)

    xf = ImageXformer(image, baseline, segments)
    apgo = Archipelago(model_wrapper, data_xformer=xf, output_indices=0, batch_size=20, verbose=True)

    inter_effects, main_effects = apgo.explain(top_k=15, separate_effects=True)
    img_arrays = show_image_explanation(inter_effects.items(), image, segments, main_effects=main_effects.items())

    # get_explanations appends to the module-level img_main. Clear it first so
    # index 1 (the "Main effects" panel, right after "Original image") belongs
    # to THIS request rather than to the first request the server handled.
    img_main.clear()
    get_explanations(img_arrays)
    # get_explanations opens a new figure on every call; close it so a
    # long-running server does not accumulate figures.
    plt.close(plt.gcf())

    return img_main[1].astype(int)

# Web demo: upload an image and get back the main-effects explanation produced
# by ExplainPrediction. share=True additionally creates a public share link,
# so anyone with the URL can query the model.
# NOTE(review): gr.inputs is Gradio's legacy input namespace (deprecated in
# 3.x, removed in 4.x) — pin the gradio version or migrate to gr.Image.
iface = gr.Interface(
    fn=ExplainPrediction,
    inputs=gr.inputs.Image(shape=(200, 200)),
    outputs="image",
)
iface.launch(share=True)