In [None]:
%matplotlib inline

In [None]:
import numpy as np
from scipy import stats
import matplotlib.pylab as plt
from mpl_toolkits.axes_grid1 import ImageGrid, AxesGrid, make_axes_locatable, SubplotDivider

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../src'))

from figure_presets import *
from plotting_functions import *

from adaptive_response.adaptive_threshold import AdaptiveThresholdNumeric
from utils.numba.tools import random_seed

In [None]:
# Dimensions handed to AdaptiveThresholdNumeric(Nl, Nr) further down
# (presumably the number of ligands and of receptors -- TODO confirm).
Nl, Nr = 256, 32

# Assigned to `model.threshold_factor`; the plots draw the threshold line at
# `alpha` times the mean excitation.
alpha = 1.3

# Passed as `mean_mixture_size` to `model.choose_commonness`: 10% of Nl.
s = Nl * 0.1

# `width` argument of the log-normal sensitivity matrix.
width = 1

In [None]:
# Options forwarded to the model constructor.
parameters = dict(c_distribution='log-normal')
model = AdaptiveThresholdNumeric(Nl, Nr, parameters=parameters)

# Threshold factor, commonness and concentration statistics. The calls are kept
# in this order because later setters may depend on earlier ones.
model.threshold_factor = alpha
model.choose_commonness('const', mean_mixture_size=s)
model.c_means = model.c_vars = 1  # assigned left to right: c_means, then c_vars

# Draw a log-normal sensitivity matrix with unit mean sensitivity.
model.choose_sensitivity_matrix('log-normal', mean_sensitivity=1, width=width)

# Ask for the 'exact' variant of every state-initialisation option
# (targets are assigned left to right, i.e. in the listed order).
init_state = model.parameters['initialize_state']
init_state['c_mean'] = init_state['c_var'] = init_state['correlations'] = 'exact'

In [None]:
# Darker variants of the blue/orange palette; they mark everything that lies
# above the threshold.
darkblue = "#02324F"
darkorange = "#914301"

# Seed first, then draw a fresh sensitivity matrix (random_seed presumably
# seeds the RNG used by the model -- TODO confirm).
random_seed(14)
model.choose_sensitivity_matrix('log-normal', mean_sensitivity=1, width=width)

ymax = 79    # upper limit of the excitation axis
trans = 2/3  # opacity of the histogram bars (relies on true division)

# Sampling effort for the histograms in the right column.
num_matrices = 100         # number of freshly drawn sensitivity matrices
samples_per_matrix = 1000  # samples per matrix (original note: "100000")

with figure_file(
        'histogram_first_receptor.pdf',
        fig_width_pt=200., crop_pdf=False, legend_frame=False,
        transparent=True, post_process=False,
        ) as fig:

    # One excitation sample; it is shown as a bar chart in the left column.
    en_plot = next(model._sample_excitations(1))

    # Swap the default axes for a 2x2 grid that occupies the same rectangle.
    ax = plt.gca()
    bounds = ax.get_position().bounds
    plt.delaxes(ax)

    grid = AxesGrid(fig, bounds,
                    nrows_ncols=(2, 2),
                    axes_pad=0.1,  # pad between axes in inch.
                    share_all=False)

    # One row per factor by which the first receptor's excitation is scaled.
    for ax_k, factor in enumerate((1, 2)):
        # ---- left panel: excitation of every channel for the single sample
        ax = grid[2*ax_k]
        ax.set_aspect(0.19)

        en = en_plot.copy()
        en[0] *= factor
        threshold = alpha * en.mean()

        # Centre each unit-width bar on its channel number. The alignment is
        # stated explicitly because the default changed from 'edge' to
        # 'center' in Matplotlib 2.0; relying on it shifts the bars by 0.5.
        xs = np.arange(len(en)) + 1
        bars = ax.bar(xs, en, width=1, align='center',
                      color=COLOR_BLUE, edgecolor='none', lw=0)

        ax.axhline(threshold, color=COLOR_RED)

        # The first receptor is orange, all others blue; channels above the
        # threshold get the dark shade of their colour.
        bars[0].set_color(COLOR_ORANGE)
        for i in np.flatnonzero(en > threshold):
            bars[i].set_color(darkorange if i == 0 else darkblue)

        # ---- right panel: excitation distribution over many samples
        axHist = grid[2*ax_k + 1]
        axHist.set_aspect(0.0006)

        ax.set_xlim(0.5, len(en) + 0.5)
        ax.set_ylim(0, ymax)
        ax.set_yticks(np.arange(0, ymax, 20))

        ax.set_ylabel('$e_n$')

        # 64 bin edges spanning the y-range; `height` is the bin width.
        bins, height = np.linspace(*ax.get_ylim(), num=64, retstep=True)

        # searchsorted puts a value v into slot i with bins[i-1] < v <= bins[i].
        # Slot 0 and the last slot catch values outside the edge range.
        counts_first = np.zeros(len(bins) + 1)
        counts_other = np.zeros(len(bins) + 1)
        for _ in range(num_matrices):
            model.choose_sensitivity_matrix('log-normal', mean_sensitivity=1, width=width)
            for sample in model._sample_excitations(samples_per_matrix):
                counts_first[np.searchsorted(bins, factor * sample[0])] += 1
                # All remaining receptors are binned in one vectorised step;
                # the counts equal those of an element-by-element loop.
                counts_other += np.bincount(
                    np.searchsorted(bins, sample[1:]),
                    minlength=counts_other.size)

        # Normalise both histograms to unit area: per sample the first
        # receptor contributes one value and the others Nr - 1 values.
        norm = counts_first.sum() + counts_other.sum()
        density_first = counts_first / (height * norm / Nr)
        density_other = counts_other / (height * norm * (Nr - 1) / Nr)

        # Threshold from the model's estimated mean excitation, corrected for
        # the scaled first receptor.
        en_mean = model.excitation_statistics_estimate()['mean'].mean()
        en_mean *= (factor + Nr - 1)/Nr
        en_thresh = alpha * en_mean

        # Index of the first bin edge above the threshold. Unlike
        # `flatnonzero(...)[0]` this gives len(bins) instead of raising when
        # the threshold lies beyond the last edge.
        idx = np.searchsorted(bins, en_thresh, side='right')

        # Slot i covers (bins[i-1], bins[i]], so its bar starts one bin width
        # below bins[i]. The overflow slot (last entry) is not drawn.
        lower_edges = bins - height
        layers = (
            (density_other, COLOR_BLUE, darkblue),      # all other receptors
            (density_first, COLOR_ORANGE, darkorange),  # first receptor
        )
        for density, light, dark in layers:
            axHist.barh(lower_edges[:idx], density[:idx], height=height,
                        align='edge', color=light, edgecolor='none', lw=0,
                        alpha=trans)
            axHist.barh(lower_edges[idx:], density[idx:-1], height=height,
                        align='edge', color=dark, edgecolor='none', lw=0,
                        alpha=trans)

        axHist.axhline(en_thresh, color=COLOR_RED)

        axHist.set_xlim(0, 0.06)
        axHist.set_xticks([0, 3e-2])
        axHist.set_xticklabels(['0', '0.03'])

    # Only the bottom row (the axes left over from the last iteration) gets
    # x-axis labels.
    ax.set_xlabel('Channel $n$')
    axHist.set_xlabel('Frequency')

print('Finished')