In [None]:
from modules import *
from functions import *

In [None]:
def prerun(W_init, n_trials):
    """Settle the weights to a baseline by running `n_trials` 'baseline' updates.

    Every trial adds a propensity-gated Hebbian term and a propensity-gated
    Gaussian noise term, weighted by the module-level `hebb_scaling`,
    `rand_scaling` and `learning_rate`. The weights are handed to
    `normalisation` whenever `t % n_steps_per_norm == 0` (so also at t = 0).

    Parameters
    ----------
    W_init : array_like
        Starting weights. The caller's array is left untouched: all updates
        are applied to a copy.
    n_trials : int
        Number of update steps to run.

    Returns
    -------
    numpy.ndarray
        The settled weights (a new array, never an alias of `W_init`).
    """
    # Copy first: the in-place `+=` below would otherwise overwrite the
    # caller's W_init and make "initial" and "baseline" the same object.
    w = np.array(W_init)
    for t in range(n_trials):
        # Keep this call order: `eta` draws from the global NumPy RNG, and the
        # Hebbian helper may do so as well, so reordering changes the stream.
        H = single_hebbian_component(N, w, theta_stim, type='baseline')
        eta = np.random.randn(N, N)
        prop_function = propensity(w, a)
        w += (hebb_scaling * H * prop_function + rand_scaling * eta * prop_function) * learning_rate
        if t % n_steps_per_norm == 0:
            # Return value is not used; presumably rescales `w` in place -- confirm.
            normalisation(w)
    return w

def evolve_W(W_old, t, type):
    """Advance the weights by one plasticity step and return the new matrix.

    The update is `W_old + (hebbian + random) * learning_rate`, where both
    terms are multiplied by `propensity(W_old, a)`. `type` is forwarded to
    `single_hebbian_component` to select the stimulation regime.

    Side effects: when `t` is a multiple of `n_steps_per_norm` the new weights
    are passed to `normalisation`; when `t` is additionally a multiple of
    `n_steps_per_norm * n_norm_per_day`, the preferred orientations of the
    *pre-update* weights `W_old` are appended to the module-level list `POs`.
    """
    # Order matters: the noise draw consumes the global NumPy RNG stream.
    hebbian_drive = single_hebbian_component(N, W_old, theta_stim, type=type)
    noise = np.random.randn(N, N)
    gate = propensity(W_old, a)

    hebbian_term = hebb_scaling * hebbian_drive * gate
    random_term = rand_scaling * noise * gate
    W_new = W_old + (hebbian_term + random_term) * learning_rate

    # Nothing else to do between normalisation points.
    if t % n_steps_per_norm != 0:
        return W_new

    # Return value unused; presumably normalises W_new in place -- confirm.
    normalisation(W_new)

    steps_per_day = n_steps_per_norm * n_norm_per_day
    if t % steps_per_day == 0:
        # Once per simulated day: record preferred orientations of W_old.
        day_PO = get_preferred_orientations(N, W_old, n_angles=n_test_angles)
        POs.append(day_PO)
    return W_new

def get_POs_over_trials(W_init, n_steps, type):
    """Run `n_steps` updates of a single regime and return the recorded POs.

    Parameters
    ----------
    W_init : array_like
        Starting weights; copied, so the caller's array is not modified here.
    n_steps : int
        Number of calls to `evolve_W`.
    type : str
        Stimulation regime forwarded to `evolve_W` (the notebook uses
        'baseline' and 'stripe_rearing').

    Returns
    -------
    list
        The module-level `POs` list, reset at the start of the call and then
        filled by `evolve_W` (one entry each time it records a day).
    """
    # `evolve_W` appends to the global list, so it has to be reset here.
    global POs
    POs = []

    # Keep only the current weight matrix. The history is never read or
    # returned, and an (N, N, n_steps + 1) float64 array is ~1.7 GB for
    # N = 500 and n_steps = 840.
    W = np.array(W_init, dtype=float)
    for t in tqdm(range(n_steps)):
        W = evolve_W(W, t, type)
    return POs

