In [1]:
# Numerics
import numpy as np
import jax
import jax.numpy as jnp

# Functional
from typing import List, Callable
from functools import partial

# Own
from data_gen_x import gen_data_x
from models import *

## Functions

In [2]:
def simulation(
    seed: int,
    num_it: int = 2000,
    cond: Callable = an_cond,
    lr_sgd: float = 1e-2,
    lr_hebb: float = 4e-3,
    lam_sgd: float = 1e-3,
    lam_hebb: float = 1e-1
):
    """
    Prepare training and validation data and run simulation for one-layer model.
    :param seed: Random PRNG seed
    :param num_it: Number of iterations to run
    :param cond: Function specifying if iteration i is an active or only passive trial.
    :param lr_sgd: SGD learning rate for v
    :param lr_hebb: Hebbian learning rate for v
    :param lam_sgd: SGD weight decay for v
    :param lam_hebb: Hebbian weight decay for v
    :return: The ``metrics`` attribute of the ``OneLayer`` model after ``run_scheme`` has finished
    """
    # Derive one independent key per consumer *before* any of them is used.
    # A JAX key that has already been consumed must not be split or reused,
    # otherwise the resulting random streams are not independent.
    train_key, val_key, model_key = jax.random.split(jax.random.PRNGKey(seed), 3)

    # Prepare means of point clouds
    data_means = parclass_v().means
    # NOTE(review): the covariances come from the parclass_wv defaults while the
    # means come from parclass_v -- confirm that this mix is intended.
    data_sigs = parclass_wv().sigs

    # Training and validation sets are drawn with identical settings but separate keys.
    X_train, y_train, _ = gen_data_x(
        train_key, 1000, 50, 2, data_sigs, data_means, vec=True
    )
    X_val, y_val, _ = gen_data_x(
        val_key, 1000, 50, 2, data_sigs, data_means, vec=True
    )

    sim_pars = parclass_v(
        lr_sgd_v=lr_sgd, lr_hebb_v=lr_hebb, lam_sgd_v=lam_sgd, lam_hebb_v=lam_hebb
    )
    # Keep the parameter object in sync with the means used to generate the data.
    sim_pars.means = data_means

    test_single = OneLayer(model_key, pars=sim_pars)
    test_single.run_scheme(cond, X_train, y_train, X_val, y_val, num_it)
    return test_single.metrics

In [3]:
def simulation_double_wrongpc(
    seed: int,
    num_it: int = 2000,
    cond: Callable = an_cond,
    lr_sgd_v: float = 0,
    lr_hebb_v: float = 1e-4,
    lr_sgd_w: float = 1e-3,
    lr_fsm_w: float = 2e-3,
    lam_sgd_v: float = 5e-2,
    lam_hebb_v: float = 1,
    lam_sgd_w: float = 1e-3,
    lam_fsm_w: float = 1,
    hebb_w: bool = False,
):
    """
    Prepare training and validation data and run simulation for non-isotropic two-layer model.
    :param seed: Random PRNG seed
    :param num_it: Number of iterations to run
    :param cond: Function specifying if iteration i is an active or only passive trial.
    :param lr_sgd_v: SGD learning rate for v
    :param lr_hebb_v: Hebbian learning rate for v
    :param lr_sgd_w: SGD learning rate for W
    :param lr_fsm_w: FSM/Hebbian learning rate for W
    :param lam_sgd_v: SGD weight decay for v
    :param lam_hebb_v: Hebbian weight decay for v
    :param lam_sgd_w: SGD weight decay for W
    :param lam_fsm_w: FSM/Hebbian weight decay for W
    :param hebb_w: If True, use Hebbian learning for W, otherwise FSM
    :return: The ``metrics`` attribute of the ``TwoLayer`` model after ``run_scheme`` has finished
    """
    # Derive one independent key per consumer *before* any of them is used.
    # A JAX key that has already been consumed must not be split or reused,
    # otherwise the resulting random streams are not independent.
    train_key, val_key, model_key = jax.random.split(jax.random.PRNGKey(seed), 3)

    # Prepare means of point clouds (default means scaled by 1.5)
    data_means = 1.5 * parclass_wv().means
    # Differs from `simulation_double` only in the mean scale above and in the
    # last argument here (8 instead of 1); see `pancake_sigs` for its meaning.
    sigs = pancake_sigs([1, 1], 50, 8)

    # Training and validation sets are drawn with identical settings but separate keys.
    X_train, y_train, _ = gen_data_x(
        train_key, 1000, 50, 2, sigs, data_means, vec=True
    )
    X_val, y_val, _ = gen_data_x(
        val_key, 1000, 50, 2, sigs, data_means, vec=True
    )

    sim_pars = parclass_wv(
        lam_fsm_w=lam_fsm_w,
        lam_sgd_w=lam_sgd_w,
        lam_hebb_v=lam_hebb_v,
        lam_sgd_v=lam_sgd_v,
        lr_hebb_v=lr_hebb_v,
        lr_fsm_w=lr_fsm_w,
        lr_sgd_w=lr_sgd_w,
        lr_sgd_v=lr_sgd_v,
        hebb_w=hebb_w,
        lr_sgd_v_decay=0,
    )
    # Keep the parameter object in sync with the statistics used to generate the data.
    sim_pars.means = data_means
    sim_pars.sigs = sigs

    test_double = TwoLayer(model_key, pars=sim_pars)
    test_double.run_scheme(cond, X_train, y_train, X_val, y_val, num_it)
    return test_double.metrics

