In [None]:
# from google.colab import drive
# drive.mount('/content/drive/')
# %cd '/content/drive/MyDrive/fly_model_shared/brian_pipeline/'

In [None]:
import pandas as pd
import pickle

# from brian2 import NeuronGroup, Synapses, PoissonInput, SpikeMonitor, Network
# from brian2 import mV, ms, Hz, Mohm, uF

# Load data

In [None]:
# Neuron table (ids and excitation type) and connectivity table,
# both indexed by their first CSV column.
df_comp = pd.read_csv(
    './data/2022_11_22_completeness_materialization_530_final.csv',
    index_col=0,
)
df_con = pd.read_csv(
    './data/2022_11_22_connectivity_530_final.csv',
    index_col=0,
)

# Six lookup dictionaries between flywire ids, neuron indices and names,
# unpacked in the order they were pickled.
# NOTE: pickle.load can execute code embedded in the file -- only load trusted data.
with open('./data/name_mappings_530.pickle', 'rb') as f:
    (flyid2i, flyid2name, i2flyid,
     i2name, name2flyid, name2i) = pickle.load(f)

# Define models

In [None]:
#                           # Kakaria and de Bivort 2017 https://doi.org/10.3389/fnbeh.2017.00008
#                           # references therein, e.g. Hodgkin and Huxley 1952
v_0     = -52 * mV          # resting potential
v_rst   = -52 * mV          # reset potential after spike
v_th    = -45 * mV          # threshold for spiking
r_mbr   = 10. * Mohm        # membrane resistance
c_mbr   = .002 * uF         # membrane capacitance
t_mbr   = c_mbr * r_mbr     # membrane time constant (2 nF * 10 Mohm = 20 ms)

#                           # Jürgensen et al https://doi.org/10.1088/2634-4386/ac3ba6
tau     = 5 * ms            # decay time constant of the synaptic input x; this is the
                            # excitatory value and it is used for ALL synapses
                            # (the inhibitory value would be 10 ms)

#                           # Lazar et al https://doi.org/10.7554/eLife.62362
#                           # they cite Kakaria and de Bivort 2017, but those used 2 ms
t_rfc   = 2.2 * ms          # refractory period

#                           # Paul et al 2015 doi: 10.3389/fncel.2015.00029
t_dly   = 1.8 * ms          # synaptic delay before a spike affects the post-synaptic neuron

#                           # adjusted arbitrarily
w_syn   = .275 * mV         # weight per synapse (note: modulated by exponential decay)
r_poi   = 150 * Hz          # rate of the Poisson input
w_poi   = w_syn * 250       # strength of one Poisson event (68.75 mV, far larger than
                            # the 7 mV gap between v_0 and v_th)

#                           # equations for neurons
#   v   : membrane potential; relaxes to v_0 with time constant t_mbr, driven by x
#   x   : synaptic input; decays with tau, incremented by w on each presynaptic spike
#   rfc : per-neuron refractory period (used through refractory='rfc')
eqs = '''
dv/dt = (x - (v - v_0)) / t_mbr : volt (unless refractory)
dx/dt = -x / tau                : volt (unless refractory) 
rfc                             : second
'''
eq_th   = 'v > v_th'        # condition for spike
# On a spike: reset v and clear the synaptic input x. The per-synapse weight w
# lives on the Synapses object, not the neurons, so it is not touched here
# (the former `w = 0` only created an unused temporary).
eq_rst  = 'v = v_rst; x = 0 * mV'

def poi(neu, names, rate=r_poi):
    '''Attach one Poisson input to each named neuron of a NeuronGroup.

    neu   : NeuronGroup containing the targets
    names : iterable of neuron names, resolved through name2i
    rate  : Poisson rate (defaults to r_poi); every event adds w_poi to v

    The refractory period of every target is set to 0 ms.
    Returns (list of PoissonInput objects, neu).
    '''
    inputs = []
    for name in names:
        idx = name2i[name]
        inputs.append(
            PoissonInput(target=neu[idx], target_var='v', N=1, rate=rate, weight=w_poi)
        )
        neu[idx].rfc = 0 * ms  # Poisson-driven neurons get no refractory period
    return inputs, neu

