In [None]:
import os
import sys
import shapiq
import numpy as np
import pandas as pd
from itertools import chain, combinations

import sys
sys.path.append('../')
import src

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Root directory of the saved explanations. The loading cells below read files laid out as
# <PATH_INPUT>/<model name>/<mode>[/<p_sampler>]/<label1_label2_...>/iv_order<k>_<image id>.pkl
# (the <p_sampler> level only exists for banzhaf modes).
PATH_INPUT = "../results/imagenet_pointing_game"

In [None]:
def powerset(
    iterable,
    min_size: int = 0,
    max_size: int | None = None,
):
    s = sorted(iterable)
    max_size = len(s) if max_size is None else min(max_size, len(s))
    return chain.from_iterable(combinations(s, r) for r in range(max(min_size, 0), max_size + 1))

In [None]:
def _quadrant_player_ids(grid_size: int) -> dict[int, np.ndarray]:
    """Flattened patch indices of the four corner quadrants of a square patch grid.

    Patches are indexed row-major (index = row * grid_size + col). Each quadrant is a
    ``grid_size // 2`` x ``grid_size // 2`` block; for an odd ``grid_size`` the middle
    row and column belong to no quadrant. (An overlapping 4x4 layout for the 7x7 grid
    was tried before and is not used.)

    Keys, assuming row 0 is the top of the image: 0 = top-left, 1 = top-right,
    2 = bottom-left, 3 = bottom-right.
    """
    side = grid_size // 2
    far = grid_size - side  # first row/column of the bottom/right quadrants
    offsets = np.arange(side)
    quadrants = {}
    corners = [(0, 0), (0, far), (far, 0), (far, far)]
    for key, (row_start, col_start) in enumerate(corners):
        rows = (row_start + offsets)[:, None]
        cols = (col_start + offsets)[None, :]
        quadrants[key] = (rows * grid_size + cols).ravel()
    return quadrants


# grid side length (sqrt of the number of image players) -> quadrant id -> image player ids
grid_players = {size: _quadrant_player_ids(size) for size in (7, 8, 14, 16)}

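## fixlip

Pointing-game evaluation of the cross-modal interaction values (`src.utils.get_crossmodal_subset`) for the selected `MODEL_NAME` and `MODE`: the k-th class label of the text input is expected to point to the k-th image quadrant defined in `grid_players`.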

In [None]:
# Model whose explanations are evaluated in the fixlip section. Alternatives used with
# this notebook: "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch16",
# "google/siglip2-base-patch32-256", "google/siglip2-base-patch16-224",
# "google/siglip-base-patch16-224", "google/siglip-large-patch16-256".
MODEL_NAME = "google/siglip2-large-patch16-256"
# Explanation type, i.e. the subdirectory of PATH_INPUT/MODEL_NAME to read: any
# "banzhaf*" mode (results further split by sampling probability) or "shapley".
MODE = "banzhaf_crossmodal"

In [None]:
# Pointing-game metrics of the cross-modal interaction values for MODEL_NAME / MODE.
# For every text input "<label1>_<label2>_...", the k-th label should point to the k-th
# quadrant of the patch grid (`grid_players`). A value counts as correct when it is
# positive inside that label's quadrant or negative inside any other quadrant, and as
# wrong when the sign is the opposite.
ORDERS = [1, 2]  # interaction orders to evaluate
P_SAMPLERS = ["0.5"]  # banzhaf sampling probabilities; other options: "0.3", "0.7"
N_IMAGES = 50  # image ids 0..N_IMAGES-1 are looked up for every text input
# Text inputs containing these labels are skipped (choice for siglip-2; for siglip,
# "husky" was excluded as well).
# NOTE(review): as before, this filter is only applied in banzhaf mode — confirm
# whether shapley results should skip the same labels.
EXCLUDED_LABELS = {"ipod", "goldfish"}


def _load_interaction_values(path_file):
    """Load a saved ``shapiq.InteractionValues``; return None if loading fails.

    Presumably not every (text input, image id) pair has a saved explanation, so
    failures are skipped on purpose. Catching ``Exception`` (instead of a bare
    ``except``) still lets KeyboardInterrupt stop the loop.
    """
    try:
        return shapiq.InteractionValues.load(path_file)
    except Exception:
        return None


def _pointing_game_scores(interaction_values, grid_ids, n_players_image, class_labels, order):
    """Pointing-game metrics of one explanation, accumulated over its class labels.

    Args:
        interaction_values: interaction values whose players are the image players
            ``0..n_players_image-1`` followed by one text player per class label.
        grid_ids: quadrant id -> image player ids (one entry of ``grid_players``).
        n_players_image: number of image players.
        class_labels: class labels of the text input, in text-player order.
        order: interaction order whose values are scored.

    Returns:
        dict with ``mass_ratio_correct`` (NaN when the total mass is 0),
        ``mass_correct``, ``mass_wrong``, ``mass_total`` and ``sign_ratio_correct``
        (mean over labels of the fraction of values with the correct sign; NaN if
        some label has no values at all).
    """
    mass_correct = mass_wrong = sign_correct = 0.0
    for token_id in range(len(class_labels)):
        text_player = n_players_image + token_id
        image_players_out = np.concatenate(
            [ids for quadrant, ids in grid_ids.items() if quadrant != token_id]
        )
        # values among this label's text player and its own / all other quadrants
        values_in = src.utils.get_subset(
            interaction_values, players=np.append(grid_ids[token_id], text_player), rename_players=False
        ).get_n_order(order).values
        values_out = src.utils.get_subset(
            interaction_values, players=np.append(image_players_out, text_player), rename_players=False
        ).get_n_order(order).values

        n_values = len(values_in) + len(values_out)
        n_sign_correct = (values_in > 0).sum().item() + (values_out < 0).sum().item()
        sign_correct += n_sign_correct / n_values if n_values else np.nan

        mass_correct += values_in[values_in > 0].sum().item() + np.abs(values_out[values_out < 0]).sum().item()
        mass_wrong += values_out[values_out > 0].sum().item() + np.abs(values_in[values_in < 0]).sum().item()

    mass_total = mass_correct + mass_wrong
    return {
        "mass_ratio_correct": mass_correct / mass_total if mass_total else np.nan,
        "mass_correct": mass_correct,
        "mass_wrong": mass_wrong,
        "mass_total": mass_total,
        "sign_ratio_correct": sign_correct / len(class_labels),
    }


# (p_sampler, directory holding one subdirectory per text input)
if MODE.startswith("banzhaf"):
    input_roots = [
        (p_sampler, os.path.join(PATH_INPUT, MODEL_NAME, MODE, p_sampler)) for p_sampler in P_SAMPLERS
    ]
    excluded_labels = EXCLUDED_LABELS
elif MODE == "shapley":
    input_roots = [(None, os.path.join(PATH_INPUT, MODEL_NAME, MODE))]  # no sampler level
    excluded_labels = set()
else:
    raise ValueError(f"unsupported MODE: {MODE!r}")

records = []
for order in ORDERS:
    for p_sampler, dir_input in input_roots:
        for input_text in os.listdir(dir_input):
            class_labels = input_text.split("_")
            if excluded_labels.intersection(class_labels):
                continue
            n_players_text = len(class_labels)
            for image_id in range(N_IMAGES):
                iv = _load_interaction_values(
                    os.path.join(dir_input, input_text, f"iv_order{order}_{image_id}.pkl")
                )
                if iv is None:
                    continue
                n_players_image = iv.n_players - n_players_text
                grid_ids = grid_players[int(np.sqrt(n_players_image))]  # square patch grid
                fixlip = src.utils.get_crossmodal_subset(iv, n_players_image, n_players_text)
                records.append({
                    "order": order,
                    "p_sampler": p_sampler,
                    "text_input": " ".join(class_labels),
                    "n_objects": n_players_text,
                    "image_id": image_id,
                    **_pointing_game_scores(fixlip, grid_ids, n_players_image, class_labels, order),
                })

# Build the frame once (growing it with pd.concat inside the loop is quadratic);
# listing the columns keeps the schema even when no file could be loaded.
results = pd.DataFrame(records, columns=[
    "order", "p_sampler", "text_input", "n_objects", "image_id",
    "mass_ratio_correct", "mass_correct", "mass_wrong", "mass_total", "sign_ratio_correct",
])

In [None]:
# Files that fail to load are skipped silently by the loading loop, so an empty frame
# (e.g. a wrong PATH_INPUT / MODEL_NAME / MODE) would otherwise go unnoticed.
assert not results.empty, "no interaction values were loaded for MODEL_NAME / MODE"
results.shape

In [None]:
# Mean/std of the mass ratio per order, sampling probability and number of objects.
# dropna=False keeps rows without a p_sampler (MODE == "shapley" has no sampler level),
# which groupby would otherwise drop, leaving an empty table.
results.groupby(["order", "p_sampler", "n_objects"], dropna=False).aggregate({
    'mass_ratio_correct': ['mean', 'std']
}).round(3)

In [None]:
# Mean/std of the mass ratio per interaction order and number of objects
# (pooled over sampling probabilities if several were evaluated).
results.groupby(by=["order", "n_objects"]).agg(
    {"mass_ratio_correct": ["mean", "std"]}
).round(decimals=3)

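## baselines

The same pointing-game metrics for the baseline explanation methods (`game`, `gradeclip`, `exclip`) on the two CLIP base models, scored on the loaded interaction values directly (no cross-modal subset).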

In [None]:
# Pointing-game metrics of baseline explanation methods on the CLIP models. The loaded
# interaction values are scored directly, and the interaction order is fixed per method.
# A value counts as correct when it is positive inside the label's own quadrant or
# negative inside any other quadrant, and as wrong when the sign is the opposite.
BASELINE_MODELS = [
    "openai/clip-vit-base-patch32",
    "openai/clip-vit-base-patch16",
]
BASELINE_MODES = ["exclip"]  # other supported methods: "game", "gradeclip"
N_BASELINE_IMAGES = 50  # image ids 0..N_BASELINE_IMAGES-1 are looked up per text input


def _baseline_order(mode):
    """Interaction order stored for a baseline: 1 for game/gradeclip, 2 for exclip*."""
    if mode in ("game", "gradeclip"):
        return 1
    if mode.startswith("exclip"):
        return 2
    raise ValueError(f"unsupported baseline mode: {mode!r}")


baseline_records = []
for model_name in BASELINE_MODELS:
    for mode in BASELINE_MODES:
        order = _baseline_order(mode)
        dir_mode = os.path.join(PATH_INPUT, model_name, mode)
        for input_text in os.listdir(dir_mode):
            class_labels = input_text.split("_")
            n_players_text = len(class_labels)
            for image_id in range(N_BASELINE_IMAGES):
                try:
                    iv = shapiq.InteractionValues.load(
                        os.path.join(dir_mode, input_text, f"iv_order{order}_{image_id}.pkl")
                    )
                except Exception:
                    # best-effort: skip pairs whose file is missing or unreadable
                    continue
                n_players_image = iv.n_players - n_players_text
                grid_ids = grid_players[int(np.sqrt(n_players_image))]  # square patch grid

                mass_correct = mass_wrong = sign_correct = 0.0
                for token_id in range(n_players_text):
                    # the k-th label should point to the k-th quadrant; all others are "out"
                    text_player = n_players_image + token_id
                    image_players_out = np.concatenate(
                        [ids for quadrant, ids in grid_ids.items() if quadrant != token_id]
                    )
                    values_in = src.utils.get_subset(
                        iv, players=np.append(grid_ids[token_id], text_player), rename_players=False
                    ).get_n_order(order).values
                    values_out = src.utils.get_subset(
                        iv, players=np.append(image_players_out, text_player), rename_players=False
                    ).get_n_order(order).values

                    n_values = len(values_in) + len(values_out)
                    n_sign_correct = (values_in > 0).sum().item() + (values_out < 0).sum().item()
                    sign_correct += n_sign_correct / n_values if n_values else np.nan

                    mass_correct += values_in[values_in > 0].sum().item() + np.abs(values_out[values_out < 0]).sum().item()
                    mass_wrong += values_out[values_out > 0].sum().item() + np.abs(values_in[values_in < 0]).sum().item()

                mass_total = mass_correct + mass_wrong
                baseline_records.append({
                    "model_name": model_name,
                    "mode": mode,
                    "text_input": " ".join(class_labels),
                    "n_objects": n_players_text,
                    "image_id": image_id,
                    # NaN instead of ZeroDivisionError when every value is exactly zero
                    "mass_ratio_correct": mass_correct / mass_total if mass_total else np.nan,
                    "mass_correct": mass_correct,
                    "mass_wrong": mass_wrong,
                    "mass_total": mass_total,
                    "sign_ratio_correct": sign_correct / n_players_text,
                })

# Build the frame once (pd.concat inside the loop is quadratic); explicit columns keep
# the schema even when no file could be loaded.
results = pd.DataFrame(baseline_records, columns=[
    "model_name", "mode", "text_input", "n_objects", "image_id",
    "mass_ratio_correct", "mass_correct", "mass_wrong", "mass_total", "sign_ratio_correct",
])

In [None]:
# Baseline files that fail to load are skipped silently, so fail loudly if none was loaded.
assert not results.empty, "no baseline interaction values were loaded"
results.shape

In [None]:
# Mean/std of the baseline mass ratio per model, explanation method and number of objects.
results.groupby(by=["model_name", "mode", "n_objects"]).agg(
    {"mass_ratio_correct": ["mean", "std"]}
).round(decimals=3)