In [4]:
def simulation_double(
    seed: int,
    num_it: int = 2000,
    cond: Callable = an_cond,
    lr_sgd_v: float = 0,
    lr_hebb_v: float = 1e-4,
    lr_sgd_w: float = 1e-3,
    lr_fsm_w: float = 2e-3,
    lam_sgd_v: float = 5e-2,
    lam_hebb_v: float = 1,
    lam_sgd_w: float = 1e-3,
    lam_fsm_w: float = 1,
    hebb_w: bool = False,
):
    """
    Prepare training and validation data and run simulation for two-layer model with isotropic input.
    :param seed: Random PRNG seed
    :param num_it: Number of iterations to run
    :param cond: Function specifying if iteration i is an active or only passive trial.
    :param lr_sgd_v: SGD learning rate for v
    :param lr_hebb_v: Hebbian learning rate for v
    :param lr_sgd_w: SGD learning rate for W
    :param lr_fsm_w: FSM/Hebbian learning rate for W
    :param lam_sgd_v: SGD weight decay for v
    :param lam_hebb_v: Hebbian weight decay for v
    :param lam_sgd_w: SGD weight decay for W
    :param lam_fsm_w: FSM/Hebbian weight decay for W
    :param hebb_w: If True, use Hebbian learning for W, otherwise FSM
    :return: The ``metrics`` attribute of the ``TwoLayer`` model after ``run_scheme`` has finished
    """
    # Derive one independent key per consumer *before* any of them is used.
    # A JAX key that has already been consumed must not be split or reused,
    # otherwise the resulting random streams are not independent.
    train_key, val_key, model_key = jax.random.split(jax.random.PRNGKey(seed), 3)

    # Prepare means of point clouds (default means, unscaled)
    data_means = 1 * parclass_wv().means
    # Differs from `simulation_double_wrongpc` only in the mean scale above and
    # in the last argument here (1 instead of 8); see `pancake_sigs` for its meaning.
    sigs = pancake_sigs([1, 1], 50, 1)

    # Training and validation sets are drawn with identical settings but separate keys.
    X_train, y_train, _ = gen_data_x(
        train_key, 1000, 50, 2, sigs, data_means, vec=True
    )
    X_val, y_val, _ = gen_data_x(
        val_key, 1000, 50, 2, sigs, data_means, vec=True
    )

    sim_pars = parclass_wv(
        lam_fsm_w=lam_fsm_w,
        lam_sgd_w=lam_sgd_w,
        lam_hebb_v=lam_hebb_v,
        lam_sgd_v=lam_sgd_v,
        lr_hebb_v=lr_hebb_v,
        lr_fsm_w=lr_fsm_w,
        lr_sgd_w=lr_sgd_w,
        lr_sgd_v=lr_sgd_v,
        hebb_w=hebb_w,
        lr_sgd_v_decay=0,
    )
    # Keep the parameter object in sync with the statistics used to generate the data.
    sim_pars.means = data_means
    sim_pars.sigs = sigs

    test_double = TwoLayer(model_key, pars=sim_pars)
    test_double.run_scheme(cond, X_train, y_train, X_val, y_val, num_it)
    return test_double.metrics

