In [None]:
import joblib
import numpy
import torch
import pandas
import polars as pl

from matplotlib import pyplot

In [None]:
from influ_examples.components.datasets import DatasetType

# Choose which experiment's analysed influence results to load
# (DatasetType.mnist is the other supported option).
dataset_type = DatasetType.cifar10

result_path = f"../../result/analyzed_influence-{dataset_type.value}.gz"
# joblib.load unpickles the file, so only point it at results you produced yourself.
results = joblib.load(result_path)
len(results)

In [None]:
# Print each result key together with its position in the mapping.
for idx, key in enumerate(results):
    print(f"{idx=} {key}")


In [None]:
# Resolve the dataset-specific class names, data loader factory and model class.
# Only MNIST and CIFAR-10 are wired up; anything else is rejected up front.
if dataset_type not in (DatasetType.mnist, DatasetType.cifar10):
    raise NotImplementedError(f"{dataset_type.value=}")

if dataset_type == DatasetType.mnist:
    from influ_examples.components.datasets.mnist import g_class_names, load_data
    from influ_examples.components.models.model_mnist import SimpleModelMnist

    SimpleModel = SimpleModelMnist
else:
    from influ_examples.components.datasets.cifar10 import g_class_names, load_data
    from influ_examples.components.models.model_cifar10 import SimpleModelCifar10

    SimpleModel = SimpleModelCifar10

# Loaders are built without shuffling.
trainloader, testloader = load_data(do_shuffle=False)

In [None]:
def _to_image_data(x: torch.Tensor):
    return (((x.permute(1, 2, 0) + 1.0) / 2.) * 255.).cpu().numpy().astype(numpy.uint8)


In [None]:
def show_train_images(samples: list, influence: list, to_influence_index: list, trainloader, n_rows: int, n_cols: int, title="train images"):
    """Plot the first ``n_rows * n_cols`` training images from ``samples`` as a grid.

    Each panel is titled with its 1-based rank, class name, training-set index and
    the influence value ``influence[to_influence_index[train_idx]]``.

    Args:
        samples: Training-set indices in ranking order. Only the first
            ``n_rows * n_cols`` entries are drawn.
        influence: Influence values, indexed through ``to_influence_index``.
        to_influence_index: Maps a training-set index to its position in ``influence``.
        trainloader: Loader whose ``.dataset[i]`` yields ``(image_tensor, label)``.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        title: Figure-level title.
    """
    n = n_rows * n_cols
    # rc_context keeps the small font local to this figure. Assigning to rcParams
    # directly would change the defaults for every later plot in the notebook.
    with pyplot.rc_context({"font.size": 5}):
        fig = pyplot.figure(figsize=(6, 6))
        # suptitle sits above the subplot grid, unlike a title on a hidden background axes.
        fig.suptitle(title)
        for rank, train_idx in enumerate(samples[:n], start=1):
            x, t = trainloader.dataset[train_idx]
            influence_value = influence[to_influence_index[train_idx]]
            ax = fig.add_subplot(n_rows, n_cols, rank)
            ax.axis("off")
            ax.set_title(f"{rank}. {g_class_names[t]} : {train_idx}\n({influence_value:0.3f})")
            train_image = _to_image_data(x)
            # cmap only affects 2-D (grayscale) arrays; RGB(A) images ignore it.
            ax.imshow(train_image, cmap="gray" if train_image.ndim == 2 else None)
        fig.tight_layout()
        pyplot.show()




In [None]:
def show_image(x, label):
    """Display a single test image, titled with its class name.

    Args:
        x: Image tensor accepted by ``_to_image_data``.
        label: Class index used to look up ``g_class_names``.
    """
    test_image = _to_image_data(x)
    # Scope the small font to this figure instead of mutating global rcParams.
    with pyplot.rc_context({"font.size": 5}):
        _, ax = pyplot.subplots(figsize=(3, 3))
        ax.axis("off")
        ax.set_title(f"test label: {g_class_names[label]}")
        # cmap only affects 2-D (grayscale) arrays; RGB(A) images ignore it.
        ax.imshow(test_image, cmap="gray" if test_image.ndim == 2 else None)
        pyplot.show()


In [None]:
def show_influenced_image(test_x, test_label, influence, helpful, harmful, to_influence_index, trainloader, n_rows: int = 4, n_cols: int = 5):
    """Show a test image, then grids of its most helpful and most harmful training images.

    Args:
        test_x: Test image tensor.
        test_label: Class index of the test image.
        influence: Influence values, indexed through ``to_influence_index``.
        helpful: Training-set indices ranked as helpful.
        harmful: Training-set indices ranked as harmful.
        to_influence_index: Maps a training-set index to its position in ``influence``.
        trainloader: Loader whose ``.dataset`` provides the training images.
        n_rows: Grid rows for each ranking figure (default 4).
        n_cols: Grid columns for each ranking figure (default 5).
    """
    show_image(test_x, test_label)

    show_train_images(helpful, influence, to_influence_index, trainloader, n_rows, n_cols, title="helpful train images ranking")

    show_train_images(harmful, influence, to_influence_index, trainloader, n_rows, n_cols, title="harmful train images ranking")



In [None]:
# NOTE (TODO): We would also like to identify the helpful and harmful training images aggregated across multiple test data points.

In [None]:
# For every analysed label and every test point recorded under it, plot the test
# image followed by its helpful and harmful training-image rankings.
testset = testloader.dataset

for lbl, rankings in results.items():
    for rec in rankings:
        (
            test_idx,
            influence_values,
            helpful_train_indices,
            harmful_train_indices,
            to_influence_index,
            test_estimation,  # not used below
        ) = rec
        test_z = testset[test_idx]
        test_x, test_t = test_z
        print("=" * 100)
        show_influenced_image(
            test_x,
            test_t,
            influence_values,
            helpful_train_indices,
            harmful_train_indices,
            to_influence_index,
            trainloader,
        )