In [None]:
import os
import sys
from PIL import Image

import clip
import torch
import shapiq
import datasets
import numpy as np
import pandas as pd
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

import matplotlib.pyplot as plt

import Game_MM_CLIP.clip as mm_clip

from generate_emap import clipmodel, preprocess, mm_clipmodel, mm_interpret, gradeclip, gradeclip_text, clip_encode_dense, clip_encode_text_dense

sys.path.append('../')
import src
PATH = "../data"

In [None]:
def explain_image(mode, input_image, input_text):
    """Compute an image-side explanation map for one image-text pair.

    Parameters
    ----------
    mode : str
        Explanation method selector. A value containing ``"gradeclip"`` uses
        the Grad-ECLIP path; otherwise a value containing ``"game"`` uses
        ``mm_interpret``. The ``"gradeclip"`` check takes precedence.
    input_image
        Image accepted by ``preprocess`` (the notebook passes PIL images).
    input_text : str
        Caption the image is explained against.

    Returns
    -------
    The explanation map summed over its leading dimension. Presumably a
    ``torch.Tensor`` holding one value per image patch -- callers flatten it
    with ``.reshape(-1)``.

    Raises
    ------
    ValueError
        If ``mode`` selects neither method.
    """
    use_gradeclip = "gradeclip" in mode
    # Fail early with a clear message instead of an UnboundLocalError on `emap`.
    if not use_gradeclip and "game" not in mode:
        raise ValueError(
            f"Unknown explanation mode {mode!r}; expected it to contain 'gradeclip' or 'game'."
        )

    image_preprocessed = preprocess(input_image).to(device).unsqueeze(0)
    if use_gradeclip:
        text_processed = clip.tokenize([input_text]).to(device)
        text_embedding = clipmodel.encode_text(text_processed)
        text_embedding = F.normalize(text_embedding, dim=-1)
        outputs, v_final, last_input, v, q_out, k_out, attn, att_output, map_size = clip_encode_dense(image_preprocessed)
        # Index 0 along dim 1 is used as the image embedding (presumably the class token).
        image_embedding = F.normalize(outputs[:, 0], dim=-1)
        # One cosine similarity per tokenized text (a single caption here).
        cosines = (image_embedding @ text_embedding.T)[0]
        emap = [gradeclip(c, q_out, k_out, v, att_output, map_size, withksim=True) for c in cosines]
        emap = torch.stack(emap, dim=0).sum(0)
    else:
        text_tokenized = mm_clip.tokenize([input_text]).to(device)
        emap = mm_interpret(model=mm_clipmodel, image=image_preprocessed, texts=text_tokenized, device=device)
        emap = emap.sum(0)
    return emap

def explain_text(mode, input_image, input_text):
    """Compute a text-side (per-token) explanation for one image-text pair.

    Parameters
    ----------
    mode : str
        Explanation method selector. A value containing ``"gradeclip"`` uses
        the Grad-ECLIP text path; otherwise a value containing ``"game"`` uses
        ``mm_interpret`` with ``flag="text"``. The ``"gradeclip"`` check takes
        precedence.
    input_image
        Image accepted by ``preprocess`` (the notebook passes PIL images).
    input_text : str
        Caption whose tokens are explained.

    Returns
    -------
    The text explanation values. In ``"game"`` mode this is a flattened 1-D
    slice covering token positions ``1 .. id_cls - 1``; in ``"gradeclip"``
    mode it is whatever ``gradeclip_text`` returns.

    Raises
    ------
    ValueError
        If ``mode`` selects neither method.
    """
    use_gradeclip = "gradeclip" in mode
    # Fail early with a clear message instead of an UnboundLocalError on `emap`.
    if not use_gradeclip and "game" not in mode:
        raise ValueError(
            f"Unknown explanation mode {mode!r}; expected it to contain 'gradeclip' or 'game'."
        )

    image_preprocessed = preprocess(input_image).to(device).unsqueeze(0)
    if use_gradeclip:
        image_embedding = clipmodel.encode_image(image_preprocessed)
        image_embedding = F.normalize(image_embedding, dim=-1)
        text_processed = clip.tokenize([input_text]).to(device)
        x, (qs, ks, vs), attns, atten_outs = clip_encode_text_dense(text_processed, n=8)
        text_embedding = F.normalize(x, dim=-1)
        cosine = (image_embedding @ text_embedding.T)
        # Position of the largest token id; presumably the end-of-text token -- TODO confirm.
        eos_position = text_processed.argmax(dim=-1)
        emap = gradeclip_text(cosine[0], qs, ks, vs, atten_outs, eos_position, withksim=True)
    else:
        text_tokenized = mm_clip.tokenize([input_text]).to(device)
        emap = mm_interpret(model=mm_clipmodel, image=image_preprocessed, texts=text_tokenized, device=device, flag="text")
        # Position of the largest token id; presumably the end-of-text token -- TODO confirm.
        id_cls = text_tokenized.argmax(dim=-1).item()
        # Row `id_cls`, columns strictly between position 0 and `id_cls`.
        r_text = emap[0][id_cls, 1:id_cls]
        emap = r_text.flatten()
    return emap