def get_POs_over_trials_7(W_baseline, n_steps):
    """Run 7 days of 'stripe_rearing' followed by 'test' steps; return the POs.

    The previous implementation tried to fast-forward the loop by assigning to
    `t` inside `for t in tqdm(range(n_steps))`. Rebinding the loop variable
    does not advance a `range` iterator, so every 'test' block was simulated
    and then immediately re-simulated as 'stripe_rearing', overwriting the
    weights and appending duplicate day entries to `POs` (46 entries instead
    of 28 for the settings used in this notebook).

    NOTE(review): had the fast-forward worked, a 'test' block would start at
    every non-zero multiple of 7 days and run for 7 days, i.e. the regime is
    'stripe_rearing' for the first 7 days and 'test' for every step after.
    That schedule is what is implemented below -- confirm it is the intended
    protocol.

    Parameters
    ----------
    W_baseline : array_like
        Starting weights; copied, so the caller's array is not modified here.
    n_steps : int
        Total number of calls to `evolve_W`.

    Returns
    -------
    list
        The module-level `POs` list, reset at the start of the call and then
        filled by `evolve_W` (one entry each time it records a day).
    """
    # `evolve_W` appends to the global list, so it has to be reset here.
    global POs
    POs = []

    steps_per_week = n_steps_per_norm * n_norm_per_day * 7

    # Only the current matrix is needed; the full (N, N, n_steps + 1) history
    # was never read and costs ~1.7 GB for N = 500 and n_steps = 840.
    W = np.array(W_baseline, dtype=float)
    for t in tqdm(range(n_steps)):
        # Exactly one update per step, with the regime chosen from the step index.
        phase = 'stripe_rearing' if t < steps_per_week else 'test'
        W = evolve_W(W, t, phase)
    return POs


## Metrics over days

In [None]:
# Simulation parameters. prerun() and evolve_W() read these as module-level
# globals, so they must be defined before any simulation below is started.
# NOTE(review): no np.random.seed(...) is set in this cell, so the draws below
# differ between runs unless a seed is set elsewhere -- confirm this is intended.
N = int(500)                                                        # Number of neurons
a = 10                                                              # Parameter passed to the propensity function
theta_stim = 90                                                     # Angle to stimulate at 
n_test_angles = 100                                                 # Number of angles to use to test for preferred orientation
vars = np.random.lognormal(2, 0.6, N)                               # Width of each neuron's tuning curve (note: shadows the builtin `vars`)
learning_rate = 0.01                                                # Learning rate

n_days = 28                                                         # Number of days to run for
n_norm_per_day = 1                                                  # How many times to normalise the weights per day  
n_steps_per_norm = 30                                               # Number of update steps between two weight normalisations
n_steps = n_steps_per_norm * n_norm_per_day * n_days                # Total number of update steps (steps per day * days)
init_steps = 300                                                    # Number of trials to run to settle to a baseline

hebb_scaling = 0.3                                                  # Scaling of Hebbian component  
rand_scaling = 1                                                    # Scaling of random component 

W_init = initialise_W(N, vars)                                      # Initialise weights 
W_baseline = prerun(W_init, init_steps)                             # Settle to a baseline

eo = 2                                                              # Plot every `eo`-th value (for visual clarity); read as a global by the figure functions

# All three conditions below start from the same W_baseline. `POs` is
# overwritten by each run, so its metrics are extracted immediately afterwards.

""" Baseline """ 

POs = get_POs_over_trials(W_baseline, n_steps, 'baseline')
drift_magnitude_baseline, drift_rate_baseline, convergence_baseline = get_metrics(N, n_days, theta_stim, POs)


""" 28 day stripe-rearing """

POs = get_POs_over_trials(W_baseline, n_steps, 'stripe_rearing')
drift_magnitude_28, drift_rate_28, convergence_28 = get_metrics(N, n_days, theta_stim, POs)

""" 7 day stripe-rearing """ 

POs = get_POs_over_trials_7(W_baseline, n_steps)
drift_magnitude_7, drift_rate_7, convergence_7 = get_metrics(N, n_days, theta_stim, POs)


## Fig. 3d