In [5]:
def simulation_double_fsm(
    seed: int,
    num_it: int = 2000,
    cond: Callable = an_cond,
    lr_sgd_v: float = 0,
    lr_hebb_v: float = 1e-4,
    lr_sgd_w: float = 1e-3,
    lr_fsm_w: float = 2e-3,
    lam_fsm_w: float = 4e0,
    lam_sgd_w: float = 1e-3,
    hebb_w: bool = False,
    lam_sgd_v: float = 5e-2,
    lam_hebb_v: float = 1,
):
    """
    Prepare training and validation data and run simulation for two-layer model with non-aligned input.
    :param seed: Random PRNG seed
    :param num_it: Number of iterations to run
    :param cond: Function specifying if iteration i is an active or only passive trial.
    :param lr_sgd_v: SGD learning rate for v
    :param lr_hebb_v: Hebbian learning rate for v
    :param lr_sgd_w: SGD learning rate for W
    :param lr_fsm_w: FSM/Hebbian learning rate for W
    :param lam_fsm_w: FSM/Hebbian weight decay for W
    :param lam_sgd_w: SGD weight decay for W
    :param hebb_w: If True, use Hebbian learning for W, otherwise FSM
    :param lam_sgd_v: SGD weight decay for v
    :param lam_hebb_v: Hebbian weight decay for v
    :return: The ``metrics`` attribute of the ``TwoLayer`` model after ``run_scheme`` has finished
    """
    # Derive one independent key per consumer *before* any of them is used.
    # A JAX key that has already been consumed must not be split or reused,
    # otherwise the resulting random streams are not independent.
    train_key, val_key, model_key = jax.random.split(jax.random.PRNGKey(seed), 3)

    # Input and hidden sizes; written onto the parameter object further down.
    dim, dim_hid = 100, 20

    # Prepare covariance and means of point cloud
    O = jnp.eye(dim)  # basis in which Sig is diagonal (identity here)
    # Diagonal entries: 2 on the first 20 axes, 1 on the remaining 80.
    Sig_diag = jnp.diag(jnp.array([2] * 20 + 80 * [1]))
    Sig = O @ Sig_diag @ O.T
    # Unit vector spread evenly over the first 30 basis vectors, i.e. it reaches
    # beyond the 20 axes that carry the larger diagonal entries.
    mean_dir = 1 / jnp.sqrt(30) * O[:, 0:30].sum(axis=1)
    data_means = 1.5 * jnp.array([-mean_dir, mean_dir])
    sigs = jnp.array([Sig, Sig])  # same matrix for both classes

    # Training and validation sets are drawn with identical settings but separate keys.
    X_train, y_train, _ = gen_data_x(
        train_key, 1000, dim, 2, sigs, data_means, vec=True
    )
    X_val, y_val, _ = gen_data_x(
        val_key, 1000, dim, 2, sigs, data_means, vec=True
    )

    sim_pars = parclass_wv(
        lam_fsm_w=lam_fsm_w,
        lam_sgd_w=lam_sgd_w,
        lam_hebb_v=lam_hebb_v,
        lam_sgd_v=lam_sgd_v,
        lr_hebb_v=lr_hebb_v,
        lr_fsm_w=lr_fsm_w,
        lr_sgd_w=lr_sgd_w,
        lr_sgd_v=lr_sgd_v,
        hebb_w=hebb_w,
        lr_sgd_v_decay=0,
    )
    # Keep the parameter object in sync with the statistics and sizes of the data.
    sim_pars.means = data_means
    sim_pars.sigs = sigs
    sim_pars.dim, sim_pars.dim_hid = dim, dim_hid
    test_double = TwoLayer(model_key, pars=sim_pars)
    test_double.run_scheme(cond, X_train, y_train, X_val, y_val, num_it)
    return test_double.metrics

