In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

import sys
import os

In [None]:
from brian2 import *
from brian2tools import *

import matplotlib.pyplot as plt

import cx_rate
import trials
import plotter


from cx_spiking.constants import *

import cx_spiking.plotting
import cx_spiking.inputs
import cx_spiking.network_creation as nc

import cx_spiking.optimisation.metric as metric
import cx_spiking.optimisation.ng_optimiser as ng_optimiser

In [3]:
######################################
### INPUTS
######################################
# Index with [] so a missing MSC_PROJECT fails immediately with KeyError('MSC_PROJECT');
# .get() would return None and os.path.join would raise an unrelated-looking TypeError.
route_file = os.path.join(os.environ['MSC_PROJECT'], 'notebooks/data/route.npz')
T_outbound = 1500
T_inbound = 1500
T = T_outbound+T_inbound

# Pass T_outbound instead of repeating the literal, so the route length stays in sync
# with the T_inbound zero-padding below if the outbound duration is changed.
h, v = cx_spiking.inputs.generate_route(T_outbound=T_outbound, vary_speed=True, route_file=route_file, load_route=True)

cx_spiking.inputs.save_route(route_file, h, v, save_route=True)

# Convert headings into per-neuron input rates for half of the TL2 population, then
# tile the columns to cover both halves (N_TL2 columns in total).
# Assumes compute_headings returns a (time, N_TL2//2) array — TODO confirm.
headings = cx_spiking.inputs.compute_headings(h, N=N_TL2//2, vmin=5, vmax=100)
headings = np.tile(headings, 2)
# Append T_inbound rows of zeros: no heading input during the inbound phase.
headings = np.concatenate((headings, np.zeros((T_inbound, headings.shape[1]))), axis=0)

# Convert velocity into optical flow
flow = cx_spiking.inputs.compute_flow(h, v, baseline=50, vmin=0, vmax=50)
# Append T_inbound rows of zeros: no flow input during the inbound phase.
flow = np.concatenate((flow, np.zeros((T_inbound, flow.shape[1]))), axis=0)


######################################
### RATE BASED CX
######################################
# Noise level shared by the rate-based CX model and the trial runner.
noise = 0.1
cx = cx_rate.CXRatePontin(noise=noise)

# Feed only the outbound part of the generated route to the rate-based trial.
# Note: h and v are rebound here to the values returned by run_trial.
outbound_route = (h[:T_outbound], v[:T_outbound])
h, v, cx_log, cpu4_snapshot = trials.run_trial(
    logging=True,
    T_outbound=T_outbound,
    T_inbound=T_inbound,
    noise=noise,
    cx=cx,
    route=outbound_route,
)

Load route from /Users/ff/dev/MSc/CX_Path_Integration/notebooks/data/route.npz
/Users/ff/dev/MSc/CX_Path_Integration/notebooks/data/route.npz exists - not overwriting it


In [21]:
######################################
### SPIKE BASED CX
######################################
# Begin a fresh Brian2 scope: objects created before this call are ignored by later run() calls.
start_scope()

# Length in ms of one input sample (TimedArray dt) and of one CPU4 accumulation window.
time_step = 20

In [22]:
# Each row of `headings` / `flow` is held for one time_step. The PoissonGroups look up the
# rate of neuron i at time t in these TimedArrays by name, via the rate strings below.
# h_stimulus and f_stimulus must therefore keep exactly these names.
stimulus_dt = float(time_step) * ms
h_stimulus = TimedArray(headings*Hz, dt=stimulus_dt)
f_stimulus = TimedArray(flow*Hz, dt=stimulus_dt)

P_HEADING = PoissonGroup(N_TL2, rates='h_stimulus(t,i)')
P_FLOW = PoissonGroup(N_TN2, rates='f_stimulus(t,i)')

In [22]:
def _neuron_group(n, params, name):
    """Build one neuron group with the shared eqs / threshold_eqs / reset_eqs."""
    return nc.generate_neuron_groups(n, eqs, threshold_eqs, reset_eqs, params, name=name)

# Groups whose parameters were already optimised individually.
G_TL2 = _neuron_group(N_TL2, TL2_neuron_params, 'TL2')
G_CL1 = _neuron_group(N_CL1, CL1_neuron_params, 'CL1')
G_TB1 = _neuron_group(N_TB1, TB1_neuron_params, 'TB1')
G_TN2 = _neuron_group(N_TN2, TN2_neuron_params, 'TN2')

# Remaining groups share the generic neuron_params.
G_CPU4 = _neuron_group(N_CPU4, neuron_params, 'CPU4')
G_CPU1A = _neuron_group(N_CPU1A, neuron_params, 'CPU1A')
G_CPU1B = _neuron_group(N_CPU1B, neuron_params, 'CPU1B')
G_PONTINE = _neuron_group(N_PONTINE, neuron_params, 'PONTINE')
G_MOTOR = _neuron_group(N_MOTOR, neuron_params, 'MOTOR')

In [None]:
# Record spikes from the two Poisson input populations...
SPM_HEADING, SPM_FLOW = SpikeMonitor(P_HEADING), SpikeMonitor(P_FLOW)

# ...and from every CX neuron group (SPM_CPU4 is read by CPU4_accumulator).
(SPM_TL2, SPM_CL1, SPM_TB1, SPM_TN2, SPM_CPU4,
 SPM_CPU1A, SPM_CPU1B, SPM_PONTINE, SPM_MOTOR) = [
    SpikeMonitor(group)
    for group in (G_TL2, G_CL1, G_TB1, G_TN2, G_CPU4,
                  G_CPU1A, G_CPU1B, G_PONTINE, G_MOTOR)
]

In [23]:
def _connect(source, target, weights, params, on_pre):
    """Connect source -> target via nc.connect_synapses with the shared synapses_model."""
    return nc.connect_synapses(source, target, weights, model=synapses_model, params=params, on_pre=on_pre)

# Heading pathway (each connection uses its own optimised synapse parameters).
S_P_HEADING_TL2 = _connect(P_HEADING, G_TL2, W_HEADING_TL2, H_TL2_synapses_params, synapses_eqs_ex)
S_TL2_CL1 = _connect(G_TL2, G_CL1, W_TL2_CL1, TL2_CL1_synapses_params, synapses_eqs_ex)
S_CL1_TB1 = _connect(G_CL1, G_TB1, W_CL1_TB1, CL1_TB1_synapses_params, synapses_eqs_ex)
S_TB1_TB1 = _connect(G_TB1, G_TB1, W_TB1_TB1, TB1_TB1_synapses_params, synapses_eqs_in)
# Flow pathway.
S_P_FLOW_TN2 = _connect(P_FLOW, G_TN2, W_FLOW_TN2, F_TN2_synapses_params, synapses_eqs_ex)
# Memory and steering circuit (generic synapses_params throughout).
S_TB1_CPU4 = _connect(G_TB1, G_CPU4, W_TB1_CPU4, synapses_params, synapses_eqs_in)
S_TN2_CPU4 = _connect(G_TN2, G_CPU4, W_TN2_CPU4, synapses_params, synapses_eqs_ex)
S_TB1_CPU1A = _connect(G_TB1, G_CPU1A, W_TB1_CPU1A, synapses_params, synapses_eqs_in)
S_CPU4_PONTINE = _connect(G_CPU4, G_PONTINE, W_CPU4_PONTINE, synapses_params, synapses_eqs_ex)
S_CPU4_CPU1A = _connect(G_CPU4, G_CPU1A, W_CPU4_CPU1A, synapses_params, synapses_eqs_ex)
S_PONTINE_CPU1A = _connect(G_PONTINE, G_CPU1A, W_PONTINE_CPU1A, synapses_params, synapses_eqs_in)
S_TB1_CPU1B = _connect(G_TB1, G_CPU1B, W_TB1_CPU1B, synapses_params, synapses_eqs_in)
S_CPU4_CPU1B = _connect(G_CPU4, G_CPU1B, W_CPU4_CPU1B, synapses_params, synapses_eqs_ex)
S_PONTINE_CPU1B = _connect(G_PONTINE, G_CPU1B, W_PONTINE_CPU1B, synapses_params, synapses_eqs_in)
S_CPU1A_MOTOR = _connect(G_CPU1A, G_MOTOR, W_CPU1A_MOTOR, synapses_params, synapses_eqs_ex)
S_CPU1B_MOTOR = _connect(G_CPU1B, G_MOTOR, W_CPU1B_MOTOR, synapses_params, synapses_eqs_ex)

In [25]:
# Accumulator state updated by CPU4_accumulator: a running spike count per CPU4 neuron,
# plus a (T, N_CPU4) history with one row per accumulation window.
# (A `global` statement at notebook top level is a no-op, so none is needed here.)
CPU4_memory = np.zeros(N_CPU4)
CPU4_memory_history = np.zeros((T, N_CPU4))

def extract_spike_counts(SPM, t, time_step):
    """Count the spikes each recorded neuron fired in the window [t - time_step ms, t).

    Parameters
    ----------
    SPM : SpikeMonitor
        Monitor whose ``spike_trains()`` dictionary (neuron index -> spike times) is queried.
    t : Quantity
        End of the window (exclusive), as a Brian2 time.
    time_step : float
        Window length in milliseconds.

    Returns
    -------
    numpy.ndarray of int
        One spike count per entry of ``SPM.spike_trains()``, indexed by neuron index.
    """
    # Query the monitor once and reuse the result for both the size and the loop.
    spike_trains = SPM.spike_trains()
    window_start = t - time_step*ms
    counts = np.zeros(len(spike_trains), dtype=int)
    for idx in range(len(spike_trains)):
        spike_train = spike_trains[idx]
        # Half-open window: successive calls at t, t + time_step, ... tile time without gaps.
        # A spike falling exactly on a boundary is therefore counted once. The previous open
        # interval (t - dt, t) dropped such a spike from both adjacent windows.
        counts[idx] = np.count_nonzero((spike_train >= window_start) & (spike_train < t))
    return counts

@network_operation(dt=time_step*ms)
def CPU4_accumulator(t):
    """Accumulate CPU4 spike counts from the window that just ended into global memory.

    Brian2 calls this every ``time_step`` ms during ``run``. For ``t >= time_step`` ms it
    counts each CPU4 neuron's spikes in the window ending at ``t`` (via
    ``extract_spike_counts`` on ``SPM_CPU4``). It adds those counts to ``CPU4_memory`` and
    to row ``round(t / time_step)`` of ``CPU4_memory_history``. At ``t == 0`` it does
    nothing, because no full window has elapsed yet.

    NOTE(review): spikes in the last window before a ``run`` ends appear to be counted only
    if another tick follows — confirm whether the final window should be flushed.
    """
    global CPU4_memory, CPU4_memory_history

    if t < time_step*ms:
        return

    # Index of the current tick; rounding absorbs floating-point error in t/ms.
    timestep = int(round(float(t/ms) / time_step))

    # Only the per-window counts are needed. The full spike_trains() dict that was built
    # here before was never used and was rebuilt from all recorded spikes on every tick.
    neurons = extract_spike_counts(SPM_CPU4, t, time_step)
    CPU4_memory_history[timestep, :] += neurons
    CPU4_memory += neurons

In [26]:
%%time
run((T_outbound)*time_step*ms, report='text')

Starting simulation at t=0. s for a duration of 30. s
1.1768 s (3%) simulated in 10s, estimated 4m 5s remaining.
3.5284 s (11%) simulated in 20s, estimated 2m 30s remaining.
6.2588 s (20%) simulated in 30s, estimated 1m 54s remaining.
8.9897 s (29%) simulated in 40s, estimated 1m 33s remaining.
11.3851 s (37%) simulated in 50s, estimated 1m 22s remaining.
13.7899 s (45%) simulated in 1m 0s, estimated 1m 11s remaining.
15.6622 s (52%) simulated in 1m 10s, estimated 1m 4s remaining.
17.6847 s (58%) simulated in 1m 20s, estimated 56s remaining.
20.0608 s (66%) simulated in 1m 30s, estimated 45s remaining.
22.5301 s (75%) simulated in 1m 40s, estimated 33s remaining.
24.9998 s (83%) simulated in 1m 50s, estimated 22s remaining.
26.9148 s (89%) simulated in 2m 0s, estimated 14s remaining.
28.3563 s (94%) simulated in 2m 10s, estimated 8s remaining.
30. s (100%) simulated in 2m 16s
CPU times: user 2min 7s, sys: 1.42 s, total: 2min 8s
Wall time: 2min 36s


In [1]:
def normalise_CPU4_accumulator(CPU4_memory, vmin=0.15, vmax=0.6):
    """Rescale the left and right halves of a CPU4 accumulator independently.

    The first ``N_CPU4 // 2`` entries (left) and the last ``(-N_CPU4) // 2`` entries (right)
    are each passed through ``cx_spiking.inputs.normalise_range(..., vmin, vmax)``. The two
    results are then concatenated, left first.
    """
    left_end = N_CPU4 // 2
    # (-N_CPU4) // 2 floors towards -inf. If CPU4_memory has N_CPU4 entries and N_CPU4 is odd,
    # the right half receives the extra middle element, so the halves still partition the array.
    right_start = (-N_CPU4) // 2
    halves = (CPU4_memory[:left_end], CPU4_memory[right_start:])
    return np.concatenate(tuple(
        cx_spiking.inputs.normalise_range(half, vmin=vmin, vmax=vmax) for half in halves
    ))