### Nonlinear learning rule capacity

In [2]:
import sys
sys.path.insert(0, '../../../network')

In [3]:
import itertools
import os
import pdb
import time
import warnings

import numpy as np
import ray
from scipy.special import erf
from scipy.stats import pearsonr, norm
from tqdm import tqdm

In [4]:
from network import Population, RateNetwork
from transfer_functions import ErrorFunction, StepFunction
from connectivity import SparseConnectivity, LinearSynapse, ThresholdPlasticityRule
from sequences import GaussianSequence

In [6]:
# Connect to the existing Ray cluster. The head-node address can be overridden
# through the RAY_REDIS_ADDRESS environment variable instead of editing the notebook.
RAY_REDIS_ADDRESS = os.environ.get("RAY_REDIS_ADDRESS", "10.122.160.26:6382")
ray.init(redis_address=RAY_REDIS_ADDRESS, include_webui=True, ignore_reinit_error=True)

{'node_ip_address': '10.122.160.26',
 'redis_address': '10.122.160.26:6382',
 'object_store_address': '/tmp/ray/session_2020-02-25_17-43-49_790168_1820902/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2020-02-25_17-43-49_790168_1820902/sockets/raylet',
 'webui_url': 'http://10.122.160.26:8080/?token=<redacted>',
 'session_dir': '/tmp/ray/session_2020-02-25_17-43-49_790168_1820902'}

In [7]:
@ray.remote
def f():
    """Return the IP address of the node on which this Ray task executes."""
    # NOTE(review): the short sleep presumably keeps each task alive long enough
    # for the scheduler to spread the tasks over several nodes.
    time.sleep(0.01)
    return ray.services.get_node_ip_address()

# Launch many tiny tasks and print the distinct node IPs that ran them,
# i.e. the nodes that have joined the cluster.
pending = [f.remote() for _ in range(1000)]
node_ips = set(ray.get(pending))
print(node_ips)

{'10.122.160.35', '10.122.160.34', '10.122.160.27', '10.122.160.21', '10.122.160.26', '10.122.160.25'}


In [18]:
def simulate(P, theta_0, x_f, q_f, N=50000, debug=False):
    """Store a reference sequence (plus an optional second one), run the network,
    and measure how well the end of the reference sequence is reached.

    Parameters
    ----------
    P : int
        First argument of the second GaussianSequence (seed=2), which is stored
        on top of the reference sequence; nothing extra is stored when P <= 0.
    theta_0 : float
        Threshold scale: the StepFunction threshold is ``theta_0 * p_0``.
    x_f : float
        Threshold of the plasticity rule; also used as x_g (x_f == x_g).
    q_f : float
        ``q_f`` argument of ThresholdPlasticityRule.
    N : int
        Size of the excitatory population.
    debug : bool
        Currently unused.

    Returns
    -------
    corr_P : np.ndarray
        Pearson correlation between ``plasticity.g(patterns[0, -1, :])`` and the
        network state at every time index; NaN where it is undefined.
    overlaps
        Result of ``sequences[0].overlaps(...)`` for the reference sequence.
    """
    # x_f == x_g, so q_g is the standard normal CDF at x_f. Use the imported
    # `norm`: the module name `scipy` itself is never imported in this notebook.
    q_g = norm.cdf(x_f)
    p = 1 - q_g
    q = 1 - q_f
    p_0 = p*(1-p)*(p*(1-q)**2 + (1-p)*q**2)
    theta = theta_0*p_0

    # Scope the warning filters to this call instead of mutating global state.
    with warnings.catch_warnings():
        # pearsonr warns whenever a state vector is (nearly) constant. Those time
        # points just yield NaN. The regex matches both the old SciPy spelling
        # ("coefficent") and the corrected one ("coefficient").
        warnings.filterwarnings('ignore', message=r'An input array is (nearly )?constant')

        phi = StepFunction(mu=theta)
        exc = Population(N, tau=1e-2, phi=phi.phi)
        c = 0.005  # passed as `p` to SparseConnectivity (presumably connection probability)
        # Reference sequence (seed=1) whose retrieval is measured below.
        sequences = [GaussianSequence(16, exc.size, seed=1)]
        patterns = np.stack([s.inputs for s in sequences])
        plasticity = ThresholdPlasticityRule(x_f=x_f, q_f=q_f)
        conn = SparseConnectivity(source=exc, target=exc, p=c, disable_pbar=True)
        synapse = LinearSynapse(conn.K, A=1)
        conn.store_sequences(patterns, synapse.h_EE, plasticity.f, plasticity.g)

        if P > 0:
            # Second, interfering sequence (seed=2) stored in the same connectivity.
            sequences2 = [GaussianSequence(P, exc.size, seed=2)]
            patterns2 = np.stack([s.inputs for s in sequences2])
            conn.store_sequences(patterns2, synapse.h_EE, plasticity.f, plasticity.g)

        net = RateNetwork(
            exc,
            c_EE=conn,
            formulation=1,
            disable_pbar=True)
        # Start from the first reference pattern and integrate up to t=0.4.
        net.simulate_euler(
            t=0.4,
            r0=exc.phi(plasticity.f(patterns[0, 0, :])))
        overlaps = sequences[0].overlaps(
            net,
            exc,
            plasticity=plasticity,
            correlation=False,
            disable_pbar=True)

        # g of the last reference pattern does not depend on t: compute it once.
        target = plasticity.g(patterns[0, -1, :])
        corr_P = np.asarray([
            pearsonr(target, net.exc.state[:, t])[0]
            for t in range(net.exc.state.shape[1])])

    return corr_P, overlaps

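### Bisection method

Find the largest $P$ for which the reference sequence is still retrieved (peak correlation between the network state and $g$ of its last pattern above `threshold`). The search first brackets $P$ and then bisects on integer midpoints.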

In [19]:
def run_bisection(
        theta_0=0.01,
        x_f=1.645,
        E_x_f=-0.15,
        P_lower=0.0,
        P_upper=0.5,
        N=40000,
        threshold=0.05,
        tol=0.001,
        debug=False,
        P_step=32):
    """Search for the largest P at which ``simulate`` still passes the recall criterion.

    The criterion is ``nanmax(correlations) > threshold``, applied to the
    correlation trace returned by ``simulate``. If it already holds at
    ``P_upper``, the bracket is shifted upward by ``P_step`` until it fails.
    The bracket is then bisected on integer midpoints.

    Parameters
    ----------
    theta_0, x_f, N : passed through to ``simulate``.
    E_x_f : float
        Offset added to Phi(x_f) (standard normal CDF) to obtain ``q_f``.
    P_lower, P_upper : number
        Initial bracket for P.
    threshold : float
        Peak-correlation threshold of the criterion.
    tol : float
        Unused; kept for call compatibility.
    debug : bool
        Print the progress of the search.
    P_step : int
        Upward shift of the bracket while the criterion still holds at
        ``P_upper`` (default 32, the previously hard-coded value).
        NOTE: if the criterion never fails, the upward search does not terminate.

    Returns
    -------
    dict
        Contents:
        - the inputs;
        - the initial (``P_lower_``/``P_upper_``) and final bracket;
        - ``P_final``, which is NaN when the criterion already fails at the
          initial ``P_lower``;
        - ``errcode`` (0 = ok, 2/3 = failed to converge);
        - ``criterion_threshold``;
        - the ``overlaps`` of the last simulation that was run.
    """
    from scipy.special import erf

    def criterion(correlations):
        # NaN entries (undefined correlations) are ignored by nanmax.
        corr_P_max = np.nanmax(correlations)
        return corr_P_max > threshold

    # q_f = Phi(x_f) + E_x_f, with Phi the standard normal CDF.
    q_f = 0.5*erf(x_f/np.sqrt(2)) + 0.5 + E_x_f

    P_lower_ = P_lower
    P_upper_ = P_upper
    P_final = np.nan  # the `np.NaN` alias was removed in NumPy 2.0
    errcode = 0

    if debug:
        print('theta_0', theta_0, 'E_x_f', E_x_f)

    correlation_P, overlaps = simulate(
        P_lower,
        theta_0,
        x_f,
        q_f, N, debug)
    seq_lower = criterion(correlation_P)
    if debug:
        print("Seq_lower", seq_lower, P_lower, P_upper, np.nanmax(correlation_P))

    if seq_lower:
        correlation_P, overlaps = simulate(
            P_upper, theta_0, x_f, q_f, N, debug)
        seq_upper = criterion(correlation_P)
        if debug:
            print("Seq_upper", seq_upper, P_lower, P_upper, np.nanmax(correlation_P))

        while True:
            if seq_lower and seq_upper:
                # Criterion still holds at the upper end: shift the bracket up.
                P_lower = P_upper
                P_upper = P_lower + P_step
                correlation_P, overlaps = simulate(
                    P_upper, theta_0, x_f, q_f, N, debug)
                seq_upper = criterion(correlation_P)
                if debug:
                    print("Seq_upper", seq_upper, P_lower, P_upper, np.nanmax(correlation_P))
            elif not seq_lower and seq_upper:
                # Failed to converge
                errcode = 2
                break
            else:
                P_mid = int((P_lower + P_upper) / 2.)
                correlation_P, overlaps = simulate(
                    P_mid, theta_0, x_f, q_f, N, debug)
                seq_mid = criterion(correlation_P)
                if debug:
                    print("Seq_mid", seq_mid, P_lower, P_mid, P_upper, np.nanmax(correlation_P))
                if seq_lower and not seq_upper:
                    # TODO: Should not need the first two conditions
                    # if rounding error is not present
                    if P_mid == P_lower:
                        P_final = P_lower
                        break  # Converged
                    elif P_mid == P_upper:
                        P_final = P_lower
                        break  # Converged
                    elif P_lower == P_upper:
                        P_final = P_lower
                        break  # Converged
                    else:
                        if seq_mid:
                            P_lower = P_mid
                        else:
                            P_upper = P_mid
                else:
                    # Failed to converge
                    errcode = 3
                    break

    return {
        'theta_0': theta_0,
        'x_f': x_f,
        'E_x_f': E_x_f,
        'P_lower_': P_lower_,
        'P_upper_': P_upper_,
        'P_lower': P_lower,
        'P_upper': P_upper,
        'P_final': P_final,
        'N': N,
        'errcode': errcode,
        'criterion_threshold': threshold,
        'overlaps': overlaps,
    }

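### Parameter exploration

Run the bisection for every combination of $\theta_0$, coding level (via $x_f$) and $E_{x_f}$. Each run is dispatched to the Ray cluster when `parallel` is set.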

In [20]:
# --- Sweep configuration ---------------------------------------------------
N = 40000
theta_0 = np.linspace(-7, 7, 28)
coding_levels = np.asarray([0.5, 0.1, 0.05])
# Standard-normal quantile leaving `coding_level` probability mass above it.
x_f = list(map(norm.ppf, 1 - coding_levels))
E_x_f = [-0.15, -0.1, -0.05]
P_lower = 0
P_upper = 64
threshold = 0.025

# Cartesian product over all swept parameters (scalars become length-1 axes).
sweep_axes = [np.atleast_1d(v) for v in (theta_0, x_f, E_x_f, N, threshold)]
combinations = list(itertools.product(*sweep_axes))

# --- Dispatch ----------------------------------------------------------------
parallel = True
debug = False
run_bisection_ray = ray.remote(num_cpus=4)(run_bisection)
# Ray handle on the cluster; plain function call (returning the result) otherwise.
func = run_bisection_ray.remote if parallel else run_bisection

object_ids = [
    func(theta_0_, x_f_, E_x_f_, P_lower, P_upper, N_, threshold_, debug=debug)
    for theta_0_, x_f_, E_x_f_, N_, threshold_ in combinations
]
n = len(object_ids)

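### Collect and store results

As each run finishes, save its result dictionary to `data/` as one `.npy` file per parameter combination.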

In [21]:
# Save each bisection result as soon as its task finishes.
directory = "data/"
os.makedirs(directory, exist_ok=True)  # np.save cannot create missing directories

pbar = tqdm(total=n)
while len(object_ids) > 0:
    if parallel:
        # Wait for one finished task. ray.wait also returns the still-pending ids.
        ready_object_ids, object_ids = ray.wait(object_ids)
        data = ray.get(ready_object_ids[0])
    else:
        # Serial mode: the list already holds the result dicts.
        data = object_ids.pop(0)
    # Build the filename from the result's own parameters. Unpacking them into
    # theta_0 / x_f / E_x_f / N clobbered the sweep configuration globals, and the
    # global `threshold` is wrong as soon as several thresholds are swept.
    filename = "theta_0_%.5f_x_f%.3f_E_x_f%.3f_N%i_thresh%.3f"%(
        data['theta_0'], data['x_f'], data['E_x_f'], data['N'],
        data['criterion_threshold']) + ".npy"
    filepath = os.path.join(directory, filename)
    with open(filepath, 'wb') as fh:  # close the handle deterministically
        np.save(fh, data)
    pbar.update(1)
    # NOTE(review): this pause adds ~1 s per result and looks unnecessary;
    # kept to preserve the original pacing.
    time.sleep(1)
pbar.close()

 58%|█████▊    | 146/252 [15:15<09:42,  5.50s/it]  

[2m[36m(pid=2879506, ip=10.122.160.34)[0m   parser = argparse.ArgumentParser(


 60%|█████▉    | 151/252 [15:43<10:40,  6.34s/it]

[2m[36m(pid=2879507, ip=10.122.160.34)[0m   parser = argparse.ArgumentParser(


 88%|████████▊ | 221/252 [22:59<02:29,  4.81s/it]

[2m[36m(pid=653164, ip=10.122.160.25)[0m   parser = argparse.ArgumentParser(


 92%|█████████▏| 231/252 [23:23<00:54,  2.61s/it]

[2m[36m(pid=1820956)[0m   parser = argparse.ArgumentParser(


 92%|█████████▏| 233/252 [23:29<01:00,  3.20s/it]

[2m[36m(pid=2879502, ip=10.122.160.34)[0m   parser = argparse.ArgumentParser(


 96%|█████████▌| 241/252 [23:48<00:28,  2.59s/it]

[2m[36m(pid=2879502, ip=10.122.160.34)[0m   parser = argparse.ArgumentParser(


100%|██████████| 252/252 [28:37<00:00, 28.78s/it]2020-02-25 18:12:39,711	ERROR worker.py:1521 -- print_logs: Connection closed by server.
2020-02-25 18:12:39,716	ERROR worker.py:1621 -- listen_error_messages_raylet: Connection closed by server.
2020-02-25 18:12:39,716	ERROR import_thread.py:89 -- ImportThread: Connection closed by server.


[2m[36m(pid=1820955)[0m Traceback (most recent call last):
[2m[36m(pid=1820955)[0m   File "/home/mhg19/.local/lib/python3.7/site-packages/ray/workers/default_worker.py", line 98, in <module>
[2m[36m(pid=1820955)[0m     ray.worker.global_worker.main_loop()
[2m[36m(pid=1820955)[0m   File "/home/mhg19/.local/lib/python3.7/site-packages/ray/worker.py", line 954, in main_loop
[2m[36m(pid=1820955)[0m     task = self._get_next_task_from_raylet()
[2m[36m(pid=1820955)[0m   File "/home/mhg19/.local/lib/python3.7/site-packages/ray/worker.py", line 937, in _get_next_task_from_raylet
[2m[36m(pid=1820955)[0m     task = self.raylet_client.get_task()
[2m[36m(pid=1820955)[0m   File "python/ray/_raylet.pyx", line 335, in ray._raylet.RayletClient.get_task
[2m[36m(pid=1820955)[0m   File "python/ray/_raylet.pyx", line 109, in ray._raylet.check_status
[2m[36m(pid=1820955)[0m ray.exceptions.RayletError: The Raylet died with this message: [RayletClient] Raylet connection closed.
