In [1]:
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import models
from skimage.segmentation import quickshift

sys.path.append("../../src")
from context_explainer import ContextExplainer
from application_utils.image_utils import *
from application_utils.utils_torch import ModelWrapperTorch

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2
%matplotlib inline

# Run on the first GPU when one is available; otherwise fall back to CPU so the
# notebook still survives Restart & Run All on a CUDA-less machine (much slower).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from tqdm import tqdm
import pickle

In [2]:
# Experiment configuration.
#   random_context_only=True  -> every evaluated context is a random mask.
#   random_context_only=False -> the input and baseline contexts are evaluated
#                                first, followed by random masks distinct from them.
random_context_only = True

# Results are cached and resumed per file, and the resume check only compares
# image filenames. Derive the file name from the flag so that toggling it never
# resumes from (and appends to) results produced under the other setting.
_context_tag = "random_context_only" if random_context_only else "input_baseline_and_random"
save_path = f"results/resnet_{_context_tag}.pickle"

## Get Model

In [3]:
# ImageNet-pretrained ResNet-152, moved to `device` and switched to eval mode
# (batch-norm layers use their running statistics), then wrapped so the
# explainer can call it on numpy batches.
model = models.resnet152(pretrained=True)
model = model.to(device)
model.eval()
model_wrapper = ModelWrapperTorch(model, device)

## Get Images

In [4]:
# All .JPEG test images, sorted so the list (and therefore the sampled indexes
# below) does not depend on os.listdir ordering.
base_path = "../../downloads/imagenet14/test"
test_data = sorted(
    f"{base_path}/{name}"
    for name in os.listdir(base_path)
    if name.endswith(".JPEG")
)

# Fixed seed: the same `target_count` images (drawn without replacement) are
# selected on every run, which the resume check in the experiment loop relies
# on when it matches cached results by filename.
target_count = 100
np.random.seed(42)
indexes = np.random.choice(len(test_data), target_count, replace=False)

In [5]:
# Resume support: start from the results pickled by an earlier (possibly
# interrupted) run, or from an empty list on a fresh start.
# NOTE: pickle.load can execute arbitrary code; only point save_path at files
# this notebook wrote itself.
if not os.path.exists(save_path):
    all_res = []
else:
    print("loaded")
    with open(save_path, 'rb') as handle:
        all_res = pickle.load(handle)

## Run Experiment

In [6]:
# Base seed for context sampling. The global numpy seed is re-derived per image
# (CONTEXT_SEED + position) instead of being set once before the loop, so an
# image's random contexts are the same whether it is reached in a fresh run or
# after resuming from a checkpoint (skipped images consume no random draws).
CONTEXT_SEED = 42


def sample_unique_contexts(n_features, n_samples, exclude=(), rng=np.random):
    """Draw `n_samples` distinct random boolean masks of length `n_features`.

    Args:
        n_features: length of each mask.
        n_samples: number of masks to return.
        exclude: masks that must not be returned (compared element-wise, as
            tuples).
        rng: object with a `randint(low, high=..., size=...)` method, e.g. the
            `np.random` module (default) or a `np.random.RandomState`.

    Returns:
        A list of `n_samples` boolean numpy arrays that are pairwise distinct
        and distinct from every entry of `exclude`.

    Raises:
        ValueError: if fewer than `n_samples` unseen masks can exist; the
            rejection-sampling loop would otherwise never terminate.
    """
    seen = {tuple(c) for c in exclude}
    # Counts every excluded entry as occupying one of the 2**n_features masks,
    # so this errs on the side of raising if `exclude` holds odd entries.
    n_available = max(2 ** n_features - len(seen), 0)
    if n_samples > n_available:
        raise ValueError(
            f"cannot draw {n_samples} distinct contexts over {n_features} "
            f"features ({n_available} available)"
        )
    contexts = []
    while len(contexts) < n_samples:
        context = rng.randint(0, high=2, size=n_features).astype(bool)
        key = tuple(context)
        if key not in seen:
            seen.add(key)
            contexts.append(context)
    return contexts


def save_results(results, path):
    """Pickle `results` to `path` atomically.

    Writes a sibling temporary file and renames it over `path` with
    `os.replace`, so an interrupted run never leaves a truncated checkpoint
    behind. Creates the parent directory if needed, so a missing `results/`
    folder does not fail only after the first expensive image.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as handle:
        pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, path)


for counter, index in enumerate(tqdm(indexes)):
    image_path = test_data[index]
    img_filename = image_path.split("/")[-1]

    if counter < len(all_res):
        # Already computed in an earlier run: make sure it is the same image, then skip.
        tqdm.write(f"skip {counter}")
        assert all_res[counter]["img_filename"] == img_filename, (
            f"cached result {counter} is for {all_res[counter]['img_filename']}, "
            f"expected {img_filename}; was {save_path} produced with different settings?"
        )
        continue

    # Per-image seed (see CONTEXT_SEED) so resuming reproduces a fresh run.
    np.random.seed(CONTEXT_SEED + counter)

    image, labels = get_image_and_labels(image_path, device)

    # Explain the model's top-1 predicted class for this image.
    predictions = model_wrapper(np.expand_dims(image, 0))
    class_idx = np.argmax(predictions[0])

    baseline = np.zeros_like(image)
    segments = quickshift(image, kernel_size=3, max_dist=300, ratio=0.2)
    xf = ImageXformer(image, baseline, segments)

    ctx = ContextExplainer(model_wrapper, data_xformer=xf, output_indices=class_idx, batch_size=20, verbose=False)

    # NOTE(review): presumably the all-features-on / all-off masks of the
    # segmentation; their length is used as the number of features below.
    context1 = ctx.input
    context2 = ctx.baseline
    n_features = len(context1)

    n_samples = 9
    if random_context_only:
        # Same total of 11 contexts, all random: two extra draws replace
        # the input and baseline contexts.
        all_contexts = sample_unique_contexts(n_features, n_samples + 2)
    else:
        new_contexts = sample_unique_contexts(
            n_features, n_samples, exclude=[context1, context2]
        )
        all_contexts = [context1, context2] + new_contexts

    res = ctx.detect_with_running_contexts(all_contexts)

    all_res.append({"img_filename": img_filename, "result": res})
    # Checkpoint after every image so an interrupted run resumes from here.
    save_results(all_res, save_path)

# Final save; also writes the file when every image was skipped or none were selected.
save_results(all_res, save_path)