def default_model():
    '''Create the default neuron/synapse model from the flywire data.

    Relies on the equations and parameters defined above (eqs, eq_th, eq_rst,
    v_0, t_rfc, t_dly, w_syn) and on the data frames df_comp and df_con.

    Returns
    -------
    neu : NeuronGroup
        One neuron per row of df_comp, initialised at rest (v = v_0, x = 0)
        with refractory period t_rfc.
    syn : Synapses
        One synapse per row of df_con; a presynaptic spike increments the
        postsynaptic x by w after a delay of t_dly.
    spk_mon : SpikeMonitor
        Records the spikes of all neurons in ``neu``.
    '''
    neu = NeuronGroup(
        N=len(df_comp),
        model=eqs,
        method='linear',
        threshold=eq_th,
        reset=eq_rst,
        refractory='rfc',
        name='default_neurons',
    )
    # initial state: resting potential, no synaptic input, default refractory period
    neu.v = v_0
    neu.x = 0 * mV
    neu.rfc = t_rfc

    syn = Synapses(neu, neu, 'w : volt', on_pre='x += w', delay=t_dly, name='default_synapses')

    # one synapse per row of the connectivity table: presynaptic -> postsynaptic index
    i_pre = df_con.loc[:, 'Presynaptic_Index'].values
    i_post = df_con.loc[:, 'Postsynaptic_Index'].values
    syn.connect(i=i_pre, j=i_post)

    # weight = 'Excitatory x Connectivity' value scaled by the weight per synapse
    syn.w = df_con.loc[:, "Excitatory x Connectivity"].values * w_syn

    spk_mon = SpikeMonitor(neu)

    return neu, syn, spk_mon

# Experiments

## define experiment

In [None]:
# experimental setup: every experiment runs n_run independent trials of t_sim each
n_run = 30          # trials per experiment
t_sim = 1000 * ms   # simulated duration of one trial

# helper functions
def save_spk(i_run, spk_mon, df):
    '''Return a copy of ``df`` with the spike times of one run added as a column.

    Parameters
    ----------
    i_run : int
        Run number; the new column is named ``run_<i_run>``.
    spk_mon : SpikeMonitor
        Object whose ``spike_trains()`` returns {neuron index: spike times}.
    df : pandas.DataFrame
        Collector frame indexed by neuron index. It is NOT modified.

    Returns
    -------
    pandas.DataFrame
        New frame with the extra column; neurons that did not spike hold NaN.
    '''
    # keep only neurons that spiked at least once
    spk_trn = {k: v for k, v in spk_mon.spike_trains().items() if len(v)}

    # Work on a copy so the caller's frame is left untouched; copying also keeps
    # the frame consolidated as columns accumulate (avoids fragmentation).
    out = df.copy()
    # object dtype: each cell holds an array of spike times (also well-defined
    # when no neuron spiked and spk_trn is empty)
    out['run_{}'.format(i_run)] = pd.Series(spk_trn, dtype=object)
    return out

