In [None]:
from math import prod

import torch
import pandas as pd
from tqdm.auto import tqdm

from utils import concept_influence
from datasets.sanity_check import DummyConceptInfluenceDataset
from datasets.batch_collate import collate_fn_ci

# --- Dummy dataset configuration -------------------------------------------
# Forwarded to DummyConceptInfluenceDataset as `dilation`, `image_size` and
# `dataset_size` respectively.
MORPHOLOGICAL_DILATION = False
IMAGE_SIZE = (16, 16)
NUM_SAMPLES = 32

# --- DataLoader configuration ----------------------------------------------
# Forwarded to torch.utils.data.DataLoader as `num_workers` and `batch_size`.
NUM_WORKERS = 2
BATCH_SIZE = 4

In [None]:
# Sanity check: run `concept_influence` over every sample of the dummy dataset
# and collect the results in a long-format DataFrame.
dummy_ci_dataset = torch.utils.data.DataLoader(
    DummyConceptInfluenceDataset(
        dilation=MORPHOLOGICAL_DILATION,
        image_size=IMAGE_SIZE,
        dataset_size=NUM_SAMPLES,
    ),
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn_ci,
    num_workers=NUM_WORKERS,
)


# One record (dict) per sample; each record holds the three outputs of
# `concept_influence` plus the sample's metadata entries.
result_records = []
for batch in tqdm(dummy_ci_dataset):

    # The last batch may be smaller than BATCH_SIZE, so read the actual size
    # from the leading dimension of the batched attribution maps.
    batch_size = batch["attribution_maps"].shape[0]
    for i in range(batch_size):
        class_labels, intersection, concept_size_px = concept_influence(
            batch["attribution_maps"][i], batch["segmentation_maps"][i]
        )
        per_sample_dict = {
            "class_label": class_labels,
            "intersection": intersection,
            "concept_size_px": concept_size_px,
        }
        # `batch["metadata"]` maps each metadata key to a per-sample sequence;
        # pick out the entry that belongs to sample `i`.
        for meta_k, meta_v in batch["metadata"].items():
            per_sample_dict[meta_k] = meta_v[i]

        # Append inside the per-sample loop: appending after it would keep
        # only the last sample of every batch and silently drop the rest.
        result_records.append(per_sample_dict)


# Number of elements in a single attribution map, read from the first sample
# of the last batch (assumes all attribution maps share the same shape).
image_size_px = prod(batch["attribution_maps"][0].shape)

# The three `concept_influence` outputs are exploded in lockstep, yielding one
# row per entry (presumably one per concept/class present in the sample —
# confirm against `concept_influence`).
df = pd.DataFrame.from_records(
    result_records,
).explode(["class_label", "intersection", "concept_size_px"])
# Concept size as a fraction of the attribution map's element count.
df["concept_size_rel"] = df["concept_size_px"] / image_size_px
# Intersection divided by the concept's relative size.
df["concept_influence"] = df["intersection"] / df["concept_size_rel"]

In [None]:
# Bare expression as the last statement of the cell: shows the resulting
# DataFrame as the cell's output when run in a notebook.
df