In [None]:
#%matplotlib inline
import matplotlib.pyplot as plt




In [None]:
import Odorant_Stim_fourodors


In [None]:
import csv
import collections

def read_connections(filename):
    """Load a synaptic-connectivity matrix from a CSV file.

    The first row is a header of postsynaptic neuron names (its first cell is
    ignored). Every following row starts with a presynaptic neuron name followed
    by integer synapse counts, one per header column.

    Parameters
    ----------
    filename : str
        Path to the CSV, e.g. 'updated_melanogaster_all_circuitry_absolute.csv'.

    Returns
    -------
    conns : dict[str, dict[str, int]]
        conns[pre][post] = synapse count. Only counts > 0 are stored, so a
        neuron with no positive outgoing counts has no entry at all.
    names : Names
        Namedtuple of header names for the left-hemisphere populations, in
        header order. They are selected by substring: ORNs ('ORN'), uPNs
        (' uPN'), mPNs ('mPN') and Pickys ('icky'). Each must also contain
        'left'.

    Blank cells are treated as 0 synapses; empty lines are skipped.
    """
    # Context manager so the file handle is closed; newline='' is what csv expects.
    with open(filename, newline='') as f:
        rows = list(csv.reader(f))

    header, data = rows[0], rows[1:]

    conns = {}
    for row in data:
        if not row:  # csv yields [] for an empty line
            continue
        pre = row[0]
        for i, item in enumerate(row[1:], start=1):
            # Blank cells (e.g. from spreadsheet exports) mean "no synapses".
            c = int(item) if item.strip() else 0
            if c > 0:
                conns.setdefault(pre, {})[header[i]] = c

    ORNs_left = [name for name in header if 'ORN' in name and 'left' in name]
    # Leading space keeps ' uPN' from matching other names containing 'uPN'.
    uPNs_left = [name for name in header if ' uPN' in name and 'left' in name]
    mPNs_left = [name for name in header if 'mPN' in name and 'left' in name]
    # 'icky' matches both 'Picky' and 'picky'.
    Pickys_left = [name for name in header if 'icky' in name and 'left' in name]

    Names = collections.namedtuple('Names', ['ORNs_left', 'uPNs_left', 'mPNs_left', 'Pickys_left'])
    return conns, Names(ORNs_left, uPNs_left, mPNs_left, Pickys_left)

def make_weights(conns, pre, post):
    """Build a dense matrix of synapse counts from the ``pre`` to the ``post`` neurons.

    Parameters
    ----------
    conns : dict
        Nested mapping conns[pre_name][post_name] -> synapse count.
    pre, post : sequence of str
        Presynaptic and postsynaptic neuron names.

    Returns
    -------
    numpy.ndarray
        Shape (len(post), len(pre)), i.e. post-major, as a Nengo Connection
        transform expects. w[j, i] is the count from pre[i] to post[j], and 0
        where conns has no such entry.
    """
    w = np.zeros((len(post), len(pre)))
    for i, pre_n in enumerate(pre):
        # read_connections only stores neurons with at least one positive count,
        # so a missing presynaptic neuron simply contributes a zero column.
        targets = conns.get(pre_n, {})
        for j, post_n in enumerate(post):
            if post_n in targets:
                w[j, i] = targets[post_n]
    return w



In [None]:
import nengo
import numpy as np
import scipy.interpolate

def compute_rate_to_current(neuron_model=None, max_current=10.0):
    """Build an interpolator mapping a target firing rate to an input current.

    A single neuron with gain 1 and bias 0 is driven by a linear current ramp
    from 0 to ``max_current`` over 10 s. Its simulated rate is recorded, and the
    (rate, current) pairs are inverted with ``scipy.interpolate.interp1d``.

    Parameters
    ----------
    neuron_model : nengo neuron type, optional
        Neuron model to characterise; defaults to ``nengo.LIFRate()``.
        NOTE(review): presumably should be a rate (non-spiking) model so the
        rate-vs-current curve is smooth — confirm before passing a spiking type.
    max_current : float
        Current at the end of the ramp. Rates above the one reached at this
        current are outside the interpolation range, and interp1d raises
        ValueError for them.

    Returns
    -------
    scipy.interpolate.interp1d
        Callable mapping rate -> current.
    """
    # None default avoids a single LIFRate instance created at definition time.
    if neuron_model is None:
        neuron_model = nengo.LIFRate()
    N = 1
    T = 10  # duration of the current ramp (s)

    tuning_model = nengo.Network()
    with tuning_model:
        n = nengo.Ensemble(n_neurons=N, dimensions=1,
                           neuron_type=neuron_model,
                           gain=[1]*N, bias=[0]*N,
                           )
        stim = nengo.Node(lambda t: t/T*max_current)
        nengo.Connection(stim, n.neurons, transform=np.ones((N, 1)), synapse=None)
        p_rate = nengo.Probe(n.neurons)
        p_current = nengo.Probe(stim)

    with nengo.Simulator(tuning_model, progress_bar=False) as sim:
        sim.run(T)
    return scipy.interpolate.interp1d(sim.data[p_rate][:, 0], sim.data[p_current][:, 0])