In [None]:
def denormalize(img, mean, std):
    """Undo a channel-wise normalization, i.e. return ``img * std + mean``.

    Parameters
    ----------
    img : torch.Tensor
        Tensor that broadcasts against a ``(3, 1, 1)`` tensor, e.g. ``(3, H, W)``.
    mean, std : sequence of 3 floats
        Per-channel statistics that were used for normalization.

    Returns
    -------
    torch.Tensor
        The denormalized tensor, on the same device as ``img``.
    """
    # Build the statistics on img's device so a non-CPU input does not fail
    # with a device mismatch.
    std_t = torch.tensor(std, device=img.device).view(3, 1, 1)
    mean_t = torch.tensor(mean, device=img.device).view(3, 1, 1)
    return img * std_t + mean_t


# Per-channel mean / std passed to `denormalize` for the preprocessed images.
image_mean = (0.48145466, 0.4578275, 0.40821073)
image_std = (0.26862954, 0.26130258, 0.27577711)

### example

In [None]:
# Explain one example image-caption pair with the "game" method and save the result.
input_text = "black dog next to a yellow hydrant"
input_image = Image.open(os.path.join("..", "assets", "dog_and_hydrant.png"))
values_image = explain_image("game", input_image, input_text).reshape(-1)
values_text = explain_text("game", input_image, input_text)
# Image values first, then text values. `np.concatenate` is used instead of
# `np.concat` because the latter only exists in NumPy >= 2.0; both parts are
# converted to ndarrays explicitly.
values = np.concatenate([
    values_image.detach().cpu().numpy(),
    values_text.detach().cpu().numpy(),
])
gv = src.utils.convert_array_to_first_order(values)
gv.save("../results/dog_and_hydrant_game.pkl")

In [None]:
# Explain the same example image-caption pair with the "gradeclip" method and save the result.
input_text = "black dog next to a yellow hydrant"
input_image = Image.open(os.path.join("..", "assets", "dog_and_hydrant.png"))
values_image = explain_image("gradeclip", input_image, input_text).reshape(-1)
values_text = explain_text("gradeclip", input_image, input_text)
# Image values first, then text values. `np.concatenate` is used instead of
# `np.concat` because the latter only exists in NumPy >= 2.0; both parts are
# converted to ndarrays explicitly.
values = np.concatenate([
    values_image.detach().cpu().numpy(),
    values_text.detach().cpu().numpy(),
])
gv = src.utils.convert_array_to_first_order(values)
gv.save("../results/dog_and_hydrant_gradeclip.pkl")

### insertion/deletion

In [None]:
# Configuration for the insertion/deletion evaluation.
PATH_RESULTS = "../results"
# Directory that holds the stored attribution files, one sub-tree per model and mode.
PATH_INPUT = f"{PATH_RESULTS}/mscoco"
# Alternative model: "openai/clip-vit-base-patch32"
MODEL_NAME = "openai/clip-vit-base-patch16"

