# Visualize the probability vs crop size for each datapoint in an experiment

This notebook creates one plot per input image that visualizes the MIRC (minimal recognizable configuration) search: the y-axis shows the decreasing probability of the correct class(es), and the x-axis shows the decreasing crop size (in real pixel space) together with the corresponding image crops. By default, a figure is only created if a datapoint contains a MIRC. The figures are saved to the experiment-specific folder that was created during the MIRC search.

By *one datapoint*, I refer to one input image that has its unique
correct class(es). If several classes were considered as *one correct
class* in the case of the data from Ullman et al. (2016), then there is
only one datapoint. If several classes were considered *separately* as
one correct class each for the data from Ullman et al. (2016), then
there are several datapoints.

## Your TODO

Please specify the path to the topmost directory of your recognition
gap experiments, i.e. the parent directory of the analysis folder.

In [None]:
import plot_utils
import data_npz_utils
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
# Absolute path to the top-level recognition_gap folder (the parent of the
# analysis folder). It is inserted into sys.path (at index 1) in a later cell
# so that the custom modules can be found.
# TODO: adjust this to the location of your own checkout.
path_to_recognition_gap_folder = "/gpfs01/bethge/home/jborowski/CHAM_recognition_gap/JOV_publication_git_bethgelab/recognition_gap/"

Please specify the name of the experiment folder whose data you would
like to visualize, i.e. the folder in which all the csv and npz files were stored.

In [None]:
# Name of the experiment folder (located in
# ../figures_and_data_from_experiments) whose csv/npz results are visualized.
exp_name = "exp_792020Ullman_list_as_one_classTrue_startidx0_stopidx9_Ullman4_v0"

## Load libraries

In [None]:

# custom imports
# Make the project's custom modules importable by inserting the
# recognition_gap folder into the module search path at index 1.
# NOTE(review): plot_utils and data_npz_utils are imported in the first code
# cell, i.e. before this insertion runs. Those imports only succeed if the
# modules are already importable (e.g. from the working directory). Consider
# moving them below this line.
sys.path.insert(1, path_to_recognition_gap_folder)

## Data

In [None]:
# Experiment output folder. It is resolved relative to the current working
# directory; this presumably assumes the notebook is run from a sibling folder
# of figures_and_data_from_experiments (e.g. the analysis folder) — confirm.
exp_dir = os.path.join("..", "figures_and_data_from_experiments", exp_name)

In [None]:
# Paths of the datapoint directories whose search procedure will be plotted.
# Datapoints without a MIRC are excluded (all_datapoints_including_nonMIRC=False).
data_point_to_plot_list = data_npz_utils.get_list_of_data_points_to_plot(
    exp_dir,
    all_datapoints_including_nonMIRC=False,
)

In [None]:
# Build one img_identifier per datapoint directory, e.g. 'plane_INclass404'.
# The identifier is the last component of the directory path. The folder
# "MIRCs_and_original_images" is not a datapoint, so every occurrence of it
# is skipped.
img_identifier_list = [
    dir_name
    for dir_name in (
        data_point_path.split(os.path.sep)[-1]
        for data_point_path in data_point_to_plot_list
    )
    if dir_name != "MIRCs_and_original_images"
]

## Plotting business

### Parameters

In [None]:
# get dictionary of dictionary with the following information:
# imagenetnumber, wordnetID, word
# The outer dict is indexed by a target class (imagenet number / neuronID).
# Each inner dict provides the keys "wordnetID" and "word", which are used
# below to build the line labels and the plot titles.
imagenetnumber_wordnetID_word_dict = plot_utils.get_dict_of_dict_with_imagenet_number_wordnetID_word()

In [None]:
# infer from the path to the data, which data is being used: "Ullman" or
# "ImageNet"
# Guard against an empty list (e.g. no datapoint contained a MIRC) so the
# user gets a clear message instead of a bare IndexError.
if not data_point_to_plot_list:
    raise ValueError(
        f"No datapoints to plot were found in {exp_dir}; cannot infer "
        "whether the data of Ullman et al. or of ImageNet was used.")
first_data_point_path = data_point_to_plot_list[0]
if "Ullman" in first_data_point_path:
    Ullman_or_ImageNet = "Ullman"
elif "ImageNet" in first_data_point_path:
    Ullman_or_ImageNet = "ImageNet"
else:
    # ValueError is a subclass of Exception, so existing handlers still work.
    raise ValueError(
        "You are using neither the data of Ullman et al. nor of ImageNet.")

### Helper functions

In [None]:
def get_target_list(img_class_dict):
    """Return the correct-class labels stored with the first search step.

    Args:
        img_class_dict:  ordered dictionary with all data for one datapoint;
                         each value exposes a ``target_list`` attribute
    Returns:
        target_list: list of labels that were considered correct in search procedure
    """
    # The first entry (in insertion order) carries the target labels.
    first_key = list(img_class_dict)[0]
    return list(img_class_dict[first_key].target_list)