In [None]:
def fig_3d(drift_magnitude_28, drift_magnitude_baseline, n_days):
    """Fig. 3d: median drift magnitude over days, 28-day deprivation vs baseline.

    For each input the median is taken along axis 1, the last entry is dropped
    and every `eo`-th value (module-level global) is plotted against days
    1 .. n_days - 1.
    """
    fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=180)

    shared_style = dict(ls='-', marker='o', ms=4, clip_on=False)
    # (data, line colour, legend label), drawn in this order
    traces = (
        (drift_magnitude_28, 'green', 'Deprivation 28 days'),
        (drift_magnitude_baseline, 'black', 'Baseline'),
    )
    for values, colour, legend_text in traces:
        days = np.arange(1, n_days)[::eo]
        medians = np.median(values, axis=1)[:-1]
        ax.plot(days, medians[::eo], c=colour, label=legend_text, **shared_style)

    ax.set_ylim([0, 5])
    ax.set_yticks([0, 5])
    ax.set_xlabel('time since start [days]')
    ax.set_ylabel(r'drift magnitude $ \; [\degree]$')
    ax.set_xlim(0, 30)
    ax.legend(frameon=False, fontsize=8)
    fig.show()

## Fig. 3e

In [1]:
def fig_3e(drift_rate_28, drift_rate_baseline, n_days):
    """Fig. 3e: median drift rate over days, 28-day deprivation vs baseline.

    For each input the median is taken along axis 1, the last entry is dropped
    and every `eo`-th value (module-level global) is plotted against days
    1 .. n_days - 1.
    """
    fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=180)

    shared_style = dict(ls='-', marker='o', ms=4, clip_on=False)
    # (data, line colour, legend label), drawn in this order
    traces = (
        (drift_rate_28, 'green', 'Deprivation 28 days'),
        (drift_rate_baseline, 'black', 'Baseline'),
    )
    for values, colour, legend_text in traces:
        days = np.arange(1, n_days)[::eo]
        medians = np.median(values, axis=1)[:-1]
        ax.plot(days, medians[::eo], c=colour, label=legend_text, **shared_style)

    ax.set_ylim([0, 5])
    ax.set_yticks([0, 5])
    ax.set_xlabel('time since start [days]')
    ax.set_ylabel(r'drift rate $ \; [\degree / $ day $]$')
    ax.set_xlim(0, 30)
    ax.legend(frameon=False, fontsize=8)
    fig.show()

## Supp. Fig. 3j (left)

In [None]:
def supp_fig_3j_left(drift_magnitude_7, drift_magnitude_28, drift_magnitude_baseline, n_days):
    """Supp. Fig. 3j (left): median drift magnitude for 7-day, 28-day and baseline runs.

    Each input has its last row dropped before the median is taken along
    axis 1; every `eo`-th value (module-level global) is plotted. Faint dashed
    vertical lines mark every 7th day.

    NOTE(review): the 7-day trace uses an x grid that ends one day later than
    the other two (`n_days + 1` vs `n_days`) -- confirm this is intended.
    """
    fig, ax = plt.subplots(1, 1, figsize=(2.5, 2), dpi=180)

    shared_style = dict(ls='-', marker='o', ms=4)
    # (data, exclusive end of the day grid, line colour, legend label)
    traces = (
        (drift_magnitude_7, n_days+1, 'orange', 'Deprivation 7 days'),
        (drift_magnitude_28, n_days, 'green', 'Deprivation 28 days'),
        (drift_magnitude_baseline, n_days, 'black', 'Baseline'),
    )
    for values, day_stop, colour, legend_text in traces:
        days = np.arange(1, day_stop)[::eo]
        medians = np.median(values[:-1], axis=1)
        ax.plot(days, medians[::eo], c=colour, label=legend_text, **shared_style)

    ax.set_ylim([0, 5])
    ax.set_yticks([0, 5])
    ax.set_xlim([0, n_days+1])  # superseded by the set_xlim(0, 30) call further down
    for week_mark in np.arange(7, n_days+1, 7):
        ax.axvline(x=week_mark, c='k', ls='--', alpha=0.1, lw=1)
    ax.set_xlabel('time since start [days]')
    ax.set_ylabel(r'drift magnitude $ \; [\degree]$')
    ax.set_xlim(0, 30)
    ax.legend(frameon=False, fontsize=8)
    fig.show()

## Supp. Fig. 3j (right)

