In [None]:
from modules import *
from functions import *

In [None]:
SEED = 0                                                            # Global NumPy seed so that every run of the notebook is reproducible
np.random.seed(SEED)

N = 500                                                             # Number of neurons
a = 10                                                              # Parameter of the propensity function
theta_stim = 90                                                     # Stimulated orientation [deg]
n_test_angles = 100                                                 # Number of angles to use to test for preferred orientation
vars = np.random.lognormal(2, 0.6, N)                               # Width of each neuron's tuning curve (NB: shadows the builtin `vars`; name kept because later cells read it)
learning_rate = 0.01                                                # Learning rate

n_norm_per_day = 5                                                  # How many times to normalise the weights per day
n_steps_per_norm = 30                                               # Orientation stimuli (plasticity steps) between two normalisations; per day: n_steps_per_norm * n_norm_per_day
init_steps = 300                                                    # Number of trials to run to settle to a baseline

hebb_scaling = 0.3                                                  # Scaling of Hebbian component
rand_scaling = 1                                                    # Scaling of random component

n_days_deprv = 28                                                   # Number of days of stripe-rearing
n_days_recovery = 1000                                              # Number of days of recovery
n_steps_deprv = n_days_deprv * n_steps_per_norm * n_norm_per_day    # Total plasticity steps during stripe-rearing
n_steps_recovery = n_days_recovery * n_steps_per_norm * n_norm_per_day  # Total plasticity steps during recovery

n_repeats = 10                                                      # Number of repeats

POS_pre_dep = []; POS_post_dep = []; POS_post_rec = []


def simulate_phase(W_start, n_steps, stim_type, snapshot_steps=()):
    """Run `n_steps` plasticity steps starting from the weight matrix `W_start`.

    Each step adds a Hebbian term (`single_hebbian_component(..., type=stim_type)`) and a Gaussian
    random term, both multiplied by `propensity(W, a)` and scaled by `learning_rate`, and calls
    `normalisation` every `n_steps_per_norm` steps. Only the current weight matrix is kept in memory:
    storing every step (an N x N x n_steps float64 array) would need ~8 GB for the 28-day deprivation
    and ~300 GB for the 1000-day recovery at the default N=500.

    Parameters
    ----------
    W_start : array
        Starting weights; copied as float64, not modified.
    n_steps : int
        Number of plasticity steps to run.
    stim_type : str
        Passed as `type=` to `single_hebbian_component` ('stripe_rearing' or 'baseline' here).
    snapshot_steps : iterable of int
        Step indices after which preferred orientations are computed.

    Returns
    -------
    W_final : array
        Weights after the last step.
    snapshots : dict
        Maps each reached index in `snapshot_steps` to the output of `get_preferred_orientations`
        for the weights produced at that step.
    """
    snapshot_steps = set(snapshot_steps)
    W_current = np.array(W_start, dtype=float)
    snapshots = {}
    for t in tqdm(range(n_steps)):
        H = single_hebbian_component(N, W_current, theta_stim, type=stim_type)
        eta = np.random.randn(N, N)
        prop_function = propensity(W_current, a)
        hebb = hebb_scaling * H * prop_function
        rand = rand_scaling * eta * prop_function
        W_new = W_current + (hebb + rand) * learning_rate
        # Return value ignored as in the original loop; normalisation presumably works in place — TODO confirm.
        if t % n_steps_per_norm == 0: normalisation(W_new)
        if t in snapshot_steps:
            snapshots[t] = get_preferred_orientations(N, W_new, n_angles=n_test_angles)
        W_current = W_new
    return W_current, snapshots


for repeat in range(n_repeats):

    W_init = initialise_W(N, vars)                                                                       # Initialise weights
    W_baseline = prerun(W_init, theta_stim, a, 0.4, 1, learning_rate, n_steps_per_norm, init_steps)      # Settle to a baseline

    # Stripe-rearing.
    # NOTE(review): the "pre-deprivation" POs are measured after the first stripe-rearing step (t == 0),
    # not on W_baseline itself; kept as in the original analysis.
    W_post_deprivation, PO_dep = simulate_phase(W_baseline, n_steps_deprv, 'stripe_rearing',
                                                snapshot_steps=(0, n_steps_deprv - 1))
    PO_pre_dep = PO_dep[0]; PO_post_dep = PO_dep[n_steps_deprv - 1]

    # Recovery ('baseline' Hebbian input), starting from the post-deprivation weights
    W_post_recovery, PO_rec = simulate_phase(W_post_deprivation, n_steps_recovery, 'baseline',
                                             snapshot_steps=(n_steps_recovery - 1,))
    PO_post_rec = PO_rec[n_steps_recovery - 1]

    POS_pre_dep.append(PO_pre_dep); POS_post_dep.append(PO_post_dep); POS_post_rec.append(PO_post_rec)


In [None]:
# Absolute angular distance of every neuron's preferred orientation from the stimulated orientation,
# one array per repeat, for each of the three time points.
rPOS_pre_dep, rPOS_post_dep, rPOS_post_rec = (
    [np.abs(po_run - theta_stim) for po_run in po_runs]
    for po_runs in (POS_pre_dep, POS_post_dep, POS_post_rec)
)

# Convergence = decrease of that distance over a period (positive = moved towards theta_stim).
convergences_during_deprivation = [before - after for before, after in zip(rPOS_pre_dep, rPOS_post_dep)]
convergences_during_recovery = [before - after for before, after in zip(rPOS_post_dep, rPOS_post_rec)]

# Median over neurons within each repeat ...
per_repeat_median_deprivation = np.median(convergences_during_deprivation, axis=1)
per_repeat_median_recovery = np.median(convergences_during_recovery, axis=1)

# ... then mean and spread of those medians across repeats.
median_convergence_during_deprivation_28 = np.mean(per_repeat_median_deprivation)
median_convergence_during_recovery_1000 = np.mean(per_repeat_median_recovery)

std_convergence_during_deprivation_28 = np.std(per_repeat_median_deprivation)
std_convergence_during_recovery_1000 = np.std(per_repeat_median_recovery)

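## Supp. Fig. 3k

Distributions of preferred orientations (POs), pooled over repeats, before stripe-rearing, after 28 days of stripe-rearing and after 1000 days of recovery. Each panel reports a Kolmogorov–Smirnov test against a uniform distribution on [0°, 180°] (shown in red when p < 0.05).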

In [None]:
def supp_fig_3k(POS_pre_dep, POS_post_dep, POS_post_rec):
    """Plot Supp. Fig. 3k: PO histograms before deprivation, after deprivation and after recovery.

    Each argument is a list (one entry per repeat) of per-neuron preferred orientations; the
    repeats are concatenated into one distribution per panel. Every panel is annotated with a
    KS test against the uniform distribution on [0, 180] (text in red when p < 0.05).
    """
    fig, axes = plt.subplots(1, 3, figsize=(12, 3), dpi=150)

    # (pooled distribution, histogram colour, x-axis label) for each panel, left to right
    panels = [
        (np.concatenate(POS_pre_dep, axis=0), 'firebrick', 'PO initial'),
        (np.concatenate(POS_post_dep, axis=0), 'coral', 'PO post deprivation (28 days)'),
        (np.concatenate(POS_post_rec, axis=0), 'goldenrod', 'PO post recovery (1000 days)'),
    ]

    for panel_ax, (distribution, colour, xlabel) in zip(axes, panels):
        panel_ax.hist(distribution, bins=50, histtype='stepfilled', color=colour, alpha=0.5)
        panel_ax.set_xlabel(xlabel)
    axes[0].set_ylabel('frequency', labelpad=20)

    for panel_ax, (distribution, _, _) in zip(axes, panels):
        ks = kstest(distribution, 'uniform', args=(0, 180))
        text_colour = 'red' if ks.pvalue < 0.05 else 'black'
        panel_ax.text(0.4, 0.2, f"D: {ks.statistic:.3f}\np: {ks.pvalue:.3f}",
                      transform=panel_ax.transAxes, fontsize=8, color=text_colour)

    for panel_ax in axes:
        panel_ax.set_xlim(0, 180)
        panel_ax.set_xticks([0, 90, 180])
        panel_ax.set_yticks([])

    fig.show()

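## Fig. 3l

Median PO convergence during recovery as a function of recovery day (mean ± std over repeats, left axis), compared with the median convergence during the 28-day deprivation (dashed line, right axis). The isolated point on the right is the value after 1000 days of recovery.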

In [None]:
def get_correlation_recovery(total_days, deprivation_days=28):
    """Simulate one stripe-rearing + recovery run and track PO convergence day by day.

    A freshly initialised network is settled to a baseline (`prerun`), stripe-reared for
    `deprivation_days` days and then recovers ('baseline' input via `evolve_weights`) for
    `total_days` days. Convergence is the decrease of each neuron's |PO - theta_stim| over a
    period (positive = moved towards theta_stim).

    Reads the notebook globals N, vars, theta_stim, a, learning_rate, hebb_scaling, rand_scaling,
    n_steps_per_norm, n_norm_per_day, init_steps and n_test_angles.

    Parameters
    ----------
    total_days : int
        Number of recovery days to simulate.
    deprivation_days : int, optional
        Number of stripe-rearing days (default 28, as in the original analysis). Must be >= 1.

    Returns
    -------
    correlations : ndarray, shape (total_days + 1,)
        Entry d: Pearson correlation across neurons between convergence during deprivation and
        convergence after d days of recovery. NaN when either vector is constant (e.g. day 0 if
        `get_preferred_orientations` is deterministic, since no recovery has happened yet).
    medians_recovery : ndarray, shape (total_days + 1,)
        Entry d: median across neurons of the convergence after d days of recovery.
    median_deprivation : float
        Median across neurons of the convergence during deprivation.

    Raises
    ------
    ValueError
        If `deprivation_days` < 1 (no post-deprivation POs would exist).

    Notes
    -----
    Only the current weight matrix is kept; the previous version stored every step
    (~8 GB for deprivation and ~24 GB for 80 recovery days at N=500). Day-d POs are sampled after
    d full days of recovery; previously they were sampled after (d - 1) days plus a single step,
    so the final day's remaining steps were simulated but never measured.
    """
    if deprivation_days < 1:
        raise ValueError("deprivation_days must be >= 1")

    steps_per_day = n_steps_per_norm * n_norm_per_day
    n_steps_deprivation = steps_per_day * deprivation_days
    n_steps_recovery = steps_per_day * total_days

    # Initialise weights and settle to a baseline before stripe-rearing
    W_init = initialise_W(N, vars)
    W_baseline = prerun(W_init, theta_stim, a, 0.4, 1, learning_rate, n_steps_per_norm, init_steps)

    # Stripe-rearing period (only the current weights are kept)
    W_current = np.array(W_baseline, dtype=float)
    for t in range(n_steps_deprivation):
        H = single_hebbian_component(N, W_current, theta_stim, type='stripe_rearing')
        eta = np.random.randn(N, N)
        prop_function = propensity(W_current, a)
        hebb = hebb_scaling * H * prop_function
        rand = rand_scaling * eta * prop_function
        W_new = W_current + (hebb + rand) * learning_rate
        # Return value ignored as before; normalisation presumably works in place — TODO confirm.
        if t % n_steps_per_norm == 0: normalisation(W_new)

        # NOTE(review): "pre" POs are taken after the first stripe-rearing step, as in the main experiment cell.
        if t == 0: PO_pre_dep = get_preferred_orientations(N, W_new, n_angles=n_test_angles)
        if t == n_steps_deprivation - 1: PO_post_dep = get_preferred_orientations(N, W_new, n_angles=n_test_angles)
        W_current = W_new

    W_post_deprivation = W_current

    rPO_post_dep = np.abs(PO_post_dep - theta_stim)
    convergence_during_deprivation = np.abs(PO_pre_dep - theta_stim) - rPO_post_dep
    median_deprivation = np.median(convergence_during_deprivation)

    # Recovery period: POs_rec[:, d] holds the POs after d full days of recovery
    POs_rec = np.zeros((N, total_days + 1))
    POs_rec[:, 0] = get_preferred_orientations(N, W_post_deprivation, n_angles=n_test_angles)

    W_current = W_post_deprivation
    for t in range(n_steps_recovery):
        W_current = evolve_weights(N, W_current, t, 'baseline', theta_stim, a, learning_rate,
                                   hebb_scaling, rand_scaling, n_steps_per_norm)
        if (t + 1) % steps_per_day == 0:
            POs_rec[:, (t + 1) // steps_per_day] = get_preferred_orientations(N, W_current, n_angles=n_test_angles)

    # convergence_recovery[:, d]: per-neuron convergence after d days of recovery
    convergence_recovery = rPO_post_dep[:, None] - np.abs(POs_rec - theta_stim)

    # Pearson correlation is undefined for a constant vector; leave NaN there instead of warning.
    correlations = np.full(total_days + 1, np.nan)
    deprivation_varies = np.std(convergence_during_deprivation) > 0
    for rec_day in range(total_days + 1):
        recovery_day = convergence_recovery[:, rec_day]
        if deprivation_varies and np.std(recovery_day) > 0:
            correlations[rec_day] = np.corrcoef(convergence_during_deprivation, recovery_day)[0, 1]

    medians_recovery = np.median(convergence_recovery, axis=0)

    return correlations, medians_recovery, median_deprivation

In [None]:
total_days = 80                                                     # Recovery days tracked per run
repeats = 50                                                        # Independent stripe-rearing + recovery runs

# One row per run; columns are recovery days 0..total_days
all_correlations = np.zeros((repeats, total_days + 1))
all_medians_recovery = np.zeros_like(all_correlations)
medians_deprivation = np.zeros(repeats)

for repeat in tqdm(range(repeats)):
    corr_run, medians_run, median_dep_run = get_correlation_recovery(total_days)
    all_correlations[repeat] = corr_run
    all_medians_recovery[repeat] = medians_run
    medians_deprivation[repeat] = median_dep_run

In [None]:
def fig_3l(all_medians_recovery, median_convergence_during_recovery_1000, std_convergence_during_recovery_1000, median_convergence_during_deprivation_28):
    """Plot Fig. 3l: median PO convergence during recovery versus recovery day.

    Left axis (inverted): mean ± std across repeats of the per-day median convergence during
    recovery, plus a separate point (mean ± std) for the 1000-day recovery placed after an axis
    break. Right axis: median convergence during the 28-day deprivation as a dashed line.

    Parameters
    ----------
    all_medians_recovery : array, shape (repeats, n_days + 1)
        Per-repeat median convergence for recovery days 0..n_days.
    median_convergence_during_recovery_1000, std_convergence_during_recovery_1000 : float
        Mean and std across repeats of the median convergence after 1000 recovery days.
    median_convergence_during_deprivation_28 : float
        Mean across repeats of the median convergence during 28 days of deprivation.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=180)
    ax2 = ax.twinx()

    # Derive the day axis from the data rather than the global `total_days`, so the x values always
    # match the curves (a mismatch made fill_between raise).
    all_medians_recovery = np.asarray(all_medians_recovery)
    last_day = all_medians_recovery.shape[1] - 1
    days = np.arange(last_day + 1)
    end_point = last_day + 15                   # x position of the 1000-day point (95 for 80 days)

    mean_rec = np.mean(all_medians_recovery, axis=0)
    std_rec = np.std(all_medians_recovery, axis=0)

    ax.plot(days, mean_rec, label='Median')
    ax.fill_between(days, mean_rec - std_rec, mean_rec + std_rec, alpha=0.2)

    ax.set_xlabel('recovery [days]')
    ax.set_ylabel(r'convergence during recovery $[\degree]$', color='steelblue', labelpad=10)

    ax.scatter(end_point, median_convergence_during_recovery_1000, c='steelblue', s=20, clip_on=False)
    ax.errorbar(end_point, median_convergence_during_recovery_1000, yerr=std_convergence_during_recovery_1000, fmt='o', c='steelblue', capsize=2, capthick=1, elinewidth=1, markersize=4, zorder=2, clip_on=False)

    ax2.axhline(median_convergence_during_deprivation_28, c='darkgreen', ls='--', lw=1)

    day_ticks = list(range(0, last_day + 1, 20))
    ax.set_xticks(day_ticks + [end_point])
    ax.set_xticklabels(day_ticks + ['1000'])

    ax.set_ylim(-4.2, 0.2)
    ax2.set_ylim(0, 4.2)
    ax.set_xlim(0, end_point + 5)
    ax.invert_yaxis()

    # White patch drawn over the x-axis between the last simulated day and the 1000-day point;
    # appears to act as an axis-break marker.
    break_x0, break_x1 = last_day + 5, last_day + 10
    polygon = Polygon(np.array([[break_x0, -5], [break_x1, -5], [break_x1, 0.5], [break_x0, 0.5]]), closed=False, color='white', lw=0.5, clip_on=False, zorder=3)
    ax2.add_patch(polygon)
    ax2.spines['right'].set_visible(True)
    ax2.set_ylabel(r'convergence during deprivation  $[\degree]$', color='darkgreen', labelpad=10)
    fig.show()