In [None]:
# Stream the test split of the captions dataset instead of loading it up front.
dataset = datasets.load_dataset("clip-benchmark/wds_mscoco_captions", split="test", streaming=True)

In [None]:
# Rank window [start, stop) of inputs to evaluate, ordered by descending "logit".
start, stop = 0, 1000
df_results = pd.read_csv(
    os.path.join(PATH_RESULTS, MODEL_NAME, "mscoco_predictions.csv"),
    index_col=0,
)
top_ids = df_results.sort_values("logit", ascending=False).iloc[start:stop].index

In [None]:
# One (initially empty) dictionary of detailed results per explanation mode.
results_details = {}
for mode in ('game', 'gradeclip'):
    results_details[mode] = {}

# Empty summary table with the columns that are filled in during evaluation.
results = pd.DataFrame(
    {column: [] for column in ('input', 'mode', 'mean', 'mean_normalized')}
)

In [None]:
# Deletion-curve evaluation of the stored first-order attributions.
# For every selected input and every mode, players are removed in order of
# attribution (most-important-first and least-important-first) and the game's
# value function is evaluated on each resulting coalition.

# Loop-invariant settings, computed once.
patch_size = 16 if MODEL_NAME.endswith("16") else 32
n_selected = len(top_ids)

# Rows are collected in a list and converted to a DataFrame once after the
# loop; growing a DataFrame with pd.concat per iteration is quadratic.
result_rows = []

n_iter = 0
for i, d in enumerate(dataset):
    if i not in top_ids:
        continue
    n_iter += 1
    if n_iter % 5 == 1:
        print(f'iter: {start + n_iter}/{stop}', flush=True)

    input_image = d['jpg']
    # The 'txt' field holds newline-separated captions; pick the stored best one.
    input_text = d['txt'].split("\n")[df_results.loc[i, "best_text_id"].item()]

    game = src.game_openai.CLIPGame(
        clipmodel, preprocess,
        input_image=input_image,
        input_text=input_text,
        patch_size=patch_size,
        batch_size=64
    )

    for mode in ['game', 'gradeclip']:

        path_file = os.path.join(PATH_INPUT, MODEL_NAME, mode, f'iv_order1_{i}.pkl')
        iv_object = shapiq.InteractionValues.load(path_file)

        attribution_values = iv_object.get_n_order(1).values
        attribution_values_sorted = np.sort(attribution_values)

        # Most-important-first deletion: row k keeps the players whose
        # attribution is <= the k-th largest value, so the first row is the
        # full coalition; the empty coalition is appended as the last row.
        coalition_matrix_deletion_mif = np.stack([attribution_values <= v for v in attribution_values_sorted[::-1]] + [game.empty_coalition])
        predictions_deletion_mif = game.value_function(coalition_matrix_deletion_mif)
        # Least-important-first deletion: row k keeps the players whose
        # attribution is >= the k-th smallest value.
        coalition_matrix_deletion_lif = np.stack([attribution_values >= v for v in attribution_values_sorted] + [game.empty_coalition])
        predictions_deletion_lif = game.value_function(coalition_matrix_deletion_lif)

        results_details[f'{mode}'][i] = {
            'predictions_deletion_mif': predictions_deletion_mif,
            'predictions_deletion_lif': predictions_deletion_lif,
        }

        # Both curves start at the full coalition and end at the empty one.
        assert predictions_deletion_mif[-1] == predictions_deletion_lif[-1]
        assert predictions_deletion_mif[0] == predictions_deletion_lif[0]

        # Normalize the curves so the full coalition maps to 1 and the empty one to 0.
        # NOTE(review): max_value == min_value would produce NaN/inf here.
        min_value = predictions_deletion_mif[-1]
        max_value = predictions_deletion_mif[0]

        predictions_deletion_mif_01 = (predictions_deletion_mif - min_value) / (max_value - min_value)
        predictions_deletion_lif_01 = (predictions_deletion_lif - min_value) / (max_value - min_value)

        result_rows.append({
            'input': i,
            'mode': mode,
            'mean': np.mean(predictions_deletion_lif - predictions_deletion_mif),
            'mean_normalized': np.mean(predictions_deletion_lif_01 - predictions_deletion_mif_01),
        })

    # All selected inputs are done; stop streaming the rest of the dataset.
    if n_iter >= n_selected:
        break