In [None]:
def supp_fig_3j_right(convergence_7, convergence_28, convergence_baseline, n_days):
    """Supp. Fig. 3j (right): median convergence for 7-day, 28-day and baseline runs.

    The median of each input is taken along axis 1 and every `eo`-th value
    (module-level global) is plotted against days 1 .. n_days - 1. Faint
    dashed vertical lines mark every 7th day.
    """
    fig, ax = plt.subplots(1, 1, figsize=(2.5, 2), dpi=180)

    # (data, line colour), drawn in this order
    traces = (
        (convergence_7, 'orange'),
        (convergence_28, 'green'),
        (convergence_baseline, 'black'),
    )
    for values, colour in traces:
        days = np.arange(1, n_days)[::eo]
        medians = np.median(values, axis=1)
        ax.plot(days, medians[::eo], c=colour, ls='-', marker='o', ms=4)

    ax.set_ylim([-2, 5])
    ax.set_xlim(0, 30)
    for week_mark in np.arange(7, n_days+1, 7):
        ax.axvline(x=week_mark, c='k', ls='--', alpha=0.1, lw=1)
    ax.locator_params(axis='y', nbins=2)
    ax.set_xlabel('time since start [days]')
    ax.set_ylabel(r'convergence $ \; [\degree]$')
    fig.show()

## Fig. 3f

In [None]:
def fig_3f(convergence_28, convergence_baseline, n_days):
    """Fig. 3f: median convergence over days, 28-day deprivation vs baseline.

    The median of each input is taken along axis 1 and every `eo`-th value
    (module-level global) is plotted against days 1 .. n_days - 1.
    """
    fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=180)

    # (data, line colour), drawn in this order
    traces = (
        (convergence_28, 'green'),
        (convergence_baseline, 'black'),
    )
    for values, colour in traces:
        days = np.arange(1, n_days)[::eo]
        medians = np.median(values, axis=1)
        ax.plot(days, medians[::eo], c=colour, ls='-', marker='o', ms=4)

    ax.set_ylim([-2, 5])
    ax.set_xticks([0, 7, 14, 21, 28])
    ax.locator_params(axis='y', nbins=2)
    ax.set_xlabel('Time since start [days]')
    ax.set_ylabel(r'Convergence $ \; [\degree]$')
    fig.show()

## "Knockouts"

In [None]:
# "Knockout" experiments: rerun the baseline and stripe-rearing simulations
# with only the Hebbian term, only the random term, or both ("optimal").
# The parameters below are module-level globals read by prerun() and evolve_W().
np.random.seed(5)                                                   # Set random seed
N = int(500)                                                        # Number of neurons

a=10                                                                # Parameter passed to the propensity function
theta_stim = 90                                                     # Angle to stimulate at 
n_test_angles = 100                                                 # Number of angles to use to test for preferred orientation
vars = np.random.lognormal(2, 0.6, N)                               # Width of each neuron's tuning curve (note: shadows the builtin `vars`)
learning_rate = 0.01                                                # Learning rate

n_days = 28                                                         # Number of days to run for
n_norm_per_day = 2                                                  # How many times to normalise the weights per day     --> If too high, leads to greater drift rate?
n_steps_per_norm = 15       # from 30                               # Update steps between two normalisations (previously 30)  --> If too high (>20), leads to greater drift magnitude
n_steps = n_steps_per_norm * n_norm_per_day * n_days                # Total number of update steps (steps per day * days)
init_steps = 300                                                    # Number of trials to run to settle to a baseline

# The baseline is settled with the "optimal" mix of both components; the
# scalings are only changed afterwards, for the individual knockout runs.
rand_scaling = 1
hebb_scaling = 0.3      
W_init = initialise_W(N, vars)                                      # Initialise weights 
W_baseline = prerun(W_init, init_steps)                             # Settle to a baseline

hebb_color = 'deeppink'                                             # Colours for plotting
rand_color = 'deepskyblue'
opt_color = 'black'


# Every run below starts from the same W_baseline. evolve_W() reads
# rand_scaling / hebb_scaling as globals, so they are reassigned before each
# run. `POs` is overwritten by every run and consumed immediately by get_metrics().

""" ----------  BASELINE ---------- """

# Hebbian only 

rand_scaling = 0
hebb_scaling = 1.3