In [6]:
def save_to_file(an_accs: jnp.ndarray, ap_accs: jnp.ndarray, pta_accs: jnp.ndarray, filename: str):
    """
    Save input arrays to compressed numpy file.

    The arrays are stored under the keys ``an``, ``ap`` and ``pta``. If
    ``filename`` is a path whose parent directory does not exist yet, the
    directory is created first so that a fresh checkout can write its results.
    :param an_accs: Array stored under key ``an``
    :param ap_accs: Array stored under key ``ap``
    :param pta_accs: Array stored under key ``pta``
    :param filename: Target path, e.g. ``"sim_data/onel.npz"``
    """
    import os  # local import keeps this cell runnable on its own

    # Only paths have a parent directory; anything else (e.g. an open file
    # object) is handed to numpy untouched.
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filename))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    np.savez_compressed(filename, an=np.array(an_accs), ap=np.array(ap_accs), pta=np.array(pta_accs))

## Model 1

In [7]:
# One-layer model, cond=an_cond: one run per seed 0..49, batched with vmap.
an_onel = jax.vmap(
    partial(
        simulation,
        cond=an_cond,
        num_it=5_000,
        lr_sgd=2e-4,
        lr_hebb=1e-4,
        lam_sgd=5e-2,
        lam_hebb=2.4,
    )
)(jnp.arange(50))

(5000,)


In [8]:
# One-layer model, cond=ap_cond: one run per seed 0..49, batched with vmap.
ap_onel = jax.vmap(
    partial(
        simulation,
        cond=ap_cond,
        num_it=50_000,
        lr_sgd=2e-4,
        lr_hebb=1e-4,
        lam_sgd=5e-2,
        lam_hebb=2.4,
    )
)(jnp.arange(50))

(50000,)


In [9]:
# One-layer model, cond=pta_cond: one run per seed 0..49, batched with vmap.
pta_onel = jax.vmap(
    partial(
        simulation,
        cond=pta_cond,
        num_it=50_000,
        lr_sgd=2e-4,
        lr_hebb=1e-4,
        lam_sgd=5e-2,
        lam_hebb=2.4,
    )
)(jnp.arange(50))

(50000,)


In [10]:
# Persist entry [1][0] of each condition's metrics for the one-layer model.
# NOTE(review): confirm what [1][0] selects in the vmapped metrics structure.
save_to_file(
    an_accs=an_onel[1][0],
    ap_accs=ap_onel[1][0],
    pta_accs=pta_onel[1][0],
    filename="sim_data/onel.npz",
)

## Model 2

In [11]:
# Two-layer model (isotropic input), cond=an_cond: seeds 0..49 batched with vmap.
an_doublel_sw = jax.vmap(
    partial(
        simulation_double,
        cond=an_cond,
        num_it=5_000,
        lr_sgd_v=0,
        lr_hebb_v=5e-2,
        lr_sgd_w=1e-3,
        lr_fsm_w=0.0,
        hebb_w=True,
    )
)(jnp.arange(50))

(5000,)


In [12]:
# Two-layer model (isotropic input), cond=ap_cond: seeds 0..49 batched with vmap.
ap_doublel_sw = jax.vmap(
    partial(
        simulation_double,
        cond=ap_cond,
        num_it=50_000,
        lr_sgd_v=0,
        lr_hebb_v=5e-2,
        lr_sgd_w=1e-3,
        lr_fsm_w=0.0,
        hebb_w=True,
    )
)(jnp.arange(50))

