In [None]:
import pickle 

import imp
from IPython.display import clear_output, display
import matplotlib
%matplotlib inline
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import os

from context import rf_pool

In [None]:
from rf_pool import models, modules, pool, ops
from rf_pool.utils import lattice, functions, visualize, datasets, stimuli

In [None]:
from experiment_functions import *

In [None]:
# make the figures folder (no-op when it already exists; avoids the
# check-then-create race of os.path.exists + os.mkdir)
os.makedirs('figures', exist_ok=True)

**Load MNIST Data**

In [None]:
# MNIST train and test splits as tensors, downloaded into ../../data on first use
mnist_root = '../../data'
transform = transforms.Compose([transforms.ToTensor()])
trainset, testset = (
    torchvision.datasets.MNIST(root=mnist_root, train=is_train, download=True,
                               transform=transform)
    for is_train in (True, False)
)

In [None]:
# shuffled single-image loaders for both MNIST splits (shared settings)
loader_kwargs = dict(batch_size=1, shuffle=True, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, **loader_kwargs)

In [None]:
# load in crowded digits base set if it exists; otherwise fall back to None
base_set_filename = 'MNIST_CrowdedDataset.pkl'
base_set_path = os.path.join('datasets', base_set_filename)
if os.path.exists(base_set_path):
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    # The `with` block closes the file handle (the bare open() leaked it).
    with open(base_set_path, 'rb') as f:
        base_set = pickle.load(f)
else:
    base_set = None

# set what labels map to what digit (identity mapping for the 10 MNIST classes)
label_map = {n: n for n in range(10)}

**Load Model**

In [None]:
# start from an empty feed-forward network; its layers are appended in the next cell
net_cls = rf_pool.models.FeedForwardNetwork
model = net_cls()

In [None]:
# append two conv(5x5) -> ReLU -> MaxPool2d(2) stages: 1->32 channels ('0'), 32->64 channels ('1')
for layer_idx, (in_ch, out_ch) in enumerate([(1, 32), (32, 64)]):
    stage = rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(in_ch, out_ch, 5),
                                        activation=torch.nn.ReLU(),
                                        pool=torch.nn.MaxPool2d(2))
    model.append(str(layer_idx), stage)

In [None]:
# restore the trained model; only the second returned value (`extras`) is kept
saved_model_path = 'models/MNIST_rate_0.2_10k_3deg.pkl'
_, extras = model.load_model(saved_model_path)

In [None]:
# remove reshape layer: drop layer '3'. Only '0' and '1' are appended above, so
# presumably '3' comes from the checkpoint restored by load_model -- TODO confirm.
# As the cell's last expression, the popped module is echoed as cell output.
model.layers.pop('3')

**Replace max pool layer with rf pool layer**

In [None]:
# lattice parameters for the RF pooling layer (reused by later figure cells)
img_shape = torch.Size((53, 53))
offset = [0., -30.]
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.

# RF parameters (mu, sigma) from a foveated lattice, then the RF max-pool layer
mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(
    img_shape, RF_rate, gap, n_rings=n_rings, std=std, offset=offset,
)
rf_layer = rf_pool.pool.RF_Pool(
    mu=mu, sigma=sigma, img_shape=img_shape,
    lattice_fn=rf_pool.utils.lattice.mask_kernel_lattice,
    pool_fn='max_pool', kernel_size=2, retain_shape=True,
)

# register the RF pool under the name 'pool' in layer '1' (replacing its max pool) and show it
layer_id = '1'
model.layers[layer_id].forward_layer.add_module('pool', rf_layer)
visualize.heatmap(model, layer_id);

**Set Figure Parameters**

In [None]:
# set batch size, image size, and test size
batch_size = 1
img_size = 118
n_test = 1
# set spacing
spacing = 1.
# set label mapping (identity: each label is its own digit)
label_map = {n: n for n in range(10)}
# crowded targets drawn from the MNIST *test* split: positional args after the
# image size are (n_flankers, axis, spacing) = (0, 0, 0), i.e. target digits only
target_set = create_crowd_set(testset, n_test, img_size, 0, 0, 0,
                              base_set=base_set, label_map=label_map)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)