def run_default_exp(exp, exc):
    '''Run the default experiment and save spike times plus metadata.

    A fresh default network is built, the neurons named in ``exc`` receive
    Poisson input, and ``n_run`` trials of ``t_sim`` each are simulated, every
    trial starting from the same stored initial state.

    Parameters
    ----------
    exp : str
        Experiment name; used in the output file names.
    exc : list of str
        Names of the neurons to excite (keys of name2i / name2flyid).

    Output files (the directory is created if missing)
    --------------------------------------------------
    ./data/exp_activation/default_<exp>.feather
        Spike times: one row per key of i2flyid, one column ``run_<i>`` per trial.
    ./data/exp_activation/default_<exp>.pickle
        Dict with keys 'exp', 'exc', 'exc_i' (flywire ids of ``exc``),
        't_sim' and 'n_run'.

    Raises
    ------
    KeyError
        If a name in ``exc`` is unknown; checked before anything is simulated.
    '''
    import os  # function-scope import keeps this cell independent of later cells

    out_fth = './data/exp_activation/default_{}.feather'.format(exp)
    out_pkl = './data/exp_activation/default_{}.pickle'.format(exp)

    # fail fast on misspelled names instead of after building the network
    unknown = [n for n in exc if n not in name2i]
    if unknown:
        raise KeyError('unknown neuron name(s): {}'.format(', '.join(unknown)))

    # create the output directory now, so a long simulation is not lost at save time
    os.makedirs(os.path.dirname(out_fth), exist_ok=True)

    print('>>> Experiment:     {}'.format(exp))
    print('    Output files:   {}'.format(out_fth))
    print('                    {}'.format(out_pkl))
    print('    Excited neurons: {}'.format(' '.join(exc)))

    neu, syn, spk_mon = default_model()          # fresh default network
    poi_inp, neu = poi(neu, exc)                 # Poisson input for the excited neurons
    net = Network(neu, syn, spk_mon, *poi_inp)
    net.store()                                  # snapshot of the initial state

    # index = keys of i2flyid (presumably neuron indices); one column per run
    df_spk = pd.DataFrame(index=i2flyid.keys())
    for i in range(n_run):
        net.restore()                            # every trial starts from the snapshot
        net.run(duration=t_sim)
        df_spk = save_spk(i, spk_mon, df_spk)

    df_spk.to_feather(out_fth)

    meta_data = {
        'exp':      exp,
        'exc':      exc,
        'exc_i':    [name2flyid[n] for n in exc],  # flywire ids of the excited neurons
        't_sim':    t_sim,
        'n_run':    n_run,
    }
    with open(out_pkl, 'wb') as f:
        pickle.dump(meta_data, f)

In [None]:
# from brian2 import prefs, set_device, device
from joblib import Parallel, delayed, parallel_backend
import os
%load_ext autoreload
%autoreload 2
from utils import run

def save_spk(spk_mon):

    spk_trn = {k: v for k, v in spk_mon.spike_trains().items() if len(v)} 
    return spk_trn
    return pd.Series(spk_trn)

# def run(exc): 
#     pid = os.getpid()

#     # cache_dir = os.path.expanduser(f'~/.cython/brian-pid-{os.getpid()}')
#     # prefs.codegen.runtime.cython.cache_dir = cache_dir
#     # prefs.codegen.runtime.cython.multiprocess_safe = False
#     # set_device('cpp_standalone', directory=None)

#     # neu, syn, spk_mon = default_model() 
#     # poi_inp, neu = poi(neu, exc)
#     # net = Network(neu, syn, spk_mon, *poi_inp)  

#     # net.run(duration=1000*ms) 
    
#     # x = save_spk(spk_mon)

#     # device.reinit()

#     return pid
    
def run_default_exp(exp, exc):

    # np = 10
    inps = [ exc for _ in range(30) ]
    with parallel_backend('loky', n_jobs=20):
        res = Parallel()(map(delayed(run), inps))

    # inps = [ exc for _ in range(20) ]
    # res = Parallel(n_jobs=10, batch_size=1)(map(delayed(run), inps))
        
    return res

res = run_default_exp('P9', ['P9_l', 'P9_r'] ) # P9
res

## define neuron groups

In [None]:
# lists of neuron groups

def _names_with_prefix(prefix):
    '''Return the neuron names in name2i that start with ``prefix`` (in name2i order).'''
    return [name for name in name2i if name.startswith(prefix)]

# walk neurons
l_p9 =      ['P9_l', 'P9_r']
l_mdn =     ['MDN_a_l', 'MDN_a_r', 'MDN_b_l', 'MDN_b_r']
l_bpn1 =    _names_with_prefix('BPN_1')
l_bpn2 =    _names_with_prefix('BPN_2')
l_bpn3 =    _names_with_prefix('BPN_3')   # may be empty if no names match
l_bpn4 =    _names_with_prefix('BPN_4')
l_bpn = l_bpn1 + l_bpn2 + l_bpn3 + l_bpn4