(50000,)


In [13]:
# Two-layer model (isotropic input), cond=pta_cond: seeds 0..49 batched with vmap.
pta_doublel_sw = jax.vmap(
    partial(
        simulation_double,
        cond=pta_cond,
        num_it=50_000,
        lr_sgd_v=0,
        lr_hebb_v=5e-2,
        lr_sgd_w=1e-3,
        lr_fsm_w=0.0,
        hebb_w=True,
    )
)(jnp.arange(50))

(50000,)


In [None]:
# Persist entry [1][0] of each condition's metrics for Model 2.
# NOTE(review): confirm what [1][0] selects in the vmapped metrics structure.
save_to_file(
    an_accs=an_doublel_sw[1][0],
    ap_accs=ap_doublel_sw[1][0],
    pta_accs=pta_doublel_sw[1][0],
    filename="sim_data/doublel_sw.npz",
)

## Model 3 (isotropic)

In [None]:
# Model 3, isotropic input, cond=an_cond: seeds 0..49 batched with vmap.
an_doublel_sv = jax.vmap(
    partial(
        simulation_double,
        cond=an_cond,
        num_it=5_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=2e-5,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        hebb_w=True,
    )
)(jnp.arange(50))

In [None]:
# Model 3, isotropic input, cond=ap_cond: seeds 0..49 batched with vmap.
ap_doublel_sv = jax.vmap(
    partial(
        simulation_double,
        cond=ap_cond,
        num_it=50_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=2e-5,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        hebb_w=True,
    )
)(jnp.arange(50))

In [None]:
# Model 3, isotropic input, cond=pta_cond: seeds 0..49 batched with vmap.
pta_doublel_sv = jax.vmap(
    partial(
        simulation_double,
        cond=pta_cond,
        num_it=50_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=2e-5,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        hebb_w=True,
    )
)(jnp.arange(50))

In [None]:
# Persist entry [1][0] of each condition's metrics for Model 3 (isotropic).
# NOTE(review): confirm what [1][0] selects in the vmapped metrics structure.
save_to_file(
    an_accs=an_doublel_sv[1][0],
    ap_accs=ap_doublel_sv[1][0],
    pta_accs=pta_doublel_sv[1][0],
    filename="sim_data/doublel_sv.npz",
)

## Model 3 (non-isotropic)

In [None]:
# Model 3, non-isotropic input, cond=an_cond: seeds 0..49 batched with vmap.
an_doublel_sv_wrongpc = jax.vmap(
    partial(
        simulation_double_wrongpc,
        cond=an_cond,
        num_it=5_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=2e-5,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        hebb_w=True,
    )
)(jnp.arange(50))

In [None]:
# Model 3, non-isotropic input, cond=ap_cond: seeds 0..49 batched with vmap.
ap_doublel_sv_wrongpc = jax.vmap(
    partial(
        simulation_double_wrongpc,
        cond=ap_cond,
        num_it=50_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=2e-5,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        hebb_w=True,
    )
)(jnp.arange(50))

In [None]:
# Model 3, non-isotropic input, cond=pta_cond: seeds 0..49 batched with vmap.
pta_doublel_sv_wrongpc = jax.vmap(
    partial(
        simulation_double_wrongpc,
        cond=pta_cond,
        num_it=50_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=2e-5,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        hebb_w=True,
    )
)(jnp.arange(50))

In [None]:
# Persist entry [1][0] of each condition's metrics for Model 3 (non-isotropic).
# NOTE(review): confirm what [1][0] selects in the vmapped metrics structure.
save_to_file(
    an_accs=an_doublel_sv_wrongpc[1][0],
    ap_accs=ap_doublel_sv_wrongpc[1][0],
    pta_accs=pta_doublel_sv_wrongpc[1][0],
    filename="sim_data/doublel_sv_wrongpc.npz",
)

## Model 4 (non-isotropic)