POs = get_POs_over_trials(W_baseline, n_steps, 'baseline')
drift_magnitude_baseline_hebb, drift_rate_baseline_hebb, convergence_baseline_hebb = get_metrics(N, n_days, theta_stim, POs)

# Random only

rand_scaling = 1.3
hebb_scaling = 0

POs = get_POs_over_trials(W_baseline, n_steps, 'baseline')
drift_magnitude_baseline_rand, drift_rate_baseline_rand, convergence_baseline_rand = get_metrics(N, n_days, theta_stim, POs)

# Optimal 

rand_scaling = 1
hebb_scaling = 0.3

POs = get_POs_over_trials(W_baseline, n_steps, 'baseline')
drift_magnitude_baseline_opt, drift_rate_baseline_opt, convergence_baseline_opt = get_metrics(N, n_days, theta_stim, POs)



""" ----------  STRIPE-REARING ---------- """

""" Hebbian only """ 

rand_scaling = 0
hebb_scaling = 1.3

POs = get_POs_over_trials(W_baseline, n_steps, 'stripe_rearing')
drift_magnitude_sr_hebb, drift_rate_sr_hebb, convergence_sr_hebb = get_metrics(N, n_days, theta_stim, POs)

""" Random only """

rand_scaling = 1.3
hebb_scaling = 0

POs = get_POs_over_trials(W_baseline, n_steps, 'stripe_rearing')
drift_magnitude_sr_rand, drift_rate_sr_rand, convergence_sr_rand = get_metrics(N, n_days, theta_stim, POs)

""" Optimal """

rand_scaling = 1
hebb_scaling = 0.3

POs = get_POs_over_trials(W_baseline, n_steps, 'stripe_rearing')
drift_magnitude_sr_opt, drift_rate_sr_opt, convergence_sr_opt = get_metrics(N, n_days, theta_stim, POs)

## Supp. Fig. 3h (left)

In [None]:
def supp_fig_3h_left(drift_magnitude_baseline_hebb, drift_magnitude_baseline_rand, drift_magnitude_baseline_opt, n_days):
    """Supp. Fig. 3h (left): median baseline drift magnitude for the three model variants.

    Each input has its last row dropped before the median is taken along
    axis 1; every `eo`-th value is plotted against days 1 .. n_days - 1.
    Colours and `eo` come from module-level globals.
    """
    fig, ax = plt.subplots(1, 1, figsize=(2.5, 2), dpi=180)

    # (data, line colour, label), drawn in this order
    traces = (
        (drift_magnitude_baseline_hebb, hebb_color, 'Hebbian only'),
        (drift_magnitude_baseline_rand, rand_color, 'Random only'),
        (drift_magnitude_baseline_opt, opt_color, 'Optimal'),
    )
    for values, colour, legend_text in traces:
        days = np.arange(1, n_days)[::eo]
        medians = np.median(values[:-1], axis=1)
        ax.plot(days, medians[::eo], c=colour, ls='-', marker='o', ms=4, label=legend_text)

    ax.locator_params(axis='y', nbins=2)
    ax.set_ylim([0, 5])
    ax.set_yticks([0, 5])
    ax.set_xticks([0, 7, 14, 21, 28])
    ax.set_xlabel('time since start [days]')
    ax.set_ylabel(r'drift magnitude $ \; [\degree]$')
    fig.show()

## Supp. Fig. 3h (right)

In [None]:
def supp_fig_3h_right(drift_rate_baseline_hebb, drift_rate_baseline_rand, drift_rate_baseline_opt, n_days):
    """Supp. Fig. 3h (right): mean baseline drift rate for the three model variants.

    The mean of each input is taken along axis 1 and every `eo`-th value is
    plotted against days 1 .. n_days - 1. Colours and `eo` come from
    module-level globals.

    NOTE(review): this panel averages with the mean, while the other panels
    in this notebook use the median -- confirm this is intended.
    """
    fig, ax = plt.subplots(1, 1, figsize=(2.5, 2), dpi=180)

    # (data, line colour), drawn in this order
    traces = (
        (drift_rate_baseline_hebb, hebb_color),
        (drift_rate_baseline_rand, rand_color),
        (drift_rate_baseline_opt, opt_color),
    )
    for values, colour in traces:
        days = np.arange(1, n_days)[::eo]
        means = np.mean(values, axis=1)
        ax.plot(days, means[::eo], c=colour, ls='-', marker='o', ms=4)

    ax.set_ylim([0, 5])
    ax.set_yticks([0, 5])
    ax.set_xticks([0, 7, 14, 21, 28])
    ax.set_xlabel('time since start [days]')
    ax.set_ylabel(r'drift rate $ \; [\degree / $ day $]$')
    fig.show()

