In [None]:
import os
import sys
from PIL import Image

import clip
import torch
import datasets
import numpy as np
import pandas as pd
import shapiq
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Make the parent directory importable so `import src` below resolves. Guarded so that
# re-running this cell does not append the same entry to sys.path again.
if '../' not in sys.path:
    sys.path.append('../')
import src
PATH = "../data"

from wrapper import exCLIP

def denormalize(img, mean, std):
    """Undo per-channel normalization, i.e. return ``img * std + mean``.

    Args:
        img: normalized image tensor whose channel axis is third from the end,
            e.g. ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: per-channel means used for normalization (one value per channel).
        std: per-channel standard deviations used for normalization.

    Returns:
        Tensor with the normalization undone, on the same device as ``img``.
    """
    # Build the statistics on img's device: CPU-only tensors fail against CUDA images.
    # dtype is left to type promotion so integer inputs still give a floating-point result.
    std_t = torch.as_tensor(std, device=img.device).view(-1, 1, 1)
    mean_t = torch.as_tensor(mean, device=img.device).view(-1, 1, 1)
    return img * std_t + mean_t


# Per-channel RGB normalization statistics (expected to match the Normalize step of CLIP's `preprocess`).
image_mean = (0.48145466, 0.4578275, 0.40821073)
image_std = (0.26862954, 0.26130258, 0.27577711)

In [None]:
MODE = "exclip"
# MODEL_NAME names the results directories and, via its suffix, selects the patch size.
MODEL_NAME = "openai/clip-vit-base-patch32"
# MODEL_NAME = "openai/clip-vit-base-patch16"

# Derive the `clip` architecture from MODEL_NAME so that switching models is a one-line
# change and the loaded weights always match the directories results are written to.
CLIP_ARCHITECTURES = {
    "openai/clip-vit-base-patch32": "ViT-B/32",
    "openai/clip-vit-base-patch16": "ViT-B/16",
}
model, preprocess = clip.load(CLIP_ARCHITECTURES[MODEL_NAME], device=device)

### Example: exCLIP explanation for a single image–text pair

In [None]:
input_text = "black dog next to a yellow hydrant"
text_processed = clip.tokenize([input_text]).to(device)
# Close the image file once it has been preprocessed.
with Image.open(os.path.join("..", "assets", "dog_and_hydrant.png")) as input_image:
    image_preprocessed = preprocess(input_image).to(device).unsqueeze(0)

explanation = exCLIP(model, device=device)
output = explanation.attribute_prediction(
    text_processed,
    image_preprocessed,
    text_layer=11,
    image_layer=11,
    N=20,
)
# Drop the first/last text positions and the first image position; presumably the
# start/end-of-text tokens and the image CLS token — TODO confirm against exCLIP.
crossmodal_interactions = output[1:-1, 1:]
gv = src.utils.convert_array_to_second_order(crossmodal_interactions, index="exclip")

# gv.save does not create missing directories.
path_example = os.path.join("..", "example", "dog_and_hydrant_exclip.pkl")
os.makedirs(os.path.dirname(path_example), exist_ok=True)
gv.save(path_example)

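### Insertion/deletion evaluation

Compares the `mif` and `lif` deletion curves built from the saved second-order interaction values, on the 1000 highest-logit MS-COCO pairs. The loop loads the `iv_order2_{i}.pkl` files written by the MS-COCO explanation section further below, so on a fresh kernel run that section first.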

In [None]:
PATH_RESULTS = "../results"
PATH_INPUT = "../results/mscoco"

# Iterate the MS-COCO captions test split lazily (streaming=True).
dataset = datasets.load_dataset(
    path="clip-benchmark/wds_mscoco_captions",
    split="test",
    streaming=True,
)

# Rank samples by logit, highest first, and keep the positions [start, stop).
start, stop = 0, 1000
path_predictions = os.path.join(PATH_RESULTS, MODEL_NAME, "mscoco_predictions.csv")
df_results = pd.read_csv(path_predictions, index_col=0)
ranked = df_results.sort_values(by="logit", ascending=False)
top_ids = ranked.index[start:stop]

In [None]:
# Raw deletion curves per sample, keyed first by mode and then by dataset index.
results_details = {MODE: {}}

# One summary row per evaluated sample; the evaluation loop appends to this frame.
RESULT_COLUMNS = ('input', 'mode', 'mean', 'mean_normalized')
results = pd.DataFrame({column: [] for column in RESULT_COLUMNS})

In [None]:
# Interaction files are located via PATH_RESULTS rather than PATH_INPUT: the explanation
# cells further down rebind PATH_INPUT, which would break re-running this cell after them.
path_iv_dir = os.path.join(PATH_RESULTS, "mscoco", MODEL_NAME, MODE)
path_summary = os.path.join(PATH_RESULTS, MODEL_NAME, f'mscoco_aid_{MODE}.csv')
path_details = os.path.join(PATH_RESULTS, MODEL_NAME, f'mscoco_aid_{MODE}.npy')

n_iter = 0
for i, d in enumerate(dataset):
    if i not in top_ids:
        continue
    n_iter += 1
    if n_iter % 5 == 1:
        print(f'iter: {start + n_iter}/{stop}', flush=True)

    input_image = d['jpg']
    input_text = d['txt'].split("\n")[df_results.loc[i, "best_text_id"].item()]

    game = src.game_openai.CLIPGame(
        model, preprocess,
        input_image=input_image,
        input_text=input_text,
        patch_size=16 if MODEL_NAME.endswith("16") else 32,
        batch_size=64
    )

    iv_object = shapiq.InteractionValues.load(os.path.join(path_iv_dir, f'iv_order2_{i}.pkl'))

    ## baseline: greedy clique orderings; the empty coalition is appended as the final point
    coalition_matrix_deletion_mif, coalition_matrix_deletion_lif = src.clique.get_cliques_greedy_mif_lif(iv=iv_object)
    predictions_deletion_mif = game.value_function(np.concatenate((coalition_matrix_deletion_mif, [game.empty_coalition]), axis=0))
    predictions_deletion_lif = game.value_function(np.concatenate((coalition_matrix_deletion_lif, [game.empty_coalition]), axis=0))

    ## bad (alternative ordering from first-order attributions; disabled)
    # attribution_values = src.utils.convert_iv_to_first_order(iv_object, p_sampler=0.5).get_n_order(1).values
    # attribution_values = src.utils.convert_iv_to_first_order(iv_object, p_sampler=1).get_n_order(1).values
    # attribution_values = src.utils.convert_exclip_to_first_order(iv_object, game.n_players_image, game.n_players_text).get_n_order(1).values
    # attribution_values_sorted = np.sort(attribution_values)
    # coalition_matrix_deletion_mif = np.stack([attribution_values <= v for v in attribution_values_sorted[::-1]] + [game.empty_coalition])
    # predictions_deletion_mif = game.value_function(coalition_matrix_deletion_mif)
    # coalition_matrix_deletion_lif = np.stack([attribution_values >= v for v in attribution_values_sorted] + [game.empty_coalition])
    # predictions_deletion_lif = game.value_function(coalition_matrix_deletion_lif)

    ## worse (complement cliques in reverse order; disabled)
    # coalition_matrix_deletion_mif, coalition_matrix_deletion_lif = src.clique.get_cliques_greedy_mif_lif(iv=iv_object, return_complement=True)
    # predictions_deletion_mif = game.value_function(np.concatenate(([game.empty_coalition == False], coalition_matrix_deletion_lif[::-1]), axis=0))
    # predictions_deletion_lif = game.value_function(np.concatenate(([game.empty_coalition == False], coalition_matrix_deletion_mif[::-1]), axis=0))

    results_details[MODE][i] = {
        'predictions_deletion_mif': predictions_deletion_mif,
        'predictions_deletion_lif': predictions_deletion_lif,
    }

    # The two curves must agree at both ends. Compare with a tolerance: the two
    # value_function calls batch differently, and batched computations are not
    # guaranteed to be bitwise identical.
    assert np.isclose(predictions_deletion_mif[-1], predictions_deletion_lif[-1])
    assert np.isclose(predictions_deletion_mif[0], predictions_deletion_lif[0])

    # normalize the curves so the first point maps to 1 and the last (empty coalition) to 0
    min_value = predictions_deletion_mif[-1]
    max_value = predictions_deletion_mif[0]
    if max_value == min_value:
        # Degenerate curve: normalizing would divide by zero, so record NaN explicitly.
        mean_normalized = np.nan
    else:
        predictions_deletion_mif_01 = (predictions_deletion_mif - min_value) / (max_value - min_value)
        predictions_deletion_lif_01 = (predictions_deletion_lif - min_value) / (max_value - min_value)
        mean_normalized = np.mean(predictions_deletion_lif_01 - predictions_deletion_mif_01)

    results = pd.concat([results, pd.DataFrame({
        'input': [i],
        'mode': [MODE],
        'mean': [np.mean(predictions_deletion_lif - predictions_deletion_mif)],
        'mean_normalized': [mean_normalized]
    })])

    # periodic checkpoint
    if n_iter == 1 or n_iter % 5 == 0:
        results.to_csv(path_summary, index=False)
        np.save(path_details, results_details)

    # Every requested id has been processed; skip decoding the rest of the stream.
    if n_iter == len(top_ids):
        break

# Persist the final state even if the last sample did not land on a checkpoint.
results.to_csv(path_summary, index=False)
np.save(path_details, results_details)

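### Compute and save explanations (MS-COCO)

Computes exCLIP interactions for the 1000 highest-logit MS-COCO pairs and saves one `iv_order2_{i}.pkl` file per pair. Nothing is plotted in this section.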

In [None]:
# NOTE: this rebinds PATH_INPUT; the insertion/deletion section set it to "../results/mscoco".
# Here it is the directory holding the per-model prediction CSVs.
PATH_INPUT = "../results"
PATH_OUTPUT = "../results/mscoco"

path_output = os.path.join(PATH_OUTPUT, MODEL_NAME, MODE)
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(path_output, exist_ok=True)

In [None]:
# Open a fresh streaming handle on the MS-COCO captions test split for this section.
dataset = datasets.load_dataset(path="clip-benchmark/wds_mscoco_captions", split="test", streaming=True)

In [None]:
path_predictions = os.path.join(PATH_INPUT, MODEL_NAME, "mscoco_predictions.csv")
df_results = pd.read_csv(path_predictions, index_col=0)
# Index labels of the 1000 highest-logit pairs (ascending sort, last 1000 rows).
top_ids = df_results.sort_values(by="logit").index[-1000:]

In [None]:
n_targets = len(top_ids)
print_counter = 0
for i, d in enumerate(dataset):
    if i not in top_ids:
        continue
    if print_counter % 100 == 0:
        print(f'{MODE} | iter: {print_counter}/{n_targets}', flush=True)
    print_counter += 1

    input_image = d['jpg']
    image_preprocessed = preprocess(input_image).to(device).unsqueeze(0)
    # Explain the caption selected by the prediction CSV for this sample.
    input_text = d['txt'].split("\n")[df_results.loc[i, "best_text_id"].item()]
    text_processed = clip.tokenize([input_text]).to(device)

    explanation = exCLIP(model, device=device)
    output = explanation.attribute_prediction(
        text_processed,
        image_preprocessed,
        text_layer=11,
        image_layer=11,
        N=20,
    )
    # Same slicing as the single-example cell (presumably drops special tokens — TODO confirm).
    crossmodal_interactions = output[1:-1, 1:]
    gv = src.utils.convert_array_to_second_order(crossmodal_interactions, index="exclip")
    gv.save(os.path.join(path_output, f'iv_order2_{i}.pkl'))

    # All requested ids are saved; stop instead of decoding the rest of the stream.
    if print_counter == n_targets:
        break


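### Pointing game (ImageNet)

For each pointing-game folder, computes exCLIP interactions for 50 images, prompting with the first 1, 2, 3 and 4 class names of the game.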

In [None]:
# Each pointing game is an ordered list of four class names; prompts use a prefix of it.
_GAME_SPECS = (
    "goldfish husky pizza tractor",
    "cat goldfish plane pizza",
    "banana cat tractor ball",
    "husky banana plane church",
    "pizza ipod goldfish banana",
    "ipod cat husky plane",
    "tractor ball banana ipod",
    "plane church ball goldfish",
    "church pizza ipod cat",
    "ball husky banana tractor",
)
GAMES = [spec.split() for spec in _GAME_SPECS]
PATH_OUTPUT = "../results/imagenet_pointing_game"

In [None]:
# Number of images per pointing-game folder (files 0.jpg .. 49.jpg).
N_IMAGES_PER_GAME = 50

# `game_classes` / `path_game_images` are used instead of `game` / `PATH_INPUT` so this cell
# does not overwrite the CLIPGame object and the input path used by the MS-COCO cells.
for game_classes in GAMES:
    path_game_images = f'../data/imagenet_pointing_game/{"_".join(game_classes)}'
    print(game_classes)
    # Prompt with the first 1, 2, ... class names of the game.
    for i_objects in range(1, len(game_classes) + 1):
        class_labels = game_classes[:i_objects]
        cl = "_".join(class_labels)
        input_text = cl.replace("_", " ")
        text_processed = clip.tokenize([input_text]).to(device)
        path_output = os.path.join(PATH_OUTPUT, MODEL_NAME, MODE, cl)
        os.makedirs(path_output, exist_ok=True)
        for i_image in range(N_IMAGES_PER_GAME):
            # Close each image file once it has been preprocessed.
            with Image.open(os.path.join(path_game_images, f'{i_image}.jpg')) as input_image:
                image_preprocessed = preprocess(input_image).to(device).unsqueeze(0)

            explanation = exCLIP(model, device=device)
            output = explanation.attribute_prediction(
                text_processed,
                image_preprocessed,
                text_layer=11,
                image_layer=11,
                N=20,
            )
            crossmodal_interactions = output[1:-1, 1:]
            gv = src.utils.convert_array_to_second_order(crossmodal_interactions, index="exclip")
            gv.save(os.path.join(path_output, f'iv_order2_{i_image}.pkl'))