In [1]:
import os, sys
import numpy as np
import math
from time import time
import pickle
from tqdm import tqdm

import torch
from torchvision import models

import matplotlib.pyplot as plt
import skimage.io as io
from skimage.color import gray2rgb

from pycocotools.coco import COCO
from experiment_utils import *

sys.path.append("../../src")
from explainer import Explainer
from application_utils.image_utils import *
from application_utils.utils_torch import ModelWrapperTorch

sys.path.append("../../baselines/integrated_gradients")
import ig, ig_utils

sys.path.append("../../baselines/shapley_interaction_index")
from si_explainer import SiExplainer

sys.path.append("../../baselines/shapley_taylor_interaction_index")
from sti_explainer import StiExplainer

import warnings
warnings.filterwarnings("ignore")


%load_ext autoreload
%autoreload 2
%matplotlib inline

# Use the first GPU when one is present; otherwise fall back to the CPU so the notebook
# still runs (slowly) on machines without CUDA instead of failing at the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
methods = ["archattribute"] # for analysis code to run smoothly, use one method per experiment run
# Derive the results file from the selected method so the file name can never drift out of
# sync with `methods` (which would write one method's scores into another method's file).
save_path = "analysis/results/segment_auc_{}.pickle".format(methods[0])

## Get Model

In [3]:
# Pretrained torchvision ResNet-152, moved to the target device and put in eval mode.
model = models.resnet152(pretrained=True)
model = model.to(device)
model = model.eval()

# Wrapped model handed to the SI / STI baseline explainers below.
model_wrapper = ModelWrapperTorch(model, device)

## Get Data

In [4]:
# NOTE: data_dir is a machine-specific absolute path; point it at a local MS-COCO copy.
data_dir = '/meladyfs/newyork/datasets/mscoco'
data_type = "val2017"
coco_to_i1k_path = "processed_data/image_data/coco_to_i1k_map.pickle"

# Instance annotations for the chosen split.
annFile = '{}/annotations/instances_{}.json'.format(data_dir, data_type)
coco = COCO(annFile)

# Lookup structures used by the experiment loop to relate model class indices to COCO categories.
i1k_idx_to_cat, valid_cat_ids, cat_map = prep_imagenet_coco_conversion(
    coco,
    data_dir=data_dir,
    data_type=data_type,
    coco_to_i1k_path=coco_to_i1k_path,
)

loading annotations into memory...
Done (t=4.17s)
creating index...
index created!


## Baseline Methods

In [5]:
def archattribute(model, image_tensor, mask_tensor, model_target_idx, device):
    """
    Score a masked image region by how much it changes one model output.

    The score is model(image * mask) - model(zeros), read at batch element 0 and
    output index `model_target_idx`.

    Args:
        model: callable mapping an input tensor to an output indexable as [batch][class].
        image_tensor: input image tensor (any device; it is moved to `device`).
        mask_tensor: tensor multiplied element-wise with the image; also defines the
            shape of the all-zero baseline input.
        model_target_idx: index of the model output to attribute.
        device: torch device the forward passes run on.

    Returns:
        The score as a numpy scalar.
    """
    # Inference only: no autograd graph is needed for the two forward passes.
    with torch.no_grad():
        # Move BOTH operands to the device so a CPU image and a GPU mask cannot clash.
        masked_input = image_tensor.to(device) * mask_tensor.to(device)
        predictions_island = model(masked_input)
        predictions_baseline = model(torch.zeros_like(mask_tensor).to(device))
    predictions = predictions_island - predictions_baseline
    att_score = predictions.detach().cpu().numpy()[0][model_target_idx]
    return att_score