# stop neurons
l_bb =      ['BB_r', 'BB_l']
l_fg =      ['FG_r', 'FG_l']
l_man1 =    ['MAN-1_r', 'MAN-1_l']
l_stop1 =   ['STOP-1_a_r', 'STOP-1_a_l', 'STOP-1_b_r', 'STOP-1_b_l', 'STOP-1_c_r', 'STOP-1_c_l']

# sensory neurons
l_sugarr =  _names_with_prefix('sugar_r_')      # sugar, right
l_sugarl =  _names_with_prefix('sugar_l_')      # sugar, left
l_ovidn =   _names_with_prefix('OviDN_')        # OviDNs
l_bitterr = _names_with_prefix('bitter_r_')     # bitter, right
l_bitterl = _names_with_prefix('bitter_l_')     # bitter, left
l_waterr =  _names_with_prefix('water_r_')      # water, right
l_waterl =  _names_with_prefix('water_l_')      # water, left
l_joe =     _names_with_prefix('JO_E')          # JO E
l_jof =     _names_with_prefix('JO_F')          # JO F
l_eye =     _names_with_prefix('eye_bristle_')  # eye bristle

## stop activation

In [None]:
# Activate each stop-neuron type on its own (estimated time: 15 min)
stop_experiments = [
    ('FG',     l_fg),     # foxglove
    ('BB',     l_bb),     # bluebell
    ('MAN',    l_man1),   # MAN-1
    ('STOP-1', l_stop1),  # STOP-1
]
for exp_name, neurons in stop_experiments:
    run_default_exp(exp_name, neurons)

## walk activation

In [None]:
# Walk-neuron activation. Only the combined BPN run is enabled; the
# alternatives below can be re-enabled as needed (all of them: ~40 min).
#
#   run_default_exp('P9', l_p9)                            # P9
#   run_default_exp('MDN', l_mdn)                          # MDN
#   run_default_exp('BPN1', l_bpn1)                        # single BPN types
#   run_default_exp('BPN2', l_bpn2)
#   run_default_exp('BPN3', l_bpn3)                        # !!! missing IDs !!!
#   run_default_exp('BPN4', l_bpn4)
#   run_default_exp('BPN1+2',  l_bpn1 + l_bpn2)            # two types
#   run_default_exp('BPN1+3',  l_bpn1 + l_bpn3)
#   run_default_exp('BPN1+4',  l_bpn1 + l_bpn4)
#   run_default_exp('BPN2+3',  l_bpn2 + l_bpn3)
#   run_default_exp('BPN2+4',  l_bpn2 + l_bpn4)
#   run_default_exp('BPN3+4',  l_bpn3 + l_bpn4)
#   run_default_exp('BPN1+2+3',  l_bpn1 + l_bpn2 + l_bpn3) # three types
#   run_default_exp('BPN2+3+4',  l_bpn2 + l_bpn3 + l_bpn4)
#   run_default_exp('BPN1+3+4',  l_bpn1 + l_bpn3 + l_bpn4)
#   run_default_exp('P9l', ['P9_l'])                       # P9 unilateral
#   run_default_exp('P9r', ['P9_r'])

run_default_exp(exp='BPN__', exc=l_bpn)  # all BPN types together

## sensory activation

In [None]:
sensory_experiments = [
    # bilateral
    ('sugar',   l_sugarl + l_sugarr),   # sugar
    ('ovidn',   l_ovidn),               # OviDNs
    ('bitter',  l_bitterl + l_bitterr), # bitter
    ('water',   l_waterl + l_waterr),   # water
    ('joe',     l_joe),                 # JO E
    ('jof',     l_jof),                 # JO F
    ('eye',     l_eye),                 # eye bristle
    # right side only
    ('sugarr',  l_sugarr),
    ('bitterr', l_bitterr),
    ('waterr',  l_waterr),
    # left side only
    ('sugarl',  l_sugarl),
    ('bitterl', l_bitterl),
    ('waterl',  l_waterl),
]
for exp_name, neurons in sensory_experiments:
    run_default_exp(exp_name, neurons)