In [None]:
# publication styling: 300-dpi figures, 8pt Times rendered through LaTeX,
# and TrueType (type 42) fonts embedded in PDF output
font = {'size': 8., 'style': 'normal', 'family': 'times', 'weight': 'normal'}
matplotlib.rcParams['figure.dpi'] = 300.
matplotlib.rc('font', **font)
matplotlib.rcParams.update({'pdf.fonttype': 42, 'text.usetex': True})

**Get Bootstrapped Confidence Intervals**

In [None]:
# select which experiment's results to bootstrap
task = 'Spacing'
index = 3  # 2 for SNR, 3 for accuracy
start = ''
end = '_space'
file = 'results/PSNR_%s_10k.pkl' % task.lower()
# close the file after loading (the bare open() leaked the handle)
with open(file, 'rb') as f:
    extras = pickle.load(f)

In [None]:
# start from an empty CI table; the next cell fills it for the selected keys
extras.update(confidence_intervals={})

In [None]:
# 95% bootstrap CIs for every result key matching the selected start/end pattern
ci_table = extras['confidence_intervals']
for k in list(extras.keys()):
    # skip the CI table itself (it would match e.g. when end == '') and any
    # key outside the selected pattern
    if k == 'confidence_intervals' or not (k.startswith(start) and k.endswith(end)):
        continue
    ci = []
    ci_table[k] = ci
    for i, result in enumerate(extras[k]):
        clear_output(wait=True)
        display('%s: %d' % (k, i))
        samples = result[index]
        # assumes functions.bootstrap returns the resampled statistic (mean) values -- confirm
        stats = functions.bootstrap(samples, n_samples=1000)
        x_m = np.mean(samples)
        # basic (reverse-percentile) interval: [2*mean - q97.5, 2*mean - q2.5]
        ci.append(2. * x_m - np.percentile(stats, [97.5, 2.5]))

In [None]:
# write the results back to the same file, now including the confidence intervals
with open(file, mode='wb') as results_fh:
    pickle.dump(extras, results_fh)

**Get bootstrapped p-values**

In [None]:
n_samples = 1000


def _load_pickle(path):
    """Load and return the object pickled at ``path``, closing the file afterwards."""
    # NOTE: pickle can execute arbitrary code -- only load trusted result files
    with open(path, 'rb') as fh:
        return pickle.load(fh)


space_exp = _load_pickle('results/PSNR_spacing_10k.pkl')
attn_exp = _load_pickle('results/PSNR_attention_10k.pkl')
density_exp = _load_pickle('results/PSNR_density_10k.pkl')

In [None]:
def bootstrap_means(x, y):
    """Test statistic: difference of sample means, ``mean(x) - mean(y)``."""
    mean_x = np.mean(x)
    mean_y = np.mean(y)
    return mean_x - mean_y

def bootstrap_slope(*args, spacing=None):
    """Difference of absolute least-squares slopes between two groups.

    Parameters
    ----------
    *args : array-likes
        Exactly ``2 * len(spacing)`` sample arrays. The first ``len(spacing)``
        form group 0 (one array per spacing value); the rest form group 1.
    spacing : sequence of float
        x-values against which each group's per-spacing means are fitted.

    Returns
    -------
    numpy float
        ``|slope_0| - |slope_1|``, where each slope is the least-squares
        slope of that group's sample means versus ``spacing``.

    Raises
    ------
    ValueError
        If ``spacing`` is missing or ``args`` is not exactly two groups of
        ``len(spacing)`` arrays.
    """
    # explicit checks instead of asserts, which are stripped under `python -O`
    if spacing is None:
        raise ValueError('spacing must be provided')
    n_spacing = len(spacing)
    # exactly one array per spacing for each of the two groups; any other count
    # would make the lstsq fits below fail with an unhelpful shape error
    if len(args) != 2 * n_spacing:
        raise ValueError('expected %d sample arrays (2 groups x %d spacings), got %d'
                         % (2 * n_spacing, n_spacing, len(args)))
    # get mean for each arg
    y0 = [np.mean(a) for a in args[:n_spacing]]
    y1 = [np.mean(a) for a in args[n_spacing:]]
    # least-squares fit of y = slope * spacing + intercept; keep |slope|
    A = np.stack([spacing, np.ones(n_spacing)]).T
    s_0 = np.abs(np.linalg.lstsq(A, y0, rcond=None)[0][0])
    s_1 = np.abs(np.linalg.lstsq(A, y1, rcond=None)[0][0])
    return s_0 - s_1