def integrated_gradients(model, image_tensor, model_target_idx, device):
    """
    Run the Integrated Gradients baseline (50 steps) for one model output.

    The image tensor is squeezed and converted to a numpy array before being passed to
    `ig.integrated_gradients`; whatever that call returns is returned unchanged.
    """
    # squeeze() drops every size-1 dimension before the conversion to numpy
    image_array = image_tensor.squeeze().cpu().numpy()
    return ig.integrated_gradients(
        image_array,
        model,
        model_target_idx,
        ig_utils.get_gradients,
        None,
        device,
        steps=50,
    )

def shapley_interaction_index(image, baseline, segments, S, model_target_idx, seed=None, num_T=20):
    """
    Shapley Interaction Index attribution for the segment set `S`.

    Builds an ImageXformer over (image, baseline, segments), wraps it in a SiExplainer
    around the notebook-level `model_wrapper`, and returns `attribution(S, num_T)`.
    """
    transformer = ImageXformer(image, baseline, segments)
    explainer = SiExplainer(
        model_wrapper,
        data_xformer=transformer,
        output_indices=model_target_idx,
        batch_size=20,
        seed=seed,
    )
    return explainer.attribution(S, num_T)

def shapley_taylor_interaction_index(image, baseline, segments, S, model_target_idx, max_order=2, num_orderings=20, seed=None):
    """
    Shapley-Taylor Interaction Index attribution for the segment set `S`, averaged
    over `num_orderings` random orderings of the segments.

    When len(S) equals `max_order`, each ordering contributes `attribution(S, T)` with
    T the features that come before every member of S in that ordering; otherwise each
    ordering contributes `attribution(S, [])`.

    Note: a non-None `seed` reseeds numpy's GLOBAL random state.
    """

    def features_preceding(subset, order, position_of):
        # everything ranked before the earliest member of `subset`
        first_pos = min(position_of[member] for member in subset)
        return order[:first_pos]

    if seed is not None:
        np.random.seed(seed)

    transformer = ImageXformer(image, baseline, segments)
    explainer = StiExplainer(
        model_wrapper,
        data_xformer=transformer,
        output_indices=model_target_idx,
        batch_size=20,
    )

    num_feats = len(np.unique(segments))
    use_ordering = len(S) == max_order

    contributions = []
    for _ in range(num_orderings):
        # a permutation is drawn on every iteration, even when it ends up unused
        order = np.random.permutation(list(range(num_feats)))
        position_of = {feature: pos for pos, feature in enumerate(order)}

        if use_ordering:
            preceding = features_preceding(S, order, position_of)
            contributions.append(explainer.attribution(S, preceding))
        else:
            contributions.append(explainer.attribution(S, []))

    return sum(contributions) / num_orderings

## Run Experiment

In [6]:
# Resume from a previous run when a results file already exists; otherwise start empty.
# (pickle is only appropriate here because the file is written by this notebook itself.)
results = {}
if os.path.exists(save_path):
    with open(save_path, 'rb') as handle:
        results = pickle.load(handle)

In [7]:
show_plots = False
max_imgs_per_category = 500  # NOTE: not referenced by the loop below

# The per-image record holds a single shared "ref" list, so it can only describe one method.
# Fail fast rather than silently keeping only the last method's scores.
if len(methods) != 1:
    raise ValueError("use exactly one method per experiment run, got: {}".format(methods))

t0 = time()

seenImgs = set()
img_count = 0