if result_rows:
    results_new = pd.DataFrame(result_rows, columns=['input', 'mode', 'mean', 'mean_normalized'])
    # Append to rows from earlier runs (e.g. other start/stop windows), but
    # never concatenate onto an empty frame, which pandas deprecates.
    results = results_new if results.empty else pd.concat([results, results_new], ignore_index=True)

In [None]:
# Mean and standard deviation of both metrics per explanation mode, rounded to 2 decimals.
summary_spec = {column: ['mean', 'std'] for column in ('mean', 'mean_normalized')}
results.groupby(["mode"]).agg(summary_spec).round(2)

In [None]:
# Write the per-input summary table as CSV without the row index.
path_results_csv = os.path.join(PATH_RESULTS, MODEL_NAME, "mscoco_aid_game_gradeclip.csv")
results.to_csv(path_results_csv, index=False)

In [None]:
# Store the detailed prediction curves. `results_details` is a dict, so NumPy
# saves it as a pickled object; reading it back needs
# np.load(..., allow_pickle=True).
path_results_npy = os.path.join(PATH_RESULTS, MODEL_NAME, "mscoco_aid_game_gradeclip.npy")
np.save(path_results_npy, results_details)

### visualize explanations

In [None]:
# Configuration for computing and plotting the explanations.
# Alternative model: "openai/clip-vit-base-patch16"
MODEL_NAME = "openai/clip-vit-base-patch32"
# Directory that holds the per-model predictions CSV.
PATH_INPUT = "../results"
# Directory that receives the explanation files and figures.
PATH_OUTPUT = f"{PATH_INPUT}/mscoco"

In [None]:
# Tokenizer used to split a caption into the tokens shown in the plots.
_tokenizer = clip.simple_tokenizer.SimpleTokenizer()

# Stream the test split of the captions dataset instead of loading it up front.
dataset = datasets.load_dataset("clip-benchmark/wds_mscoco_captions", split="test", streaming=True)

In [None]:
# Indices of the 1000 rows with the highest "logit" (ascending sort, last 1000 rows).
df_results = pd.read_csv(
    os.path.join(PATH_INPUT, MODEL_NAME, "mscoco_predictions.csv"),
    index_col=0,
)
top_ids = df_results.sort_values("logit").iloc[-1000:].index

In [None]:
# Compute, store and plot first-order explanations for every selected input,
# once per explanation mode.
n_selected = len(top_ids)
for mode in [
    'gradeclip',
    'game'
]:
    path_output = os.path.join(PATH_OUTPUT, MODEL_NAME, mode)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(path_output, exist_ok=True)

    print_counter = 0
    for i, d in enumerate(dataset):
        if i not in top_ids:
            continue
        if print_counter % 100 == 0:
            print(f'{mode} | iter: {print_counter}/{n_selected}', flush=True)
        print_counter += 1

        input_image = d['jpg']
        # The 'txt' field holds newline-separated captions; pick the stored best one.
        input_text = d['txt'].split("\n")[df_results.loc[i, "best_text_id"].item()]

        values_image = explain_image(mode, input_image, input_text).reshape(-1)
        values_text = explain_text(mode, input_image, input_text)
        # Image values first, then text values. `np.concatenate` is used
        # instead of `np.concat`, which only exists in NumPy >= 2.0.
        values = np.concatenate([
            values_image.detach().cpu().numpy(),
            values_text.detach().cpu().numpy(),
        ])
        gv = src.utils.convert_array_to_first_order(values, index=mode)
        gv.save(os.path.join(path_output, f'iv_order1_{i}.pkl'))

        # Plot the explanation over the denormalized image and the decoded tokens.
        text_tokens = _tokenizer.encode(input_text)
        text_tokens_decoded = [_tokenizer.decode([a]) for a in text_tokens]
        image_preprocessed = preprocess(input_image)
        input_image_denormalized = denormalize(image_preprocessed, image_mean, image_std)
        # Channels-first -> channels-last for plotting.
        input_image_denormalized = input_image_denormalized.permute(1, 2, 0).numpy()
        fig = src.plots.plot_image_and_text_together(
            img=input_image_denormalized,
            text=text_tokens_decoded,
            image_players=list(range(len(values_image))),
            iv=gv,
            normalize_jointly=False,
            figsize=(8, 8),
            show=False
        )
        fig.suptitle(f'{MODEL_NAME} {mode}', fontsize=20, y=1.05)
        fig.savefig(os.path.join(path_output, f'ex_order1_{i}.png'), bbox_inches='tight')
        # Close the figure so figures do not pile up in memory across iterations.
        plt.close(fig)

        # All selected inputs are done; stop streaming the rest of the dataset.
        if print_counter >= n_selected:
            break