def bootstrap_test(*args, fn, n_samples, fn_kwargs=None):
    """Two-sided bootstrap hypothesis test for the statistic ``fn``.

    The first half of ``args`` are group-x samples; the second half are the
    matching group-y samples. Each x/y pair is shifted onto its pooled mean
    to impose the null hypothesis. The shifted samples are resampled via
    ``functions.bootstrap``. The p-value is the fraction of resampled
    statistics whose magnitude exceeds that of the observed statistic.

    Parameters
    ----------
    *args : array-likes
        An even number of sample arrays: x_1..x_k followed by y_1..y_k.
    fn : callable
        Statistic evaluated as ``fn(*args, **fn_kwargs)``, e.g. ``bootstrap_means``.
    n_samples : int
        Number of bootstrap resamples.
    fn_kwargs : dict, optional
        Extra keyword arguments forwarded to ``fn`` (default: none).

    Returns
    -------
    numpy float
        Two-sided bootstrap p-value.

    Raises
    ------
    ValueError
        If an odd number of sample arrays is passed; previously the unpaired
        array was silently dropped by ``zip``.
    """
    # avoid a shared mutable default argument
    fn_kwargs = {} if fn_kwargs is None else fn_kwargs
    n_args = len(args)
    if n_args % 2:
        raise ValueError('bootstrap_test expects an even number of sample arrays, got %d'
                         % n_args)

    # observed statistic
    diff0 = fn(*args, **fn_kwargs)

    # null distribution: shift each x/y pair onto their pooled mean
    half = n_args // 2
    x_primes = []
    y_primes = []
    for x, y in zip(args[:half], args[half:]):
        # accept plain lists (as built by the pooling cells) as well as arrays
        x = np.asarray(x)
        y = np.asarray(y)
        z_mean = np.mean(np.concatenate([x, y]))
        x_primes.append(x - np.mean(x) + z_mean)
        y_primes.append(y - np.mean(y) + z_mean)
    args_prime = x_primes + y_primes
    # bootstrap resample under the null; assumes one fn value per resample is returned
    diff = np.stack(functions.bootstrap(*args_prime, fn=fn, n_samples=n_samples,
                                        fn_kwargs=fn_kwargs))
    return np.mean(np.abs(diff) > np.abs(diff0))

In [None]:
# test radial min extent attention vs. radial 2-spacing (per-sample accuracy, index 3)
keys = ['radial_space', 'radial_attn']
space_acc = space_exp[keys[0]][-1][3]  # last entry of space_exp['spacing']
attn_acc = attn_exp[keys[1]][0][3]     # first entry of attn_exp['extent']
p = bootstrap_test(space_acc, attn_acc, fn=bootstrap_means, n_samples=n_samples)
print('%s (%a spacing) < %s (%a extent) p-value: %a'
      % (keys[0], space_exp['spacing'][-1], keys[1], attn_exp['extent'][0], p))

In [None]:
# test spacing outer vs. inner, radial vs. tangential (pooled across spacings 1 to 1.5)
pooled_spacings = [1., 1.25, 1.5]
for key in zip(['outer_space', 'radial_space'], ['inner_space', 'tangential_space']):
    x = []
    y = []
    # pool per-sample accuracies (index 3) across the selected spacings; use a
    # distinct loop name so the notebook-level `spacing` is not clobbered
    for i, sp in enumerate(space_exp['spacing']):
        if sp in pooled_spacings:
            x.extend(space_exp[key[0]][i][3])
            y.extend(space_exp[key[1]][i][3])
    # bootstrap test between means
    p = bootstrap_test(x, y, fn=bootstrap_means, n_samples=n_samples)
    # report the pooled spacings, not whichever spacing the inner loop ended on
    print('%s < %s (%a) p-value: %a' % (key[0], key[1], pooled_spacings, p))

In [None]:
# test attention outer vs. inner, radial vs. tangential, separately at every extent
for key_a, key_b in zip(['outer_attn', 'radial_attn'], ['inner_attn', 'tangential_attn']):
    for i, extent in enumerate(attn_exp['extent']):
        acc_a = attn_exp[key_a][i][3]
        acc_b = attn_exp[key_b][i][3]
        p = bootstrap_test(acc_a, acc_b, fn=bootstrap_means, n_samples=n_samples)
        print('%s < %s (%a) p-value: %a' % (key_a, key_b, extent, p))