for cat_id in tqdm(valid_cat_ids):
    # get image ids corresponding to a category
    imgIds = coco.getImgIds(catIds=[cat_id])

    for imgId in imgIds:

        # an image can belong to several categories; visit it only once
        if imgId in seenImgs:
            continue
        seenImgs.add(imgId)

        # load the image metadata
        img = coco.loadImgs(imgId)[0]
        # load the annotation ids for this image
        annIds = coco.getAnnIds(imgIds=img['id'], catIds=valid_cat_ids, iscrowd=None)

        # if this image and all of its annotations have already been examined for each method, skip
        # (annotations skipped further down are never appended, so such images are re-examined on resume)
        if imgId in results and all(len(results[imgId]["est"][m]) == len(annIds) for m in methods):
            continue

        # load the actual image
        I = io.imread('%s/images/%s/%s'%(data_dir,data_type,img['file_name']))

        # if grey, convert to RGB
        if len(I.shape) == 2:
            I = gray2rgb(I)

        if show_plots:
            plt.imshow(I); plt.axis('off')

        image, image_tensor = transform_img(I, preprocess)

        # rank all model classes for this image; inference only, so no autograd graph is needed
        with torch.no_grad():
            logits = model(image_tensor.to(device))
        top_model_class_idxs = logits.detach().cpu().numpy()[0].argsort()[::-1]

        # select the top predicted class that intersects with coco classes
        model_target_idx = None
        for class_idx in top_model_class_idxs:
            if class_idx in i1k_idx_to_cat:
                model_target_idx = class_idx
                break
        if model_target_idx is None:
            # no predicted class maps to a coco category: never reuse a stale target from a previous image
            continue

        # use superpixel segmenting with SI or STI
        if any(m in {"si", "sti"} for m in methods):
            segments = quickshift(image, kernel_size=3, max_dist=300, ratio=0.2)

        # actually load the annotations
        anns = coco.loadAnns(annIds)

        # (re)start the record for this image once, outside the per-method loop
        results[imgId] = {"ref": [], "est": {}}

        for method in methods:

            results[imgId]["est"][method] = []

            # integrated gradients is computed once per image and then summed under each mask
            if method == "integrated_gradients":
                ig_score = integrated_gradients(model, image_tensor, model_target_idx, device)

            for ann in anns:

                assert(ann["category_id"] in valid_cat_ids)

                # get the mask for this annotation
                mask = coco.annToMask(ann)

                # transform (resize) the mask through torch, but maintain a mask with black background (dont normalize)
                mask_resize = np.tile(np.expand_dims(mask, 2), 3).astype(np.uint8)
                mask_orig, mask_tensor = transform_img(mask_resize, preprocess_mask)

                # there must be something to show through the mask after resizing
                if math.isnan(mask_tensor.sum().item()):
                    continue

                # process how the mask corresponds to superpixel segments
                if method in {"si", "sti"}:
                    inter = match_segments_and_mask(segments, mask_orig)
                    # require that with SI and STI, only two segments are selected for pairwise interaction attribution
                    if len(inter) != 2:
                        continue

                if show_plots:
                    plt.figure(figsize = (6,6))
                    plt.axis('off')
                    plt.imshow((image*mask_orig)/2+0.5)
                    plt.show()

                # apply the different methods for the model_target_idx, the "top" prediction on the original image
                if method == "archattribute":
                    att_score = archattribute(model, image_tensor, mask_tensor, model_target_idx, device)
                elif method == "integrated_gradients":
                    att_score = (mask_tensor.cpu().numpy()*ig_score).sum()
                elif method == "si":
                    att_score = shapley_interaction_index(image, np.zeros_like(image), segments, inter, model_target_idx, seed=img_count)
                elif method == "sti":
                    att_score = shapley_taylor_interaction_index(image, np.zeros_like(image), segments, inter, model_target_idx, seed=img_count)

                results[imgId]["est"][method].append(att_score)

                # if the annotation does not belong to the category of model_target_idx, ground truth is 0, else it is 1
                if cat_map[ann["category_id"]] not in i1k_idx_to_cat[model_target_idx]: # "in" handles the vase case, which maps to two coco categories
                    results[imgId]["ref"].append(0)
                else:
                    results[imgId]["ref"].append(1)

        # save attribution results and corresponding ground truth
        # (the modulus is 1, so this checkpoints after every processed image)
        if img_count % 1 == 0:
            # write to a temp file and swap it in, so an interrupted run cannot leave a
            # truncated pickle that would break the resume step
            tmp_path = save_path + ".tmp"
            with open(tmp_path, 'wb') as handle:
                pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, save_path)

        img_count += 1

t1 = time()