In [59]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import os

In [60]:
import config  # project-local module; `config.project_dir` is the root folder of the data

# Save/render every figure at 200 dpi.
plt.rcParams.update({'figure.dpi': 200})

# (row, column) position of the unit that is marked in every frame and that the crop is centred on.
unit = (32, 48)
# Half-width of the square crop around `unit`; the crop is (2 * crop_size + 1) cells on a side.
crop_size = 13
# True -> plot only the crop around `unit`; False -> plot the whole frame.
CROP = True

In [61]:
# Root folder with one sub-directory per model; each holds a `five_model_average.npy` file.
parent_dir = f'{config.project_dir}/01_Output/Receptive_Field_Investigation_Per_Model/'
figure_dir = 'figures/Figure_5'

# model name (= sub-directory name) -> array loaded from that model's `five_model_average.npy`
averaged_output_per_model = {}

# Only real, non-hidden sub-directories are treated as models: stray files (e.g. macOS
# `.DS_Store`) and hidden folders (e.g. `.ipynb_checkpoints`) would otherwise make np.load fail.
# Sorting makes the model (and hence plotting) order independent of filesystem listing order.
for subdir in sorted(os.listdir(parent_dir)):
    model_dir = os.path.join(parent_dir, subdir)
    if subdir.startswith('.') or not os.path.isdir(model_dir):
        continue
    averaged_output_per_model[subdir] = np.load(os.path.join(model_dir, 'five_model_average.npy'))

In [62]:
# Colour limits shared by every heatmap, taken over all models' arrays.
# Both limits are seeded with 0, so the resulting range always contains 0
# (and is (0, 0) when no models were loaded).
global_max_value = max([0, *(arr.max() for arr in averaged_output_per_model.values())])
global_min_value = min([0, *(arr.min() for arr in averaged_output_per_model.values())])

In [63]:
# Diverging palette (hue 220 on the negative side, hue 10 on the positive side).
# NOTE(review): `rdgn` is never passed to `sns.heatmap` below, so frames are drawn with
# seaborn's default colormap; pass `cmap=rdgn` if the diverging palette was intended.
rdgn = sns.diverging_palette(h_neg=220, h_pos=10, s=99, l=55, sep=3, as_cmap=True)

# Loop-invariant global style: set once rather than on every model iteration.
plt.rcParams.update({'font.size': 14})

# Region of each frame to show, and where to draw the unit marker.
# Heatmap cell (r, c) spans x in [c, c+1], y in [r, r+1], so its centre is (c + 0.5, r + 0.5).
if CROP:
    # A negative slice start would not raise but would silently select the wrong window,
    # breaking the assumption that the unit sits at index `crop_size` of the crop.
    if unit[0] - crop_size < 0 or unit[1] - crop_size < 0:
        raise ValueError(
            f'crop of half-width {crop_size} around unit {unit} extends past the frame origin'
        )
    row_slice = slice(unit[0] - crop_size, unit[0] + crop_size + 1)
    col_slice = slice(unit[1] - crop_size, unit[1] + crop_size + 1)
    marker_x, marker_y = crop_size + 0.5, crop_size + 0.5
else:
    row_slice = col_slice = slice(None)
    # Cell centre, consistent with the cropped branch (previously drawn at the cell corner).
    marker_x, marker_y = unit[1] + 0.5, unit[0] + 0.5

for model_name, averaged_output in averaged_output_per_model.items():
    # One output directory per model.
    os.makedirs(f'{figure_dir}/{model_name}', exist_ok=True)
    for nr, full_frame in enumerate(averaged_output):
        frame = full_frame[row_slice, col_slice]
        n_rows, n_cols = frame.shape

        fig, ax = plt.subplots(figsize=(4, 4))
        heat = sns.heatmap(
            frame,
            vmin=global_min_value,
            vmax=global_max_value,
            cbar_kws=dict(use_gridspec=False, location="top"),
            ax=ax,
        )
        ax.scatter(x=marker_x, y=marker_y, marker="x", color="black", s=100)

        heat.set(xticklabels=[], yticklabels=[])
        heat.tick_params(bottom=False, left=False)
        # Frame border: horizontal lines sit at row boundaries (y), vertical lines at
        # column boundaries (x); the two differ for non-square (uncropped) frames.
        heat.axhline(y=0, color='k', linewidth=1)
        heat.axhline(y=n_rows, color='k', linewidth=2)
        heat.axvline(x=0, color='k', linewidth=2)
        heat.axvline(x=n_cols, color='k', linewidth=1)

        if nr == 0:
            # Save the first frame once with its colorbar, to obtain the colorbar image.
            fig.savefig(f'{figure_dir}/{model_name}/colorbar_frame{nr}.png', bbox_inches='tight')

        # All per-frame images are saved without a colorbar.
        heat.collections[0].colorbar.remove()
        fig.tight_layout()
        fig.savefig(f'{figure_dir}/{model_name}/frame{nr}.png', bbox_inches='tight')
        plt.close(fig)