In [None]:
# Model 4, non-isotropic input with FSM on W (hebb_w=False), cond=an_cond:
# seeds 0..49 batched with vmap.
an_doublel_sv_wrongpc_fsm = jax.vmap(
    partial(
        simulation_double_wrongpc,
        cond=an_cond,
        num_it=5_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=5e-6,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        lam_fsm_w=5e-1,
        hebb_w=False,
    )
)(jnp.arange(50))

In [None]:
# Model 4, non-isotropic input with FSM on W (hebb_w=False), cond=ap_cond:
# seeds 0..49 batched with vmap.
ap_doublel_sv_wrongpc_fsm = jax.vmap(
    partial(
        simulation_double_wrongpc,
        cond=ap_cond,
        num_it=50_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=5e-6,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        lam_fsm_w=5e-1,
        hebb_w=False,
    )
)(jnp.arange(50))

In [None]:
# Model 4, non-isotropic input with FSM on W (hebb_w=False), cond=pta_cond:
# seeds 0..49 batched with vmap.
pta_doublel_sv_wrongpc_fsm = jax.vmap(
    partial(
        simulation_double_wrongpc,
        cond=pta_cond,
        num_it=50_000,
        lr_sgd_v=1e-2,
        lr_hebb_v=0.0,
        lr_sgd_w=0,
        lr_fsm_w=5e-6,
        lam_sgd_v=2e-2,
        lam_hebb_v=1,
        lam_fsm_w=5e-1,
        hebb_w=False,
    )
)(jnp.arange(50))

In [None]:
# Persist entry [1][0] of each condition's metrics for Model 4.
# NOTE(review): confirm what [1][0] selects in the vmapped metrics structure.
save_to_file(
    an_accs=an_doublel_sv_wrongpc_fsm[1][0],
    ap_accs=ap_doublel_sv_wrongpc_fsm[1][0],
    pta_accs=pta_doublel_sv_wrongpc_fsm[1][0],
    filename="sim_data/doublel_sv_wrongpc_fsm.npz",
)

## Model 5 (non-aligned)

In [None]:
# Model 5, non-aligned input, cond=an_cond: seeds 0..49 batched with vmap.
an_doublel_svw_fsm = jax.vmap(
    partial(
        simulation_double_fsm,
        cond=an_cond,
        num_it=5_000,
        lr_sgd_v=2e-3,
        lr_hebb_v=1e-5,
        lr_sgd_w=1e-4,
        lr_fsm_w=6e-5,
        lam_fsm_w=4e0,
        hebb_w=False,
    )
)(jnp.arange(50))

In [None]:
# Model 5, non-aligned input, cond=ap_cond: seeds 0..49 batched with vmap.
ap_doublel_svw_fsm = jax.vmap(
    partial(
        simulation_double_fsm,
        cond=ap_cond,
        num_it=50_000,
        lr_sgd_v=2e-3,
        lr_hebb_v=1e-5,
        lr_sgd_w=1e-4,
        lr_fsm_w=6e-5,
        lam_fsm_w=4e0,
        hebb_w=False,
    )
)(jnp.arange(50))

In [None]:
# Model 5, non-aligned input, cond=pta_cond: seeds 0..49 batched with vmap.
pta_doublel_svw_fsm = jax.vmap(
    partial(
        simulation_double_fsm,
        cond=pta_cond,
        num_it=50_000,
        lr_sgd_v=2e-3,
        lr_hebb_v=1e-5,
        lr_sgd_w=1e-4,
        lr_fsm_w=6e-5,
        lam_fsm_w=4e0,
        hebb_w=False,
    )
)(jnp.arange(50))

In [None]:
# Persist entry [1][0] of each condition's metrics for Model 5.
# NOTE(review): confirm what [1][0] selects in the vmapped metrics structure.
save_to_file(
    an_accs=an_doublel_svw_fsm[1][0],
    ap_accs=ap_doublel_svw_fsm[1][0],
    pta_accs=pta_doublel_svw_fsm[1][0],
    filename="sim_data/doublel_svw_fsm.npz",
)