rate_to_current = compute_rate_to_current()

In [None]:
import pytry
import seaborn as sns

class PickyTrial_MultiOSN(pytry.PlotTrial):
    """One simulated trial of the left ORN -> uPN / mPN / Picky circuit.

    A single odorant is presented at ``concentration`` (log10) from t = 1 s to
    t = 4 s of a 7 s run: 1 s blank, 3 s odour, 3 s blank. Connection weights
    are the species' synapse counts scaled by the ``w_*`` parameters. Each
    neuron's probed activity is low-pass filtered forwards and backwards
    (``filtfilt``) and averaged into 250 ms bins. The trial result is the peak
    bin per neuron.

    Relies on ``read_connections``, ``make_weights`` and the module-level
    ``rate_to_current`` from earlier cells, and on ``Odorant_Stim_fourodors``.
    """

    def params(self):
        """Declare the trial parameters (description and default) for pytry."""
        self.param('species (melanogaster|erecta)', species='melanogaster')
        self.param('background OR rate', background_rate_OR=6.0)
        self.param('concentration of odorant (log scale)', concentration=-6)
        self.param('odorant (geranyl acetate|anisole|2-heptanone|menthol)', odorant='2-heptanone')

        self.param('synapse strength for ORN to uPN', w_ORN_uPN=0.002)
        self.param('synapse strength for ORN to mPN', w_ORN_mPN=0.002)
        self.param('synapse strength for ORN to Picky', w_ORN_Picky=0.0005)
        self.param('synapse strength for Picky to uPN', w_Picky_uPN=-0.005)
        self.param('synapse strength for Picky to mPN', w_Picky_mPN=-0.005)
        self.param('synapse strength for Picky to Picky internal feedback loop', w_Picky_Picky=-0.005)

    def evaluate(self, p, plt):
        """Build and simulate the circuit for parameter set ``p``.

        Returns a dict with ``responses_ORN``, ``responses_uPN``,
        ``responses_mPN`` and ``responses_Picky``. Each holds one value per
        neuron: the peak of its filtered, 250 ms-binned activity over the run.
        If ``plt`` is truthy, it also plots every neuron's filtered activity over
        time, with one panel per population.
        """
        blank = -20                # log10 concentration for "no odour" (10**-20 is ~0)
        presentation_time = 1.0    # seconds per entry of the stimulus schedule
        bin_width = 0.25           # seconds per bin when summarising activity
        filter_tau = 0.3           # low-pass time constant (s); 0.1 rougher, 0.01 close to raw ephys
        odorants = ['geranyl acetate', 'anisole', '2-heptanone', 'menthol']
        max_rate_range = [80, 50, 50, 50]  # peak odour-evoked ORN rate (Hz), same order as odorants

        conns, names = read_connections('updated_'+p.species+'_all_circuitry_absolute.csv')

        # 1 s blank, 3 s odour, 3 s blank.
        stims = [blank] + [p.concentration] * 3 + [blank] * 3
        run_time = len(stims) * presentation_time
        odour_on, odour_off = 1 * presentation_time, 4 * presentation_time
        odorant_index = odorants.index(p.odorant)
        max_rate = max_rate_range[odorant_index]

        model = nengo.Network(seed=p.seed)
        with model:
            log_concentrations = nengo.Node(
                nengo.processes.PresentInput(stims, presentation_time=presentation_time))

            def logconc_to_conc_func(t, x):
                return 10**x
            concentrations = nengo.Node(logconc_to_conc_func, size_in=1)

            def OR_func(t, x):
                # x holds one concentration per odorant; only odorant_index is driven.
                rel = Odorant_Stim_fourodors.convert_compounds_to_responses(x)
                # Target ORN rates -> input currents for the gain-1, bias-0 LIF ORNs.
                return rate_to_current(rel*max_rate + p.background_rate_OR)

            l_ORN_current = nengo.Node(OR_func, size_in=4)
            n_orn = len(names.ORNs_left)
            l_ORN = nengo.Ensemble(n_neurons=n_orn, dimensions=1,
                                   neuron_type=nengo.LIF(),
                                   noise=nengo.processes.WhiteNoise(nengo.dists.Gaussian(0, 0.02)),
                                   gain=[1]*n_orn, bias=[0]*n_orn)

            nengo.Connection(l_ORN_current, l_ORN.neurons, synapse=None)
            nengo.Connection(log_concentrations, concentrations, synapse=None)
            nengo.Connection(concentrations, l_ORN_current[odorant_index], synapse=None)

            def make_population(labels):
                # Gain 1 and bias 0: each neuron's input current is just its synaptic input.
                return nengo.Ensemble(n_neurons=len(labels), dimensions=1,
                                      gain=np.ones(len(labels)), bias=np.zeros(len(labels)))

            l_uPN = make_population(names.uPNs_left)
            l_mPN = make_population(names.mPNs_left)
            l_Picky = make_population(names.Pickys_left)

            # (pre ensemble, pre names, post ensemble, post names, weight scale)
            projections = [
                (l_ORN, names.ORNs_left, l_uPN, names.uPNs_left, p.w_ORN_uPN),
                (l_ORN, names.ORNs_left, l_mPN, names.mPNs_left, p.w_ORN_mPN),
                (l_ORN, names.ORNs_left, l_Picky, names.Pickys_left, p.w_ORN_Picky),
                (l_Picky, names.Pickys_left, l_uPN, names.uPNs_left, p.w_Picky_uPN),
                (l_Picky, names.Pickys_left, l_mPN, names.mPNs_left, p.w_Picky_mPN),
                (l_Picky, names.Pickys_left, l_Picky, names.Pickys_left, p.w_Picky_Picky),
            ]
            for pre, pre_labels, post, post_labels, scale in projections:
                nengo.Connection(pre.neurons, post.neurons,
                                 transform=scale*make_weights(conns, pre_labels, post_labels),
                                 synapse=0.01)

            p_Picky = nengo.Probe(l_Picky.neurons)
            p_uPN = nengo.Probe(l_uPN.neurons)
            p_mPN = nengo.Probe(l_mPN.neurons)
            p_ORN = nengo.Probe(l_ORN.neurons)

        with nengo.Simulator(model, seed=p.seed+1) as sim:
            sim.run(run_time)

        filt = nengo.synapses.Lowpass(filter_tau)
        # Derived from sim.dt rather than hard-coding 250 samples per bin.
        steps_per_bin = int(round(bin_width / sim.dt))

        def calc_max_activity_at_peak(data_neurons):
            """Return the peak binned activity of each neuron.

            ``data_neurons`` is a (time steps, neurons) probe array. It is
            low-pass filtered forwards and backwards, then averaged over
            ``bin_width``-second bins (a trailing partial bin is dropped). The
            maximum bin is returned for each neuron.
            """
            filt_data = filt.filtfilt(data_neurons, dt=sim.dt)
            n_bins = filt_data.shape[0] // steps_per_bin
            binned = filt_data[:n_bins*steps_per_bin].reshape(
                n_bins, steps_per_bin, filt_data.shape[1]).mean(axis=1)
            return binned.max(axis=0)

        result = dict(
            responses_ORN=calc_max_activity_at_peak(sim.data[p_ORN]),
            responses_uPN=calc_max_activity_at_peak(sim.data[p_uPN]),
            responses_mPN=calc_max_activity_at_peak(sim.data[p_mPN]),
            responses_Picky=calc_max_activity_at_peak(sim.data[p_Picky]),
        )
        if plt:
            species_label = {'melanogaster': 'dmel', 'erecta': 'dere'}.get(p.species, p.species)
            groups = [
                (p_ORN, names.ORNs_left, 'ORNs left'),
                (p_uPN, names.uPNs_left, 'uPNs left'),
                (p_mPN, names.mPNs_left, 'mPNs left'),
                (p_Picky, names.Pickys_left, 'Pickys left'),
            ]
            linestyles = ['-', '--', ':']
            t = sim.trange()

            fig, axs = plt.subplots(len(groups), 1, figsize=(30, 40))
            fig.suptitle('Firing Rates w temporal dynamics (%s)' % species_label, fontsize=30, y=0.92)

            for ax, (probe, labels, title) in zip(axs, groups):
                rates = sim.data[probe]
                for m in range(rates.shape[1]):
                    # Line style changes every 10 neurons so traces that repeat a colour
                    # (the default colour cycle has 10 entries) stay distinguishable;
                    # every neuron is plotted so the legend labels line up with the lines.
                    ax.plot(t, filt.filtfilt(rates[:, m], dt=sim.dt), linewidth=2,
                            linestyle=linestyles[(m // 10) % len(linestyles)])
                ax.axvline(x=odour_on, color='gray')
                ax.axvline(x=odour_off, color='gray')
                ax.tick_params(axis='y', labelsize=20)
                ax.set_xticks([0, odour_on, odour_off, run_time])
                ax.set_xticklabels(['start', 'ON', 'OFF', 'stop'], fontsize=20)
                ax.legend(labels=labels, bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0., fontsize=15)
                ax.set_title(title, fontsize=25)
                ax.set_ylabel('Firing rate (spikes/s)', fontsize=30)
            plt.show()

        return result
        
        


In [None]:
# Print pytry's summary of the parameters declared in PickyTrial_MultiOSN.params()
# (presumably names, descriptions and defaults — see pytry's show_params).
print(PickyTrial_MultiOSN().show_params())

In [None]:
# One trial at log10 concentration -4 with plotting enabled; parameters not passed
# here presumably take the defaults declared in params(). `r` holds the value
# returned by pytry's run() (presumably the evaluate() result dict — confirm).
r = PickyTrial_MultiOSN().run(concentration=-4, plt=True)


In [None]:
# Grid sweep over the Picky-circuit weights:
# 3 seeds x 15 w_ORN_Picky values x 15 w_Picky_Picky values x 3 log10 concentrations
# = 2025 trials, written by pytry to data_dir in npz format.
# NOTE(review): conns/names loaded here are not used by the loop; each trial
# re-reads the connectivity CSV inside evaluate().
conns, names = read_connections('updated_melanogaster_all_circuitry_absolute.csv')

for seed in range(3):
    # x: ORN -> Picky weight scale (excitatory, >= 0).
    for x in [0, 0.0001, 0.0002, 0.0003, 0.0004, 0.0005, 0.0006, 0.0007, 0.0008, 0.0009, 0.001, 0.002, 0.003, 0.004, 0.005]:
        # y: Picky -> Picky weight scale (inhibitory, <= 0).
        for y in [0, -0.0001, -0.0002, -0.0003, -0.0004, -0.0005, -0.0006, -0.0007, -0.0008, -0.0009, -0.001, -0.002, -0.003, -0.004, -0.005]:
            # i: log10 odorant concentration.
            for i in [-6,-5,-4]:
                PickyTrial_MultiOSN().run(verbose=False, w_ORN_Picky=x, w_Picky_Picky=y, concentration=i, 
                                   data_dir='8July2021_MultiOSN_Picky-Grid-Tuning_exp1', data_format='npz', seed=seed) 


In [None]:
# Scratch arithmetic: 15 * 15 * 3 * 21 = 14175.
# NOTE(review): the exp1 sweep runs 3 seeds x 15 x 15 x 3 concentrations = 2025
# trials, so this product does not match it — confirm what the factor 21 counts.
15*15*3*21

In [None]:
## Old version of the model.
## None of this code was used to obtain the multi-OSN stimulation data.

import pytry
import seaborn as sns

class PickyTrial_MultiOSN(pytry.PlotTrial):
    """Older variant of the left ORN -> uPN / mPN / Picky trial.

    Executing this definition rebinds the name ``PickyTrial_MultiOSN``, so
    every cell run after it uses this version.

    The odorant is presented at ``concentration`` (log10) from t = 3 s to
    t = 6 s of a 9 s run, through an Alpha(1) synapse. Probed activity is
    averaged (unfiltered) into 250 ms bins, and the peak bin per neuron is
    returned.

    Relies on ``read_connections``, ``make_weights`` and the module-level
    ``rate_to_current`` from earlier cells, and on ``Odorant_Stim_fourodors``.
    """

    def params(self):
        """Declare the trial parameters (description and default) for pytry."""
        self.param('species (melanogaster|erecta)', species='melanogaster')
        self.param('concentration of odorant (log scale)', concentration=-6)
        self.param('odorant (geranyl acetate|anisole|2-heptanone|menthol)', odorant='2-heptanone')
        self.param('synapse strength for ORN to uPN', w_ORN_uPN=0.002)
        self.param('synapse strength for ORN to mPN', w_ORN_mPN=0.002)
        self.param('synapse strength for ORN to Picky', w_ORN_Picky=0.0005)
        self.param('synapse strength for Picky to uPN', w_Picky_uPN=-0.005)
        self.param('synapse strength for Picky to mPN', w_Picky_mPN=-0.005)
        self.param('synapse strength for Picky to Picky internal feedback loop', w_Picky_Picky=-0.005)
        self.param('background OR rate', background_rate_OR=6.0)

    def evaluate(self, p, plt):
        """Build and simulate the circuit for parameter set ``p``.

        Returns a dict with ``responses_ORN``, ``responses_uPN``,
        ``responses_mPN`` and ``responses_Picky``. Each holds one value per
        neuron: its peak 250 ms-binned activity over the run. If ``plt`` is
        truthy, the peaks are also drawn as one single-column heatmap per
        population.
        """
        blank = -20                # log10 concentration for "no odour" (10**-20 is ~0)
        presentation_time = 3.0    # seconds per entry of the stimulus schedule
        bin_width = 0.25           # seconds per bin when summarising activity
        odorants = ['geranyl acetate', 'anisole', '2-heptanone', 'menthol']
        max_rate_range = [80, 50, 50, 50]  # peak odour-evoked ORN rate (Hz), same order as odorants

        conns, names = read_connections('updated_'+p.species+'_all_circuitry_absolute.csv')

        # The rate -> current lookup does not depend on p, so the module-level
        # rate_to_current is reused instead of re-running its 10 s tuning
        # simulation on every trial.

        stims = [blank, p.concentration, blank]
        run_time = len(stims) * presentation_time
        odorant_index = odorants.index(p.odorant)
        max_rate = max_rate_range[odorant_index]

        model = nengo.Network(seed=p.seed)
        with model:
            log_concentrations = nengo.Node(
                nengo.processes.PresentInput(stims, presentation_time=presentation_time))

            def logconc_to_conc_func(t, x):
                return 10**x
            concentrations = nengo.Node(logconc_to_conc_func, size_in=1)

            def OR_func(t, x):
                # x holds one concentration per odorant; only odorant_index is driven.
                rel = Odorant_Stim_fourodors.convert_compounds_to_responses(x)
                return rate_to_current(rel*max_rate + p.background_rate_OR)

            l_ORN_current = nengo.Node(OR_func, size_in=4)

            n_orn = len(names.ORNs_left)
            l_ORN = nengo.Ensemble(n_neurons=n_orn, dimensions=1,
                                   neuron_type=nengo.LIF(),
                                   noise=nengo.processes.WhiteNoise(nengo.dists.Gaussian(0, 0.02)),
                                   gain=[1]*n_orn, bias=[0]*n_orn)

            nengo.Connection(log_concentrations, concentrations, synapse=None)
            nengo.Connection(concentrations, l_ORN_current[odorant_index], synapse=nengo.synapses.Alpha(1))
            nengo.Connection(l_ORN_current, l_ORN.neurons, synapse=None)

            def make_population(labels):
                # Gain 1 and bias 0: each neuron's input current is just its synaptic input.
                return nengo.Ensemble(n_neurons=len(labels), dimensions=1,
                                      gain=np.ones(len(labels)), bias=np.zeros(len(labels)))

            l_uPN = make_population(names.uPNs_left)
            l_mPN = make_population(names.mPNs_left)
            l_Picky = make_population(names.Pickys_left)

            # (pre ensemble, pre names, post ensemble, post names, weight scale)
            projections = [
                (l_ORN, names.ORNs_left, l_uPN, names.uPNs_left, p.w_ORN_uPN),
                (l_ORN, names.ORNs_left, l_mPN, names.mPNs_left, p.w_ORN_mPN),
                (l_ORN, names.ORNs_left, l_Picky, names.Pickys_left, p.w_ORN_Picky),
                (l_Picky, names.Pickys_left, l_uPN, names.uPNs_left, p.w_Picky_uPN),
                (l_Picky, names.Pickys_left, l_mPN, names.mPNs_left, p.w_Picky_mPN),
                (l_Picky, names.Pickys_left, l_Picky, names.Pickys_left, p.w_Picky_Picky),
            ]
            for pre, pre_labels, post, post_labels, scale in projections:
                nengo.Connection(pre.neurons, post.neurons,
                                 transform=scale*make_weights(conns, pre_labels, post_labels),
                                 synapse=0.01)

            p_Picky = nengo.Probe(l_Picky.neurons)
            p_uPN = nengo.Probe(l_uPN.neurons)
            p_mPN = nengo.Probe(l_mPN.neurons)
            p_ORN = nengo.Probe(l_ORN.neurons)

        with nengo.Simulator(model, seed=p.seed+1) as sim:
            sim.run(run_time)

        steps_per_bin = int(round(bin_width / sim.dt))

        def calc_max_activity_at_peak(data_neurons):
            """Return the peak binned activity of each neuron.

            ``data_neurons`` is a (time steps, neurons) probe array. It is
            averaged over ``bin_width``-second bins (a trailing partial bin is
            dropped), and the maximum bin is returned for each neuron.
            """
            n_bins = data_neurons.shape[0] // steps_per_bin
            binned = data_neurons[:n_bins*steps_per_bin].reshape(
                n_bins, steps_per_bin, data_neurons.shape[1]).mean(axis=1)
            return binned.max(axis=0)

        result = dict(
            responses_ORN=calc_max_activity_at_peak(sim.data[p_ORN]),
            responses_uPN=calc_max_activity_at_peak(sim.data[p_uPN]),
            responses_mPN=calc_max_activity_at_peak(sim.data[p_mPN]),
            responses_Picky=calc_max_activity_at_peak(sim.data[p_Picky]),
        )
        if plt:
            fig, axes = plt.subplots(4, 1, figsize=(12, 16))
            panels = [
                ('responses_ORN', 'Greens', names.ORNs_left),
                ('responses_uPN', 'Blues', names.uPNs_left),
                ('responses_mPN', 'Oranges', names.mPNs_left),
                ('responses_Picky', 'Reds', names.Pickys_left),
            ]
            for ax, (key, cmap, labels) in zip(axes, panels):
                # sns.heatmap needs 2-D data; show the per-neuron peaks as a single column.
                # The results hold one peak per neuron (no time axis), so no ON/OFF markers.
                sns.heatmap(result[key][:, np.newaxis], cmap=cmap, yticklabels=labels, ax=ax, square=True)
                ax.set_xticks([0.5])
                ax.set_xticklabels(['['+str(p.concentration)+'] peak'], rotation=0)

        return result
        
        


In [None]:
# Print pytry's summary of the parameters declared by the most recently executed
# PickyTrial_MultiOSN definition (presumably names, descriptions and defaults).
print(PickyTrial_MultiOSN().show_params())
# Example single run — incomplete: fill in a log10 concentration before uncommenting.
#r = PickyTrial_MultiOSN().run(concentration= )

In [None]:
# Grid sweep run with the most recently executed PickyTrial_MultiOSN definition:
# 3 seeds x 8 w_ORN_Picky values x 8 w_Picky_Picky values x 21 repeats = 4032 trials.
# concentration is not passed, so each trial uses the default from params().
# NOTE(review): `i` is never used, so each (seed, x, y) setting is run 21 times with
# identical arguments and seed — if the simulation is fully seeded these are
# duplicate runs. Confirm whether a per-ORN parameter was meant to be passed here.
# NOTE(review): y is non-negative, making Picky -> Picky excitatory, whereas the
# w_Picky_Picky default (-0.005) is negative — confirm the intended sign.
# NOTE(review): conns/names loaded here are not used by the loop.
conns, names = read_connections('updated_melanogaster_all_circuitry_absolute.csv')

for seed in range(3):
    print(seed)
    for x in [0, 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01]:
        for y in [0, 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01]:
            for i in range(21):
                PickyTrial_MultiOSN().run(verbose=False, w_ORN_Picky=x, w_Picky_Picky=y, 
                                   data_dir='25Jun2021_Picky-Grid-Tuning_exp2', data_format='npz', seed=seed) 