In [None]:
def customize_label(
        Ullman_or_ImageNet,
        target_list,
        imagenetnumber_wordnetID_word_dict):
    """Build the legend label for the probability line.

    The label has the form ``p(<word>)``, where ``<word>`` is the word of the
    first correct class. ``, ...`` is appended when more than one class was
    considered correct. Both data sources ("Ullman" and "ImageNet") build the
    label in exactly the same way, so ``Ullman_or_ImageNet`` does not
    influence the result; it is kept for interface compatibility.

    Args:
        Ullman_or_ImageNet:                 string specifying which data was used
        target_list:                        list of labels that were considered correct in search procedure
        imagenetnumber_wordnetID_word_dict: dictionary with imagenetnumber, wordnetID, word

    Returns:
        this_label:                         label for probability line
    """
    first_class_word = imagenetnumber_wordnetID_word_dict[target_list[0]]["word"]
    more_classes_suffix = ", ..." if len(target_list) > 1 else ""
    return f"p({first_class_word}{more_classes_suffix})"

In [None]:
def customize_legend(ax):
    """Draw the legend in the lower-left corner with a white face and edge.

    Args:
        ax: axes of plot
    """
    legend_frame = ax.legend(loc="lower left").get_frame()
    frame_color = "white"
    legend_frame.set_facecolor(frame_color)
    legend_frame.set_edgecolor(frame_color)

In [None]:
def customize_title(ax, target_list, imagenetnumber_wordnetID_word_dict,
                    data_source=None, identifier=None):
    """Set the title to the wordnetID, neuronID and number of correct classes.

    Args:
        ax:                                 axes of plot
        target_list:                        list of labels that were considered correct in search procedure
        imagenetnumber_wordnetID_word_dict: dictionary with imagenetnumber, wordnetID, word
        data_source:                        "Ullman" or "ImageNet"; if None, the
                                            notebook-level ``Ullman_or_ImageNet`` is used
        identifier:                         img_identifier whose part before the first
                                            "_" is the wordnetID (ImageNet only); if
                                            None, the notebook-level ``img_identifier``
                                            is used
    """
    # Fall back to the notebook globals so existing calls keep working.
    if data_source is None:
        data_source = Ullman_or_ImageNet

    if data_source == "Ullman":
        neuronID = target_list[0]
        wordnetID = imagenetnumber_wordnetID_word_dict[neuronID]["wordnetID"]
    else:
        if identifier is None:
            identifier = img_identifier
        wordnetID = identifier.split('_')[0]
        # find neuronID in ImageNet. Report "not found" instead of silently
        # showing an unrelated neuronID when no entry matches this wordnetID.
        neuronID = next(
            (candidate_id
             for candidate_id, word_wordnetID_dict
             in imagenetnumber_wordnetID_word_dict.items()
             if word_wordnetID_dict["wordnetID"] == wordnetID),
            "not found")

    this_title = f"wordnetID: {wordnetID}    " \
        f"neuronID: {neuronID}    " \
        f"total number of classes: {len(target_list)}"
    ax.set_title(this_title)

### Main function

In [None]:
def create_and_save_plot():
    """Create the search-procedure figure for the current datapoint and save it.

    The function reads the notebook-level variables ``img_class_dict``,
    ``img_identifier``, ``Ullman_or_ImageNet``,
    ``imagenetnumber_wordnetID_word_dict`` and ``exp_dir``. The figure is
    written to ``<exp_dir>/<img_identifier>_search_procedure.png`` and is
    always closed afterwards, even if plotting or saving fails.
    """
    fig, ax = plt.subplots(1, 1, figsize=(17, 2))
    # try/finally: this runs once per datapoint, so a failure must not leave
    # open figures behind.
    try:
        # Keep plots comparable despite varying number of search steps: a
        # white (i.e. invisible) line over 20 x-positions appears to keep the
        # x-range at least that wide.
        ax.plot([0.1] * 20, color="white")

        target_list = get_target_list(img_class_dict)
        this_label = customize_label(
            Ullman_or_ImageNet,
            target_list,
            imagenetnumber_wordnetID_word_dict)
        crop_probability = [val.probability for val in img_class_dict.values()]
        plot_utils.plot_probabilities(ax, crop_probability, this_label)
        plot_utils.plot_recognition_criterion(ax, crop_probability)
        plot_utils.plot_crops_below_xaxis(fig, ax, img_class_dict)

        # crop sizes (in real pixel space) are used as x tick labels
        orig_px_size = [str(val.crop_size) for val in img_class_dict.values()]
        # NOTE(review): 1.04 is passed through to plot_utils.customize_axes;
        # presumably an upper y-limit slightly above probability 1 — confirm.
        plot_utils.customize_axes(ax, crop_probability, 1.04, orig_px_size)
        plot_utils.hide_right_and_top_spine(ax)
        customize_legend(ax)
        customize_title(ax, target_list, imagenetnumber_wordnetID_word_dict)

        fig_name = os.path.join(exp_dir, f"{img_identifier}_search_procedure.png")
        fig.savefig(
            fig_name,
            bbox_inches="tight")
    finally:
        plt.close(fig)

## Plot it!

In [None]:
# loop through the different datapoints for which a plot should be
# generated
# NOTE: img_identifier and img_class_dict are module-level names that
# create_and_save_plot() (and customize_title) read as globals, so they must
# keep these exact names.
for img_identifier in img_identifier_list:
    # get the data for that datapoint
    img_class_dict = data_npz_utils.get_img_class_dict_all_data(
        data_point_to_plot_list, exp_dir, img_identifier)
    # plot it!
    create_and_save_plot()