In [None]:
from modules import *
from functions import *

In [None]:
SEED = 0                                                            # Seed for NumPy's global RNG so Restart & Run All reproduces results
np.random.seed(SEED)                                                # NOTE(review): assumes the helper functions draw from np.random's global state — confirm

N = 500                                                             # Number of neurons
a = 10                                                              # Parameter for propensity function
theta_stim = 90                                                     # Angle to stimulate at
n_test_angles = 100                                                 # Number of angles to use to test for preferred orientation
vars = np.random.lognormal(2, 0.6, N)                               # Width of each neuron's tuning curve (name shadows builtin vars(); kept because later cells use it)
learning_rate = 0.01                                                # Learning rate

n_days = 100                                                        # Number of days to run for
n_norm_per_day = 1                                                  # How many times to normalise the weights per day
n_steps_per_norm = 30                                               # Orientation stimuli between normalisations (per day: n_steps_per_norm * n_norm_per_day)
n_steps = n_steps_per_norm * n_norm_per_day * n_days                # Total number of stimulus steps over all days
init_steps = 300                                                    # Number of trials to run to settle to a baseline

def get_correlation(ratio, n_steps, n_repeats):
    """Return the per-day correlation values averaged over `n_repeats` simulations.

    Each repeat initialises fresh weights, settles them to a baseline with
    `prerun`, runs the 'stripe_rearing' protocol via `get_POs_over_trials`
    for `n_steps` steps, and scores the outcome with `get_r_values`.

    Parameters
    ----------
    ratio : float
        Passed straight through as the Hebbian-scaling argument of both
        `prerun` and `get_POs_over_trials`.
    n_steps : int
        Number of stimulus steps for the stripe-rearing run.
    n_repeats : int
        Number of independent repeats to average over.

    Returns
    -------
    numpy.ndarray
        Mean over repeats (axis 0) of an (n_repeats, n_days) array, i.e. one
        value per day.

    Reads the notebook-level globals N, vars, theta_stim, a, learning_rate,
    n_steps_per_norm, init_steps, n_norm_per_day, n_test_angles and n_days.
    """
    per_repeat_r = np.zeros((n_repeats, n_days))
    for rep in range(n_repeats):
        weights_start = initialise_W(N, vars)
        # NOTE(review): the literal 1 following the ratio is an argument of the
        # opaque helpers; its meaning is not visible from this file.
        weights_settled = prerun(weights_start, theta_stim, a, ratio, 1, learning_rate,
                                 n_steps_per_norm, init_steps)
        preferred_oris = get_POs_over_trials(weights_settled, n_steps, a, ratio, 1, learning_rate,
                                             theta_stim, n_steps_per_norm, n_norm_per_day,
                                             n_test_angles, 'stripe_rearing')
        per_repeat_r[rep] = get_r_values(preferred_oris, theta_stim, n_days, N)
    return per_repeat_r.mean(axis=0)

# Sweep the H-to-xi ratio over `ny` evenly spaced values in [0, 1]. Row k of
# r_matrix receives the repeat-averaged per-day correlation for the k-th ratio.
n_repeats = 10        # independent simulations averaged per ratio
nx = n_days           # heat-map columns: one per day
ny = 20               # heat-map rows: number of ratios sampled

r_matrix = np.zeros((ny, nx))
for row_idx, hebb_ratio in enumerate(tqdm(np.linspace(0, 1, ny))):
    r_matrix[row_idx] = get_correlation(hebb_ratio, n_steps, n_repeats)

## Fig. 3h

In [None]:
def fig_3h(r_matrix, marker_day=28, marker_row_frac=0.3):
    """Draw Fig. 3h: heat map of mean correlation across days for each H-to-xi ratio.

    Parameters
    ----------
    r_matrix : 2-D array-like
        Mean correlation values. Rows index the ratio: row 0 is drawn at the
        bottom (tick label 0) and the last row at the top (tick label 1).
        Columns index days.
    marker_day : float, optional
        Column position of the dashed vertical guide line and the marker
        point. Defaults to 28.
    marker_row_frac : float, optional
        Height of the dashed horizontal guide line and the marker point, as a
        fraction of the number of rows. Defaults to 0.3.

    The figure is rendered with ``plt.show()``; nothing is returned.
    """
    # Take the grid size from the data rather than from the notebook globals
    # nx / ny / n_days, so ticks and limits match whatever matrix is passed in.
    ny, nx = np.shape(r_matrix)
    n_ticks = 3
    cbar_max = 0.5

    fig, ax = plt.subplots(1, 1, figsize=(4, 2), dpi=180)
    im1 = ax.imshow(r_matrix, cmap='RdBu_r', vmin=-cbar_max, vmax=cbar_max)

    # Colour bar with symmetric limits and ticks at -cbar_max, 0, +cbar_max.
    cbar = fig.colorbar(im1, ax=ax, fraction=0.026, pad=0.04)
    cbar.ax.set_yticks(np.round(np.linspace(-cbar_max, cbar_max, 3), 1))
    cbar.ax.tick_params(labelsize=9)
    cbar.ax.set_ylabel('mean correlation', rotation=-90, va="bottom", fontsize=9, labelpad=2)

    ax.set_xlabel('days', labelpad=5)
    ax.set_ylabel(r'ratio H to $\xi$', labelpad=1)
    ax.set_xticks(np.linspace(-0.5, nx, n_ticks))
    ax.set_xticklabels(np.linspace(0, nx, n_ticks).astype(int))
    ax.set_yticks(np.linspace(-0.5, ny - 0.2, 2))
    ax.set_yticklabels(np.linspace(0, 1, 2).astype(int))
    ax.set_facecolor('k')

    # Passing bottom=-0.5 < top puts row 0 at the bottom of the axes, so no
    # separate invert_yaxis() call is needed.
    ax.set_ylim([-0.5, ny - 0.2])
    ax.set_xlim([-0.5, nx + 0.3])

    # NOTE(review): the guide line sits at row coordinate marker_row_frac * ny;
    # if it is meant to mark ratio == marker_row_frac on the tick scale, the
    # exact row would differ slightly — confirm intent.
    marker_row = marker_row_frac * ny
    ax.axhline(marker_row, color='k', linestyle='--', linewidth=1)
    ax.axvline(marker_day, color='k', linestyle='--', linewidth=1)
    ax.scatter(marker_day, marker_row, ec='k', fc='w', s=30, zorder=3)

    # Force a square plotting box regardless of the data aspect ratio.
    ax.set_aspect(1.0 / ax.get_data_ratio())

    # plt.show() renders cleanly in notebooks; Figure.show() emits a warning on
    # non-GUI backends such as the inline backend.
    plt.show()