# estimated time: 170 min

## walk/stop coactivation

### P9

In [None]:
# P9 plus each stop-neuron type (estimated time: 20 min)
for stop_tag, stop_neurons in [
    ('BB',    l_bb),     # plus bluebell
    ('FG',    l_fg),     # plus foxglove
    ('MAN',   l_man1),   # plus MAN-1
    ('STOP1', l_stop1),  # plus STOP-1
]:
    run_default_exp('P9+' + stop_tag, l_p9 + stop_neurons)

### MDN

In [None]:
# MDN plus each stop-neuron type (estimated time: 20 min)
for stop_tag, stop_neurons in [
    ('BB',    l_bb),     # plus bluebell
    ('FG',    l_fg),     # plus foxglove
    ('MAN',   l_man1),   # plus MAN-1
    ('STOP1', l_stop1),  # plus STOP-1
]:
    run_default_exp('MDN+' + stop_tag, l_mdn + stop_neurons)

### BPN

In [None]:
# BPN (all types, then each single type) plus each stop-neuron type
# (estimated time: 110 min)
bpn_sets = [
    ('BPN',  l_bpn),
    ('BPN1', l_bpn1),
    ('BPN2', l_bpn2),
    ('BPN3', l_bpn3),
    ('BPN4', l_bpn4),
]
for stop_tag, stop_neurons in [
    ('BB',    l_bb),     # plus bluebell
    ('FG',    l_fg),     # plus foxglove
    ('MAN',   l_man1),   # plus MAN-1
    ('STOP1', l_stop1),  # plus STOP-1
]:
    for bpn_tag, bpn_neurons in bpn_sets:
        run_default_exp('{}+{}'.format(bpn_tag, stop_tag), bpn_neurons + stop_neurons)

### P9 (unilateral)

In [None]:
# Unilateral P9 plus a unilateral stop neuron, ipsi- then contralateral
# (estimated time: 60 min). Names follow 'P9<side>+<STOP><side>', e.g. 'P9l+FGr'.
unilateral_stop = [
    ('BB',    ['BB']),                                # plus bluebell
    ('FG',    ['FG']),                                # plus foxglove
    ('MAN',   ['MAN-1']),                             # plus MAN-1
    ('STOP1', ['STOP-1_a', 'STOP-1_b', 'STOP-1_c']),  # plus STOP-1
]
side_pairs = [
    ('l', 'l'), ('r', 'r'),  # ipsi
    ('l', 'r'), ('r', 'l'),  # contra
]
for stop_tag, stop_bases in unilateral_stop:
    for p9_side, stop_side in side_pairs:
        run_default_exp(
            'P9{}+{}{}'.format(p9_side, stop_tag, stop_side),
            ['P9_' + p9_side] + [base + '_' + stop_side for base in stop_bases],
        )

## walk/walk coactivation

In [None]:
# Pairs of walk-neuron populations (estimated time: 70 min)
bpn_sets = [
    ('BPN',  l_bpn),
    ('BPN1', l_bpn1),
    ('BPN2', l_bpn2),
    ('BPN3', l_bpn3),
    ('BPN4', l_bpn4),
]

run_default_exp('P9+MDN', l_p9 + l_mdn)  # P9 plus MDN

# P9 plus BPN, then MDN plus BPN
for walk_tag, walk_neurons in [('P9', l_p9), ('MDN', l_mdn)]:
    for bpn_tag, bpn_neurons in bpn_sets:
        run_default_exp('{}+{}'.format(walk_tag, bpn_tag), walk_neurons + bpn_neurons)

## sensory/walk coactivation

### sugar

In [None]:
walk_sets = [
    ('P9',   l_p9),    # P9
    ('MDN',  l_mdn),   # MDN
    ('BPN',  l_bpn),   # BPN, all types
    ('BPN1', l_bpn1),  # BPN type 1
    ('BPN2', l_bpn2),  #     type 2
    ('BPN3', l_bpn3),  #     type 3
    ('BPN4', l_bpn4),  #     type 4
]
# right side first, then bilateral (incomplete)
for sensory_tag, sensory_neurons in [
    ('sugarr', l_sugarr),
    ('sugar',  l_sugarr + l_sugarl),
]:
    for walk_tag, walk_neurons in walk_sets:
        run_default_exp('{}+{}'.format(sensory_tag, walk_tag), sensory_neurons + walk_neurons)