## Supp. Fig. 3i

In [None]:
def supp_fig_3i(convergence_sr_hebb, convergence_sr_rand, convergence_sr_opt, n_days):
    """Supp. Fig. 3i: median stripe-rearing convergence for the three model variants.

    The median of each input is taken along axis 1 and every `eo`-th value is
    plotted against days 1 .. n_days - 1. The full-model trace is drawn in
    green; the other colours and `eo` come from module-level globals.
    """
    fig, ax = plt.subplots(1, 1, figsize=(2.5, 2), dpi=180)

    # (data, line colour), drawn in this order
    traces = (
        (convergence_sr_hebb, hebb_color),
        (convergence_sr_rand, rand_color),
        (convergence_sr_opt, 'green'),
    )
    for values, colour in traces:
        days = np.arange(1, n_days)[::eo]
        medians = np.median(values, axis=1)
        ax.plot(days, medians[::eo], c=colour, ls='-', marker='o', ms=4)

    ax.set_ylim([-2, 8])
    ax.set_yticks([0, 5])
    ax.set_xticks([0, 7, 14, 21, 28])
    ax.locator_params(axis='y', nbins=2)
    ax.set_xlabel('time since start [days]')
    ax.set_ylabel(r'convergence $ \; [\degree]$')
    fig.show()

## Fig. 3k

In [None]:
def fig_3k(drift_magnitude_baseline_opt, drift_magnitude_baseline_rand, drift_magnitude_baseline_hebb):
    """Fig. 3k: bar chart of the median of the last row of each baseline drift magnitude.

    Bars are ordered full model, random only, Hebbian only, and coloured with
    the module-level `opt_color`, `rand_color` and `hebb_color`.
    """
    colors = [opt_color, rand_color, hebb_color]
    per_variant = (
        drift_magnitude_baseline_opt,
        drift_magnitude_baseline_rand,
        drift_magnitude_baseline_hebb,
    )
    # One bar per variant: median over the final row only.
    magnitude_values = [np.median(values[-1]) for values in per_variant]
    bar_positions = [0, 3, 6]

    fig, ax = plt.subplots(figsize=(2, 3), dpi=180)
    ax.bar(bar_positions, magnitude_values, color=colors, width=1, label='magnitude')
    ax.set_ylabel(r'baseline drift magnitude $ \; [\degree]$')
    ax.set_xticks(bar_positions)
    ax.set_xlim([-1, 7])
    ax.set_ylim([0, 9])
    ax.set_yticks([0, 5])
    ax.set_xticklabels(['model', 'random only', 'hebbian only'], rotation=50, ha='right')
    fig.tight_layout()
    fig.show()

## Fig. 3j

In [None]:
def fig_3j(convergence_sr_opt, convergence_sr_rand, convergence_sr_hebb):
    """Fig. 3j: bar chart of the median of the last row of each stripe-rearing convergence.

    Bars are ordered full model, random only, Hebbian only. The full-model bar
    is green; the others use the module-level `rand_color` and `hebb_color`.
    """
    colors = ['green', rand_color, hebb_color]
    per_variant = (convergence_sr_opt, convergence_sr_rand, convergence_sr_hebb)
    # One bar per variant: median over the final row only.
    convergence_values = [np.median(values[-1]) for values in per_variant]
    bar_positions = [0, 3, 6]

    fig, ax = plt.subplots(figsize=(2, 3), dpi=180)
    ax.bar(bar_positions, convergence_values, color=colors, width=1, label='convergence')
    ax.set_ylabel(r'convergence $ \; [\degree]$')
    ax.set_xticks(bar_positions)
    ax.set_xlim([-1, 7])
    ax.set_ylim([0, 9])
    ax.set_yticks([0, 5])
    ax.set_xticklabels(['model', 'random only', 'hebbian only'], rotation=50, ha='right')
    fig.tight_layout()
    fig.show()