In [None]:
# test slope of density SNR for each pair of cost conditions
density_spacing = density_exp['spacing']
for key_a, key_b in [('cost_0', 'cost_1'), ('cost_0', 'cost_2'), ('cost_1', 'cost_2')]:
    # index 2 holds the per-sample SNR values
    snr_a = [res[2] for res in density_exp[key_a]]
    snr_b = [res[2] for res in density_exp[key_b]]
    p = bootstrap_test(*snr_a, *snr_b, fn=bootstrap_slope, n_samples=n_samples,
                       fn_kwargs={'spacing': density_spacing})
    # 'cost_N' is reported with 'sigma_N' (same trailing digit)
    sigma_a = density_exp['sigma_' + key_a[-1]][0]
    sigma_b = density_exp['sigma_' + key_b[-1]][0]
    print('Sigma %0.2f < %0.2f slope p-value: %a' % (sigma_a, sigma_b, p))

**Heatmap Figures**

In [None]:
heatmap_file = 'results/PSNR_heatmaps_10k.pkl'
figsize = (3.3, 3.3)
fig, axes = plt.subplots(2, 2, figsize=figsize)
titles = ['Inner', 'Outer', 'Radial', 'Tangential']
end = '_hm'
layer_id = '1'
task = 'Spacing'

# load heatmaps (the file is closed after loading)
with open(heatmap_file, 'rb') as f:
    extras = pickle.load(f)

# pick the two heatmap conditions to subtract; going by the colorbar labels
# below, [1, 0] is max. - min. spacing and [2, 0] is attention - no attention
if task.lower() == 'spacing':
    indices = [1, 0]
    extent = None
    spacing = extras['spacing'][-1]
else:
    indices = [2, 0]
    extent = extras['extent'][0]
    spacing = 1.

# difference map per configuration plus a shared vmax
# (largest non-NaN difference across all maps, floored at 0)
vmax = 0.
diff_scores = []
for key in titles:
    scores = [extras[key.lower() + end][i] for i in indices]
    diff = scores[0] - scores[1]
    diff_scores.append(diff)
    vmax = max(vmax, torch.max(diff[~torch.isnan(diff)]).item())

# get heatmaps into subplots (row-major: Inner, Outer / Radial, Tangential)
for idx, title in enumerate(titles):
    r, c = divmod(idx, 2)
    # get flankers, axis for this configuration
    n_flankers, axis = get_crowd_params(title.lower())
    # create crowd set
    crowd_set = create_crowd_set(testset, 1, img_size, n_flankers, axis, spacing,
                                 base_set=base_set, label_map=label_map)
    # reset the RF lattice with the same parameters the RF pool layer was built with
    mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, RF_rate, gap,
                                                            n_rings=n_rings, std=std,
                                                            offset=offset)
    model.layers[layer_id].forward_layer.pool.set(mu=mu, sigma=sigma)
    if task.lower() == 'attention':
        model = apply_attention_field(model, layer_id, mu, sigma, [26, 26], extent)
    # draw into this subplot; `fig` is deliberately NOT rebound to the return
    # value, so the next cell saves this 2x2 figure
    visualize.heatmap(model, layer_id, scores=diff_scores[idx].squeeze(0), cmap='Greens',
                      input=crowd_set[0][0][0], RF_alpha=0.1,
                      vmin=0., vmax=vmax, ax=axes[r, c], show=False)
    axes[r, c].set_title(title)

# colorbar shared by all four panels
cbar = plt.colorbar(axes[0, 0].get_images()[0], ax=axes)
if task.lower() == 'spacing':
    cbar_label = 'Max. - Min. %s PSNR (dB)' % task
else:
    cbar_label = '%s - No %s PSNR (dB)' % (task, task)
cbar.ax.set_ylabel(cbar_label, labelpad=10, rotation=270)

# redraw the RF scatter after adding the colorbar (it takes space from the axes)
for ax_rc in axes.flat:
    visualize.scatter_rfs(model, layer_id, updates={'sizes': None}, figsize=figsize, ax=ax_rc)

