This notebook creates the query images in batch 0. 
To run it, place the dog-swan image in a folder called `artificial_class`.
Do not put any other images (e.g. reference images) in this folder yet.

In [None]:
# sudo apt update
# sudo apt install libgl1-mesa-glx
# pip install opencv-python

# Save query images of bird-dog instruction trial

## Imports

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
import time
import cv2

In [None]:
import shutil
from torch.utils import data
from torchvision import transforms
from torchvision.datasets.folder import default_loader as default_image_loader

In [None]:
import occlusion_utils as ut

# Natural Query Image

Downloaded from https://pixnio.com/fauna-animals/dogs/dog-water-bird-swan-lake-waterfowl-animal-swimming

In [None]:
# Load the natural query image.
# "$DATAPATH" is a shell-style placeholder that Python does not expand on its
# own, so expand it from the environment here. If DATAPATH is not set,
# os.path.expandvars leaves the string unchanged (previous behavior).
query_path = os.path.expandvars("$DATAPATH/pre_stimulus_instruction_20210512/")
data_loader = ut.get_data_loader(query_path)
# Take the first item the loader yields; only its first element (the image
# tensor) is used, the other two are discarded.
image, _, _ = next(iter(data_loader))
# Move axis 1 to the last position so the array can be indexed as
# [batch, axis0, axis1, channel] further down.
image_np_transformed = image.numpy().transpose(0, 2, 3, 1)  # (1, 224, 224, 3)

In [None]:
# Components of the output directory path built in the generation loop below.
trial_type = "instruction_practice_catch"
# Layer, kernel size and channel are all set to the same string "a"; they are
# only used to build directory names.
layer_number = kernel_size = channel_number = "a"
batch = 0

In [None]:
# Top-level output folders; the generation loop writes the query images
# once per entry.
conditions_list = [
    "stimuli_mixed_condition",
    "stimuli_pure_conditions",
]

In [None]:
# Start corner of the occlusion patch used for the "max_activation" query.
# key: occlusion_size, value: tuple of x_start and y_start
start_positions_max = {
    66: (150, 110),
    90: (134, 75),
    112: (112, 52),
}

In [None]:
# Start corner of the occlusion patch used for the "min_activation" query.
# key: occlusion_size, value: tuple of x_start and y_start
start_positions_min = {
    66: (0, 135),
    90: (0, 134),
    112: (0, 112),
}

In [None]:
# Write the three query images (unoccluded, max-activation occlusion,
# min-activation occlusion) for every condition folder and occlusion size.

# Which hand-crafted start-position table belongs to which occluded query type.
# "default" is deliberately absent: it is saved without a patch.
start_positions_by_query_type = {
    "min_activation": start_positions_min,
    "max_activation": start_positions_max,
}

for conditions_list_i in conditions_list:

    # <query_path>/<condition>/channel/<trial>/layer_*/kernel_size_*/channel_*/natural_images/batch_*
    path_components = [
        query_path,
        conditions_list_i,
        "channel",
        trial_type,
        f"layer_{layer_number}",
        f"kernel_size_{kernel_size}",
        f"channel_{channel_number}",
        "natural_images",
        f"batch_{batch}",
    ]
    data_dir = os.path.join(*path_components)

    # The three lists from occlusion_utils are walked in lockstep, one entry
    # per occlusion size.
    occlusion_settings = zip(
        ut.percentage_side_length_list,
        ut.occlusion_sizes_list,
        ut.heatmap_sizes_list,
    )
    for percentage_side_length_i, occlusion_size_i, heatmap_size_i in occlusion_settings:
        print(f"percentage_side_length_i {percentage_side_length_i}, occlusion_size_i {occlusion_size_i}, heatmap_size_i {heatmap_size_i}")

        query_dir = os.path.join(
            data_dir, f"{percentage_side_length_i}_percent_side_length_dog_swan"
        )
        os.makedirs(query_dir, exist_ok=True)

        for query_type_i in ("default", "max_activation", "min_activation"):
            print("query_type_i", query_type_i)

            # Work on a fresh copy so a patch from one query type never leaks
            # into the next image.
            image_to_be_saved = copy.deepcopy(image_np_transformed.squeeze())

            # Occlusion positions were hand crafted by Judy.
            start_positions = start_positions_by_query_type.get(query_type_i)
            if start_positions is not None:
                x_start, y_start = start_positions[occlusion_size_i]
                x_end = x_start + occlusion_size_i
                y_end = y_start + occlusion_size_i

                # NOTE: x_* slices array axis 0 and y_* slices array axis 1.
                patch = image_to_be_saved[x_start:x_end, y_start:y_end, :]
                # Per-channel mean over the patch (averaged over axis 0, then
                # over the remaining spatial axis).
                mean_color = np.mean(np.mean(patch, axis=0), axis=0)
                # `patch` is a view, so this fills the region in the image.
                patch[...] = mean_color

            image_path = os.path.join(query_dir, f"query_{query_type_i}.png")
            # The array is treated as RGB; OpenCV writes BGR, hence the
            # channel swap after scaling by 255.
            bgr_image = cv2.cvtColor(image_to_be_saved*255, cv2.COLOR_RGB2BGR)
            cv2.imwrite(image_path, bgr_image)

            plt.imshow(image_to_be_saved)
            plt.show()