In [None]:
# Log banner marking the start of this processing step.
print("start 3_occlusion_save_query_images", end="\n")

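# Save extremal query images of occlusion stimuli

For every unit and batch, this notebook saves three query images per occlusion size: the unoccluded image, and the image occluded at the position that yields the maximal and the minimal activation, respectively.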

## Imports

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
import time
import cv2
import argparse

In [None]:
import occlusion_utils as ut

## Parameters

In [None]:
# Command-line interface; both options are mandatory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-s",
    "--stimuli-dir",
    required=True,
    help="Path to save stimuli to.",
)
parser.add_argument(
    "-t",
    "--trial-type",
    required=True,
    help="instruction_practice_catch or sampled_trials.",
)
args = parser.parse_args()
print(args)

In [None]:
# Unpack the parsed command-line options into the plain names used by later cells.
stimuli_dir, trial_type = args.stimuli_dir, args.trial_type

## Load experiment specification

In [None]:
# Read the unit specifications from CSV into a pandas DataFrame.
# header=1: the second line of the CSV holds the column names; the line before it is discarded.
path_to_csv_file = os.path.join(stimuli_dir, f"layer_folder_mapping_{trial_type}.csv")
unit_specs_df = pd.read_csv(path_to_csv_file, header=1)

# The saving loop reads each of these columns per row; fail early with a clear message
# instead of a KeyError deep inside the batch loop.
_missing_unit_spec_columns = {
    "layer_number",
    "kernel_size_number",
    "channel_number",
    "feature_map_number",
    "layer_name",
    "pre_post_relu",
} - set(unit_specs_df.columns)
if _missing_unit_spec_columns:
    raise KeyError(
        f"{path_to_csv_file} is missing columns: {sorted(_missing_unit_spec_columns)}"
    )

In [None]:
# Show the loaded unit specifications (a bare expression only renders when run as a notebook cell).
unit_specs_df

## Functions

In [None]:
def get_activations_data(stimulus_type, percentage_side_length_i, base_dir=None):
    """Load the saved unit activations for one occlusion size.

    Reads ``<base_dir>/<p>_percent_side_length/activations_for_occlusions_of_<p>_percent.npy``,
    where ``<p>`` is ``percentage_side_length_i``.

    Args:
        stimulus_type: Not used by the body; kept so existing callers keep working.
        percentage_side_length_i: Occlusion side length in percent; used in both the
            sub-directory name and the file name.
        base_dir: Directory containing the ``<p>_percent_side_length`` sub-folders.
            When None (the default), falls back to the module-level ``data_dir``,
            which the saving loop assigns for the current batch. This is the
            original behavior.

    Returns:
        The array stored in the ``.npy`` file, as returned by ``np.load``.
    """
    if base_dir is None:
        # Original behavior: rely on the global set by the batch loop.
        base_dir = data_dir
    npy_file_name = f"activations_for_occlusions_of_{percentage_side_length_i}_percent.npy"
    path_to_npy = os.path.join(
        base_dir,
        f"{percentage_side_length_i}_percent_side_length",
        npy_file_name,
    )
    return np.load(path_to_npy)

## Save the query images

In [None]:
def _occlude_with_mean(image_hwc, position):
    """Return a copy of `image_hwc` whose patch at `position` is replaced by its per-channel mean.

    `position` is an (x_start, x_end, y_start, y_end) tuple; x slices axis 0 and
    y slices axis 1. The input array is not modified.
    """
    x_start, x_end, y_start, y_end = position
    occluded = image_hwc.copy()
    patch = occluded[x_start:x_end, y_start:y_end, :]
    # Mean over axis 0, then over the remaining spatial axis: same arithmetic order as before,
    # so the saved stimuli are bit-for-bit unchanged.
    occluded[x_start:x_end, y_start:y_end, :] = patch.mean(axis=0).mean(axis=0)
    return occluded


def _to_uint8_bgr(image_rgb):
    """Scale an RGB float image by 255, round, saturate to [0, 255] and return it as uint8 BGR.

    This makes explicit the round-and-saturate conversion cv2.imwrite otherwise applies
    implicitly to non-8-bit PNG input. Values are presumably in [0, 1]; TODO confirm
    against ut.get_data_loader.
    """
    scaled = np.clip(np.rint(image_rgb * 255), 0, 255).astype(np.uint8)
    return cv2.cvtColor(scaled, cv2.COLOR_RGB2BGR)


def _show_query_image(image_rgb, title):
    """Display one query image, then close its figure so figures do not pile up across the loop."""
    fig, ax = plt.subplots()
    ax.imshow(image_rgb)
    ax.set_title(title)
    plt.show()
    plt.close(fig)


# Occlusion-position selectors for the extremal query types; "default" is saved unoccluded.
_EXTREMUM_FINDERS = {"max_activation": np.argmax, "min_activation": np.argmin}

for _, row in unit_specs_df.iterrows():

    # load unit specification
    layer_number = row["layer_number"]
    kernel_size_number = row["kernel_size_number"]
    channel_number = row["channel_number"]
    feature_map_number = row["feature_map_number"]
    layer_name = row["layer_name"]
    pre_post_relu = row["pre_post_relu"]

    print(row)

    for batch in range(ut.n_batches):
        start = time.time()

        # load images
        # NOTE: `data_dir` is module-level on purpose: get_activations_data reads it as a global.
        data_dir = os.path.join(
            stimuli_dir,
            ut.objective,
            trial_type,
            f"layer_{layer_number}",
            f"kernel_size_{kernel_size_number}",
            f"channel_{channel_number}",
            "natural_images",
            f"batch_{batch}"
        )
        data_loader = ut.get_data_loader(os.path.join(data_dir, "val"))
        image, _, _ = next(iter(data_loader))
        image_np_transformed = image.numpy().transpose(0, 2, 3, 1)  # channels last, e.g. (1, 224, 224, 3)
        # Loop-invariant: the single unoccluded image that every query image starts from.
        base_image = image_np_transformed.squeeze()
        if base_image.ndim != 3:
            raise ValueError(
                f"expected a single (H, W, C) image after squeezing, got shape {image_np_transformed.shape}"
            )

        # loop through occlusion sizes
        for percentage_side_length_i, occlusion_size_i, heatmap_size_i in zip(
            ut.percentage_side_length_list,
            ut.occlusion_sizes_list,
            ut.heatmap_sizes_list):
            print(f"percentage_side_length_i {percentage_side_length_i}, occlusion_size_i {occlusion_size_i}, heatmap_size_i {heatmap_size_i}")

            activations_data_one_occlusion_size = get_activations_data("occlusions", percentage_side_length_i)

            query_dir = os.path.join(
                data_dir,
                f"{percentage_side_length_i}_percent_side_length"
            )
            os.makedirs(query_dir, exist_ok=True)

            list_of_positions = ut.get_list_of_occlusion_positions(heatmap_size_i, occlusion_size_i)

            # loop through query images
            for query_type_i in ["default", "max_activation", "min_activation"]:
                find_extremum = _EXTREMUM_FINDERS.get(query_type_i)
                if find_extremum is None:
                    image_to_be_saved = base_image.copy()
                else:
                    # [:-1] drops the last activation entry, presumably the one for the
                    # unoccluded image; TODO confirm against the activation-saving script.
                    extreme_idx = find_extremum(activations_data_one_occlusion_size[:-1])
                    image_to_be_saved = _occlude_with_mean(base_image, list_of_positions[extreme_idx])

                image_path = os.path.join(query_dir, f"query_{query_type_i}.png")
                cv2.imwrite(image_path, _to_uint8_bgr(image_to_be_saved))

                _show_query_image(
                    image_to_be_saved,
                    f"{query_type_i} ({percentage_side_length_i}% side length)",
                )
        end = time.time()
        print(f"       time for one batch: {end-start}")

In [None]:
# Log banner marking the end of this processing step; the name matches the start banner
# so the start/done pair can be matched in logs.
print("done with 3_occlusion_save_query_images")