In [None]:
# save the 2x2 difference-heatmap figure, e.g. figures/spacing_heatmaps.pdf
fig.savefig('figures/{}_heatmaps.pdf'.format(task.lower()), bbox_inches='tight', dpi=300)

In [None]:
key = 'radial'
end = '_hm'
layer_id = '1'
task = 'Spacing'
# input spacing per panel: max.-spacing map, min.-spacing map, and the
# difference map drawn over the max.-spacing stimulus
spacings = [2., 1., 2.]

# load the heatmaps explicitly rather than relying on `extras` left over from an earlier cell
with open('results/PSNR_heatmaps_10k.pkl', 'rb') as f:
    heatmaps = pickle.load(f)

fig, axes = plt.subplots(1, 3, figsize=(10., 2.5))
indices = [1, 0]
scores = [heatmaps[key.lower() + end][i] for i in indices]
scores.append(scores[0] - scores[1])

# get flankers, axis
n_flankers, axis = get_crowd_params(key.lower())

# shared color scale: largest non-NaN value of the first map plus 5 dB headroom
vmax = torch.max(scores[0][~torch.isnan(scores[0])]) + 5.
vmin = 0.
for r in range(3):
    # crowd set at this panel's spacing (the extra set built before the loop
    # was never used, so it has been removed)
    crowd_set = create_crowd_set(testset, 1, img_size, n_flankers, axis, spacings[r],
                                 base_set=base_set, label_map=label_map)
    visualize.heatmap(model, layer_id, scores=scores[r].squeeze(0), cmap='Greens',
                      input=crowd_set[0][0][0], RF_alpha=0.1, vmin=vmin, vmax=vmax,
                      ax=axes[r], show=False)

cbar = plt.colorbar(axes[0].get_images()[0], ax=axes)
cbar_label = 'Max. - Min. %s PSNR (dB)' % task
cbar.ax.set_ylabel(cbar_label, labelpad=10, rotation=270)

# update RFs due to colorbar
for r in range(3):
    # NOTE(review): `figsize` comes from an earlier cell ((3.3, 3.3) in a linear run)
    visualize.scatter_rfs(model, layer_id, updates={'sizes': None}, figsize=figsize, ax=axes[r])

In [None]:
# save the max./min./difference heatmap triptych at 600 dpi
fig.savefig('figures/spacing_heatmap_subtraction.pdf', bbox_inches='tight', dpi=600)

**Spacing, Attention Heatmap/Accuracy comparison**

In [None]:
figsize = (7.2, 1.8)
fig, axes = plt.subplots(1, 3, figsize=figsize)
layer_id = '1'
key = 'radial'
exps = ['Attention', 'Spacing']
ends = ['_attn', '_space']

# get data (each file is closed after loading)
with open('results/PSNR_attention_10k.pkl', 'rb') as f:
    extras = pickle.load(f)
with open('results/PSNR_spacing_10k.pkl', 'rb') as f:
    extras2 = pickle.load(f)
with open('results/PSNR_heatmaps_10k.pkl', 'rb') as f:
    heatmap = pickle.load(f)
extent = heatmap['extent'][0]

# heatmap differences vs. index 0: attention (index 2) and max. spacing (index 1)
radial_hm = heatmap[key + '_hm']
diff_scores = [radial_hm[2] - radial_hm[0], radial_hm[1] - radial_hm[0]]
vmax = torch.max(diff_scores[1][~torch.isnan(diff_scores[1])])

# accuracy (index 1) at the first attention extent and the last spacing, plus
# distances from the accuracy to each CI bound, as [lower, upper]
acc_0 = extras[key + ends[0]][0][1]
acc_1 = extras2[key + ends[1]][-1][1]
ci_0 = np.abs(np.array(extras['confidence_intervals'][key + ends[0]][0]) - acc_0)
ci_1 = np.abs(np.array(extras2['confidence_intervals'][key + ends[1]][-1]) - acc_1)