# estimated time: 130 min

### water

In [None]:
walk_sets = [
    ('P9',   l_p9),    # P9
    ('MDN',  l_mdn),   # MDN
    ('BPN',  l_bpn),   # BPN, all types
    ('BPN1', l_bpn1),  #     type 1
    ('BPN2', l_bpn2),  #     type 2
    ('BPN3', l_bpn3),  #     type 3
    ('BPN4', l_bpn4),  #     type 4
]
# right side first, then bilateral (incomplete)
for sensory_tag, sensory_neurons in [
    ('waterr', l_waterr),
    ('water',  l_waterr + l_waterl),
]:
    for walk_tag, walk_neurons in walk_sets:
        run_default_exp('{}+{}'.format(sensory_tag, walk_tag), sensory_neurons + walk_neurons)

# estimated time: 130 min

### bitter

In [None]:
walk_sets = [
    ('P9',   l_p9),    # P9
    ('MDN',  l_mdn),   # MDN
    ('BPN',  l_bpn),   # BPN, all types
    ('BPN1', l_bpn1),  #     type 1
    ('BPN2', l_bpn2),  #     type 2
    ('BPN3', l_bpn3),  #     type 3
    ('BPN4', l_bpn4),  #     type 4
]
# right side first, then bilateral (incomplete)
for sensory_tag, sensory_neurons in [
    ('bitterr', l_bitterr),
    ('bitter',  l_bitterr + l_bitterl),
]:
    for walk_tag, walk_neurons in walk_sets:
        run_default_exp('{}+{}'.format(sensory_tag, walk_tag), sensory_neurons + walk_neurons)

# estimated time: 120 min

### OviDNs

In [None]:
# OviDNs plus each walk-neuron population (estimated time: 50 min)
walk_sets = [
    ('P9',   l_p9),    # P9
    ('MDN',  l_mdn),   # MDN
    ('BPN',  l_bpn),   # BPN, all types
    ('BPN1', l_bpn1),  #     type 1
    ('BPN2', l_bpn2),  #     type 2
    ('BPN3', l_bpn3),  #     type 3
    ('BPN4', l_bpn4),  #     type 4
]
for walk_tag, walk_neurons in walk_sets:
    run_default_exp('ovidn+' + walk_tag, l_ovidn + walk_neurons)

### JO

In [None]:
walk_sets = [
    ('P9',   l_p9),    # P9
    ('MDN',  l_mdn),   # MDN
    ('BPN',  l_bpn),   # BPN, all types
    ('BPN1', l_bpn1),  #     type 1
    ('BPN2', l_bpn2),  #     type 2
    ('BPN3', l_bpn3),  #     type 3
    ('BPN4', l_bpn4),  #     type 4
]
# JO E first, then JO F
for jo_tag, jo_neurons in [('joe', l_joe), ('jof', l_jof)]:
    for walk_tag, walk_neurons in walk_sets:
        run_default_exp('{}+{}'.format(jo_tag, walk_tag), jo_neurons + walk_neurons)

# estimated time: 180 min

### Eye bristle

In [None]:
# eye bristle plus each walk-neuron population (estimated time: 610 min)
walk_sets = [
    ('P9',   l_p9),    # P9
    ('MDN',  l_mdn),   # MDN
    ('BPN',  l_bpn),   # BPN, all types
    ('BPN1', l_bpn1),  #     type 1
    ('BPN2', l_bpn2),  #     type 2
    ('BPN3', l_bpn3),  #     type 3
    ('BPN4', l_bpn4),  #     type 4
]
for walk_tag, walk_neurons in walk_sets:
    run_default_exp('eye+' + walk_tag, l_eye + walk_neurons)