In [None]:
# Standard library and imaging helpers.
import os
import sys
from PIL import Image
from copy import deepcopy

# Numerics, tabular data and plotting.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
# Set PyTorch's float32 matmul precision mode to "high" before any model is
# loaded (see the PyTorch docs for what this mode permits on a given device).
torch.set_float32_matmul_precision("high")
from transformers import CLIPProcessor, CLIPModel
import datasets

import shapiq 

# Put the parent directory on the import path so that `src` below resolves.
# NOTE(review): presumably `src` is the project-local package one level up,
# which means the notebook has to be run from its own directory -- confirm.
sys.path.append('../')
import src
# NOTE(review): PATH is not referenced by any of the cells shown here.
PATH = "../data"

In [None]:
# --- Experiment configuration -------------------------------------------
# CLIP checkpoint whose precomputed results are visualised.
MODEL_NAME = "openai/clip-vit-base-patch32"
# MODEL_NAME = "openai/clip-vit-base-patch16"

# Name of the results sub-directory (part of PATH_DATA below).
METHOD = "banzhaf/0.5"
# Order encoded in the interaction-values file name (`iv_order{ORDER}_...`).
ORDER = 2
# Sample identifier: used in the results file name and in the output figure name.
ID = 4343

# Directory holding the precomputed results for this model / method pair.
PATH_DATA = f'../results/mscoco/{MODEL_NAME}/{METHOD}'
# Output file name; slashes are replaced so that it stays a flat file name.
PATH_SAVE = f'figure5_{MODEL_NAME.replace("/", "-")}_{METHOD.replace("/", "-")}_{ID}.pdf'

In [None]:
# Precomputed interaction values for the configured sample and order.
# NOTE(review): the ".pkl" suffix suggests pickle-based deserialisation --
# only load files that were produced by a trusted run of this project.
interaction_file = f"iv_order{ORDER}_{ID}.pkl"
interaction_path = os.path.join(PATH_DATA, interaction_file)
interaction_values = shapiq.InteractionValues.load(interaction_path)

# Per-sample prediction metadata; the first CSV column becomes the index.
# NOTE(review): unlike PATH_DATA, this path has no "mscoco" directory
# component -- confirm that this is the intended location.
predictions_csv = os.path.join("..", "results", MODEL_NAME, "mscoco_predictions.csv")
df_metadata = pd.read_csv(predictions_csv, index_col=0)

In [None]:
# Print the loaded interaction values as a quick sanity check.
print(interaction_values)

In [None]:
# Load the MS-COCO captions test split and pick the configured sample.
dataset = datasets.load_dataset("clip-benchmark/wds_mscoco_captions", split="test")
data = dataset[ID]
image = data['jpg']

# The 'txt' field is split on newlines (presumably one caption per line);
# the metadata table says which of the resulting entries to use for this sample.
captions = data['txt'].split("\n")
best_text_id = df_metadata.loc[ID, "best_text_id"].item()
text = captions[best_text_id]

# Build the vision-language game around the chosen image / caption pair.
clip_model = CLIPModel.from_pretrained(MODEL_NAME)
clip_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
game = src.game.VisionLanguageGame(
    model=clip_model,
    processor=clip_processor,
    input_image=image,
    input_text=text,
)

# Drop the first and last token (presumably the tokenizer's start/end markers)
# and remove the '</w>' marker from the remaining token strings.
text_tokens = [token.replace('</w>', '') for token in game.inputs.tokens()[1:-1]]

n_players_image = game.n_players_image

# Turn the preprocessed pixel tensor into an array for plotting, passing the
# processor's mean/std (presumably so the helper can undo the normalisation).
image_processor = game.processor.image_processor
image_array = src.plots.image_torch_to_array(
    game.inputs['pixel_values'].squeeze(0),
    image_processor.image_mean,
    image_processor.image_std,
)

In [None]:
# Display the number of players stored with the loaded interaction values.
interaction_values.n_players

In [None]:
# What to draw: the image, the caption tokens and the interaction values.
plot_inputs = dict(
    img=image_array,
    text=text_tokens,
    image_players=list(range(n_players_image)),
    iv=interaction_values,
)

# How to draw it.
plot_style = dict(
    plot_interactions=True,
    top_k=16,
    normalize_jointly=True,
    figsize=(7, 7),
    fontsize=22,
    margin=0.3,
    color_text=True,
    plot_heatmap=True,
    show=False,  # keep the figure open so it can be laid out and saved below
)

src.plots.plot_image_and_text_together(**plot_inputs, **plot_style)

# Tighten the layout, then write the figure to the configured output path.
plt.tight_layout(pad=0.15)
plt.savefig(PATH_SAVE)