# plot accuracy
axes[0].bar(exps, [acc_0, acc_1], color='green', alpha=0.8)
# matplotlib wants asymmetric yerr as shape (2, n_bars): row 0 = lower errors,
# row 1 = upper errors. Stack the per-bar [lower, upper] pairs and transpose;
# passing [ci_0, ci_1] directly would mix the two bars' errors.
axes[0].errorbar([0, 1], [acc_0, acc_1], yerr=np.array([ci_0, ci_1]).T, capsize=2.,
                 color='green', alpha=0.8, fmt='none')
axes[0].set_ylabel('Accuracy (Proportion Correct)')
# asterisk based on testing radial min extent attention vs. radial 2-spacing
axes[0].plot(0.5, 0.6, marker='*', color='Black', markersize=4., markeredgewidth=0.25)

# set heatmaps
n_flankers, axis = get_crowd_params(key)
for i, end in enumerate(ends):
    # attention panel uses spacing 1, spacing panel uses spacing 2
    spacing = i + 1.
    crowd_set = create_crowd_set(testset, 1, img_size, n_flankers, axis, spacing,
                                 base_set=base_set, label_map=label_map)
    # reset the RF lattice; apply the attention field for the attention panel only
    mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, RF_rate, gap,
                                                            n_rings=n_rings, std=std,
                                                            offset=offset)
    model.layers[layer_id].forward_layer.pool.set(mu=mu, sigma=sigma)
    if i == 0:
        model = apply_attention_field(model, layer_id, mu, sigma, [26, 26], extent)
    # get heatmap
    axes[i + 1].set_xlabel(exps[i])
    visualize.heatmap(model, layer_id, scores=diff_scores[i].squeeze(0), cmap='Greens',
                      outline_rfs=True, input=crowd_set[0][0][0], RF_alpha=0.1,
                      vmin=0., vmax=vmax, ax=axes[i + 1], show=False)

# colorbar
cbar = plt.colorbar(axes[2].get_images()[0], ax=fig.axes)
cbar_label = 'PSNR Change (dB)'
cbar.ax.set_ylabel(cbar_label, labelpad=10, rotation=270)

In [None]:
# save the accuracy bar plot + attention/spacing heatmaps
fig.savefig('figures/attention_spacing_heatmaps.pdf', bbox_inches='tight', dpi=300)

**Accuracy Figures**

In [None]:
task = 'Spacing'
with open('results/PSNR_%s_10k.pkl' % task.lower(), 'rb') as f:
    extras = pickle.load(f)
fig, ax = plt.subplots(1, 2, figsize=(4.8, 2.4))
colors = ['blue', 'orange', 'green', 'red']
linestyles = ['--', '-', '-.', ':']
end = '_space'
layer_id = '1'
configs = ['outer', 'inner', 'radial', 'tangential']

# plot spacing accuracy (index 1) with asymmetric CI error bars
for i, key in enumerate(configs):
    acc = np.array([x[1] for x in extras[key + end]]).reshape(-1, 1)
    # (n, 2) [lower, upper] distances -> (2, n), the layout errorbar's yerr expects
    ci = np.abs(np.array(extras['confidence_intervals'][key + end]) - acc).T
    ax[1].errorbar(extras['spacing'], acc, yerr=ci, capsize=2., alpha=0.8, color=colors[i],
                   linestyle=linestyles[i])
ax[1].set_ylabel('Accuracy (Proportion Correct)')
ax[1].set_xlabel('Target-Flanker Spacing (DVA)')
ax[1].legend([k.capitalize() for k in configs])
# reference line: accuracy of the 'none' condition across the spacing range
ax[1].hlines(extras['none' + end][0][1], extras['spacing'][0], extras['spacing'][-1])
fig.tight_layout()

# create crowded stimulus with 4 flankers (axis 0, spacing 1.5), without the base set
crowd_set = create_crowd_set(testset, 1, img_size, 4, 0., 1.5,
                             base_set=None, label_map=label_map)

# bounding boxes for the different configurations (59 = center pixel of the 118-px image)
centers = [(59 + 15, 59), (59 - 15, 59), (59, 59), (59, 59)]
widths = [56, 56, 90, 30]
heights = [26, 22, 30, 90]
visualize.bounding_box(ax[0], 4, centers, widths, heights, alpha=0.8, lw=2,
                       color=colors, linestyle=linestyles)
# add the RF array (same lattice parameters as the RF pool layer)
mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, RF_rate, gap,
                                                        n_rings=n_rings, std=std,
                                                        offset=offset)