### pointing game

In [None]:
# Configuration for the pointing-game experiment.
MODEL_NAME = "openai/clip-vit-base-patch32"
# Each game is an ordered list of four class labels.
GAMES = [
    labels.split() for labels in (
        "goldfish husky pizza tractor",
        "cat goldfish plane pizza",
        "banana cat tractor ball",
        "husky banana plane church",
        "pizza ipod goldfish banana",
        "ipod cat husky plane",
        "tractor ball banana ipod",
        "plane church ball goldfish",
        "church pizza ipod cat",
        "ball husky banana tractor",
    )
]
PATH_OUTPUT = "../results/imagenet_pointing_game"

In [None]:
# Pointing game: for each mode and game, explain 50 images against prompts
# built from the first 1..4 class labels, then store and plot the result.
for mode in ['gradeclip', 'game']:
    for game in GAMES:
        # NOTE: this rebinds the notebook-level PATH_INPUT for each game.
        PATH_INPUT = f'../data/imagenet_pointing_game/{"_".join(game)}'
        for i_objects in range(1, 5):
            class_labels = game[:i_objects]
            cl = "_".join(class_labels)
            input_text = cl.replace("_", " ")
            path_output = os.path.join(PATH_OUTPUT, MODEL_NAME, mode, cl)
            # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
            os.makedirs(path_output, exist_ok=True)
            for i_image in range(50):
                # The context manager closes the image file after each
                # iteration; all uses of the image stay inside the block.
                with Image.open(os.path.join(PATH_INPUT, f'{i_image}.jpg')) as input_image:
                    values_image = explain_image(mode, input_image, input_text).reshape(-1)
                    values_text = explain_text(mode, input_image, input_text)
                    # Image values first, then text values. `np.concatenate`
                    # is used instead of `np.concat`, which only exists in
                    # NumPy >= 2.0.
                    values = np.concatenate([
                        values_image.detach().cpu().numpy(),
                        values_text.detach().cpu().numpy(),
                    ])
                    gv = src.utils.convert_array_to_first_order(values)
                    gv.save(os.path.join(path_output, f'iv_order1_{i_image}.pkl'))

                    # Plot the explanation over the denormalized image and the class labels.
                    image_preprocessed = preprocess(input_image)
                    input_image_denormalized = denormalize(image_preprocessed, image_mean, image_std)
                    # Channels-first -> channels-last for plotting.
                    input_image_denormalized = input_image_denormalized.permute(1, 2, 0).numpy()
                    fig = src.plots.plot_image_and_text_together(
                        img=input_image_denormalized,
                        text=cl.split("_"),
                        image_players=list(range(len(values_image))),
                        iv=gv,
                        normalize_jointly=True,
                        figsize=(6, 6),
                        show=False
                    )
                    fig.suptitle(f'{MODEL_NAME} {mode}', fontsize=20, y=1.05)
                    fig.savefig(os.path.join(path_output, f'ex_order1_{i_image}.png'), bbox_inches='tight')
                    # Close the figure so figures do not pile up in memory across iterations.
                    plt.close(fig)