model.layers[layer_id].forward_layer.pool.set(mu=mu, sigma=sigma)
visualize.heatmap(model, layer_id, input=crowd_set[0][0][0], ax=ax[0])

In [None]:
# save the accuracy figure, e.g. figures/spacing_acc.pdf
fig.savefig('figures/{}_acc.pdf'.format(task.lower()), bbox_inches='tight', dpi=300)

In [None]:
task = 'Attention'
fig, ax = plt.subplots(2, 3, figsize=(7.2, 2.8),
                       gridspec_kw={'height_ratios': [1, 4], 'hspace': 0.})
colors = ['blue', 'orange', 'green', 'red']
linestyles = ['--', '-', '-.', ':']
end = '_attn'
layer_id = '1'
configs = ['outer', 'inner', 'radial', 'tangential']

with open('results/PSNR_attention_10k.pkl', 'rb') as f:
    extras = pickle.load(f)

ax[0, 2].axis('off')
# extents converted for the x-axis (image space / 20 -> DVA, per the axis label);
# identical for every configuration, so computed once
extent_dva_all = model.rf_to_image_space(layer_id, extras['extent'])[0] / 20.
for i, key in enumerate(configs):
    acc = np.array([x[1] for x in extras[key + end]]).reshape(-1, 1)
    ci = np.abs(np.array(extras['confidence_intervals'][key + end]) - acc).T
    ax[1, 2].errorbar(extent_dva_all, acc, yerr=ci, color=colors[i], capsize=2., alpha=0.8,
                      linestyle=linestyles[i])
ax[1, 2].set_ylabel('Accuracy (Proportion Correct)')
ax[1, 2].set_xlabel('Attentional Field Extent (DVA)')
ax[1, 2].legend([k.capitalize() for k in configs])
fig.tight_layout()

extent_types = ['Min.', 'Max.']
for i, extent in enumerate([7., 27.]):
    # attentional field from a single-pixel priority map at the image center (59, 59)
    im_extent = model.rf_to_image_space(layer_id, extent)[0]
    priority_map = torch.zeros(img_size, img_size)
    priority_map[59, 59] = 1. / im_extent
    af = rf_pool.utils.lattice.gaussian_field(priority_map)
    max_y = np.max(af.numpy()) + 0.1

    # reset the RF lattice, then apply the attention field at this extent
    mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, RF_rate, gap,
                                                            n_rings=n_rings, std=std,
                                                            offset=offset)
    model.layers[layer_id].forward_layer.pool.set(mu=mu, sigma=sigma)
    model = apply_attention_field(model, layer_id, mu, sigma, [26, 26], extent)
    # NOTE(review): return value unused and no ax is given -- confirm this call is needed
    visualize.heatmap(model, layer_id, show=False)

    # field profile along the center column drawn above the heatmap
    if i == 1:
        af = af - np.min(af.numpy()) - 0.2
    ax[0, i].plot(np.arange(img_size), af[0, :, 59], 'black')
    ax[0, i].axis('off')
    ax[0, i].set_ylim(0., max_y)

    extent_dva = model.rf_to_image_space(layer_id, extent)[0] / 20.
    ax[1, i].set_xlabel('%s Attentional Field Extent (%d DVA)' % (extent_types[i], extent_dva))
    visualize.heatmap(model, layer_id, ax=ax[1, i])

In [None]:
# save the accuracy figure, e.g. figures/attention_acc.pdf
fig.savefig('figures/{}_acc.pdf'.format(task.lower()), bbox_inches='tight', dpi=300)

**Density/Size Figures**

In [None]:
file = 'results/PSNR_density_10k.pkl'
# close the file after loading (the bare open() leaked the handle)
with open(file, 'rb') as f:
    extras = pickle.load(f)

In [None]:
task = 'Density_Size'
fig, ax = plt.subplots(1, 2, figsize=(4.8, 2.4))
colors = ['orange', 'black', 'blue']
linestyles = ['-', '-', '--']
end = ''
layer_id = '1'

# density plot: first entry of each result (plotted as PSNR) vs. RF spacing, one line per cost
for i, key in enumerate(['cost_0', 'cost_1', 'cost_2']):
    snr = np.array([x[0] for x in extras[key + end]]).reshape(-1, 1)
    ci = np.abs(np.array(extras['confidence_intervals'][key + end]) - snr).T
    ax[0].errorbar(extras['spacing'], snr, yerr=ci, color=colors[i], capsize=2., alpha=0.8,
                   linestyle=linestyles[i])
ax[0].set_ylabel('Peak Signal-to-Noise Ratio (dB)')
ax[0].set_xlabel('RF Spacing (Units of Sigma)')
# raw string: '\s' in a plain literal is an invalid escape sequence (warns on
# newer Pythons); the runtime text is unchanged
ax[0].legend([r'$\sigma$ = %0.2f' % (model.rf_to_image_space(layer_id, extras[k][0])[0] / 20.)
              for k in ['sigma_0', 'sigma_1', 'sigma_2']])

In [None]:
file = 'results/PSNR_size_10k.pkl'
# close the file after loading (the bare open() leaked the handle)
with open(file, 'rb') as f:
    extras = pickle.load(f)

In [None]:
# add size plot: PSNR vs. RF sigma for 'cost_0', skipping the first entry ([1:])
key = 'cost_0'
end = ''
colors = ['black']

snr = np.array([x[0] for x in extras[key + end][1:]]).reshape(-1, 1)
ci = np.abs(np.array(extras['confidence_intervals'][key + end][1:]) - snr).T
sigma_dva = model.rf_to_image_space(layer_id, extras['sigma_0'][1:])[0] / 20.
ax[1].errorbar(sigma_dva, snr, yerr=ci, color=colors[0], capsize=2., alpha=0.8)
ax[1].set_ylabel('Peak Signal-to-Noise Ratio (dB)')
ax[1].set_xlabel('RF Sigma Size (DVA)')

fig.tight_layout()

In [None]:
# save the density/size PSNR figure, e.g. figures/density_size_snr.pdf
fig.savefig('figures/{}_snr.pdf'.format(task.lower()), bbox_inches='tight', dpi=300)

**MISC**

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(5., 2.8),
                       gridspec_kw={'height_ratios': [1, 4], 'hspace': 0.})
layer_id = '1'

# rebuild the RF lattice here so this cell does not depend on whichever cell
# last assigned mu/sigma (same parameters as the RF pool layer)
mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, RF_rate, gap,
                                                        n_rings=n_rings, std=std,
                                                        offset=offset)
model.layers[layer_id].forward_layer.pool.set(mu=mu, sigma=sigma)
# NOTE(review): `n_kernels` was never defined in this notebook (NameError on a
# fresh kernel); assumes it is the number of RFs, i.e. len(mu) -- confirm
n_kernels = len(mu)
# RFs to highlight with red dashed outlines; every other RF is thin and black
idx = [33, 51]
rf_linestyles = ['dashed' if k in idx else 'solid' for k in range(n_kernels)]
rf_linewidths = [1. if k in idx else 0.25 for k in range(n_kernels)]
rf_edgecolors = ['red' if k in idx else 'black' for k in range(n_kernels)]

for i, extent in enumerate([0., 10.]):
    if extent > 0.:
        # attentional field from a single-pixel priority map at the image center
        im_extent = model.rf_to_image_space(layer_id, extent)[0]
        priority_map = torch.zeros(img_size, img_size)
        priority_map[59, 59] = 1. / im_extent
        af = rf_pool.utils.lattice.gaussian_field(priority_map)
    else:
        # no attention: flat field
        af = torch.ones(1, img_size, img_size)

    model = apply_attention_field(model, layer_id, mu, sigma, [26, 26], extent)
    visualize.heatmap(model, layer_id, input=target_set[0][0][0], ax=ax[1, i],
                      RF_linestyles=rf_linestyles,
                      RF_alpha=1.,
                      RF_linewidths=rf_linewidths,
                      RF_edgecolors=rf_edgecolors)

    # field profile along the center column drawn above the heatmap
    ax[0, i].plot(np.arange(img_size), af[0, :, 59], 'black')
    ax[0, i].axis('off')
    ax[1, i].set_xlabel(['No Attention', 'Attention'][i])

In [None]:
# save the no-attention vs. attention example heatmaps at 600 dpi
fig.savefig('figures/example_heatmap.pdf', dpi=600., bbox_inches='tight')